From b8ff4e9418aaf619f9ee9243a1accd14c1abe874 Mon Sep 17 00:00:00 2001 From: Santiago Carmuega Date: Sun, 9 Apr 2023 13:50:56 +0200 Subject: [PATCH] feat: Migrate to asynchronous I/O (#241) This commit updates the networking stack to use asynchronous I/O for improved performance and concurrency. We have replaced synchronous I/O calls with their asynchronous counterparts and refactored the code to use async/await and Tokio runtime. --- examples/block-decode/Cargo.toml | 1 - examples/block-download/Cargo.toml | 2 - examples/n2c-miniprotocols/Cargo.toml | 3 - examples/n2n-miniprotocols/Cargo.toml | 2 - pallas-miniprotocols/Cargo.toml | 3 + pallas-miniprotocols/src/blockfetch/client.rs | 52 +- pallas-miniprotocols/src/chainsync/client.rs | 57 +- pallas-miniprotocols/src/handshake/client.rs | 32 +- pallas-miniprotocols/src/handshake/server.rs | 27 +- pallas-miniprotocols/src/localstate/client.rs | 39 +- pallas-miniprotocols/src/machines.rs | 20 +- .../src/txsubmission/client.rs | 34 +- .../src/txsubmission/server.rs | 30 +- pallas-miniprotocols/tests/integration.rs | 88 +-- pallas-multiplexer/Cargo.toml | 3 + pallas-multiplexer/src/agents.rs | 82 +-- pallas-multiplexer/src/lib.rs | 2 + pallas-multiplexer/src/std.rs | 4 +- pallas-multiplexer/src/sync.rs | 4 +- pallas-multiplexer/tests/integration.rs | 8 +- pallas-upstream/Cargo.toml | 8 + pallas-upstream/src/api.rs | 29 +- pallas-upstream/src/blockfetch.rs | 37 +- pallas-upstream/src/chainsync.rs | 159 ++++-- pallas-upstream/src/cursor.rs | 56 -- pallas-upstream/src/framework.rs | 47 +- pallas-upstream/src/lib.rs | 4 +- pallas-upstream/src/plexer.rs | 527 +++++++++++++----- pallas-upstream/tests/integration.rs | 87 +++ 29 files changed, 924 insertions(+), 523 deletions(-) delete mode 100644 pallas-upstream/src/cursor.rs create mode 100644 pallas-upstream/tests/integration.rs diff --git a/examples/block-decode/Cargo.toml b/examples/block-decode/Cargo.toml index c35d02d..8eec03e 100644 --- a/examples/block-decode/Cargo.toml +++ b/examples/block-decode/Cargo.toml @@ -9,5 +9,4 @@ publish = false [dependencies] pallas = { path = "../../pallas" } net2 = "0.2.37" -env_logger = "0.10.0" hex = "0.4.3" diff --git a/examples/block-download/Cargo.toml b/examples/block-download/Cargo.toml index b8878bf..a313e3d 100644 --- a/examples/block-download/Cargo.toml +++ b/examples/block-download/Cargo.toml @@ -9,6 +9,4 @@ publish = false [dependencies] pallas = { path = "../../pallas" } net2 = "0.2.37" -env_logger = "0.10.0" hex = "0.4.3" - diff --git a/examples/n2c-miniprotocols/Cargo.toml b/examples/n2c-miniprotocols/Cargo.toml index 35dd294..bfd3dcd 100644 --- a/examples/n2c-miniprotocols/Cargo.toml +++ b/examples/n2c-miniprotocols/Cargo.toml @@ -9,8 +9,5 @@ publish = false [dependencies] pallas = { path = "../../pallas" } net2 = "0.2.37" -env_logger = "0.10.0" hex = "0.4.3" log = "0.4.16" - - diff --git a/examples/n2n-miniprotocols/Cargo.toml b/examples/n2n-miniprotocols/Cargo.toml index cd84003..ac04c27 100644 --- a/examples/n2n-miniprotocols/Cargo.toml +++ b/examples/n2n-miniprotocols/Cargo.toml @@ -9,7 +9,5 @@ publish = false [dependencies] pallas = { path = "../../pallas" } net2 = "0.2.37" -env_logger = "0.10.0" hex = "0.4.3" log = "0.4.16" - diff --git a/pallas-miniprotocols/Cargo.toml b/pallas-miniprotocols/Cargo.toml index 9264749..7b1815a 100644 --- a/pallas-miniprotocols/Cargo.toml +++ b/pallas-miniprotocols/Cargo.toml @@ -20,3 +20,6 @@ hex = "0.4.3" itertools = "0.10.3" thiserror = "1.0.31" tracing = "0.1.37" + +[dev-dependencies] +tokio = { version = "1.27.0", features = ["macros", "rt"] } diff --git a/pallas-miniprotocols/src/blockfetch/client.rs b/pallas-miniprotocols/src/blockfetch/client.rs index 0ba207e..17b03e3 100644 --- a/pallas-miniprotocols/src/blockfetch/client.rs +++ b/pallas-miniprotocols/src/blockfetch/client.rs @@ -99,32 +99,35 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result { + pub async fn recv_message(&mut self) -> Result { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn send_request_range(&mut self, range: (Point, Point)) -> Result<(), Error> { + pub async fn send_request_range(&mut self, range: (Point, Point)) -> Result<(), Error> { let msg = Message::RequestRange { range }; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Busy; Ok(()) } - pub fn recv_while_busy(&mut self) -> Result { - match self.recv_message()? { + pub async fn recv_while_busy(&mut self) -> Result { + match self.recv_message().await? { Message::StartBatch => { info!("batch start"); self.0 = State::Streaming; @@ -139,14 +142,14 @@ where } } - pub fn request_range(&mut self, range: Range) -> Result { - self.send_request_range(range)?; + pub async fn request_range(&mut self, range: Range) -> Result { + self.send_request_range(range).await?; debug!("range requested"); - self.recv_while_busy() + self.recv_while_busy().await } - pub fn recv_while_streaming(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn recv_while_streaming(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::Block { body } => Ok(Some(body)), Message::BatchDone => { self.0 = State::Idle; @@ -156,25 +159,30 @@ where } } - pub fn fetch_single(&mut self, point: Point) -> Result { - self.request_range((point.clone(), point))? + pub async fn fetch_single(&mut self, point: Point) -> Result { + self.request_range((point.clone(), point)) + .await? .ok_or(Error::NoBlocks)?; - let body = self.recv_while_streaming()?.ok_or(Error::InvalidInbound)?; + let body = self + .recv_while_streaming() + .await? + .ok_or(Error::InvalidInbound)?; + debug!("body received"); - match self.recv_while_streaming()? { + match self.recv_while_streaming().await? { Some(_) => Err(Error::InvalidInbound), None => Ok(body), } } - pub fn fetch_range(&mut self, range: Range) -> Result, Error> { - self.request_range(range)?.ok_or(Error::NoBlocks)?; + pub async fn fetch_range(&mut self, range: Range) -> Result, Error> { + self.request_range(range).await?.ok_or(Error::NoBlocks)?; let mut all = vec![]; - while let Some(block) = self.recv_while_streaming()? { + while let Some(block) = self.recv_while_streaming().await? { debug!("body received"); all.push(block); } @@ -182,9 +190,9 @@ where Ok(all) } - pub fn send_done(&mut self) -> Result<(), Error> { + pub async fn send_done(&mut self) -> Result<(), Error> { let msg = Message::ClientDone; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Done; Ok(()) diff --git a/pallas-miniprotocols/src/chainsync/client.rs b/pallas-miniprotocols/src/chainsync/client.rs index ec2f01f..747d682 100644 --- a/pallas-miniprotocols/src/chainsync/client.rs +++ b/pallas-miniprotocols/src/chainsync/client.rs @@ -108,35 +108,38 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn send_find_intersect(&mut self, points: Vec) -> Result<(), Error> { + pub async fn send_find_intersect(&mut self, points: Vec) -> Result<(), Error> { let msg = Message::FindIntersect(points); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Intersect; Ok(()) } - pub fn recv_intersect_response(&mut self) -> Result { - match self.recv_message()? { + pub async fn recv_intersect_response(&mut self) -> Result { + match self.recv_message().await? { Message::IntersectFound(point, tip) => { self.0 = State::Idle; Ok((Some(point), tip)) @@ -149,21 +152,21 @@ where } } - pub fn find_intersect(&mut self, points: Vec) -> Result { - self.send_find_intersect(points)?; - self.recv_intersect_response() + pub async fn find_intersect(&mut self, points: Vec) -> Result { + self.send_find_intersect(points).await?; + self.recv_intersect_response().await } - pub fn send_request_next(&mut self) -> Result<(), Error> { + pub async fn send_request_next(&mut self) -> Result<(), Error> { let msg = Message::RequestNext; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::CanAwait; Ok(()) } - pub fn recv_while_can_await(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn recv_while_can_await(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::AwaitReply => { self.0 = State::MustReply; Ok(NextResponse::Await) @@ -180,8 +183,8 @@ where } } - pub fn recv_while_must_reply(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn recv_while_must_reply(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::RollForward(a, b) => { self.0 = State::Idle; Ok(NextResponse::RollForward(a, b)) @@ -194,35 +197,35 @@ where } } - pub fn request_next(&mut self) -> Result, Error> { + pub async fn request_next(&mut self) -> Result, Error> { debug!("requesting next block"); - self.send_request_next()?; + self.send_request_next().await?; - self.recv_while_can_await() + self.recv_while_can_await().await } - pub fn intersect_origin(&mut self) -> Result { + pub async fn intersect_origin(&mut self) -> Result { debug!("intersecting origin"); - let (point, _) = self.find_intersect(vec![Point::Origin])?; + let (point, _) = self.find_intersect(vec![Point::Origin]).await?; point.ok_or(Error::IntersectionNotFound) } - pub fn intersect_tip(&mut self) -> Result { - let (_, Tip(point, _)) = self.find_intersect(vec![Point::Origin])?; + pub async fn intersect_tip(&mut self) -> Result { + let (_, Tip(point, _)) = self.find_intersect(vec![Point::Origin]).await?; debug!(?point, "found tip value"); - let (point, _) = self.find_intersect(vec![point])?; + let (point, _) = self.find_intersect(vec![point]).await?; point.ok_or(Error::IntersectionNotFound) } - pub fn send_done(&mut self) -> Result<(), Error> { + pub async fn send_done(&mut self) -> Result<(), Error> { let msg = Message::Done; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Done; Ok(()) diff --git a/pallas-miniprotocols/src/handshake/client.rs b/pallas-miniprotocols/src/handshake/client.rs index e307706..eb1a0c8 100644 --- a/pallas-miniprotocols/src/handshake/client.rs +++ b/pallas-miniprotocols/src/handshake/client.rs @@ -1,6 +1,7 @@ use pallas_codec::Fragment; use pallas_multiplexer::agents::{Channel, ChannelBuffer}; use std::marker::PhantomData; +use tracing::debug; use super::{Error, Message, RefuseReason, State, VersionNumber, VersionTable}; @@ -71,47 +72,56 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn send_propose(&mut self, versions: VersionTable) -> Result<(), Error> { + pub async fn send_propose(&mut self, versions: VersionTable) -> Result<(), Error> { let msg = Message::Propose(versions); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Confirm; + debug!("version proposed"); + Ok(()) } - pub fn recv_while_confirm(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn recv_while_confirm(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::Accept(v, m) => { self.0 = State::Done; + debug!("handshake accepted"); + Ok(Confirmation::Accepted(v, m)) } Message::Refuse(r) => { self.0 = State::Done; + debug!("handshake refused"); + Ok(Confirmation::Rejected(r)) } _ => Err(Error::InvalidInbound), } } - pub fn handshake(&mut self, versions: VersionTable) -> Result, Error> { - self.send_propose(versions)?; - self.recv_while_confirm() + pub async fn handshake(&mut self, versions: VersionTable) -> Result, Error> { + self.send_propose(versions).await?; + self.recv_while_confirm().await } pub fn unwrap(self) -> H { diff --git a/pallas-miniprotocols/src/handshake/server.rs b/pallas-miniprotocols/src/handshake/server.rs index fddd1cb..5c619f2 100644 --- a/pallas-miniprotocols/src/handshake/server.rs +++ b/pallas-miniprotocols/src/handshake/server.rs @@ -62,24 +62,27 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn receive_proposed_versions(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn receive_proposed_versions(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::Propose(v) => { self.0 = State::Confirm; Ok(v) @@ -88,17 +91,21 @@ where } } - pub fn accept_version(&mut self, version: VersionNumber, extra_params: D) -> Result<(), Error> { + pub async fn accept_version( + &mut self, + version: VersionNumber, + extra_params: D, + ) -> Result<(), Error> { let message = Message::Accept(version, extra_params); - self.send_message(&message)?; + self.send_message(&message).await?; self.0 = State::Done; Ok(()) } - pub fn refuse(&mut self, reason: RefuseReason) -> Result<(), Error> { + pub async fn refuse(&mut self, reason: RefuseReason) -> Result<(), Error> { let message = Message::Refuse(reason); - self.send_message(&message)?; + self.send_message(&message).await?; self.0 = State::Done; Ok(()) diff --git a/pallas-miniprotocols/src/localstate/client.rs b/pallas-miniprotocols/src/localstate/client.rs index 58cab05..f84d724 100644 --- a/pallas-miniprotocols/src/localstate/client.rs +++ b/pallas-miniprotocols/src/localstate/client.rs @@ -105,32 +105,35 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn send_acquire(&mut self, point: Option) -> Result<(), Error> { + pub async fn send_acquire(&mut self, point: Option) -> Result<(), Error> { let msg = Message::::Acquire(point); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Acquiring; Ok(()) } - pub fn recv_while_acquiring(&mut self) -> Result<(), Error> { - match self.recv_message()? { + pub async fn recv_while_acquiring(&mut self) -> Result<(), Error> { + match self.recv_message().await? { Message::Acquired => { self.0 = State::Acquired; Ok(()) @@ -143,21 +146,21 @@ where } } - pub fn acquire(&mut self, point: Option) -> Result<(), Error> { - self.send_acquire(point)?; - self.recv_while_acquiring() + pub async fn acquire(&mut self, point: Option) -> Result<(), Error> { + self.send_acquire(point).await?; + self.recv_while_acquiring().await } - pub fn send_query(&mut self, request: Q::Request) -> Result<(), Error> { + pub async fn send_query(&mut self, request: Q::Request) -> Result<(), Error> { let msg = Message::::Query(request); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Querying; Ok(()) } - pub fn recv_while_querying(&mut self) -> Result { - match self.recv_message()? { + pub async fn recv_while_querying(&mut self) -> Result { + match self.recv_message().await? { Message::Result(x) => { self.0 = State::Acquired; Ok(x) @@ -166,9 +169,9 @@ where } } - pub fn query(&mut self, request: Q::Request) -> Result { - self.send_query(request)?; - self.recv_while_querying() + pub async fn query(&mut self, request: Q::Request) -> Result { + self.send_query(request).await?; + self.recv_while_querying().await } } diff --git a/pallas-miniprotocols/src/machines.rs b/pallas-miniprotocols/src/machines.rs index 18c6976..ce3f9e8 100644 --- a/pallas-miniprotocols/src/machines.rs +++ b/pallas-miniprotocols/src/machines.rs @@ -74,9 +74,9 @@ where Ok(()) } - pub fn run_step(&mut self) -> Result { + pub async fn run_step(&mut self) -> Result { let prev = self.agent.take().unwrap(); - let next = run_agent_step(prev, &mut self.buffer)?; + let next = run_agent_step(prev, &mut self.buffer).await?; let is_done = next.is_done(); self.agent.set(Some(next)); @@ -84,16 +84,16 @@ where Ok(is_done) } - pub fn fulfill(mut self) -> Result<(), MachineError> { + pub async fn fulfill(mut self) -> Result<(), MachineError> { self.start()?; - while self.run_step()? {} + while self.run_step().await? {} Ok(()) } } -pub fn run_agent_step(agent: A, channel: &mut ChannelBuffer) -> Transition +pub async fn run_agent_step(agent: A, channel: &mut ChannelBuffer) -> Transition where A: Agent, A::Message: Fragment + std::fmt::Debug, @@ -106,12 +106,16 @@ where channel .send_msg_chunks(&msg) + .await .map_err(MachineError::channel)?; agent.apply_outbound(msg) } false => { - let msg = channel.recv_full_msg().map_err(MachineError::channel)?; + let msg = channel + .recv_full_msg() + .await + .map_err(MachineError::channel)?; trace!(?msg, "processing inbound msg"); @@ -120,7 +124,7 @@ where } } -pub fn run_agent(agent: A, buffer: &mut ChannelBuffer) -> Transition +pub async fn run_agent(agent: A, buffer: &mut ChannelBuffer) -> Transition where A: Agent, A::Message: Fragment + std::fmt::Debug, @@ -129,7 +133,7 @@ where let mut agent = agent.apply_start()?; while !agent.is_done() { - agent = run_agent_step(agent, buffer)?; + agent = run_agent_step(agent, buffer).await?; } Ok(agent) diff --git a/pallas-miniprotocols/src/txsubmission/client.rs b/pallas-miniprotocols/src/txsubmission/client.rs index 79d0a2e..4b44c16 100644 --- a/pallas-miniprotocols/src/txsubmission/client.rs +++ b/pallas-miniprotocols/src/txsubmission/client.rs @@ -14,7 +14,8 @@ pub enum Request { Txs(Vec), } -/// A generic Ouroboros client for submitting a generic notion of "transactions" to another server +/// A generic Ouroboros client for submitting a generic notion of "transactions" +/// to another server pub struct GenericClient( State, ChannelBuffer, @@ -91,48 +92,51 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn send_init(&mut self) -> Result<(), Error> { + pub async fn send_init(&mut self) -> Result<(), Error> { let msg = Message::Init; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Idle; Ok(()) } - pub fn reply_tx_ids(&mut self, ids: Vec>) -> Result<(), Error> { + pub async fn reply_tx_ids(&mut self, ids: Vec>) -> Result<(), Error> { let msg = Message::ReplyTxIds(ids); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Idle; Ok(()) } - pub fn reply_txs(&mut self, txs: Vec) -> Result<(), Error> { + pub async fn reply_txs(&mut self, txs: Vec) -> Result<(), Error> { let msg = Message::ReplyTxs(txs); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Idle; Ok(()) } - pub fn next_request(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn next_request(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::RequestTxIds(blocking, ack, req) => { self.0 = State::TxIdsBlocking; @@ -149,9 +153,9 @@ where } } - pub fn send_done(&mut self) -> Result<(), Error> { + pub async fn send_done(&mut self) -> Result<(), Error> { let msg = Message::Done; - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Done; Ok(()) diff --git a/pallas-miniprotocols/src/txsubmission/server.rs b/pallas-miniprotocols/src/txsubmission/server.rs index 431f73c..4cb3f8b 100644 --- a/pallas-miniprotocols/src/txsubmission/server.rs +++ b/pallas-miniprotocols/src/txsubmission/server.rs @@ -14,7 +14,8 @@ pub enum Reply { Done, } -/// A generic implementation of an ouroboros server protocol ready to request and receive transactions from a client +/// A generic implementation of an ouroboros server protocol ready to request +/// and receive transactions from a client pub struct GenericServer( State, ChannelBuffer, @@ -91,42 +92,45 @@ where } } - pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + self.1 + .send_msg_chunks(msg) + .await + .map_err(Error::ChannelError)?; Ok(()) } - pub fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, Error> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + let msg = self.1.recv_full_msg().await.map_err(Error::ChannelError)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub fn wait_for_init(&mut self) -> Result<(), Error> { + pub async fn wait_for_init(&mut self) -> Result<(), Error> { if self.0 != State::Init { return Err(Error::AlreadyInitialized); } // recv_message calls assert_inbound_state, which ensures we get an init message - self.recv_message()?; + self.recv_message().await?; self.0 = State::Idle; Ok(()) } - pub fn acknowledge_and_request_tx_ids( + pub async fn acknowledge_and_request_tx_ids( &mut self, blocking: Blocking, acknowledge: TxCount, count: TxCount, ) -> Result<(), Error> { let msg = Message::RequestTxIds(blocking, acknowledge, count); - self.send_message(&msg)?; + self.send_message(&msg).await?; match blocking { true => self.0 = State::TxIdsBlocking, false => self.0 = State::TxIdsNonBlocking, @@ -135,16 +139,16 @@ where Ok(()) } - pub fn request_txs(&mut self, ids: Vec) -> Result<(), Error> { + pub async fn request_txs(&mut self, ids: Vec) -> Result<(), Error> { let msg = Message::RequestTxs(ids); - self.send_message(&msg)?; + self.send_message(&msg).await?; self.0 = State::Txs; Ok(()) } - pub fn receive_next_reply(&mut self) -> Result, Error> { - match self.recv_message()? { + pub async fn receive_next_reply(&mut self) -> Result, Error> { + match self.recv_message().await? { Message::ReplyTxIds(ids_and_sizes) => { self.0 = State::Idle; diff --git a/pallas-miniprotocols/tests/integration.rs b/pallas-miniprotocols/tests/integration.rs index fcbd567..bd7800c 100644 --- a/pallas-miniprotocols/tests/integration.rs +++ b/pallas-miniprotocols/tests/integration.rs @@ -14,7 +14,7 @@ struct N2NChannels { txsubmission: StdChannel, } -fn setup_n2n_client_connection() -> N2NChannels { +async fn setup_n2n_client_connection() -> N2NChannels { let bearer = Bearer::connect_tcp("preview-node.world.dev.cardano.org:30002").unwrap(); let mut plexer = StdPlexer::new(bearer); @@ -30,6 +30,7 @@ fn setup_n2n_client_connection() -> N2NChannels { let confirmation = client .handshake(handshake::n2n::VersionTable::v7_and_above(2)) + .await .unwrap(); assert!(matches!(confirmation, Confirmation::Accepted(..))); @@ -45,10 +46,10 @@ fn setup_n2n_client_connection() -> N2NChannels { } } -#[test] +#[tokio::test] #[ignore] -pub fn chainsync_history_happy_path() { - let N2NChannels { chainsync, .. } = setup_n2n_client_connection(); +pub async fn chainsync_history_happy_path() { + let N2NChannels { chainsync, .. } = setup_n2n_client_connection().await; let known_point = Point::Specific( 1654413, @@ -57,7 +58,10 @@ pub fn chainsync_history_happy_path() { let mut client = chainsync::N2NClient::new(chainsync); - let (point, _) = client.find_intersect(vec![known_point.clone()]).unwrap(); + let (point, _) = client + .find_intersect(vec![known_point.clone()]) + .await + .unwrap(); assert!(matches!(client.state(), chainsync::State::Idle)); @@ -66,7 +70,7 @@ pub fn chainsync_history_happy_path() { None => panic!("expected point"), } - let next = client.request_next().unwrap(); + let next = client.request_next().await.unwrap(); match next { NextResponse::RollBackward(point, _) => assert_eq!(point, known_point), @@ -76,7 +80,7 @@ pub fn chainsync_history_happy_path() { assert!(matches!(client.state(), chainsync::State::Idle)); for _ in 0..10 { - let next = client.request_next().unwrap(); + let next = client.request_next().await.unwrap(); match next { NextResponse::RollForward(_, _) => (), @@ -86,23 +90,23 @@ pub fn chainsync_history_happy_path() { assert!(matches!(client.state(), chainsync::State::Idle)); } - client.send_done().unwrap(); + client.send_done().await.unwrap(); assert!(matches!(client.state(), chainsync::State::Done)); } -#[test] +#[tokio::test] #[ignore] -pub fn chainsync_tip_happy_path() { - let N2NChannels { chainsync, .. } = setup_n2n_client_connection(); +pub async fn chainsync_tip_happy_path() { + let N2NChannels { chainsync, .. } = setup_n2n_client_connection().await; let mut client = chainsync::N2NClient::new(chainsync); - client.intersect_tip().unwrap(); + client.intersect_tip().await.unwrap(); assert!(matches!(client.state(), chainsync::State::Idle)); - let next = client.request_next().unwrap(); + let next = client.request_next().await.unwrap(); assert!(matches!(next, NextResponse::RollBackward(..))); @@ -110,10 +114,10 @@ pub fn chainsync_tip_happy_path() { for _ in 0..4 { let next = if client.has_agency() { - client.request_next().unwrap() + client.request_next().await.unwrap() } else { await_count += 1; - client.recv_while_must_reply().unwrap() + client.recv_while_must_reply().await.unwrap() }; match next { @@ -125,15 +129,15 @@ pub fn chainsync_tip_happy_path() { assert!(await_count > 0, "tip was never reached"); - client.send_done().unwrap(); + client.send_done().await.unwrap(); assert!(matches!(client.state(), chainsync::State::Done)); } -#[test] +#[tokio::test] #[ignore] -pub fn blockfetch_happy_path() { - let N2NChannels { blockfetch, .. } = setup_n2n_client_connection(); +pub async fn blockfetch_happy_path() { + let N2NChannels { blockfetch, .. } = setup_n2n_client_connection().await; let known_point = Point::Specific( 1654413, @@ -142,14 +146,16 @@ pub fn blockfetch_happy_path() { let mut client = blockfetch::Client::new(blockfetch); - let range_ok = client.request_range((known_point.clone(), known_point)); + let range_ok = client + .request_range((known_point.clone(), known_point)) + .await; assert!(matches!(client.state(), blockfetch::State::Streaming)); assert!(matches!(range_ok, Ok(_))); for _ in 0..1 { - let next = client.recv_while_streaming().unwrap(); + let next = client.recv_while_streaming().await.unwrap(); match next { Some(body) => assert_eq!(body.len(), 3251), @@ -159,60 +165,62 @@ pub fn blockfetch_happy_path() { assert!(matches!(client.state(), blockfetch::State::Streaming)); } - let next = client.recv_while_streaming().unwrap(); + let next = client.recv_while_streaming().await.unwrap(); assert!(matches!(next, None)); - client.send_done().unwrap(); + client.send_done().await.unwrap(); assert!(matches!(client.state(), blockfetch::State::Done)); } -#[test] +#[tokio::test] #[ignore] -pub fn txsubmission_server_happy_path() { +pub async fn txsubmission_server_happy_path() { // TODO(pi): Note that the below doesn't work; we need a node to connect *to us* // during the integration test which seems awkward; // Alternatively, we can just set up both a client and server connecting to // themselves for testing! - let N2NChannels { txsubmission, .. } = setup_n2n_client_connection(); + let N2NChannels { txsubmission, .. } = setup_n2n_client_connection().await; let mut server = txsubmission::Server::new(txsubmission); - assert!(matches!(server.wait_for_init(), Ok(_))); + assert!(matches!(server.wait_for_init().await, Ok(_))); assert!(matches!( - server.acknowledge_and_request_tx_ids(false, 0, 3), + server.acknowledge_and_request_tx_ids(false, 0, 3).await, Ok(_) )); - let reply: Result<_, _> = server.receive_next_reply(); + let reply: Result<_, _> = server.receive_next_reply().await; assert!(matches!(reply, Ok(Reply::TxIds(_)))); let Ok(Reply::TxIds(tx_ids)) = reply else { unreachable!() }; assert!(tx_ids.len() <= 3); assert!(matches!( - server.request_txs( - tx_ids - .into_iter() - .map(|txid: TxIdAndSize| txid.0) - .collect() - ), + server + .request_txs( + tx_ids + .into_iter() + .map(|txid: TxIdAndSize| txid.0) + .collect() + ) + .await, Ok(_) )); - let reply = server.receive_next_reply(); + let reply = server.receive_next_reply().await; assert!(matches!(reply, Ok(Reply::Txs(_)))); let Ok(Reply::Txs(first_txs)) = reply else { unreachable!() }; assert!(matches!( - server.acknowledge_and_request_tx_ids(false, 1, 3), + server.acknowledge_and_request_tx_ids(false, 1, 3).await, Ok(_) )); - let reply = server.receive_next_reply(); + let reply = server.receive_next_reply().await; assert!(matches!(reply, Ok(Reply::Txs(_)))); let Ok(Reply::Txs(second_txs)) = reply else { unreachable!() }; @@ -222,11 +230,11 @@ pub fn txsubmission_server_happy_path() { assert_eq!(second_txs[1], first_txs[2]); assert!(matches!( - server.acknowledge_and_request_tx_ids(true, 3, 3), + server.acknowledge_and_request_tx_ids(true, 3, 3).await, Ok(_) )); - match server.receive_next_reply() { + match server.receive_next_reply().await { Ok(Reply::Done) => (), // Server aint havin none of our sh*t Ok(Reply::TxIds(tx_ids)) => assert_eq!(tx_ids.len(), 3), Ok(_) | Err(_) => unreachable!(), diff --git a/pallas-multiplexer/Cargo.toml b/pallas-multiplexer/Cargo.toml index 0274d3b..e3772c7 100644 --- a/pallas-multiplexer/Cargo.toml +++ b/pallas-multiplexer/Cargo.toml @@ -23,3 +23,6 @@ tracing = "0.1.37" std = [] sync = [] default = ["std", "sync"] + +[dev-dependencies] +tokio = { version = "1.27.0", features = ["macros", "rt"] } diff --git a/pallas-multiplexer/src/agents.rs b/pallas-multiplexer/src/agents.rs index fee7432..64121d0 100644 --- a/pallas-multiplexer/src/agents.rs +++ b/pallas-multiplexer/src/agents.rs @@ -3,6 +3,7 @@ use crate::Payload; use pallas_codec::{minicbor, Fragment}; use thiserror::Error; +use tracing::{debug, error, trace}; #[derive(Debug, Error)] pub enum ChannelError { @@ -18,20 +19,14 @@ pub enum ChannelError { /// A raw link to the ingress / egress of the multiplexer pub trait Channel { - fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError>; - fn dequeue_chunk(&mut self) -> Result; + async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError>; + async fn dequeue_chunk(&mut self) -> Result; } /// Protocol value that defines max segment length pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535; -enum Decoding { - Done(M, usize), - NotEnoughData, - UnexpectedError(Box), -} - -fn try_decode_message(buffer: &[u8]) -> Decoding +fn try_decode_message(buffer: &mut Vec) -> Result, ChannelError> where M: Fragment, { @@ -39,9 +34,17 @@ where let maybe_msg = decoder.decode(); match maybe_msg { - Ok(msg) => Decoding::Done(msg, decoder.position()), - Err(err) if err.is_end_of_input() => Decoding::NotEnoughData, - Err(err) => Decoding::UnexpectedError(Box::new(err)), + Ok(msg) => { + let pos = decoder.position(); + buffer.drain(0..pos); + Ok(Some(msg)) + } + Err(err) if err.is_end_of_input() => Ok(None), + Err(err) => { + error!(?err); + error!("{}", hex::encode(buffer)); + Err(ChannelError::Decoding(err.to_string())) + } } } @@ -60,7 +63,7 @@ impl ChannelBuffer { } /// Enqueues a msg as a sequence payload chunks - pub fn send_msg_chunks(&mut self, msg: &M) -> Result<(), ChannelError> + pub async fn send_msg_chunks(&mut self, msg: &M) -> Result<(), ChannelError> where M: Fragment, { @@ -71,38 +74,34 @@ impl ChannelBuffer { let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH); for chunk in chunks { - self.channel.enqueue_chunk(Vec::from(chunk))?; + self.channel.enqueue_chunk(Vec::from(chunk)).await?; } Ok(()) } /// Reads from the channel until a complete message is found - pub fn recv_full_msg(&mut self) -> Result + pub async fn recv_full_msg(&mut self) -> Result where M: Fragment, { - // do an eager reading if buffer is empty, no point in going through the error - // handling - if self.temp.is_empty() { - let chunk = self.channel.dequeue_chunk()?; - self.temp.extend(chunk); + if !self.temp.is_empty() { + if let Some(msg) = try_decode_message::(&mut self.temp)? { + debug!("decoding done"); + return Ok(msg); + } } - let decoding = try_decode_message::(&self.temp); + loop { + let chunk = self.channel.dequeue_chunk().await?; + self.temp.extend(chunk); - match decoding { - Decoding::Done(msg, pos) => { - self.temp.drain(0..pos); - Ok(msg) + if let Some(msg) = try_decode_message::(&mut self.temp)? { + debug!("decoding done"); + return Ok(msg); } - Decoding::UnexpectedError(err) => Err(ChannelError::Decoding(err.to_string())), - Decoding::NotEnoughData => { - let chunk = self.channel.dequeue_chunk()?; - self.temp.extend(chunk); - self.recv_full_msg() - } + trace!("not enough data"); } } @@ -124,19 +123,19 @@ mod tests { use super::*; impl Channel for VecDeque { - fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError> { + async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError> { self.push_back(chunk); Ok(()) } - fn dequeue_chunk(&mut self) -> Result { + async fn dequeue_chunk(&mut self) -> Result { let chunk = self.pop_front().ok_or(ChannelError::NotConnected(None))?; Ok(chunk) } } - #[test] - fn multiple_messages_in_same_payload() { + #[tokio::test] + async fn multiple_messages_in_same_payload() { let mut input = Vec::new(); let in_part1 = (1u8, 2u8, 3u8); let in_part2 = (6u8, 5u8, 4u8); @@ -149,15 +148,15 @@ mod tests { let mut buf = ChannelBuffer::new(channel); - let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); - let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); + let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().await.unwrap(); + let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().await.unwrap(); assert_eq!(in_part1, out_part1); assert_eq!(in_part2, out_part2); } - #[test] - fn fragmented_message_in_multiple_payloads() { + #[tokio::test] + async fn fragmented_message_in_multiple_payloads() { let mut input = Vec::new(); let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); minicbor::encode(msg, &mut input).unwrap(); @@ -171,7 +170,10 @@ mod tests { let mut buf = ChannelBuffer::new(channel); - let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap(); + let out_msg = buf + .recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>() + .await + .unwrap(); assert_eq!(msg, out_msg); } diff --git a/pallas-multiplexer/src/lib.rs b/pallas-multiplexer/src/lib.rs index d6922c5..dbad307 100644 --- a/pallas-multiplexer/src/lib.rs +++ b/pallas-multiplexer/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(async_fn_in_trait)] + pub mod agents; pub mod bearers; pub mod demux; diff --git a/pallas-multiplexer/src/std.rs b/pallas-multiplexer/src/std.rs index 0862fe7..7806683 100644 --- a/pallas-multiplexer/src/std.rs +++ b/pallas-multiplexer/src/std.rs @@ -138,14 +138,14 @@ pub type StdChannel = (u16, Sender, Receiver); pub type StdChannelBuffer = ChannelBuffer; impl agents::Channel for StdChannel { - fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { + async fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { match self.1.send((self.0, payload)) { Ok(_) => Ok(()), Err(SendError((_, payload))) => Err(agents::ChannelError::NotConnected(Some(payload))), } } - fn dequeue_chunk(&mut self) -> Result { + async fn dequeue_chunk(&mut self) -> Result { match self.2.recv() { Ok(payload) => Ok(payload), Err(_) => Err(agents::ChannelError::NotConnected(None)), diff --git a/pallas-multiplexer/src/sync.rs b/pallas-multiplexer/src/sync.rs index 96f068d..3d4ff94 100644 --- a/pallas-multiplexer/src/sync.rs +++ b/pallas-multiplexer/src/sync.rs @@ -29,7 +29,7 @@ impl SyncPlexer { pub type SyncChannel = ChannelBuffer; impl agents::Channel for SyncPlexer { - fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { + async fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { let segment = Segment::new(self.clock, self.protocol, payload); self.bearer @@ -37,7 +37,7 @@ impl agents::Channel for SyncPlexer { .map_err(|_| agents::ChannelError::NotConnected(None)) } - fn dequeue_chunk(&mut self) -> Result { + async fn dequeue_chunk(&mut self) -> Result { match self.bearer.read_segment() { Ok(segment) => match segment { Some(x) => { diff --git a/pallas-multiplexer/tests/integration.rs b/pallas-multiplexer/tests/integration.rs index d64d210..0d46fee 100644 --- a/pallas-multiplexer/tests/integration.rs +++ b/pallas-multiplexer/tests/integration.rs @@ -31,8 +31,8 @@ fn random_payload(size: usize) -> Vec { rand::thread_rng().sample_iter(&range).take(size).collect() } -#[test] -fn one_way_small_sequence_of_payloads() { +#[tokio::test] +async fn one_way_small_sequence_of_payloads() { let passive = setup_passive_muxer::<50301>(); // HACK: a small sleep seems to be required for Github actions runner to @@ -52,8 +52,8 @@ fn one_way_small_sequence_of_payloads() { for _ in 0..100 { let payload = random_payload(50); - sender_channel.enqueue_chunk(payload.clone()).unwrap(); - let received_payload = receiver_channel.dequeue_chunk().unwrap(); + sender_channel.enqueue_chunk(payload.clone()).await.unwrap(); + let received_payload = receiver_channel.dequeue_chunk().await.unwrap(); assert_eq!(payload, received_payload); } } diff --git a/pallas-upstream/Cargo.toml b/pallas-upstream/Cargo.toml index 9fefe42..18be522 100644 --- a/pallas-upstream/Cargo.toml +++ b/pallas-upstream/Cargo.toml @@ -11,7 +11,11 @@ readme = "README.md" authors = ["Santiago Carmuega "] [dependencies] +byteorder = "1.4.3" gasket = { git = "https://github.com/construkts/gasket-rs" } +# gasket = { path = "../../../construkts/gasket-rs" } +hex = "0.4.3" +mio = { version = "0.8.6", features = ["net", "os-poll"] } # gasket = { version = "0.1.0", path = "../../../construkts/gasket-rs" } pallas-codec = { version = "0.18.0", path = "../pallas-codec" } pallas-crypto = { version = "0.18.0", path = "../pallas-crypto" } @@ -21,4 +25,8 @@ pallas-traverse = { version = "0.18.0", path = "../pallas-traverse" } rayon = "1.7.0" serde = { version = "1.0.154", features = ["derive"] } thiserror = "1.0.31" +tokio = { version = "1", features = ["net", "macros", "io-util"] } tracing = "0.1.37" + +[dev-dependencies] +tracing-subscriber = "0.3.16" diff --git a/pallas-upstream/src/api.rs b/pallas-upstream/src/api.rs index e68d2a4..1193875 100644 --- a/pallas-upstream/src/api.rs +++ b/pallas-upstream/src/api.rs @@ -1,11 +1,8 @@ -pub use crate::cursor; - -pub use crate::framework::BlockFetchEvent; - -pub use crate::framework::DownstreamPort; +pub use crate::framework::{BlockFetchEvent, Cursor, DownstreamPort, Intersection}; pub mod n2n { - use crate::{blockfetch, chainsync, cursor::Cursor, framework::*, plexer}; + use crate::{blockfetch, chainsync, framework::*, plexer}; + use gasket::{ messaging::{SendAdapter, SendPort}, runtime::Tether, @@ -17,21 +14,23 @@ pub mod n2n { pub blockfetch_tether: Tether, } - pub struct Bootstrapper + pub struct Bootstrapper where A: SendAdapter, + C: Cursor, { - cursor: Cursor, + cursor: C, peer_address: String, network_magic: u64, output: super::DownstreamPort, } - impl Bootstrapper + impl Bootstrapper where A: SendAdapter + 'static, + C: Cursor + 'static, { - pub fn new(cursor: Cursor, peer_address: String, network_magic: u64) -> Self { + pub fn new(cursor: C, peer_address: String, network_magic: u64) -> Self { Bootstrapper { cursor, peer_address, @@ -67,15 +66,15 @@ pub mod n2n { let mut demux2_out = DemuxOutputPort::default(); let mut demux2_in = DemuxInputPort::default(); - gasket::messaging::crossbeam::connect_ports(&mut demux2_out, &mut demux2_in, 1000); + gasket::messaging::tokio::connect_ports(&mut demux2_out, &mut demux2_in, 1000); let mut demux3_out = DemuxOutputPort::default(); let mut demux3_in = DemuxInputPort::default(); - gasket::messaging::crossbeam::connect_ports(&mut demux3_out, &mut demux3_in, 1000); + gasket::messaging::tokio::connect_ports(&mut demux3_out, &mut demux3_in, 1000); let mut mux2_out = MuxOutputPort::default(); let mut mux3_out = MuxOutputPort::default(); - gasket::messaging::crossbeam::funnel_ports( + gasket::messaging::tokio::funnel_ports( vec![&mut mux2_out, &mut mux3_out], &mut mux_input, 1000, @@ -83,10 +82,10 @@ pub mod n2n { let mut chainsync_downstream = chainsync::DownstreamPort::default(); let mut blockfetch_upstream = blockfetch::UpstreamPort::default(); - gasket::messaging::crossbeam::connect_ports( + gasket::messaging::tokio::connect_ports( &mut chainsync_downstream, &mut blockfetch_upstream, - 20, + 100, ); let plexer_tether = gasket::runtime::spawn_stage( diff --git a/pallas-upstream/src/blockfetch.rs b/pallas-upstream/src/blockfetch.rs index 107bf52..3df0b27 100644 --- a/pallas-upstream/src/blockfetch.rs +++ b/pallas-upstream/src/blockfetch.rs @@ -1,4 +1,5 @@ use gasket::messaging::SendAdapter; +use gasket::runtime::WorkSchedule; use tracing::{error, info, instrument}; use pallas_crypto::hash::Hash; @@ -7,7 +8,7 @@ use pallas_miniprotocols::Point; use crate::framework::*; -pub type UpstreamPort = gasket::messaging::crossbeam::TwoPhaseInputPort; +pub type UpstreamPort = gasket::messaging::tokio::InputPort; pub type OuroborosClient = blockfetch::Client; pub struct Worker @@ -40,12 +41,17 @@ where } #[instrument(skip(self), fields(slot, %hash))] - fn fetch_block(&mut self, slot: u64, hash: Hash<32>) -> Result, gasket::error::Error> { + async fn fetch_block( + &mut self, + slot: u64, + hash: &Hash<32>, + ) -> Result, gasket::error::Error> { info!("fetching block"); match self .client .fetch_single(Point::Specific(slot, hash.to_vec())) + .await { Ok(x) => { info!("block fetch succeeded"); @@ -73,23 +79,28 @@ where .build() } - fn work(&mut self) -> gasket::runtime::WorkResult { - let msg = self.upstream.recv_or_idle()?; + type WorkUnit = ChainSyncEvent; - let msg = match msg.payload { + async fn schedule(&mut self) -> gasket::runtime::ScheduleResult { + let msg = self.upstream.recv().await?; + info!("scheduling block betch"); + Ok(WorkSchedule::Unit(msg.payload)) + } + + async fn execute(&mut self, unit: &Self::WorkUnit) -> Result<(), gasket::error::Error> { + let output = match unit { ChainSyncEvent::RollForward(s, h) => { - let body = self.fetch_block(s, h)?; + let body = self.fetch_block(*s, h).await?; + self.block_count.inc(1); - BlockFetchEvent::RollForward(s, h, body) + + BlockFetchEvent::RollForward(*s, h.clone(), body) } - ChainSyncEvent::Rollback(x) => BlockFetchEvent::Rollback(x), + ChainSyncEvent::Rollback(x) => BlockFetchEvent::Rollback(x.clone()), }; - self.downstream.send(msg.into())?; + self.downstream.send(output.into()).await?; - // remove the processed event from the queue - self.upstream.commit(); - - Ok(gasket::runtime::WorkOutcome::Partial) + Ok(()) } } diff --git a/pallas-upstream/src/chainsync.rs b/pallas-upstream/src/chainsync.rs index 91e23be..4db09ba 100644 --- a/pallas-upstream/src/chainsync.rs +++ b/pallas-upstream/src/chainsync.rs @@ -1,11 +1,10 @@ use gasket::error::AsWorkError; use tracing::{debug, info}; -use pallas_miniprotocols::chainsync::{HeaderContent, NextResponse}; +use pallas_miniprotocols::chainsync::{HeaderContent, NextResponse, Tip}; use pallas_miniprotocols::{chainsync, Point}; use pallas_traverse::MultiEraHeader; -use crate::cursor::{Cursor, Intersection}; use crate::framework::*; fn to_traverse(header: &chainsync::HeaderContent) -> Result, Error> { @@ -17,20 +16,26 @@ fn to_traverse(header: &chainsync::HeaderContent) -> Result, out.map_err(Error::parse) } -pub type DownstreamPort = gasket::messaging::crossbeam::OutputPort; +pub type DownstreamPort = gasket::messaging::tokio::OutputPort; pub type OuroborosClient = chainsync::N2NClient; -pub struct Worker { - chain_cursor: Cursor, +pub struct Worker +where + C: Cursor, +{ + chain_cursor: C, client: OuroborosClient, downstream: DownstreamPort, block_count: gasket::metrics::Counter, chain_tip: gasket::metrics::Gauge, } -impl Worker { - pub fn new(chain_cursor: Cursor, plexer: ProtocolChannel, downstream: DownstreamPort) -> Self { +impl Worker +where + C: Cursor, +{ + pub fn new(chain_cursor: C, plexer: ProtocolChannel, downstream: DownstreamPort) -> Self { let client = OuroborosClient::new(plexer); Self { @@ -42,45 +47,71 @@ impl Worker { } } - fn intersect(&mut self) -> Result, gasket::error::Error> { - let value = self.chain_cursor.read(); - - match value { - Intersection::Origin => { - let point = self.client.intersect_origin().or_restart()?; - - Ok(Some(point)) - } - Intersection::Tip => { - let point = self.client.intersect_tip().or_restart()?; - - Ok(Some(point)) - } - Intersection::Breadcrumbs(points) => { - let (point, _) = self.client.find_intersect(Vec::from(points)).or_restart()?; - - Ok(point) - } - } + fn notify_tip(&self, tip: Tip) { + self.chain_tip.set(tip.0.slot_or_default() as i64); } - fn process_next( + async fn intersect(&mut self) -> Result<(), gasket::error::Error> { + let value = self.chain_cursor.intersection(); + + let intersect = match value { + Intersection::Origin => { + info!("intersecting origin"); + self.client.intersect_origin().await.or_restart()?.into() + } + Intersection::Tip => { + info!("intersecting tip"); + self.client.intersect_tip().await.or_restart()?.into() + } + Intersection::Breadcrumbs(points) => { + info!("intersecting breadcrumbs"); + let (point, tip) = self + .client + .find_intersect(Vec::from(points)) + .await + .or_restart()?; + + self.notify_tip(tip); + + point + } + }; + + info!(?intersect, "intersected"); + + Ok(()) + } + + async fn process_next( &mut self, next: NextResponse, ) -> Result<(), gasket::error::Error> { match next { - chainsync::NextResponse::RollForward(h, t) => { - let h = to_traverse(&h).or_panic()?; - self.downstream - .send(ChainSyncEvent::RollForward(h.slot(), h.hash()).into())?; + chainsync::NextResponse::RollForward(header, tip) => { + let header = to_traverse(&header).or_panic()?; + + debug!(slot = header.slot(), hash = %header.hash(), "chain sync roll forward"); + + self.downstream + .send(ChainSyncEvent::RollForward(header.slot(), header.hash()).into()) + .await?; + + self.notify_tip(tip); - debug!(slot = h.slot(), hash = %h.hash(), "chain sync roll forward"); - self.chain_tip.set(t.1 as i64); Ok(()) } - chainsync::NextResponse::RollBackward(p, t) => { - self.downstream.send(ChainSyncEvent::Rollback(p).into())?; - self.chain_tip.set(t.1 as i64); + chainsync::NextResponse::RollBackward(point, tip) => { + match &point { + Point::Origin => debug!("rollback to origin"), + Point::Specific(slot, _) => debug!(slot, "rollback"), + }; + + self.downstream + .send(ChainSyncEvent::Rollback(point).into()) + .await?; + + self.notify_tip(tip); + Ok(()) } chainsync::NextResponse::Await => { @@ -90,20 +121,31 @@ impl Worker { } } - fn request_next(&mut self) -> Result<(), gasket::error::Error> { + async fn request_next(&mut self) -> Result<(), gasket::error::Error> { info!("requesting next block"); - let next = self.client.request_next().or_restart()?; - self.process_next(next) + let next = self.client.request_next().await.or_restart()?; + self.process_next(next).await } - fn await_next(&mut self) -> Result<(), gasket::error::Error> { + async fn await_next(&mut self) -> Result<(), gasket::error::Error> { info!("awaiting next block (blocking)"); - let next = self.client.recv_while_must_reply().or_restart()?; - self.process_next(next) + let next = self.client.recv_while_must_reply().await.or_restart()?; + self.process_next(next).await } } -impl gasket::runtime::Worker for Worker { +pub enum WorkUnit { + Intersect, + RequestNext, + AwaitNext, +} + +impl gasket::runtime::Worker for Worker +where + C: Cursor + Sync + Send, +{ + type WorkUnit = WorkUnit; + fn metrics(&self) -> gasket::metrics::Registry { gasket::metrics::Builder::new() .with_counter("received_blocks", &self.block_count) @@ -111,19 +153,24 @@ impl gasket::runtime::Worker for Worker { .build() } - fn bootstrap(&mut self) -> Result<(), gasket::error::Error> { - let intersect = self.intersect()?; - info!(?intersect, "chain-sync intersected"); + async fn bootstrap(&mut self) -> gasket::runtime::ScheduleResult { + Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::Intersect)) + } + + async fn schedule(&mut self) -> gasket::runtime::ScheduleResult { + match self.client.has_agency() { + true => Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::RequestNext)), + false => Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::AwaitNext)), + } + } + + async fn execute(&mut self, unit: &Self::WorkUnit) -> Result<(), gasket::error::Error> { + match unit { + WorkUnit::Intersect => self.intersect().await?, + WorkUnit::RequestNext => self.request_next().await?, + WorkUnit::AwaitNext => self.await_next().await?, + }; Ok(()) } - - fn work(&mut self) -> gasket::runtime::WorkResult { - match self.client.has_agency() { - true => self.request_next()?, - false => self.await_next()?, - }; - - Ok(gasket::runtime::WorkOutcome::Partial) - } } diff --git a/pallas-upstream/src/cursor.rs b/pallas-upstream/src/cursor.rs deleted file mode 100644 index d77c21e..0000000 --- a/pallas-upstream/src/cursor.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::{ - collections::VecDeque, - sync::{Arc, RwLock}, -}; - -use pallas_miniprotocols::Point; - -#[derive(Clone)] -pub enum Intersection { - Tip, - Origin, - Breadcrumbs(VecDeque), -} - -const HARDCODED_BREADCRUMBS: usize = 20; - -// TODO: include exponential breadcrumbs logic here -#[derive(Clone)] -pub struct Cursor(Arc>); - -impl Cursor { - pub fn new(value: Intersection) -> Self { - Self(Arc::new(RwLock::new(value))) - } - - pub fn read(&self) -> Intersection { - let v = self.0.read().unwrap(); - v.clone() - } - - pub fn latest_known_point(&self) -> Option { - let guard = self.0.read().unwrap(); - - match &*guard { - Intersection::Breadcrumbs(v) => v.front().cloned(), - _ => None, - } - } - - pub fn add_breadcrumb(&self, value: Point) { - let mut guard = self.0.write().unwrap(); - - match &mut *guard { - Intersection::Tip | Intersection::Origin => { - *guard = Intersection::Breadcrumbs(VecDeque::from(vec![value])); - } - Intersection::Breadcrumbs(crumbs) => { - crumbs.push_front(value); - - if crumbs.len() > HARDCODED_BREADCRUMBS { - crumbs.pop_back(); - } - } - } - } -} diff --git a/pallas-upstream/src/framework.rs b/pallas-upstream/src/framework.rs index 17bc5c3..9c02397 100644 --- a/pallas-upstream/src/framework.rs +++ b/pallas-upstream/src/framework.rs @@ -2,12 +2,23 @@ use pallas_crypto::hash::Hash; use pallas_miniprotocols::Point; use pallas_multiplexer as multiplexer; use thiserror::Error; -use tracing::error; +use tracing::{error, trace}; pub type BlockSlot = u64; pub type BlockHash = Hash<32>; pub type RawBlock = Vec; +#[derive(Clone)] +pub enum Intersection { + Tip, + Origin, + Breadcrumbs(Vec), +} + +pub trait Cursor: Send + Sync { + fn intersection(&self) -> Intersection; +} + #[derive(Debug, Clone)] pub enum ChainSyncEvent { RollForward(BlockSlot, BlockHash), @@ -21,12 +32,12 @@ pub enum BlockFetchEvent { } // ports used by plexer -pub type MuxOutputPort = gasket::messaging::crossbeam::OutputPort<(u16, multiplexer::Payload)>; -pub type DemuxInputPort = gasket::messaging::crossbeam::InputPort; +pub type MuxOutputPort = gasket::messaging::tokio::OutputPort<(u16, multiplexer::Payload)>; +pub type DemuxInputPort = gasket::messaging::tokio::InputPort; // ports used by mini-protocols -pub type MuxInputPort = gasket::messaging::crossbeam::InputPort<(u16, multiplexer::Payload)>; -pub type DemuxOutputPort = gasket::messaging::crossbeam::OutputPort; +pub type MuxInputPort = gasket::messaging::tokio::InputPort<(u16, multiplexer::Payload)>; +pub type DemuxOutputPort = gasket::messaging::tokio::OutputPort; // final output port pub type DownstreamPort = gasket::messaging::OutputPort; @@ -34,14 +45,22 @@ pub type DownstreamPort = gasket::messaging::OutputPort; pub struct ProtocolChannel(pub u16, pub MuxOutputPort, pub DemuxInputPort); impl multiplexer::agents::Channel for ProtocolChannel { - fn enqueue_chunk( + async fn enqueue_chunk( &mut self, payload: multiplexer::Payload, ) -> Result<(), multiplexer::agents::ChannelError> { - match self + trace!( + protocol = self.0, + payload = hex::encode(&payload), + "enqueing" + ); + + let res = self .1 .send(gasket::messaging::Message::from((self.0, payload))) - { + .await; + + match res { Ok(_) => Ok(()), Err(error) => { error!(?error, "enqueue chunk failed"); @@ -50,8 +69,12 @@ impl multiplexer::agents::Channel for ProtocolChannel { } } - fn dequeue_chunk(&mut self) -> Result { - match self.2.recv() { + async fn dequeue_chunk( + &mut self, + ) -> Result { + let res = self.2.recv().await; + + match res { Ok(msg) => Ok(msg.payload), Err(error) => { error!(?error, "dequeue chunk failed"); @@ -96,8 +119,8 @@ impl Error { Error::Message(error.to_string()) } - pub fn custom(error: Box) -> Error { - Error::Custom(format!("{error}")) + pub fn custom(error: impl Into>) -> Error { + Error::Custom(format!("{}", error.into())) } } diff --git a/pallas-upstream/src/lib.rs b/pallas-upstream/src/lib.rs index b630243..f0a274c 100644 --- a/pallas-upstream/src/lib.rs +++ b/pallas-upstream/src/lib.rs @@ -1,10 +1,10 @@ +#![feature(async_fn_in_trait)] + pub(crate) mod blockfetch; pub(crate) mod chainsync; pub(crate) mod framework; pub(crate) mod plexer; -pub mod cursor; - mod api; pub use api::*; diff --git a/pallas-upstream/src/plexer.rs b/pallas-upstream/src/plexer.rs index dc47f01..b0f51b9 100644 --- a/pallas-upstream/src/plexer.rs +++ b/pallas-upstream/src/plexer.rs @@ -1,84 +1,319 @@ +use std::future::ready; + +use byteorder::{ByteOrder, NetworkEndian}; use gasket::error::AsWorkError; -use tracing::{debug, error, info, warn}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf}; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio::select; +use tokio::time::Instant; +use tracing::{debug, error, info, trace, warn}; use pallas_miniprotocols::handshake; -use pallas_multiplexer as multiplexer; -use pallas_multiplexer::bearers::Bearer; -use pallas_multiplexer::demux::{Demuxer, Egress}; -use pallas_multiplexer::mux::{Ingress, Muxer}; -use pallas_multiplexer::sync::SyncPlexer; use crate::framework::*; -struct GasketEgress(DemuxOutputPort); +const HEADER_LEN: usize = 8; -impl Egress for GasketEgress { - fn send( - &mut self, - payload: multiplexer::Payload, - ) -> Result<(), multiplexer::demux::EgressError> { - self.0 - .send(gasket::messaging::Message::from(payload)) - .map_err(|_| multiplexer::demux::EgressError(vec![])) - } +pub type Timestamp = u32; + +pub type Payload = Vec; + +pub type Protocol = u16; + +/// A `Header` struct represents an Ouroboros segment header. +/// +/// # Examples +/// +/// Converting a `Header` to bytes: +/// +/// ``` +/// use byteorder::{BigEndian, ByteOrder}; +/// use pallas_upstream::plexer::Header; +/// +/// let header = Header { +/// protocol: 0x01, +/// timestamp: 1619804871, +/// payload_len: 42, +/// }; +/// +/// let header_bytes: [u8; 8] = header.into(); +/// assert_eq!(header_bytes, [97, 75, 168, 15, 128, 1, 0, 42]); +/// ``` +/// +/// Converting bytes to a `Header`: +/// +/// ``` +/// use byteorder::{BigEndian, ByteOrder}; +/// use pallas_upstream::plexer::Header; +/// +/// let bytes = [97, 75, 168, 15, 128, 1, 0, 42]; +/// let header: Header = (&bytes[..]).into(); +/// +/// assert_eq!(header.protocol, 0x01); +/// assert_eq!(header.timestamp, 1619804871); +/// assert_eq!(header.payload_len, 42); +/// ``` +#[derive(Debug)] +pub struct Header { + pub protocol: Protocol, + pub timestamp: Timestamp, + pub payload_len: u16, } -struct GasketIngress(MuxInputPort); +impl From<&[u8]> for Header { + fn from(value: &[u8]) -> Self { + let timestamp = NetworkEndian::read_u32(&value[0..4]); + let protocol = NetworkEndian::read_u16(&value[4..6]) ^ 0x8000; + let payload_len = NetworkEndian::read_u16(&value[6..8]); -impl Ingress for GasketIngress { - fn recv_timeout( - &mut self, - duration: std::time::Duration, - ) -> Result { - self.0 - .recv_timeout(duration) - .map(|msg| msg.payload) - .map_err(|err| match err { - gasket::error::Error::RecvIdle => multiplexer::mux::IngressError::Empty, - _ => multiplexer::mux::IngressError::Disconnected, - }) - } -} - -type IsBusy = bool; - -fn handle_demux_outcome( - outcome: Result, -) -> Result { - match outcome { - Ok(x) => match x { - multiplexer::demux::TickOutcome::Busy => Ok(true), - multiplexer::demux::TickOutcome::Idle => Ok(false), - }, - Err(err) => match err { - multiplexer::demux::DemuxError::BearerError(err) => { - error!("{}", err.kind()); - Err(gasket::error::Error::ShouldRestart) - } - multiplexer::demux::DemuxError::EgressDisconnected(x, _) => { - error!(protocol = x, "egress disconnected"); - Err(gasket::error::Error::WorkPanic) - } - multiplexer::demux::DemuxError::EgressUnknown(x, _) => { - error!(protocol = x, "unknown egress"); - Err(gasket::error::Error::WorkPanic) - } - }, - } -} - -fn handle_mux_outcome( - outcome: multiplexer::mux::TickOutcome, -) -> Result { - match outcome { - multiplexer::mux::TickOutcome::Busy => Ok(true), - multiplexer::mux::TickOutcome::Idle => Ok(false), - multiplexer::mux::TickOutcome::BearerError(err) => { - warn!(%err); - Err(gasket::error::Error::ShouldRestart) + Self { + timestamp, + protocol, + payload_len, } - multiplexer::mux::TickOutcome::IngressDisconnected => { - error!("ingress disconnected"); + } +} + +impl From
for [u8; 8] { + fn from(value: Header) -> Self { + let mut out = [0u8; 8]; + NetworkEndian::write_u32(&mut out[0..4], value.timestamp); + NetworkEndian::write_u16(&mut out[4..6], value.protocol); + NetworkEndian::write_u16(&mut out[6..8], value.payload_len); + + out + } +} + +pub struct Segment { + pub header: Header, + pub payload: Payload, +} + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +struct AsyncBearer(OwnedReadHalf, OwnedWriteHalf, Instant); + +impl AsyncBearer { + async fn connect_tcp(addr: impl ToSocketAddrs) -> Result { + let stream = TcpStream::connect(addr).await?; + let (read, write) = stream.into_split(); + + Ok(Self(read, write, Instant::now())) + } +} + +impl AsyncBearer { + async fn readable(&self) -> tokio::io::Result<()> { + self.0.readable().await + } + + /// Peek the available data in search for a frame header + async fn peek_header(&mut self) -> tokio::io::Result> { + let mut buf = [0u8; HEADER_LEN]; + let len = self.0.peek(&mut buf).await?; + + if len < HEADER_LEN { + return Ok(None); + } + + Ok(Some(Header::from(buf.as_slice()))) + } + + async fn has_payload(&mut self, payload_len: usize) -> tokio::io::Result { + let segment_size = HEADER_LEN + payload_len; + let mut buf = vec![0u8; segment_size]; + + let available = self.0.peek(&mut buf).await?; + + return Ok(available >= segment_size); + } + + /// Peeks the bearer to see if a full segment is available to be read + async fn has_segment(&mut self) -> std::io::Result { + let header = match self.peek_header().await? { + Some(x) => x, + None => return Ok(false), + }; + + self.has_payload(header.payload_len as usize).await + } + + /// Reads a full segment from the bearer while consuming the bytes + /// + /// This function is NOT "cancel safe", meaning that it shouldn't be used + /// inside the context of a select!. Only call this function once you're + /// sure that you can await until all the required bytes are available. + async fn read_segment(&mut self) -> tokio::io::Result<(Protocol, Payload)> { + let mut buf = [0u8; HEADER_LEN]; + self.0.read_exact(&mut buf).await?; + let header = Header::from(buf.as_slice()); + + // TODO: assert any business invariants regarding timestamp from the other party + + let mut payload = vec![0u8; header.payload_len as usize]; + self.0.read_exact(&mut payload).await?; + + Ok((header.protocol, payload)) + } + + async fn write_segment(&mut self, protocol: u16, payload: &[u8]) -> Result<(), std::io::Error> { + let header = Header { + protocol, + timestamp: self.2.elapsed().as_micros() as u32, + payload_len: payload.len() as u16, + }; + + let buf: [u8; 8] = header.into(); + self.1.write_all(&buf).await?; + + self.1.write_all(&payload).await?; + + Ok(()) + } +} + +pub struct AsyncAgentChannel( + Protocol, + tokio::sync::mpsc::Sender<(Protocol, Payload)>, + tokio::sync::broadcast::Receiver<(Protocol, Payload)>, +); + +impl pallas_multiplexer::agents::Channel for AsyncAgentChannel { + async fn enqueue_chunk( + &mut self, + chunk: pallas_multiplexer::Payload, + ) -> Result<(), pallas_multiplexer::agents::ChannelError> { + let res = self.1.send((self.0, chunk)).await; + + res.map_err(|err| pallas_multiplexer::agents::ChannelError::NotConnected(Some(err.0 .1))) + } + + async fn dequeue_chunk( + &mut self, + ) -> Result { + loop { + let (protocol, payload) = self + .2 + .recv() + .await + .map_err(|err| pallas_multiplexer::agents::ChannelError::NotConnected(None))?; + + if protocol == self.0 { + break Ok(payload); + } + } + } +} + +pub type AsyncIngress = ( + tokio::sync::mpsc::Sender<(Protocol, Payload)>, + tokio::sync::mpsc::Receiver<(Protocol, Payload)>, +); +pub type AsyncEgress = ( + tokio::sync::broadcast::Sender<(Protocol, Payload)>, + tokio::sync::broadcast::Receiver<(Protocol, Payload)>, +); + +struct AsyncPlexer { + bearer: AsyncBearer, + ingress: AsyncIngress, + egress: AsyncEgress, +} + +impl AsyncPlexer { + pub fn new(bearer: AsyncBearer) -> Self { + Self { + bearer, + ingress: tokio::sync::mpsc::channel(100), // TODO: define buffer + egress: tokio::sync::broadcast::channel(100), + } + } + + async fn mux(&mut self, msg: (Protocol, Payload)) -> tokio::io::Result<()> { + self.bearer.write_segment(msg.0, &msg.1).await?; + + Ok(()) + } + + async fn demux(&mut self) -> tokio::io::Result<()> { + let (protocol, payload) = self.bearer.read_segment().await?; + + self.egress.0.send((protocol, payload)).unwrap(); + + Ok(()) + } + + pub fn subscribe(&mut self, protocol: Protocol) -> AsyncAgentChannel { + let agent_tx = self.ingress.0.clone(); + let agent_rx = self.egress.0.subscribe(); + + AsyncAgentChannel(protocol, agent_tx, agent_rx) + } + + pub async fn run(&mut self) -> tokio::io::Result<()> { + loop { + select! { + Ok(_) = self.bearer.readable() => { + if let Ok(true) = self.bearer.has_segment().await { + trace!("demux selected"); + self.demux().await? + } + }, + Some(x) = self.ingress.1.recv() => { + trace!("mux selected"); + self.mux(x).await? + }, + } + } + } +} + +impl From for AsyncPlexer { + fn from(value: AsyncBearer) -> Self { + Self::new(value) + } +} + +impl From for AsyncBearer { + fn from(value: AsyncPlexer) -> Self { + value.bearer + } +} + +async fn handshake( + plexer: &mut AsyncPlexer, + network_magic: u64, +) -> Result<(), gasket::error::Error> { + info!("executing handshake"); + + let channel0 = plexer.subscribe(0); + let versions = handshake::n2n::VersionTable::v7_and_above(network_magic); + let mut client = handshake::Client::new(channel0); + + //let p = tokio::spawn(plexer.run()); + //let output = client.handshake(versions).or_restart()?; + + let output = select! { + x = client.handshake(versions) => x.or_restart()?, + x = plexer.run() => { + match x.or_restart() { + Err(x) => return Err(x), + _ => unreachable!(), + }; + }, + }; + + debug!("handshake output: {:?}", output); + //p.abort(); + + match output { + handshake::Confirmation::Accepted(version, _) => { + info!(version, "connected to upstream peer"); + Ok(()) + } + _ => { + error!("couldn't agree on handshake version"); Err(gasket::error::Error::WorkPanic) } } @@ -87,11 +322,10 @@ fn handle_mux_outcome( pub struct Worker { peer_address: String, network_magic: u64, - input: MuxInputPort, + bearer: Option, + mux_input: MuxInputPort, channel2_out: Option, channel3_out: Option, - demuxer: Option>, - muxer: Option>, ops_count: gasket::metrics::Counter, } @@ -99,48 +333,31 @@ impl Worker { pub fn new( peer_address: String, network_magic: u64, - input: MuxInputPort, + mux_input: MuxInputPort, channel2_out: Option, channel3_out: Option, ) -> Self { Self { peer_address, network_magic, - input, channel2_out, channel3_out, - demuxer: None, - muxer: None, + mux_input, + bearer: None, ops_count: Default::default(), } } - - fn handshake(&self, bearer: Bearer) -> Result { - info!("excuting handshake"); - - let plexer = SyncPlexer::new(bearer, 0); - let versions = handshake::n2n::VersionTable::v7_and_above(self.network_magic); - let mut client = handshake::Client::new(plexer); - - let output = client.handshake(versions).or_panic()?; - debug!("handshake output: {:?}", output); - - let bearer = client.unwrap().unwrap(); - - match output { - handshake::Confirmation::Accepted(version, _) => { - info!(version, "connected to upstream peer"); - Ok(bearer) - } - _ => { - error!("couldn't agree on handshake version"); - Err(gasket::error::Error::WorkPanic) - } - } - } +} + +pub enum WorkUnit { + Connect, + Mux((u16, Vec)), + Demux, } impl gasket::runtime::Worker for Worker { + type WorkUnit = WorkUnit; + fn metrics(&self) -> gasket::metrics::Registry { // TODO: define networking metrics (bytes in / out, etc) gasket::metrics::Builder::new() @@ -148,60 +365,72 @@ impl gasket::runtime::Worker for Worker { .build() } - fn bootstrap(&mut self) -> Result<(), gasket::error::Error> { - debug!("connecting muxer"); + async fn bootstrap(&mut self) -> gasket::runtime::ScheduleResult { + Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::Connect)) + } - let bearer = multiplexer::bearers::Bearer::connect_tcp(&self.peer_address).or_restart()?; - - let bearer = self.handshake(bearer)?; - - let mut demuxer = Demuxer::new(bearer.clone()); - - if let Some(c2) = &self.channel2_out { - demuxer.register(2, GasketEgress(c2.clone())); + async fn schedule(&mut self) -> gasket::runtime::ScheduleResult { + let bearer = self.bearer.as_mut().unwrap(); + trace!("selecting"); + select! { + Ok(msg) = self.mux_input.recv() => { Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::Mux(msg.payload))) } + Ok(true) = bearer.has_segment() => Ok(gasket::runtime::WorkSchedule::Unit(WorkUnit::Demux)), + _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => Ok(gasket::runtime::WorkSchedule::Idle), } + } - if let Some(c3) = &self.channel3_out { - demuxer.register(3, GasketEgress(c3.clone())); - } + async fn execute(&mut self, unit: &Self::WorkUnit) -> Result<(), gasket::error::Error> { + match unit { + WorkUnit::Connect => { + debug!("connecting"); + let bearer = AsyncBearer::connect_tcp(&self.peer_address) + .await + .or_retry()?; - self.demuxer = Some(demuxer); + let mut plexer = bearer.into(); - let muxer = Muxer::new(bearer, GasketIngress(self.input.clone())); - self.muxer = Some(muxer); + handshake(&mut plexer, self.network_magic).await?; + + self.bearer = Some(plexer.into()); + } + WorkUnit::Mux(x) => { + trace!("muxing"); + self.bearer + .as_mut() + .unwrap() + .write_segment(x.0, &x.1) + .await + .or_restart()?; + } + WorkUnit::Demux => { + trace!("demuxing"); + + let (protocol, payload) = self + .bearer + .as_mut() + .unwrap() + .read_segment() + .await + .or_restart()?; + + match protocol { + 2 => { + if let Some(channel) = &mut self.channel2_out { + channel.send(payload.into()).await?; + trace!("sent protocol 2 msg"); + } + } + 3 => { + if let Some(channel) = &mut self.channel3_out { + channel.send(payload.into()).await?; + trace!("sent protocol 3 msg"); + } + } + x => warn!("trying to demux unexpected protocol {x}"), + } + } + }; Ok(()) } - - fn work(&mut self) -> gasket::runtime::WorkResult { - let muxer = self.muxer.as_mut().unwrap(); - let demuxer = self.demuxer.as_mut().unwrap(); - - let span = tracing::span::Span::current(); - - let mut mux_res = None; - let mut demux_res = None; - - rayon::scope(|s| { - s.spawn(|_| { - let _guard = span.enter(); - info!("mux ticking"); - let outcome = muxer.tick(); - mux_res = Some(handle_mux_outcome(outcome)); - }); - s.spawn(|_| { - let _guard = span.enter(); - info!("demux ticking"); - let outcome = demuxer.tick(); - demux_res = Some(handle_demux_outcome(outcome)); - }); - }); - - mux_res.unwrap()?; - demux_res.unwrap()?; - - self.ops_count.inc(1); - - Ok(gasket::runtime::WorkOutcome::Partial) - } } diff --git a/pallas-upstream/tests/integration.rs b/pallas-upstream/tests/integration.rs new file mode 100644 index 0000000..6999fa5 --- /dev/null +++ b/pallas-upstream/tests/integration.rs @@ -0,0 +1,87 @@ +#![feature(async_fn_in_trait)] + +use std::time::Duration; + +use gasket::{ + messaging::{ + tokio::{InputPort, OutputPort}, + RecvPort, SendPort, + }, + runtime::{ScheduleResult, WorkSchedule, Worker}, +}; +use pallas_miniprotocols::Point; +use pallas_upstream::{BlockFetchEvent, Cursor}; +use tracing::{error, info}; + +struct Witness { + input: InputPort, +} + +impl Worker for Witness { + type WorkUnit = BlockFetchEvent; + + fn metrics(&self) -> gasket::metrics::Registry { + gasket::metrics::Registry::new() + } + + async fn schedule(&mut self) -> gasket::runtime::ScheduleResult { + error!("dequeing form witness"); + let msg = self.input.recv().await?; + Ok(WorkSchedule::Unit(msg.payload)) + } + + async fn execute(&mut self, unit: &Self::WorkUnit) -> Result<(), gasket::error::Error> { + error!("witnessing block event"); + + Ok(()) + } +} + +struct StaticCursor; + +impl Cursor for StaticCursor { + fn intersection(&self) -> pallas_upstream::Intersection { + pallas_upstream::Intersection::Origin + } +} + +#[test] +fn test_mainnet_upstream() { + tracing::subscriber::set_global_default( + tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::TRACE) + .finish(), + ) + .unwrap(); + + let mut b = pallas_upstream::n2n::Bootstrapper::new( + StaticCursor, + "relays-new.cardano-mainnet.iohk.io:3001".into(), + 764824073, + ); + + let (send, receive) = gasket::messaging::tokio::channel(200); + + // let mut f = Faker { + // output: Default::default(), + // }; + + //f.output.connect(send); + + b.connect_output(send); + + let b = b.spawn().unwrap(); + + let mut w = Witness { + input: Default::default(), + }; + + w.input.connect(receive); + + //let f = gasket::runtime::spawn_stage(f, Default::default(), Some("faker")); + let w = gasket::runtime::spawn_stage(w, Default::default(), Some("witness")); + + let d = gasket::daemon::Daemon(vec![w]); + + d.block(); +}