feat: add handshake with query for n2c (#266)
This commit is contained in:
parent
31a87032ca
commit
554fa1578e
6 changed files with 138 additions and 28 deletions
|
|
@ -55,7 +55,15 @@ async fn main() {
|
|||
|
||||
// we connect to the unix socket of the local node. Make sure you have the right
|
||||
// path for your environment
|
||||
let mut client = NodeClient::connect("/tmp/node.socket", MAINNET_MAGIC)
|
||||
let socket_path = "/tmp/node.socket";
|
||||
|
||||
// we connect to the unix socket of the local node and perform a handshake query
|
||||
let version_table = NodeClient::handshake_query(socket_path, MAINNET_MAGIC)
|
||||
.await
|
||||
.unwrap();
|
||||
info!("handshake query result: {:?}", version_table);
|
||||
|
||||
let mut client = NodeClient::connect(socket_path, MAINNET_MAGIC)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ use crate::{
|
|||
},
|
||||
multiplexer::{self, Bearer},
|
||||
};
|
||||
use crate::miniprotocols::handshake::Confirmation;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
|
|
@ -126,6 +127,44 @@ impl NodeClient {
|
|||
})
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
pub async fn handshake_query(path: impl AsRef<Path>, magic: u64) -> Result<handshake::n2c::VersionTable, Error> {
|
||||
debug!("connecting");
|
||||
|
||||
let bearer = Bearer::connect_unix(path)
|
||||
.await
|
||||
.map_err(Error::ConnectFailure)?;
|
||||
|
||||
let mut plexer = multiplexer::Plexer::new(bearer);
|
||||
|
||||
let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE);
|
||||
|
||||
let plexer_handle = tokio::spawn(async move { plexer.run().await });
|
||||
|
||||
let versions = handshake::n2c::VersionTable::v15_with_query(magic);
|
||||
let mut client = handshake::Client::new(hs_channel);
|
||||
|
||||
let handshake = client
|
||||
.handshake(versions)
|
||||
.await
|
||||
.map_err(Error::HandshakeProtocol)?;
|
||||
|
||||
match handshake {
|
||||
Confirmation::Accepted(_, _) => {
|
||||
error!("handshake accepted when we expected query reply");
|
||||
Err(Error::HandshakeProtocol(handshake::Error::InvalidInbound))
|
||||
}
|
||||
Confirmation::Rejected(reason) => {
|
||||
error!(?reason, "handshake refused");
|
||||
Err(Error::IncompatibleVersion)
|
||||
}
|
||||
Confirmation::QueryReply(version_table) => {
|
||||
plexer_handle.abort();
|
||||
Ok(version_table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chainsync(&mut self) -> &mut chainsync::N2CClient {
|
||||
&mut self.chainsync
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::fmt::Debug;
|
||||
use pallas_codec::Fragment;
|
||||
use std::marker::PhantomData;
|
||||
use tracing::debug;
|
||||
|
|
@ -6,16 +7,17 @@ use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable};
|
|||
use crate::multiplexer;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Confirmation<D> {
|
||||
pub enum Confirmation<D: Debug + Clone> {
|
||||
Accepted(VersionNumber, D),
|
||||
Rejected(RefuseReason),
|
||||
QueryReply(VersionTable<D>),
|
||||
}
|
||||
|
||||
pub struct Client<D>(State, multiplexer::ChannelBuffer, PhantomData<D>);
|
||||
|
||||
impl<D> Client<D>
|
||||
where
|
||||
D: std::fmt::Debug + Clone,
|
||||
D: Debug + Clone,
|
||||
Message<D>: Fragment,
|
||||
{
|
||||
pub fn new(channel: multiplexer::AgentChannel) -> Self {
|
||||
|
|
@ -69,6 +71,7 @@ where
|
|||
match (&self.0, msg) {
|
||||
(State::Confirm, Message::Accept(..)) => Ok(()),
|
||||
(State::Confirm, Message::Refuse(..)) => Ok(()),
|
||||
(State::Confirm, Message::QueryReply(..)) => Ok(()),
|
||||
_ => Err(Error::InvalidInbound),
|
||||
}
|
||||
}
|
||||
|
|
@ -113,6 +116,11 @@ where
|
|||
|
||||
Ok(Confirmation::Rejected(r))
|
||||
}
|
||||
Message::QueryReply(version_table) => {
|
||||
debug!("handshake query reply");
|
||||
|
||||
Ok(Confirmation::QueryReply(version_table))
|
||||
}
|
||||
_ => Err(Error::InvalidInbound),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder};
|
||||
use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder};
|
||||
use pallas_codec::minicbor::data::Type;
|
||||
|
||||
use super::protocol::NetworkMagic;
|
||||
|
||||
|
|
@ -18,22 +19,28 @@ const PROTOCOL_V9: u64 = 32777;
|
|||
const PROTOCOL_V10: u64 = 32778;
|
||||
const PROTOCOL_V11: u64 = 32779;
|
||||
const PROTOCOL_V12: u64 = 32780;
|
||||
const PROTOCOL_V13: u64 = 32781;
|
||||
const PROTOCOL_V14: u64 = 32782;
|
||||
const PROTOCOL_V15: u64 = 32783;
|
||||
|
||||
impl VersionTable {
|
||||
pub fn v1_and_above(network_magic: u64) -> VersionTable {
|
||||
let values = vec![
|
||||
(PROTOCOL_V1, VersionData(network_magic)),
|
||||
(PROTOCOL_V2, VersionData(network_magic)),
|
||||
(PROTOCOL_V3, VersionData(network_magic)),
|
||||
(PROTOCOL_V4, VersionData(network_magic)),
|
||||
(PROTOCOL_V5, VersionData(network_magic)),
|
||||
(PROTOCOL_V6, VersionData(network_magic)),
|
||||
(PROTOCOL_V7, VersionData(network_magic)),
|
||||
(PROTOCOL_V8, VersionData(network_magic)),
|
||||
(PROTOCOL_V9, VersionData(network_magic)),
|
||||
(PROTOCOL_V10, VersionData(network_magic)),
|
||||
(PROTOCOL_V11, VersionData(network_magic)),
|
||||
(PROTOCOL_V12, VersionData(network_magic)),
|
||||
(PROTOCOL_V1, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V2, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V3, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V4, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V5, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V6, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V7, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V8, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V9, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V10, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V11, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V12, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V13, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V14, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<HashMap<u64, VersionData>>();
|
||||
|
|
@ -42,7 +49,7 @@ impl VersionTable {
|
|||
}
|
||||
|
||||
pub fn only_v10(network_magic: u64) -> VersionTable {
|
||||
let values = vec![(PROTOCOL_V10, VersionData(network_magic))]
|
||||
let values = vec![(PROTOCOL_V10, VersionData(network_magic, None))]
|
||||
.into_iter()
|
||||
.collect::<HashMap<u64, VersionData>>();
|
||||
|
||||
|
|
@ -51,9 +58,22 @@ impl VersionTable {
|
|||
|
||||
pub fn v10_and_above(network_magic: u64) -> VersionTable {
|
||||
let values = vec![
|
||||
(PROTOCOL_V10, VersionData(network_magic)),
|
||||
(PROTOCOL_V11, VersionData(network_magic)),
|
||||
(PROTOCOL_V12, VersionData(network_magic)),
|
||||
(PROTOCOL_V10, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V11, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V12, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V13, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V14, VersionData(network_magic, None)),
|
||||
(PROTOCOL_V15, VersionData(network_magic, Some(false))),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<HashMap<u64, VersionData>>();
|
||||
|
||||
VersionTable { values }
|
||||
}
|
||||
|
||||
pub fn v15_with_query(network_magic: u64) -> VersionTable {
|
||||
let values = vec![
|
||||
(PROTOCOL_V15, VersionData(network_magic, Some(true))),
|
||||
]
|
||||
.into_iter()
|
||||
.collect::<HashMap<u64, VersionData>>();
|
||||
|
|
@ -63,7 +83,7 @@ impl VersionTable {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VersionData(NetworkMagic);
|
||||
pub struct VersionData(NetworkMagic, Option<bool>);
|
||||
|
||||
impl Encode<()> for VersionData {
|
||||
fn encode<W: encode::Write>(
|
||||
|
|
@ -71,7 +91,14 @@ impl Encode<()> for VersionData {
|
|||
e: &mut Encoder<W>,
|
||||
_ctx: &mut (),
|
||||
) -> Result<(), encode::Error<W::Error>> {
|
||||
e.u64(self.0)?;
|
||||
match self.1 {
|
||||
None => { e.u64(self.0)?; }
|
||||
Some(is_query) => {
|
||||
e.array(2)?;
|
||||
e.u64(self.0)?;
|
||||
e.bool(is_query)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -79,8 +106,20 @@ impl Encode<()> for VersionData {
|
|||
|
||||
impl<'b> Decode<'b, ()> for VersionData {
|
||||
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
|
||||
let network_magic = d.u64()?;
|
||||
|
||||
Ok(Self(network_magic))
|
||||
match d.datatype()? {
|
||||
Type::U8 | Type::U16 | Type::U32 | Type::U64 => {
|
||||
let network_magic = d.u64()?;
|
||||
Ok(Self(network_magic, None))
|
||||
}
|
||||
Type::Array => {
|
||||
d.array()?;
|
||||
let network_magic = d.u64()?;
|
||||
let is_query = d.bool()?;
|
||||
Ok(Self(network_magic, Some(is_query)))
|
||||
}
|
||||
_ => Err(decode::Error::message(
|
||||
"unknown type for VersionData",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,8 +55,15 @@ impl<'b, T> Decode<'b, ()> for VersionTable<T>
|
|||
where
|
||||
T: Debug + Clone + Decode<'b, ()>,
|
||||
{
|
||||
fn decode(d: &mut Decoder<'b>, ctx: &mut ()) -> Result<Self, decode::Error> {
|
||||
let values = d.map_iter_with(ctx)?.collect::<Result<_, _>>()?;
|
||||
fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
|
||||
let len = d.map()?.ok_or(decode::Error::message("expected def-length map for versiontable"))?;
|
||||
let mut values = HashMap::new();
|
||||
|
||||
for _ in 0..len {
|
||||
let key = d.u64()?;
|
||||
let value = d.decode()?;
|
||||
values.insert(key, value);
|
||||
}
|
||||
Ok(VersionTable { values })
|
||||
}
|
||||
}
|
||||
|
|
@ -73,6 +80,7 @@ where
|
|||
Propose(VersionTable<D>),
|
||||
Accept(VersionNumber, D),
|
||||
Refuse(RefuseReason),
|
||||
QueryReply(VersionTable<D>),
|
||||
}
|
||||
|
||||
impl<D> Encode<()> for Message<D>
|
||||
|
|
@ -100,6 +108,10 @@ where
|
|||
e.array(2)?.u16(2)?;
|
||||
e.encode(reason)?;
|
||||
}
|
||||
Message::QueryReply(version_table) => {
|
||||
e.array(2)?.u16(3)?;
|
||||
e.encode(version_table)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
|
|
@ -128,6 +140,10 @@ where
|
|||
let reason: RefuseReason = d.decode()?;
|
||||
Ok(Message::Refuse(reason))
|
||||
}
|
||||
3 => {
|
||||
let version_table = d.decode()?;
|
||||
Ok(Message::QueryReply(version_table))
|
||||
}
|
||||
_ => Err(decode::Error::message(
|
||||
"unknown variant for handshake message",
|
||||
)),
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ use tokio::time::Instant;
|
|||
use tracing::{debug, error, trace};
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
use UnixStream;
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
const HEADER_LEN: usize = 8;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue