feat: add handshake with query for n2c (#266)

This commit is contained in:
Andrew Westberg 2023-06-27 12:46:24 -04:00 committed by GitHub
parent 31a87032ca
commit 554fa1578e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 138 additions and 28 deletions

View file

@ -55,7 +55,15 @@ async fn main() {
// we connect to the unix socket of the local node. Make sure you have the right // we connect to the unix socket of the local node. Make sure you have the right
// path for your environment // path for your environment
let mut client = NodeClient::connect("/tmp/node.socket", MAINNET_MAGIC) let socket_path = "/tmp/node.socket";
// we connect to the unix socket of the local node and perform a handshake query
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
.await
.unwrap();
info!("handshake query result: {:?}", version_table);
let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
.await .await
.unwrap(); .unwrap();

View file

@ -11,6 +11,7 @@ use crate::{
}, },
multiplexer::{self, Bearer}, multiplexer::{self, Bearer},
}; };
use crate::miniprotocols::handshake::Confirmation;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
@ -126,6 +127,44 @@ impl NodeClient {
}) })
} }
#[cfg(not(target_os = "windows"))]
pub async fn handshake_query(path: impl AsRef<Path>, magic: u64) -> Result<handshake::n2c::VersionTable, Error> {
debug!("connecting");
let bearer = Bearer::connect_unix(path)
.await
.map_err(Error::ConnectFailure)?;
let mut plexer = multiplexer::Plexer::new(bearer);
let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE);
let plexer_handle = tokio::spawn(async move { plexer.run().await });
let versions = handshake::n2c::VersionTable::v15_with_query(magic);
let mut client = handshake::Client::new(hs_channel);
let handshake = client
.handshake(versions)
.await
.map_err(Error::HandshakeProtocol)?;
match handshake {
Confirmation::Accepted(_, _) => {
error!("handshake accepted when we expected query reply");
Err(Error::HandshakeProtocol(handshake::Error::InvalidInbound))
}
Confirmation::Rejected(reason) => {
error!(?reason, "handshake refused");
Err(Error::IncompatibleVersion)
}
Confirmation::QueryReply(version_table) => {
plexer_handle.abort();
Ok(version_table)
}
}
}
pub fn chainsync(&mut self) -> &mut chainsync::N2CClient { pub fn chainsync(&mut self) -> &mut chainsync::N2CClient {
&mut self.chainsync &mut self.chainsync
} }

View file

@ -1,3 +1,4 @@
use std::fmt::Debug;
use pallas_codec::Fragment; use pallas_codec::Fragment;
use std::marker::PhantomData; use std::marker::PhantomData;
use tracing::debug; use tracing::debug;
@ -6,16 +7,17 @@ use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable};
use crate::multiplexer; use crate::multiplexer;
#[derive(Debug)] #[derive(Debug)]
pub enum Confirmation<D> { pub enum Confirmation<D: Debug + Clone> {
Accepted(VersionNumber, D), Accepted(VersionNumber, D),
Rejected(RefuseReason), Rejected(RefuseReason),
QueryReply(VersionTable<D>),
} }
pub struct Client<D>(State, multiplexer::ChannelBuffer, PhantomData<D>); pub struct Client<D>(State, multiplexer::ChannelBuffer, PhantomData<D>);
impl<D> Client<D> impl<D> Client<D>
where where
D: std::fmt::Debug + Clone, D: Debug + Clone,
Message<D>: Fragment, Message<D>: Fragment,
{ {
pub fn new(channel: multiplexer::AgentChannel) -> Self { pub fn new(channel: multiplexer::AgentChannel) -> Self {
@ -69,6 +71,7 @@ where
match (&self.0, msg) { match (&self.0, msg) {
(State::Confirm, Message::Accept(..)) => Ok(()), (State::Confirm, Message::Accept(..)) => Ok(()),
(State::Confirm, Message::Refuse(..)) => Ok(()), (State::Confirm, Message::Refuse(..)) => Ok(()),
(State::Confirm, Message::QueryReply(..)) => Ok(()),
_ => Err(Error::InvalidInbound), _ => Err(Error::InvalidInbound),
} }
} }
@ -113,6 +116,11 @@ where
Ok(Confirmation::Rejected(r)) Ok(Confirmation::Rejected(r))
} }
Message::QueryReply(version_table) => {
debug!("handshake query reply");
Ok(Confirmation::QueryReply(version_table))
}
_ => Err(Error::InvalidInbound), _ => Err(Error::InvalidInbound),
} }
} }

View file

@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder};
use pallas_codec::minicbor::data::Type;
use super::protocol::NetworkMagic; use super::protocol::NetworkMagic;
@ -18,22 +19,28 @@ const PROTOCOL_V9: u64 = 32777;
const PROTOCOL_V10: u64 = 32778; const PROTOCOL_V10: u64 = 32778;
const PROTOCOL_V11: u64 = 32779; const PROTOCOL_V11: u64 = 32779;
const PROTOCOL_V12: u64 = 32780; const PROTOCOL_V12: u64 = 32780;
const PROTOCOL_V13: u64 = 32781;
const PROTOCOL_V14: u64 = 32782;
const PROTOCOL_V15: u64 = 32783;
impl VersionTable { impl VersionTable {
pub fn v1_and_above(network_magic: u64) -> VersionTable { pub fn v1_and_above(network_magic: u64) -> VersionTable {
let values = vec![ let values = vec![
(PROTOCOL_V1, VersionData(network_magic)), (PROTOCOL_V1, VersionData(network_magic, None)),
(PROTOCOL_V2, VersionData(network_magic)), (PROTOCOL_V2, VersionData(network_magic, None)),
(PROTOCOL_V3, VersionData(network_magic)), (PROTOCOL_V3, VersionData(network_magic, None)),
(PROTOCOL_V4, VersionData(network_magic)), (PROTOCOL_V4, VersionData(network_magic, None)),
(PROTOCOL_V5, VersionData(network_magic)), (PROTOCOL_V5, VersionData(network_magic, None)),
(PROTOCOL_V6, VersionData(network_magic)), (PROTOCOL_V6, VersionData(network_magic, None)),
(PROTOCOL_V7, VersionData(network_magic)), (PROTOCOL_V7, VersionData(network_magic, None)),
(PROTOCOL_V8, VersionData(network_magic)), (PROTOCOL_V8, VersionData(network_magic, None)),
(PROTOCOL_V9, VersionData(network_magic)), (PROTOCOL_V9, VersionData(network_magic, None)),
(PROTOCOL_V10, VersionData(network_magic)), (PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic)), (PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic)), (PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
] ]
.into_iter() .into_iter()
.collect::<HashMap<u64, VersionData>>(); .collect::<HashMap<u64, VersionData>>();
@ -42,7 +49,7 @@ impl VersionTable {
} }
pub fn only_v10(network_magic: u64) -> VersionTable { pub fn only_v10(network_magic: u64) -> VersionTable {
let values = vec![(PROTOCOL_V10, VersionData(network_magic))] let values = vec![(PROTOCOL_V10, VersionData(network_magic, None))]
.into_iter() .into_iter()
.collect::<HashMap<u64, VersionData>>(); .collect::<HashMap<u64, VersionData>>();
@ -51,9 +58,22 @@ impl VersionTable {
pub fn v10_and_above(network_magic: u64) -> VersionTable { pub fn v10_and_above(network_magic: u64) -> VersionTable {
let values = vec![ let values = vec![
(PROTOCOL_V10, VersionData(network_magic)), (PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic)), (PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic)), (PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
VersionTable { values }
}
pub fn v15_with_query(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V15, VersionData(network_magic, Some(true))),
] ]
.into_iter() .into_iter()
.collect::<HashMap<u64, VersionData>>(); .collect::<HashMap<u64, VersionData>>();
@ -63,7 +83,7 @@ impl VersionTable {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct VersionData(NetworkMagic); pub struct VersionData(NetworkMagic, Option<bool>);
impl Encode<()> for VersionData { impl Encode<()> for VersionData {
fn encode<W: encode::Write>( fn encode<W: encode::Write>(
@ -71,7 +91,14 @@ impl Encode<()> for VersionData {
e: &mut Encoder<W>, e: &mut Encoder<W>,
_ctx: &mut (), _ctx: &mut (),
) -> Result<(), encode::Error<W::Error>> { ) -> Result<(), encode::Error<W::Error>> {
e.u64(self.0)?; match self.1 {
None => { e.u64(self.0)?; }
Some(is_query) => {
e.array(2)?;
e.u64(self.0)?;
e.bool(is_query)?;
}
}
Ok(()) Ok(())
} }
@ -79,8 +106,20 @@ impl Encode<()> for VersionData {
impl<'b> Decode<'b, ()> for VersionData { impl<'b> Decode<'b, ()> for VersionData {
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> { fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let network_magic = d.u64()?; match d.datatype()? {
Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
Ok(Self(network_magic)) let network_magic = d.u64()?;
Ok(Self(network_magic, None))
}
Type::Array => {
d.array()?;
let network_magic = d.u64()?;
let is_query = d.bool()?;
Ok(Self(network_magic, Some(is_query)))
}
_ => Err(decode::Error::message(
"unknown type for VersionData",
)),
}
} }
} }

View file

@ -55,8 +55,15 @@ impl<'b, T> Decode<'b, ()> for VersionTable<T>
where where
T: Debug + Clone + Decode<'b, ()>, T: Debug + Clone + Decode<'b, ()>,
{ {
fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result<Self, decode::Error> { fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let values = d.map_iter_with(ctx)?.collect::<Result<_, _>>()?; let len = d.map()?.ok_or(decode::Error::message("expected def-length map for versiontable"))?;
let mut values = HashMap::new();
for _ in 0..len {
let key = d.u64()?;
let value = d.decode()?;
values.insert(key, value);
}
Ok(VersionTable { values }) Ok(VersionTable { values })
} }
} }
@ -73,6 +80,7 @@ where
Propose(VersionTable<D>), Propose(VersionTable<D>),
Accept(VersionNumber, D), Accept(VersionNumber, D),
Refuse(RefuseReason), Refuse(RefuseReason),
QueryReply(VersionTable<D>),
} }
impl<D> Encode<()> for Message<D> impl<D> Encode<()> for Message<D>
@ -100,6 +108,10 @@ where
e.array(2)?.u16(2)?; e.array(2)?.u16(2)?;
e.encode(reason)?; e.encode(reason)?;
} }
Message::QueryReply(version_table) => {
e.array(2)?.u16(3)?;
e.encode(version_table)?;
}
}; };
Ok(()) Ok(())
@ -128,6 +140,10 @@ where
let reason: RefuseReason = d.decode()?; let reason: RefuseReason = d.decode()?;
Ok(Message::Refuse(reason)) Ok(Message::Refuse(reason))
} }
3 => {
let version_table = d.decode()?;
Ok(Message::QueryReply(version_table))
}
_ => Err(decode::Error::message( _ => Err(decode::Error::message(
"unknown variant for handshake message", "unknown variant for handshake message",
)), )),

View file

@ -13,7 +13,7 @@ use tokio::time::Instant;
use tracing::{debug, error, trace}; use tracing::{debug, error, trace};
#[cfg(not(target_os = "windows"))] #[cfg(not(target_os = "windows"))]
use UnixStream; use tokio::net::UnixStream;
const HEADER_LEN: usize = 8; const HEADER_LEN: usize = 8;