From fdf41b5b3259612245edc92ce50fb2495dd51d5a Mon Sep 17 00:00:00 2001 From: Harper Date: Thu, 24 Aug 2023 00:17:17 +0100 Subject: [PATCH] feat(network): implement chain sync server side (#277) --- .../src/miniprotocols/chainsync/client.rs | 4 +- .../src/miniprotocols/chainsync/codec.rs | 40 ++- .../src/miniprotocols/chainsync/mod.rs | 2 + .../src/miniprotocols/chainsync/protocol.rs | 2 + .../src/miniprotocols/chainsync/server.rs | 285 ++++++++++++++++++ pallas-network/tests/protocols.rs | 200 ++++++++++++ 6 files changed, 524 insertions(+), 9 deletions(-) create mode 100644 pallas-network/src/miniprotocols/chainsync/server.rs diff --git a/pallas-network/src/miniprotocols/chainsync/client.rs b/pallas-network/src/miniprotocols/chainsync/client.rs index 593859b..e79c870 100644 --- a/pallas-network/src/miniprotocols/chainsync/client.rs +++ b/pallas-network/src/miniprotocols/chainsync/client.rs @@ -6,7 +6,7 @@ use tracing::debug; use crate::miniprotocols::Point; use crate::multiplexer; -use super::{BlockContent, HeaderContent, Message, State, Tip}; +use super::{BlockContent, HeaderContent, IntersectResponse, Message, State, Tip}; #[derive(Error, Debug)] pub enum Error { @@ -29,8 +29,6 @@ pub enum Error { Plexer(multiplexer::Error), } -pub type IntersectResponse = (Option, Tip); - #[derive(Debug)] pub enum NextResponse { RollForward(CONTENT, Tip), diff --git a/pallas-network/src/miniprotocols/chainsync/codec.rs b/pallas-network/src/miniprotocols/chainsync/codec.rs index ba23923..9a123d5 100644 --- a/pallas-network/src/miniprotocols/chainsync/codec.rs +++ b/pallas-network/src/miniprotocols/chainsync/codec.rs @@ -1,4 +1,5 @@ use pallas_codec::minicbor; +use pallas_codec::minicbor::encode::Error; use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; use super::{BlockContent, HeaderContent, Message, SkippedContent, Tip}; @@ -167,10 +168,32 @@ impl<'b> Decode<'b, ()> for HeaderContent { impl Encode<()> for HeaderContent { fn encode( &self, - _e: &mut Encoder, + e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { - todo!() + e.array(2)?; + e.u8(self.variant)?; + + // variant 0 is byron + if self.variant == 0 { + e.array(2)?; + + if let Some((a, b)) = self.byron_prefix { + e.array(2)?; + e.u8(a)?; + e.u64(b)?; + } else { + return Err(Error::message("header variant 0 but no byron prefix")); + } + + e.tag(minicbor::data::Tag::Cbor)?; + e.bytes(&self.cbor)?; + } else { + e.tag(minicbor::data::Tag::Cbor)?; + e.bytes(&self.cbor)?; + } + + Ok(()) } } @@ -185,10 +208,13 @@ impl<'b> Decode<'b, ()> for BlockContent { impl Encode<()> for BlockContent { fn encode( &self, - _e: &mut Encoder, + e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { - todo!() + e.tag(minicbor::data::Tag::Cbor)?; + e.bytes(&self.0)?; + + Ok(()) } } @@ -202,9 +228,11 @@ impl<'b> Decode<'b, ()> for SkippedContent { impl Encode<()> for SkippedContent { fn encode( &self, - _e: &mut Encoder, + e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { - todo!() + e.null()?; + + Ok(()) } } diff --git a/pallas-network/src/miniprotocols/chainsync/mod.rs b/pallas-network/src/miniprotocols/chainsync/mod.rs index 2ad863e..6b732fe 100644 --- a/pallas-network/src/miniprotocols/chainsync/mod.rs +++ b/pallas-network/src/miniprotocols/chainsync/mod.rs @@ -2,8 +2,10 @@ mod buffer; mod client; mod codec; mod protocol; +mod server; pub use buffer::*; pub use client::*; pub use codec::*; pub use protocol::*; +pub use server::*; diff --git a/pallas-network/src/miniprotocols/chainsync/protocol.rs b/pallas-network/src/miniprotocols/chainsync/protocol.rs index ee1eef5..1a9015c 100644 --- a/pallas-network/src/miniprotocols/chainsync/protocol.rs +++ b/pallas-network/src/miniprotocols/chainsync/protocol.rs @@ -5,6 +5,8 @@ use crate::miniprotocols::Point; #[derive(Debug, Clone)] pub struct Tip(pub Point, pub u64); +pub type IntersectResponse = (Option, Tip); + #[derive(Debug, PartialEq, Eq, Clone)] pub enum State { Idle, diff --git a/pallas-network/src/miniprotocols/chainsync/server.rs b/pallas-network/src/miniprotocols/chainsync/server.rs new file mode 100644 index 0000000..ae58e47 --- /dev/null +++ b/pallas-network/src/miniprotocols/chainsync/server.rs @@ -0,0 +1,285 @@ +use pallas_codec::Fragment; +use std::marker::PhantomData; +use thiserror::Error; +use tracing::debug; + +use crate::miniprotocols::Point; +use crate::multiplexer; + +use super::{BlockContent, HeaderContent, Message, State, Tip}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + + #[error("inbound message is not valid for current state")] + InvalidInbound, + + #[error("outbound message is not valid for current state")] + InvalidOutbound, + + #[error("error while sending or receiving data through the channel")] + Plexer(multiplexer::Error), +} + +#[derive(Debug)] +pub enum ClientRequest { + Intersect(Vec), + RequestNext, +} + +pub struct Server(State, multiplexer::ChannelBuffer, PhantomData) +where + Message: Fragment; + +impl Server +where + Message: Fragment, +{ + /// Constructs a new ChainSync `Server` instance. + /// + /// # Arguments + /// + /// * `channel` - An instance of `multiplexer::AgentChannel` to be used for + /// communication. + pub fn new(channel: multiplexer::AgentChannel) -> Self { + Self( + State::Idle, + multiplexer::ChannelBuffer::new(channel), + PhantomData {}, + ) + } + + /// Returns the current state of the server. + pub fn state(&self) -> &State { + &self.0 + } + + /// Checks if the server state is done. + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + /// Checks if the server has agency. + pub fn has_agency(&self) -> bool { + match self.state() { + State::Idle => false, + State::CanAwait => true, + State::MustReply => true, + State::Intersect => true, + State::Done => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::CanAwait, Message::RollForward(_, _)) => Ok(()), + (State::CanAwait, Message::RollBackward(_, _)) => Ok(()), + (State::CanAwait, Message::AwaitReply) => Ok(()), + (State::MustReply, Message::RollForward(_, _)) => Ok(()), + (State::MustReply, Message::RollBackward(_, _)) => Ok(()), + (State::Intersect, Message::IntersectFound(_, _)) => Ok(()), + (State::Intersect, Message::IntersectNotFound(_)) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Idle, Message::RequestNext) => Ok(()), + (State::Idle, Message::FindIntersect(_)) => Ok(()), + (State::Idle, Message::Done) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + /// Sends a message to the client + /// + /// # Arguments + /// + /// * `msg` - A reference to the `Message` to be sent. + /// + /// # Errors + /// + /// Returns an error if the agency is not ours or if the outbound state is + /// invalid. + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + + self.1.send_msg_chunks(msg).await.map_err(Error::Plexer)?; + + Ok(()) + } + + /// Receives the next message from the client. + /// + /// # Errors + /// + /// Returns an error if the agency is not theirs or if the inbound state is + /// invalid. + async fn recv_message(&mut self) -> Result, Error> { + self.assert_agency_is_theirs()?; + + let msg = self.1.recv_full_msg().await.map_err(Error::Plexer)?; + + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + /// Receive a message from the client when the protocol state is Idle. + /// + /// # Errors + /// + /// Returns an error if the agency is not theirs or if the inbound message + /// is invalid for Idle protocol state. + pub async fn recv_while_idle(&mut self) -> Result, Error> { + match self.recv_message().await? { + Message::FindIntersect(points) => { + self.0 = State::Intersect; + Ok(Some(ClientRequest::Intersect(points))) + } + Message::RequestNext => { + self.0 = State::CanAwait; + Ok(Some(ClientRequest::RequestNext)) + } + Message::Done => { + self.0 = State::Done; + + Ok(None) + } + _ => Err(Error::InvalidInbound), + } + } + + /// Sends an IntersectNotFound message to the client. + /// + /// # Arguments + /// + /// * `tip` - the most recent point of the server's chain. + /// + /// # Errors + /// + /// Returns an error if the message cannot be sent or if it's not valid for + /// the current state of the server. + pub async fn send_intersect_not_found(&mut self, tip: Tip) -> Result<(), Error> { + debug!("send intersect not found"); + + let msg = Message::IntersectNotFound(tip); + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + /// Sends an IntersectFound message to the client. + /// + /// # Arguments + /// + /// * `point` - the first point in the client's provided list of intersect + /// points that was found in the servers's current chain. + /// * `tip` - the most recent point of the server's chain. + /// + /// # Errors + /// + /// Returns an error if the message cannot be sent or if it's not valid for + /// the current state of the server. + pub async fn send_intersect_found(&mut self, point: Point, tip: Tip) -> Result<(), Error> { + debug!("send intersect found ({point:?}"); + + let msg = Message::IntersectFound(point, tip); + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + /// Sends a RollForward message to the client. + /// + /// # Arguments + /// + /// * `content` - the data to send to the client: for example block headers + /// for N2N or full blocks for N2C. + /// * `tip` - the most recent point of the server's chain. + /// + /// # Errors + /// + /// Returns an error if the message cannot be sent or if it's not valid for + /// the current state of the server. + pub async fn send_roll_forward(&mut self, content: O, tip: Tip) -> Result<(), Error> { + debug!("send roll forward"); + + let msg = Message::RollForward(content, tip); + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + /// Sends a RollBackward message to the client. + /// + /// # Arguments + /// + /// * `point` - point at which the client should rollback their chain to. + /// * `tip` - the most recent point of the server's chain. + /// + /// # Errors + /// + /// Returns an error if the message cannot be sent or if it's not valid for + /// the current state of the server. + pub async fn send_roll_backward(&mut self, point: Point, tip: Tip) -> Result<(), Error> { + debug!("send roll backward {point:?}"); + + let msg = Message::RollBackward(point, tip); + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + /// Sends an AwaitReply message to the client. + /// + /// # Arguments + /// + /// * `point` - point at which the client should rollback their chain to. + /// * `tip` - the most recent point of the server's chain. + /// + /// # Errors + /// + /// Returns an error if the message cannot be sent or if it's not valid for + /// the current state of the server. + pub async fn send_await_reply(&mut self) -> Result<(), Error> { + debug!("send await reply"); + + let msg = Message::AwaitReply; + self.send_message(&msg).await?; + self.0 = State::MustReply; + + Ok(()) + } +} + +pub type N2NServer = Server; + +pub type N2CServer = Server; diff --git a/pallas-network/tests/protocols.rs b/pallas-network/tests/protocols.rs index 2c52658..eb92f80 100644 --- a/pallas-network/tests/protocols.rs +++ b/pallas-network/tests/protocols.rs @@ -3,6 +3,7 @@ use std::time::Duration; use pallas_network::facades::PeerClient; use pallas_network::miniprotocols::blockfetch::BlockRequest; +use pallas_network::miniprotocols::chainsync::{ClientRequest, HeaderContent, Tip}; use pallas_network::miniprotocols::handshake; use pallas_network::miniprotocols::handshake::n2n::VersionData; use pallas_network::miniprotocols::{ @@ -261,4 +262,203 @@ pub async fn blockfetch_server_and_client_happy_path() { _ = tokio::join!(client, server); } +#[tokio::test] +#[ignore] +pub async fn chainsync_server_and_client_happy_path_n2n() { + let point1 = Point::Specific(1, vec![0x01]); + let point2 = Point::Specific(2, vec![0x02]); + + let server = tokio::spawn({ + let point1 = point1.clone(); + let point2 = point2.clone(); + async move { + // server setup + + let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30001)) + .await + .unwrap(); + + let (bearer, _) = Bearer::accept_tcp(server_listener).await.unwrap(); + + let mut server_plexer = Plexer::new(bearer); + + let mut server_hs: handshake::Server = + handshake::Server::new(server_plexer.subscribe_server(0)); + let mut server_cs = chainsync::N2NServer::new(server_plexer.subscribe_server(2)); + + tokio::spawn(async move { server_plexer.run().await }); + + server_hs.receive_proposed_versions().await.unwrap(); + server_hs + .accept_version(10, VersionData::new(0, false)) + .await + .unwrap(); + + // server receives find intersect from client, sends intersect point + + assert_eq!(*server_cs.state(), chainsync::State::Idle); + + let intersect_points = match server_cs.recv_while_idle().await.unwrap().unwrap() { + ClientRequest::Intersect(points) => points, + ClientRequest::RequestNext => panic!("unexpected message"), + }; + + assert_eq!(*server_cs.state(), chainsync::State::Intersect); + assert_eq!(intersect_points, vec![point2.clone(), point1.clone()]); + + server_cs + .send_intersect_found(point2.clone(), Tip(point2.clone(), 1337)) + .await + .unwrap(); + + assert_eq!(*server_cs.state(), chainsync::State::Idle); + + // server receives request next from client, sends rollbackwards + + match server_cs.recv_while_idle().await.unwrap().unwrap() { + ClientRequest::RequestNext => (), + ClientRequest::Intersect(_) => panic!("unexpected message"), + }; + + assert_eq!(*server_cs.state(), chainsync::State::CanAwait); + + server_cs + .send_roll_backward(point2.clone(), Tip(point2.clone(), 1337)) + .await + .unwrap(); + + assert_eq!(*server_cs.state(), chainsync::State::Idle); + + // server receives request next from client, sends rollforwards + + match server_cs.recv_while_idle().await.unwrap().unwrap() { + ClientRequest::RequestNext => (), + ClientRequest::Intersect(_) => panic!("unexpected message"), + }; + + assert_eq!(*server_cs.state(), chainsync::State::CanAwait); + + let header2 = HeaderContent { + variant: 1, + byron_prefix: None, + cbor: hex::decode("c0ffeec0ffeec0ffee").unwrap(), + }; + + server_cs + .send_roll_forward(header2, Tip(point2.clone(), 1337)) + .await + .unwrap(); + + assert_eq!(*server_cs.state(), chainsync::State::Idle); + + // server receives request next from client, sends await reply + // then rollforwards + + match server_cs.recv_while_idle().await.unwrap().unwrap() { + ClientRequest::RequestNext => (), + ClientRequest::Intersect(_) => panic!("unexpected message"), + }; + + assert_eq!(*server_cs.state(), chainsync::State::CanAwait); + + server_cs.send_await_reply().await.unwrap(); + + assert_eq!(*server_cs.state(), chainsync::State::MustReply); + + let header1 = HeaderContent { + variant: 1, + byron_prefix: None, + cbor: hex::decode("deadbeefdeadbeef").unwrap(), + }; + + server_cs + .send_roll_forward(header1, Tip(point1.clone(), 123)) + .await + .unwrap(); + + assert_eq!(*server_cs.state(), chainsync::State::Idle); + + // server receives client done + + assert!(server_cs.recv_while_idle().await.unwrap().is_none()); + assert_eq!(*server_cs.state(), chainsync::State::Done); + } + }); + + let client = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + + // client setup + + let mut client_to_server_conn = PeerClient::connect("localhost:30001", 0).await.unwrap(); + + let client_cs = client_to_server_conn.chainsync(); + + // client sends find intersect + + let intersect_response = client_cs + .find_intersect(vec![point2.clone(), point1.clone()]) + .await + .unwrap(); + + assert_eq!(intersect_response.0, Some(point2.clone())); + assert_eq!(intersect_response.1 .0, point2.clone()); + assert_eq!(intersect_response.1 .1, 1337); + + // client sends msg request next + + client_cs.send_request_next().await.unwrap(); + + // client receives rollback + + match client_cs.recv_while_can_await().await.unwrap() { + NextResponse::RollBackward(point, tip) => { + assert_eq!(point, point2.clone()); + assert_eq!(tip.0, point2.clone()); + assert_eq!(tip.1, 1337); + } + _ => panic!("unexpected response"), + } + + client_cs.send_request_next().await.unwrap(); + + // client receives roll forward + + match client_cs.recv_while_can_await().await.unwrap() { + NextResponse::RollForward(content, tip) => { + assert_eq!(content.cbor, hex::decode("c0ffeec0ffeec0ffee").unwrap()); + assert_eq!(tip.0, point2.clone()); + assert_eq!(tip.1, 1337); + } + _ => panic!("unexpected response"), + } + + // client sends msg request next + + client_cs.send_request_next().await.unwrap(); + + // client receives await + + match client_cs.recv_while_can_await().await.unwrap() { + NextResponse::Await => (), + _ => panic!("unexpected response"), + } + + match client_cs.recv_while_must_reply().await.unwrap() { + NextResponse::RollForward(content, tip) => { + assert_eq!(content.cbor, hex::decode("deadbeefdeadbeef").unwrap()); + assert_eq!(tip.0, point1.clone()); + assert_eq!(tip.1, 123); + } + _ => panic!("unexpected response"), + } + + // client sends done + + client_cs.send_done().await.unwrap(); + }); + + _ = tokio::join!(client, server); +} + // TODO: redo txsubmission client test