feat: add handshake with query for n2c (#266)

This commit is contained in:
Andrew Westberg 2023-06-27 12:46:24 -04:00 committed by GitHub
parent 31a87032ca
commit 554fa1578e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 138 additions and 28 deletions

View file

@ -55,7 +55,15 @@ async fn main() {
// we connect to the unix socket of the local node. Make sure you have the right
// path for your environment
let mut client = NodeClient::connect("/tmp/node.socket", MAINNET_MAGIC)
let socket_path = "/tmp/node.socket";
// we connect to the unix socket of the local node and perform a handshake query
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
.await
.unwrap();
info!("handshake query result: {:?}", version_table);
let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
.await
.unwrap();

View file

@ -11,6 +11,7 @@ use crate::{
},
multiplexer::{self, Bearer},
};
use crate::miniprotocols::handshake::Confirmation;
#[derive(Debug, Error)]
pub enum Error {
@ -126,6 +127,44 @@ impl NodeClient {
})
}
#[cfg(not(target_os = "windows"))]
pub async fn handshake_query(path: impl AsRef<Path>, magic: u64) -> Result<handshake::n2c::VersionTable, Error> {
debug!("connecting");
let bearer = Bearer::connect_unix(path)
.await
.map_err(Error::ConnectFailure)?;
let mut plexer = multiplexer::Plexer::new(bearer);
let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE);
let plexer_handle = tokio::spawn(async move { plexer.run().await });
let versions = handshake::n2c::VersionTable::v15_with_query(magic);
let mut client = handshake::Client::new(hs_channel);
let handshake = client
.handshake(versions)
.await
.map_err(Error::HandshakeProtocol)?;
match handshake {
Confirmation::Accepted(_, _) => {
error!("handshake accepted when we expected query reply");
Err(Error::HandshakeProtocol(handshake::Error::InvalidInbound))
}
Confirmation::Rejected(reason) => {
error!(?reason, "handshake refused");
Err(Error::IncompatibleVersion)
}
Confirmation::QueryReply(version_table) => {
plexer_handle.abort();
Ok(version_table)
}
}
}
pub fn chainsync(&mut self) -> &mut chainsync::N2CClient {
&mut self.chainsync
}

View file

@ -1,3 +1,4 @@
use std::fmt::Debug;
use pallas_codec::Fragment;
use std::marker::PhantomData;
use tracing::debug;
@ -6,16 +7,17 @@ use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable};
use crate::multiplexer;
#[derive(Debug)]
pub enum Confirmation<D> {
pub enum Confirmation<D: Debug + Clone> {
Accepted(VersionNumber, D),
Rejected(RefuseReason),
QueryReply(VersionTable<D>),
}
pub struct Client<D>(State, multiplexer::ChannelBuffer, PhantomData<D>);
impl<D> Client<D>
where
D: std::fmt::Debug + Clone,
D: Debug + Clone,
Message<D>: Fragment,
{
pub fn new(channel: multiplexer::AgentChannel) -> Self {
@ -69,6 +71,7 @@ where
match (&self.0, msg) {
(State::Confirm, Message::Accept(..)) => Ok(()),
(State::Confirm, Message::Refuse(..)) => Ok(()),
(State::Confirm, Message::QueryReply(..)) => Ok(()),
_ => Err(Error::InvalidInbound),
}
}
@ -113,6 +116,11 @@ where
Ok(Confirmation::Rejected(r))
}
Message::QueryReply(version_table) => {
debug!("handshake query reply");
Ok(Confirmation::QueryReply(version_table))
}
_ => Err(Error::InvalidInbound),
}
}

View file

@ -1,6 +1,7 @@
use std::collections::HashMap;
use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder};
use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder};
use pallas_codec::minicbor::data::Type;
use super::protocol::NetworkMagic;
@ -18,22 +19,28 @@ const PROTOCOL_V9: u64 = 32777;
const PROTOCOL_V10: u64 = 32778;
const PROTOCOL_V11: u64 = 32779;
const PROTOCOL_V12: u64 = 32780;
const PROTOCOL_V13: u64 = 32781;
const PROTOCOL_V14: u64 = 32782;
const PROTOCOL_V15: u64 = 32783;
impl VersionTable {
pub fn v1_and_above(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V1, VersionData(network_magic)),
(PROTOCOL_V2, VersionData(network_magic)),
(PROTOCOL_V3, VersionData(network_magic)),
(PROTOCOL_V4, VersionData(network_magic)),
(PROTOCOL_V5, VersionData(network_magic)),
(PROTOCOL_V6, VersionData(network_magic)),
(PROTOCOL_V7, VersionData(network_magic)),
(PROTOCOL_V8, VersionData(network_magic)),
(PROTOCOL_V9, VersionData(network_magic)),
(PROTOCOL_V10, VersionData(network_magic)),
(PROTOCOL_V11, VersionData(network_magic)),
(PROTOCOL_V12, VersionData(network_magic)),
(PROTOCOL_V1, VersionData(network_magic, None)),
(PROTOCOL_V2, VersionData(network_magic, None)),
(PROTOCOL_V3, VersionData(network_magic, None)),
(PROTOCOL_V4, VersionData(network_magic, None)),
(PROTOCOL_V5, VersionData(network_magic, None)),
(PROTOCOL_V6, VersionData(network_magic, None)),
(PROTOCOL_V7, VersionData(network_magic, None)),
(PROTOCOL_V8, VersionData(network_magic, None)),
(PROTOCOL_V9, VersionData(network_magic, None)),
(PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
@ -42,7 +49,7 @@ impl VersionTable {
}
pub fn only_v10(network_magic: u64) -> VersionTable {
let values = vec![(PROTOCOL_V10, VersionData(network_magic))]
let values = vec![(PROTOCOL_V10, VersionData(network_magic, None))]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
@ -51,9 +58,22 @@ impl VersionTable {
pub fn v10_and_above(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V10, VersionData(network_magic)),
(PROTOCOL_V11, VersionData(network_magic)),
(PROTOCOL_V12, VersionData(network_magic)),
(PROTOCOL_V10, VersionData(network_magic, None)),
(PROTOCOL_V11, VersionData(network_magic, None)),
(PROTOCOL_V12, VersionData(network_magic, None)),
(PROTOCOL_V13, VersionData(network_magic, None)),
(PROTOCOL_V14, VersionData(network_magic, None)),
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
VersionTable { values }
}
pub fn v15_with_query(network_magic: u64) -> VersionTable {
let values = vec![
(PROTOCOL_V15, VersionData(network_magic, Some(true))),
]
.into_iter()
.collect::<HashMap<u64, VersionData>>();
@ -63,7 +83,7 @@ impl VersionTable {
}
#[derive(Debug, Clone)]
pub struct VersionData(NetworkMagic);
pub struct VersionData(NetworkMagic, Option<bool>);
impl Encode<()> for VersionData {
fn encode<W: encode::Write>(
@ -71,7 +91,14 @@ impl Encode<()> for VersionData {
e: &mut Encoder<W>,
_ctx: &mut (),
) -> Result<(), encode::Error<W::Error>> {
e.u64(self.0)?;
match self.1 {
None => { e.u64(self.0)?; }
Some(is_query) => {
e.array(2)?;
e.u64(self.0)?;
e.bool(is_query)?;
}
}
Ok(())
}
@ -79,8 +106,20 @@ impl Encode<()> for VersionData {
impl<'b> Decode<'b, ()> for VersionData {
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let network_magic = d.u64()?;
Ok(Self(network_magic))
match d.datatype()? {
Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
let network_magic = d.u64()?;
Ok(Self(network_magic, None))
}
Type::Array => {
d.array()?;
let network_magic = d.u64()?;
let is_query = d.bool()?;
Ok(Self(network_magic, Some(is_query)))
}
_ => Err(decode::Error::message(
"unknown type for VersionData",
)),
}
}
}

View file

@ -55,8 +55,15 @@ impl<'b, T> Decode<'b, ()> for VersionTable<T>
where
T: Debug + Clone + Decode<'b, ()>,
{
fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result<Self, decode::Error> {
let values = d.map_iter_with(ctx)?.collect::<Result<_, _>>()?;
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
let len = d.map()?.ok_or(decode::Error::message("expected def-length map for versiontable"))?;
let mut values = HashMap::new();
for _ in 0..len {
let key = d.u64()?;
let value = d.decode()?;
values.insert(key, value);
}
Ok(VersionTable { values })
}
}
@ -73,6 +80,7 @@ where
Propose(VersionTable<D>),
Accept(VersionNumber, D),
Refuse(RefuseReason),
QueryReply(VersionTable<D>),
}
impl<D> Encode<()> for Message<D>
@ -100,6 +108,10 @@ where
e.array(2)?.u16(2)?;
e.encode(reason)?;
}
Message::QueryReply(version_table) => {
e.array(2)?.u16(3)?;
e.encode(version_table)?;
}
};
Ok(())
@ -128,6 +140,10 @@ where
let reason: RefuseReason = d.decode()?;
Ok(Message::Refuse(reason))
}
3 => {
let version_table = d.decode()?;
Ok(Message::QueryReply(version_table))
}
_ => Err(decode::Error::message(
"unknown variant for handshake message",
)),

View file

@ -13,7 +13,7 @@ use tokio::time::Instant;
use tracing::{debug, error, trace};
#[cfg(not(target_os = "windows"))]
use UnixStream;
use tokio::net::UnixStream;
const HEADER_LEN: usize = 8;