diff --git a/examples/n2n-miniprotocols/src/main.rs b/examples/n2n-miniprotocols/src/main.rs index 7c5f187..6ff26cb 100644 --- a/examples/n2n-miniprotocols/src/main.rs +++ b/examples/n2n-miniprotocols/src/main.rs @@ -1,10 +1,13 @@ -use pallas::{network::{ - facades::PeerClient, - miniprotocols::{chainsync, Point, MAINNET_MAGIC, blockfetch, keepalive}, -}, ledger::traverse::MultiEraHeader}; -use tokio::{time::Instant, select}; +use pallas::{ + ledger::traverse::MultiEraHeader, + network::{ + facades::PeerClient, + miniprotocols::{blockfetch, chainsync, keepalive, Point, MAINNET_MAGIC}, + }, +}; +use std::time::Duration; use thiserror::Error; -use futures::{future::FutureExt, pin_mut}; +use tokio::time::Instant; #[derive(Error, Debug)] pub enum Error { @@ -24,17 +27,27 @@ pub enum Error { PallasTraverseError(#[from] pallas::ledger::traverse::Error), } -async fn do_blockfetch(blockfetch_client: &mut blockfetch::Client, range: (Point, Point)) -> Result<(), Error> { +async fn do_blockfetch( + blockfetch_client: &mut blockfetch::Client, + range: (Point, Point), +) -> Result<(), Error> { let blocks = blockfetch_client.fetch_range(range.clone()).await?; for block in &blocks { tracing::trace!("received block of size: {}", block.len()); } - tracing::info!("received {} blocks. last slot: {}", blocks.len(), range.1.slot_or_default()); + tracing::info!( + "received {} blocks. last slot: {}", + blocks.len(), + range.1.slot_or_default() + ); Ok(()) } -async fn do_chainsync(chainsync_client: &mut chainsync::N2NClient, blockfetch_client: &mut blockfetch::Client) -> Result<(), Error> { +async fn do_chainsync( + mut chainsync_client: chainsync::N2NClient, + mut blockfetch_client: blockfetch::Client, +) -> Result<(), Error> { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45")?, @@ -64,18 +77,18 @@ async fn do_chainsync(chainsync_client: &mut chainsync::N2NClient, blockfetch_cl MultiEraHeader::EpochBoundary(_) => { tracing::info!("epoch boundary"); None - }, + } MultiEraHeader::AlonzoCompatible(_) | MultiEraHeader::Babbage(_) => { if next_log.elapsed().as_secs() > 1 { tracing::info!("chainsync block header: {}", number); next_log = Instant::now(); } Some(Point::Specific(slot, hash)) - }, + } MultiEraHeader::Byron(_) => { tracing::info!("ignoring byron header"); None - }, + } } } Some(_) => { @@ -83,19 +96,19 @@ async fn do_chainsync(chainsync_client: &mut chainsync::N2NClient, blockfetch_cl None } }; - match point { - Some(p) => { - block_count += 1; - if block_count == 1 { - start_point = p; - } - else if block_count == 10 { - end_point = p; - do_blockfetch(blockfetch_client, (start_point.clone(), end_point.clone())).await?; - block_count = 0; - } - }, - None => {}, + if let Some(p) = point { + block_count += 1; + if block_count == 1 { + start_point = p; + } else if block_count == 10 { + end_point = p; + do_blockfetch( + &mut blockfetch_client, + (start_point.clone(), end_point.clone()), + ) + .await?; + block_count = 0; + } }; } chainsync::NextResponse::RollBackward(x, _) => log::info!("rollback to {:?}", x), @@ -104,15 +117,11 @@ async fn do_chainsync(chainsync_client: &mut chainsync::N2NClient, blockfetch_cl } } -async fn do_keepalive(keepalive_client: &mut keepalive::Client) -> Result<(), Error> { - let mut keepalive_timer = Instant::now(); +async fn do_keepalive(mut keepalive_client: keepalive::Client) -> Result<(), Error> { loop { - if keepalive_timer.elapsed().as_secs() > 20 { - tracing::info!("sending keepalive..."); - keepalive_client.send_keepalive().await?; - tracing::info!("keepalive sent"); - keepalive_timer = Instant::now(); - } + tokio::time::sleep(Duration::from_secs(20)).await; + keepalive_client.send_keepalive().await?; + tracing::info!("keepalive sent"); } } @@ -130,55 +139,33 @@ async fn main() { // relay. let server = "backbone.cardano-mainnet.iohk.io:3001"; // let server = "localhost:6000"; - let mut peer = PeerClient::connect(server, MAINNET_MAGIC) - .await - .unwrap(); + let peer = PeerClient::connect(server, MAINNET_MAGIC).await.unwrap(); - let chainsync_handle = tokio::spawn(async move { - do_chainsync(&mut peer.chainsync, &mut peer.blockfetch).await?; - Ok::<(), Error>(()) - }).fuse(); - let keepalive_handle = tokio::spawn(async move { - do_keepalive(&mut peer.keepalive).await?; - Ok::<(), Error>(()) - }).fuse(); + let PeerClient { + plexer, + chainsync, + blockfetch, + keepalive, + .. + } = peer; - pin_mut!(chainsync_handle, keepalive_handle); + let chainsync_handle = tokio::spawn(do_chainsync(chainsync, blockfetch)); + let keepalive_handle = tokio::spawn(do_keepalive(keepalive)); // If any of these concurrent tasks exit or fail, the others are canceled. - select! { - chainsync_result = chainsync_handle => { - match chainsync_result { - Ok(result) => { - match result { - Ok(_) => {} - Err(error) => { - tracing::error!("chainsync error: {:?}", error); - } - } - } - Err(error) => { - tracing::error!("chainsync error: {:?}", error); - } - } - } - keepalive_result = keepalive_handle => { - match keepalive_result { - Ok(result) => { - match result { - Ok(_) => {} - Err(error) => { - tracing::error!("keepalive error: {:?}", error); - } - } - } - Err(error) => { - tracing::error!("keepalive error: {:?}", error); - } - } - } + let (chainsync_result, keepalive_result) = + tokio::try_join!(chainsync_handle, keepalive_handle) + .expect("error joining tokio threads"); + + if let Err(err) = chainsync_result { + tracing::error!("chainsync error: {:?}", err); } - peer.plexer_handle.abort(); + + if let Err(err) = keepalive_result { + tracing::error!("keepalive error: {:?}", err); + } + + plexer.abort().await; tracing::info!("waiting 10 seconds before reconnecting..."); tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; diff --git a/pallas-network/src/facades.rs b/pallas-network/src/facades.rs index 31d3412..cc834d9 100644 --- a/pallas-network/src/facades.rs +++ b/pallas-network/src/facades.rs @@ -1,25 +1,29 @@ +use std::net::SocketAddr; use std::path::Path; - use thiserror::Error; +use tracing::error; + use tokio::net::TcpListener; -use tokio::task::JoinHandle; -use tracing::{debug, error}; #[cfg(unix)] -use tokio::net::UnixListener; +use tokio::net::{unix::SocketAddr as UnixSocketAddr, UnixListener}; -use crate::miniprotocols::handshake::{n2c, n2n, Confirmation, VersionNumber, VersionTable}; +use crate::miniprotocols::handshake::{n2c, n2n, Confirmation, VersionNumber}; use crate::miniprotocols::{ - txsubmission, keepalive, blockfetch, chainsync, handshake, localstate, - PROTOCOL_N2N_HANDSHAKE, PROTOCOL_N2N_TX_SUBMISSION, PROTOCOL_N2N_KEEP_ALIVE, - PROTOCOL_N2C_CHAIN_SYNC, PROTOCOL_N2C_HANDSHAKE, PROTOCOL_N2C_STATE_QUERY, - PROTOCOL_N2N_BLOCK_FETCH, PROTOCOL_N2N_CHAIN_SYNC, + blockfetch, chainsync, handshake, keepalive, localstate, txsubmission, PROTOCOL_N2C_CHAIN_SYNC, + PROTOCOL_N2C_HANDSHAKE, PROTOCOL_N2C_STATE_QUERY, PROTOCOL_N2N_BLOCK_FETCH, + PROTOCOL_N2N_CHAIN_SYNC, PROTOCOL_N2N_HANDSHAKE, PROTOCOL_N2N_KEEP_ALIVE, + PROTOCOL_N2N_TX_SUBMISSION, }; -use crate::multiplexer::{self, Bearer}; + +use crate::multiplexer::{self, Bearer, RunningPlexer}; #[derive(Debug, Error)] pub enum Error { + #[error("error in multiplexer")] + PlexerFailure(#[source] multiplexer::Error), + #[error("error connecting bearer")] ConnectFailure(#[source] tokio::io::Error), @@ -32,8 +36,8 @@ pub enum Error { /// Client of N2N Ouroboros pub struct PeerClient { - pub plexer_handle: JoinHandle>, - pub handshake: handshake::Confirmation, + pub plexer: RunningPlexer, + pub handshake: handshake::N2NClient, pub chainsync: chainsync::N2NClient, pub blockfetch: blockfetch::Client, pub txsubmission: txsubmission::Client, @@ -41,12 +45,7 @@ pub struct PeerClient { } impl PeerClient { - pub async fn connect(address: &str, magic: u64) -> Result { - debug!("connecting"); - let bearer = Bearer::connect_tcp(address) - .await - .map_err(Error::ConnectFailure)?; - + pub fn new(bearer: Bearer) -> Self { let mut plexer = multiplexer::Plexer::new(bearer); let hs_channel = plexer.subscribe_client(PROTOCOL_N2N_HANDSHAKE); @@ -55,12 +54,29 @@ impl PeerClient { let txsub_channel = plexer.subscribe_client(PROTOCOL_N2N_TX_SUBMISSION); let keepalive_channel = plexer.subscribe_client(PROTOCOL_N2N_KEEP_ALIVE); - let plexer_handle = tokio::spawn(async move { plexer.run().await }); + let plexer = plexer.spawn(); + + Self { + plexer, + handshake: handshake::Client::new(hs_channel), + chainsync: chainsync::Client::new(cs_channel), + blockfetch: blockfetch::Client::new(bf_channel), + txsubmission: txsubmission::Client::new(txsub_channel), + keepalive: keepalive::Client::new(keepalive_channel), + } + } + + pub async fn connect(addr: &'static str, magic: u64) -> Result { + let bearer = Bearer::connect_tcp(addr) + .await + .map_err(Error::ConnectFailure)?; + + let mut client = Self::new(bearer); let versions = handshake::n2n::VersionTable::v7_and_above(magic); - let mut client = handshake::Client::new(hs_channel); let handshake = client + .handshake() .handshake(versions) .await .map_err(Error::HandshakeProtocol)?; @@ -70,20 +86,26 @@ impl PeerClient { return Err(Error::IncompatibleVersion); } - Ok(Self { - plexer_handle, - handshake, - chainsync: chainsync::Client::new(cs_channel), - blockfetch: blockfetch::Client::new(bf_channel), - txsubmission: txsubmission::Client::new(txsub_channel), - keepalive: keepalive::Client::new(keepalive_channel), - }) + Ok(client) + } + + pub fn handshake(&mut self) -> &mut handshake::N2NClient { + &mut self.handshake } pub fn chainsync(&mut self) -> &mut chainsync::N2NClient { &mut self.chainsync } + pub async fn with_chainsync(&mut self, op: T) -> tokio::task::JoinHandle + where + T: FnOnce(&mut chainsync::N2NClient) -> Fut, + Fut: std::future::Future + Send + 'static, + O: Send + 'static, + { + tokio::spawn(op(&mut self.chainsync)) + } + pub fn blockfetch(&mut self) -> &mut blockfetch::Client { &mut self.blockfetch } @@ -96,59 +118,76 @@ impl PeerClient { &mut self.keepalive } - pub fn abort(&mut self) { - self.plexer_handle.abort(); + pub async fn abort(self) { + self.plexer.abort().await } } /// Server of N2N Ouroboros pub struct PeerServer { - pub plexer_handle: JoinHandle>, - pub version: (VersionNumber, n2n::VersionData), - pub chainsync: chainsync::N2NServer, - pub blockfetch: blockfetch::Server, - pub txsubmission: txsubmission::Server, + plexer: RunningPlexer, + handshake: handshake::N2NServer, + chainsync: chainsync::N2NServer, + blockfetch: blockfetch::Server, + txsubmission: txsubmission::Server, + accepted_address: Option, + accepted_version: Option, } impl PeerServer { + pub fn new(bearer: Bearer) -> Self { + let mut plexer = multiplexer::Plexer::new(bearer); + + let hs_channel = plexer.subscribe_server(PROTOCOL_N2N_HANDSHAKE); + let cs_channel = plexer.subscribe_server(PROTOCOL_N2N_CHAIN_SYNC); + let bf_channel = plexer.subscribe_server(PROTOCOL_N2N_BLOCK_FETCH); + let txsub_channel = plexer.subscribe_server(PROTOCOL_N2N_TX_SUBMISSION); + + let hs = handshake::N2NServer::new(hs_channel); + let cs = chainsync::N2NServer::new(cs_channel); + let bf = blockfetch::Server::new(bf_channel); + let txsub = txsubmission::Server::new(txsub_channel); + + let plexer = plexer.spawn(); + + Self { + plexer, + handshake: hs, + chainsync: cs, + blockfetch: bf, + txsubmission: txsub, + accepted_address: None, + accepted_version: None, + } + } + pub async fn accept(listener: &TcpListener, magic: u64) -> Result { - let (bearer, _) = Bearer::accept_tcp(listener) + let (bearer, address) = Bearer::accept_tcp(listener) .await .map_err(Error::ConnectFailure)?; - let mut server_plexer = multiplexer::Plexer::new(bearer); + let mut client = Self::new(bearer); - let hs_channel = server_plexer.subscribe_server(PROTOCOL_N2N_HANDSHAKE); - let cs_channel = server_plexer.subscribe_server(PROTOCOL_N2N_CHAIN_SYNC); - let bf_channel = server_plexer.subscribe_server(PROTOCOL_N2N_BLOCK_FETCH); - let txsub_channel = server_plexer.subscribe_server(PROTOCOL_N2N_TX_SUBMISSION); - - let mut server_hs: handshake::Server = handshake::Server::new(hs_channel); - let server_cs = chainsync::N2NServer::new(cs_channel); - let server_bf = blockfetch::Server::new(bf_channel); - let server_txsub = txsubmission::Server::new(txsub_channel); - - let plexer_handle = tokio::spawn(async move { server_plexer.run().await }); - - let accepted_version = server_hs + let accepted_version = client + .handshake() .handshake(n2n::VersionTable::v7_and_above(magic)) .await .map_err(Error::HandshakeProtocol)?; - if let Some(ver) = accepted_version { - Ok(Self { - plexer_handle, - version: ver, - chainsync: server_cs, - blockfetch: server_bf, - txsubmission: server_txsub, - }) + if let Some((version, _)) = accepted_version { + client.accepted_address = Some(address); + client.accepted_version = Some(version); + Ok(client) } else { - plexer_handle.abort(); + client.abort().await; Err(Error::IncompatibleVersion) } } + pub fn handshake(&mut self) -> &mut handshake::N2NServer { + &mut self.handshake + } + pub fn chainsync(&mut self) -> &mut chainsync::N2NServer { &mut self.chainsync } @@ -161,35 +200,52 @@ impl PeerServer { &mut self.txsubmission } - pub fn abort(&mut self) { - self.plexer_handle.abort(); + pub async fn abort(self) { + self.plexer.abort().await } } /// Client of N2C Ouroboros pub struct NodeClient { - pub plexer_handle: JoinHandle>, - pub handshake: handshake::Confirmation, - pub chainsync: chainsync::N2CClient, - pub statequery: localstate::Client, + plexer: RunningPlexer, + handshake: handshake::N2CClient, + chainsync: chainsync::N2CClient, + statequery: localstate::Client, } impl NodeClient { - async fn connect_bearer( - bearer: Bearer, - versions: VersionTable, - ) -> Result { + pub fn new(bearer: Bearer) -> Self { let mut plexer = multiplexer::Plexer::new(bearer); let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE); let cs_channel = plexer.subscribe_client(PROTOCOL_N2C_CHAIN_SYNC); let sq_channel = plexer.subscribe_client(PROTOCOL_N2C_STATE_QUERY); - let plexer_handle = tokio::spawn(async move { plexer.run().await }); + let plexer = plexer.spawn(); - let mut client = handshake::Client::new(hs_channel); + Self { + plexer, + handshake: handshake::Client::new(hs_channel), + chainsync: chainsync::Client::new(cs_channel), + statequery: localstate::Client::new(sq_channel), + } + } + + #[cfg(unix)] + pub async fn connect( + path: impl AsRef + Send + 'static, + magic: u64, + ) -> Result { + let bearer = Bearer::connect_unix(path) + .await + .map_err(Error::ConnectFailure)?; + + let mut client = Self::new(bearer); + + let versions = handshake::n2c::VersionTable::v10_and_above(magic); let handshake = client + .handshake() .handshake(versions) .await .map_err(Error::HandshakeProtocol)?; @@ -199,59 +255,47 @@ impl NodeClient { return Err(Error::IncompatibleVersion); } - Ok(Self { - plexer_handle, - handshake, - chainsync: chainsync::Client::new(cs_channel), - statequery: localstate::Client::new(sq_channel), - }) + Ok(client) } - #[cfg(unix)] - pub async fn connect(path: impl AsRef, magic: u64) -> Result { - debug!("connecting"); + // #[cfg(windows)] + // pub async fn connect( + // pipe_name: impl AsRef, + // magic: u64, + // ) -> Result { + // let bearer = tokio::task::spawn_blocking(move || + // Bearer::connect_named_pipe(pipe_name)) .await + // .expect("can't join tokio thread") + // .map_err(Error::ConnectFailure)?; - let bearer = Bearer::connect_unix(path) - .await - .map_err(Error::ConnectFailure)?; + // let mut client = Self::new(bearer); - let versions = handshake::n2c::VersionTable::v10_and_above(magic); + // let versions = handshake::n2c::VersionTable::v10_and_above(magic); - Self::connect_bearer(bearer, versions).await - } + // let handshake = client + // .handshake() + // .handshake(versions) + // .await + // .map_err(Error::HandshakeProtocol)?; - #[cfg(windows)] - pub async fn connect( - pipe_name: impl AsRef, - magic: u64, - ) -> Result { - debug!("connecting"); + // if let handshake::Confirmation::Rejected(reason) = handshake { + // error!(?reason, "handshake refused"); + // return Err(Error::IncompatibleVersion); + // } - let bearer = Bearer::connect_named_pipe(pipe_name) - .await - .map_err(Error::ConnectFailure)?; - - let versions = handshake::n2c::VersionTable::v10_and_above(magic); - - Self::connect_bearer(bearer, versions).await - } + // Ok(client) + // } #[cfg(unix)] pub async fn handshake_query( - path: impl AsRef, + bearer: Bearer, magic: u64, ) -> Result { - debug!("connecting"); - - let bearer = Bearer::connect_unix(path) - .await - .map_err(Error::ConnectFailure)?; - let mut plexer = multiplexer::Plexer::new(bearer); let hs_channel = plexer.subscribe_client(PROTOCOL_N2C_HANDSHAKE); - let plexer_handle = tokio::spawn(async move { plexer.run().await }); + let plexer = plexer.spawn(); let versions = handshake::n2c::VersionTable::v15_with_query(magic); let mut client = handshake::Client::new(hs_channel); @@ -271,12 +315,16 @@ impl NodeClient { Err(Error::IncompatibleVersion) } Confirmation::QueryReply(version_table) => { - plexer_handle.abort(); + plexer.abort().await; Ok(version_table) } } } + pub fn handshake(&mut self) -> &mut handshake::N2CClient { + &mut self.handshake + } + pub fn chainsync(&mut self) -> &mut chainsync::N2CClient { &mut self.chainsync } @@ -285,57 +333,74 @@ impl NodeClient { &mut self.statequery } - pub fn abort(&mut self) { - self.plexer_handle.abort(); + pub async fn abort(self) { + self.plexer.abort().await } } /// Server of N2C Ouroboros. #[cfg(unix)] pub struct NodeServer { - pub plexer_handle: JoinHandle>, - pub version: (VersionNumber, n2c::VersionData), - pub chainsync: chainsync::N2CServer, - pub statequery: localstate::Server, + plexer: RunningPlexer, + handshake: handshake::N2CServer, + chainsync: chainsync::N2CServer, + statequery: localstate::Server, + accepted_address: Option, + accpeted_version: Option<(VersionNumber, n2c::VersionData)>, } #[cfg(unix)] impl NodeServer { - pub async fn accept(listener: &UnixListener, magic: u64) -> Result { - let (bearer, _) = Bearer::accept_unix(listener) - .await - .map_err(Error::ConnectFailure)?; + pub async fn new(bearer: Bearer) -> Self { + let mut plexer = multiplexer::Plexer::new(bearer); - let mut server_plexer = multiplexer::Plexer::new(bearer); + let hs_channel = plexer.subscribe_server(PROTOCOL_N2C_HANDSHAKE); + let cs_channel = plexer.subscribe_server(PROTOCOL_N2C_CHAIN_SYNC); + let sq_channel = plexer.subscribe_server(PROTOCOL_N2C_STATE_QUERY); - let hs_channel = server_plexer.subscribe_server(PROTOCOL_N2C_HANDSHAKE); - let cs_channel = server_plexer.subscribe_server(PROTOCOL_N2C_CHAIN_SYNC); - let sq_channel = server_plexer.subscribe_server(PROTOCOL_N2C_STATE_QUERY); - - let mut server_hs: handshake::Server = handshake::Server::new(hs_channel); + let server_hs = handshake::Server::::new(hs_channel); let server_cs = chainsync::N2CServer::new(cs_channel); let server_sq = localstate::Server::new(sq_channel); - let plexer_handle = tokio::spawn(async move { server_plexer.run().await }); + let plexer = plexer.spawn(); - let accepted_version = server_hs + Self { + plexer, + handshake: server_hs, + chainsync: server_cs, + statequery: server_sq, + accepted_address: None, + accpeted_version: None, + } + } + + pub async fn accept(listener: &UnixListener, magic: u64) -> Result { + let (bearer, address) = Bearer::accept_unix(listener) + .await + .map_err(Error::ConnectFailure)?; + + let mut client = Self::new(bearer).await; + + let accepted_version = client + .handshake() .handshake(n2c::VersionTable::v10_and_above(magic)) .await .map_err(Error::HandshakeProtocol)?; - if let Some(ver) = accepted_version { - Ok(Self { - plexer_handle, - version: ver, - chainsync: server_cs, - statequery: server_sq, - }) + if let Some(version) = accepted_version { + client.accepted_address = Some(address); + client.accpeted_version = Some(version); + Ok(client) } else { - plexer_handle.abort(); + client.abort().await; Err(Error::IncompatibleVersion) } } + pub fn handshake(&mut self) -> &mut handshake::N2CServer { + &mut self.handshake + } + pub fn chainsync(&mut self) -> &mut chainsync::N2CServer { &mut self.chainsync } @@ -344,7 +409,7 @@ impl NodeServer { &mut self.statequery } - pub fn abort(&mut self) { - self.plexer_handle.abort(); + pub async fn abort(self) { + self.plexer.abort().await } } diff --git a/pallas-network/src/miniprotocols/handshake/n2n.rs b/pallas-network/src/miniprotocols/handshake/n2n.rs index ad7bc17..eb1b273 100644 --- a/pallas-network/src/miniprotocols/handshake/n2n.rs +++ b/pallas-network/src/miniprotocols/handshake/n2n.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use pallas_codec::minicbor::{decode, Decode, Decoder, encode, Encode, Encoder}; +use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; pub type VersionTable = super::protocol::VersionTable; @@ -15,24 +15,38 @@ const PROTOCOL_V13: u64 = 13; const PEER_SHARING_DISABLED: u8 = 0; impl VersionTable { + #[deprecated(note = "no longer supported by spec")] pub fn v4_and_above(network_magic: u64) -> VersionTable { // Older versions are not supported anymore (removed from network-spec.pdf). // Try not to break compatibility with older pallas users. - return Self::v7_and_above(network_magic); + Self::v7_and_above(network_magic) } + #[deprecated(note = "no longer supported by spec")] pub fn v6_and_above(network_magic: u64) -> VersionTable { // Older versions are not supported anymore (removed from network-spec.pdf). // Try not to break compatibility with older pallas users. - return Self::v7_and_above(network_magic); + Self::v7_and_above(network_magic) } pub fn v7_to_v10(network_magic: u64) -> VersionTable { let values = vec![ - (PROTOCOL_V7, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V8, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V9, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V10, VersionData::new(network_magic, false, None, None)), + ( + PROTOCOL_V7, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V8, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V9, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V10, + VersionData::new(network_magic, false, None, None), + ), ] .into_iter() .collect::>(); @@ -42,13 +56,49 @@ impl VersionTable { pub fn v7_and_above(network_magic: u64) -> VersionTable { let values = vec![ - (PROTOCOL_V7, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V8, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V9, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V10, VersionData::new(network_magic, false, None, None)), - (PROTOCOL_V11, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), - (PROTOCOL_V12, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), - (PROTOCOL_V13, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), + ( + PROTOCOL_V7, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V8, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V9, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V10, + VersionData::new(network_magic, false, None, None), + ), + ( + PROTOCOL_V11, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), + ( + PROTOCOL_V12, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), + ( + PROTOCOL_V13, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), ] .into_iter() .collect::>(); @@ -58,9 +108,33 @@ impl VersionTable { pub fn v11_and_above(network_magic: u64) -> VersionTable { let values = vec![ - (PROTOCOL_V11, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), - (PROTOCOL_V12, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), - (PROTOCOL_V13, VersionData::new(network_magic, false, Some(PEER_SHARING_DISABLED), Some(false))), + ( + PROTOCOL_V11, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), + ( + PROTOCOL_V12, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), + ( + PROTOCOL_V13, + VersionData::new( + network_magic, + false, + Some(PEER_SHARING_DISABLED), + Some(false), + ), + ), ] .into_iter() .collect::>(); @@ -78,7 +152,12 @@ pub struct VersionData { } impl VersionData { - pub fn new(network_magic: u64, initiator_and_responder_diffusion_mode: bool, peer_sharing: Option, query: Option) -> Self { + pub fn new( + network_magic: u64, + initiator_and_responder_diffusion_mode: bool, + peer_sharing: Option, + query: Option, + ) -> Self { VersionData { network_magic, initiator_and_responder_diffusion_mode, @@ -101,12 +180,12 @@ impl Encode<()> for VersionData { .bool(self.initiator_and_responder_diffusion_mode)? .u8(peer_sharing)? .bool(query)?; - }, + } _ => { e.array(2)? .u64(self.network_magic)? .bool(self.initiator_and_responder_diffusion_mode)?; - }, + } }; Ok(()) diff --git a/pallas-network/src/miniprotocols/keepalive/client.rs b/pallas-network/src/miniprotocols/keepalive/client.rs index 45aa48b..0143507 100644 --- a/pallas-network/src/miniprotocols/keepalive/client.rs +++ b/pallas-network/src/miniprotocols/keepalive/client.rs @@ -1,5 +1,5 @@ -use std::fmt::Debug; use rand::Rng; +use std::fmt::Debug; use thiserror::*; use tracing::debug; @@ -35,7 +35,11 @@ pub struct Client(State, multiplexer::ChannelBuffer, KeepAliveSharedState); impl Client { pub fn new(channel: multiplexer::AgentChannel) -> Self { - Self(State::Client, multiplexer::ChannelBuffer::new(channel), KeepAliveSharedState{ saved_cookie: 0 }) + Self( + State::Client, + multiplexer::ChannelBuffer::new(channel), + KeepAliveSharedState { saved_cookie: 0 }, + ) } pub fn state(&self) -> &State { diff --git a/pallas-network/src/miniprotocols/keepalive/codec.rs b/pallas-network/src/miniprotocols/keepalive/codec.rs index e00ed9b..a9e5c0e 100644 --- a/pallas-network/src/miniprotocols/keepalive/codec.rs +++ b/pallas-network/src/miniprotocols/keepalive/codec.rs @@ -11,14 +11,14 @@ impl Encode<()> for Message { Message::KeepAlive(cookie) => { e.array(2)?.u16(0)?; e.encode(cookie)?; - }, + } Message::ResponseKeepAlive(cookie) => { e.array(2)?.u16(1)?; e.encode(cookie)?; - }, + } Message::Done => { e.array(1)?.u16(2)?; - }, + } } Ok(()) diff --git a/pallas-network/src/multiplexer.rs b/pallas-network/src/multiplexer.rs index e69802b..3bd5286 100644 --- a/pallas-network/src/multiplexer.rs +++ b/pallas-network/src/multiplexer.rs @@ -2,17 +2,19 @@ use byteorder::{ByteOrder, NetworkEndian}; use pallas_codec::{minicbor, Fragment}; -use std::net::SocketAddr; use thiserror::Error; -use tokio::io::AsyncWriteExt; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio::select; -use tokio::sync::mpsc::error::SendError; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::task::JoinHandle; use tokio::time::Instant; +use tokio::{select, sync::mpsc::error::SendError}; use tracing::{debug, error, trace}; +type IOResult = tokio::io::Result; + +use tokio::net as tcp; + #[cfg(unix)] -use tokio::net::{UnixListener, UnixStream}; +use tokio::net as unix; #[cfg(windows)] use tokio::net::windows::named_pipe::NamedPipeClient; @@ -63,112 +65,128 @@ pub struct Segment { } pub enum Bearer { - Tcp(TcpStream), + Tcp(tcp::TcpStream), #[cfg(unix)] - Unix(UnixStream), - - #[cfg(windows)] - NamedPipe(NamedPipeClient), + Unix(unix::UnixStream), + // #[cfg(windows)] + // NamedPipe(NamedPipeClient), } -const BUFFER_LEN: usize = 1024 * 10; - impl Bearer { - pub async fn connect_tcp(addr: impl ToSocketAddrs) -> Result { - let stream = TcpStream::connect(addr).await?; - // add tcp_keepalive + fn configure_tcp(stream: &tcp::TcpStream) -> IOResult<()> { let sock_ref = socket2::SockRef::from(&stream); let mut tcp_keepalive = socket2::TcpKeepalive::new(); tcp_keepalive = tcp_keepalive.with_time(tokio::time::Duration::from_secs(20)); tcp_keepalive = tcp_keepalive.with_interval(tokio::time::Duration::from_secs(20)); - let _ = sock_ref.set_tcp_keepalive(&tcp_keepalive); - // add tcp_nodelay - let _ = sock_ref.set_nodelay(true); + sock_ref.set_tcp_keepalive(&tcp_keepalive)?; + sock_ref.set_nodelay(true)?; + Ok(()) + } + + pub async fn connect_tcp(addr: impl tcp::ToSocketAddrs) -> Result { + let stream = tcp::TcpStream::connect(addr).await?; + Self::configure_tcp(&stream)?; Ok(Self::Tcp(stream)) } - pub async fn connect_tcp_timeout(addr: impl ToSocketAddrs, timeout: std::time::Duration) -> Result { - match tokio::time::timeout(timeout, Self::connect_tcp(addr)).await { - Ok(Ok(stream)) => Ok(stream), - Ok(Err(err)) => Err(err), - Err(_) => Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "connection timed out")), + pub async fn connect_tcp_timeout( + addr: impl tcp::ToSocketAddrs, + timeout: std::time::Duration, + ) -> IOResult { + select! { + result = Self::connect_tcp(addr) => result, + _ = tokio::time::sleep(timeout) => Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "connect timeout")), } } - pub async fn accept_tcp(listener: &TcpListener) -> tokio::io::Result<(Self, SocketAddr)> { + pub async fn accept_tcp(listener: &tcp::TcpListener) -> IOResult<(Self, std::net::SocketAddr)> { let (stream, addr) = listener.accept().await?; + Self::configure_tcp(&stream)?; Ok((Self::Tcp(stream), addr)) } #[cfg(unix)] - pub async fn connect_unix(path: impl AsRef) -> Result { - let stream = UnixStream::connect(path).await?; + pub async fn connect_unix(path: impl AsRef) -> IOResult { + let stream = unix::UnixStream::connect(path).await?; Ok(Self::Unix(stream)) } #[cfg(unix)] pub async fn accept_unix( - listener: &UnixListener, - ) -> tokio::io::Result<(Self, tokio::net::unix::SocketAddr)> { + listener: &unix::UnixListener, + ) -> IOResult<(Self, unix::unix::SocketAddr)> { let (stream, addr) = listener.accept().await?; Ok((Self::Unix(stream), addr)) } - #[cfg(windows)] - pub async fn connect_named_pipe( - pipe_name: impl AsRef, - ) -> Result { - let client = tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)?; - Ok(Self::NamedPipe(client)) - } + // #[cfg(windows)] + // pub fn connect_named_pipe(pipe_name: impl AsRef) -> + // IOResult { let client = + // tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)?; + // Ok(Self::NamedPipe(client)) + // } - pub async fn readable(&mut self) -> tokio::io::Result<()> { + pub fn into_split(self) -> (BearerReadHalf, BearerWriteHalf) { match self { - Bearer::Tcp(x) => x.readable().await, + Bearer::Tcp(x) => { + let (r, w) = x.into_split(); + (BearerReadHalf::Tcp(r), BearerWriteHalf::Tcp(w)) + } #[cfg(unix)] - Bearer::Unix(x) => x.readable().await, + Bearer::Unix(x) => { + let (r, w) = x.into_split(); + (BearerReadHalf::Unix(r), BearerWriteHalf::Unix(w)) + } + } + } +} - #[cfg(windows)] - Bearer::NamedPipe(x) => x.readable().await, +pub enum BearerReadHalf { + Tcp(tcp::tcp::OwnedReadHalf), + + #[cfg(unix)] + Unix(unix::unix::OwnedReadHalf), +} + +impl BearerReadHalf { + async fn read_exact(&mut self, buf: &mut [u8]) -> IOResult { + match self { + BearerReadHalf::Tcp(x) => x.read_exact(buf).await, + + #[cfg(unix)] + BearerReadHalf::Unix(x) => x.read_exact(buf).await, + } + } +} + +pub enum BearerWriteHalf { + Tcp(tcp::tcp::OwnedWriteHalf), + + #[cfg(unix)] + Unix(unix::unix::OwnedWriteHalf), +} + +impl BearerWriteHalf { + async fn write_all(&mut self, buf: &[u8]) -> IOResult<()> { + match self { + Self::Tcp(x) => x.write_all(buf).await, + + #[cfg(unix)] + Self::Unix(x) => x.write_all(buf).await, } } - fn try_read(&mut self, buf: &mut [u8]) -> tokio::io::Result { + async fn flush(&mut self) -> IOResult<()> { match self { - Bearer::Tcp(x) => x.try_read(buf), + Self::Tcp(x) => x.flush().await, #[cfg(unix)] - Bearer::Unix(x) => x.try_read(buf), - - #[cfg(windows)] - Bearer::NamedPipe(x) => x.try_read(buf), - } - } - - async fn write_all(&mut self, buf: &[u8]) -> tokio::io::Result<()> { - match self { - Bearer::Tcp(x) => x.write_all(buf).await, - - #[cfg(unix)] - Bearer::Unix(x) => x.write_all(buf).await, - - #[cfg(windows)] - Bearer::NamedPipe(x) => x.write_all(buf).await, - } - } - - async fn flush(&mut self) -> tokio::io::Result<()> { - match self { - Bearer::Tcp(x) => x.flush().await, - - #[cfg(unix)] - Bearer::Unix(x) => x.flush().await, - - #[cfg(windows)] - Bearer::NamedPipe(x) => x.flush().await, + Self::Unix(x) => x.flush().await, + //#[cfg(windows)] + //Bearer::NamedPipe(x) => x.flush().await, } } } @@ -198,85 +216,94 @@ pub enum Error { #[error("plexer failed to mux chunk")] PlexerMux, + + #[error("failure to abort the plexer threads")] + AbortFailure, } -pub struct SegmentBuffer(Bearer, Vec); +type Egress = ( + tokio::sync::broadcast::Sender<(Protocol, Payload)>, + tokio::sync::broadcast::Receiver<(Protocol, Payload)>, +); -impl SegmentBuffer { - pub fn new(bearer: Bearer) -> Self { - Self(bearer, Vec::with_capacity(BUFFER_LEN)) +const EGRESS_MSG_QUEUE_BUFFER: usize = 100_000; + +pub struct Demuxer(BearerReadHalf, Egress); + +impl Demuxer { + pub fn new(bearer: BearerReadHalf) -> Self { + let egress = tokio::sync::broadcast::channel(EGRESS_MSG_QUEUE_BUFFER); + Self(bearer, egress) } - /// Cancel-safe loop that reads from bearer until certain len - async fn cancellable_read(&mut self, required: usize) -> Result<(), Error> { + pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> { + trace!("waiting for segment header"); + let mut buf = vec![0u8; HEADER_LEN]; + self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?; + let header = Header::from(buf.as_slice()); + + trace!("waiting for full segment"); + let segment_size = header.payload_len as usize; + let mut buf = vec![0u8; segment_size]; + self.0.read_exact(&mut buf).await.map_err(Error::BearerIo)?; + + Ok((header.protocol, buf)) + } + + fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> { + if tracing::event_enabled!(tracing::Level::TRACE) { + trace!(protocol, data = hex::encode(&payload), "read from bearer"); + } + + self.1 + .0 + .send((protocol, payload)) + .map_err(|err| Error::PlexerDemux(err.0 .0, err.0 .1))?; + + Ok(()) + } + + pub fn subscribe_recv(&self) -> tokio::sync::broadcast::Receiver<(Protocol, Payload)> { + self.1 .0.subscribe() + } + + pub async fn tick(&mut self) -> Result<(), Error> { + let (protocol, payload) = self.read_segment().await?; + trace!(protocol, "demux happening"); + self.demux(protocol, payload) + } + + pub async fn run(&mut self) -> Result<(), Error> { loop { - self.0.readable().await.map_err(Error::BearerIo)?; - trace!("bearer is readable"); - - let remaining = required - self.1.len(); - let mut buf = vec![0u8; remaining]; - - match self.0.try_read(&mut buf) { - Ok(0) => { - error!("empty bearer"); - break Err(Error::EmptyBearer); - } - Ok(n) => { - trace!(n, "found data on bearer"); - self.1.extend_from_slice(&buf[0..n]); - - if self.1.len() >= required { - break Ok(()); - } - } - Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => { - trace!("reading from bearer would block"); - continue; - } - Err(err) => { - error!(?err, "beaerer IO error"); - break Err(Error::BearerIo(err)); - } + if let Err(err) = self.tick().await { + break Err(err); } } } +} - /// Peek the available data in search for a frame header - async fn peek_header(&mut self) -> Result { - trace!("waiting for header buf"); - self.cancellable_read(HEADER_LEN).await?; +type Ingress = ( + tokio::sync::mpsc::Sender<(Protocol, Payload)>, + tokio::sync::mpsc::Receiver<(Protocol, Payload)>, +); - trace!("found enough data for header"); - let header = &self.1[..HEADER_LEN]; +type Clock = Instant; - Ok(Header::from(header)) +const INGRESS_MSG_QUEUE_BUFFER: usize = 100; + +pub struct Muxer(BearerWriteHalf, Clock, Ingress); + +impl Muxer { + pub fn new(bearer: BearerWriteHalf) -> Self { + let ingress = tokio::sync::mpsc::channel(INGRESS_MSG_QUEUE_BUFFER); + let clock = Instant::now(); + Self(bearer, clock, ingress) } - // Cancel-safe read of a full segment from the bearer - pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> { - let header = self.peek_header().await?; - - trace!("waiting for full segment buf"); - let segment_size = HEADER_LEN + header.payload_len as usize; - - self.cancellable_read(segment_size).await?; - - trace!("draining segment buffer"); - let segment = self.1.drain(..segment_size); - let payload = segment.skip(HEADER_LEN).collect(); - - Ok((header.protocol, payload)) - } - - pub async fn write_segment( - &mut self, - protocol: u16, - clock: &Instant, - payload: &[u8], - ) -> Result<(), std::io::Error> { + async fn write_segment(&mut self, protocol: u16, payload: &[u8]) -> Result<(), std::io::Error> { let header = Header { protocol, - timestamp: clock.elapsed().as_micros() as u32, + timestamp: self.1.elapsed().as_micros() as u32, payload_len: payload.len() as u16, }; @@ -288,31 +315,81 @@ impl SegmentBuffer { Ok(()) } + + pub async fn mux(&mut self, msg: (Protocol, Payload)) -> Result<(), Error> { + self.write_segment(msg.0, &msg.1) + .await + .map_err(|_| Error::PlexerMux)?; + + if tracing::event_enabled!(tracing::Level::TRACE) { + trace!( + protocol = msg.0, + data = hex::encode(&msg.1), + "write to bearer" + ); + } + + Ok(()) + } + + pub fn clone_sender(&self) -> tokio::sync::mpsc::Sender<(Protocol, Payload)> { + self.2 .0.clone() + } + + pub async fn tick(&mut self) -> Result<(), Error> { + let msg = self.2 .1.recv().await; + + if let Some(x) = msg { + trace!(protocol = x.0, "mux happening"); + self.mux(x).await? + } + + Ok(()) + } + + pub async fn run(&mut self) -> Result<(), Error> { + loop { + if let Err(err) = self.tick().await { + break Err(err); + } + } + } } +type ToPlexerPort = tokio::sync::mpsc::Sender<(Protocol, Payload)>; +type FromPlexerPort = tokio::sync::broadcast::Receiver<(Protocol, Payload)>; + pub struct AgentChannel { enqueue_protocol: Protocol, dequeue_protocol: Protocol, - to_plexer: tokio::sync::mpsc::Sender<(Protocol, Payload)>, - from_plexer: tokio::sync::broadcast::Receiver<(Protocol, Payload)>, + to_plexer: ToPlexerPort, + from_plexer: FromPlexerPort, } impl AgentChannel { - fn for_client(protocol: Protocol, ingress: &Ingress, egress: &Egress) -> Self { + fn for_client( + protocol: Protocol, + to_plexer: ToPlexerPort, + from_plexer: FromPlexerPort, + ) -> Self { Self { enqueue_protocol: protocol, dequeue_protocol: protocol ^ 0x8000, - to_plexer: ingress.0.clone(), - from_plexer: egress.0.subscribe(), + from_plexer, + to_plexer, } } - fn for_server(protocol: Protocol, ingress: &Ingress, egress: &Egress) -> Self { + fn for_server( + protocol: Protocol, + to_plexer: ToPlexerPort, + from_plexer: FromPlexerPort, + ) -> Self { Self { enqueue_protocol: protocol ^ 0x8000, dequeue_protocol: protocol, - to_plexer: ingress.0.clone(), - from_plexer: egress.0.subscribe(), + from_plexer, + to_plexer, } } @@ -339,92 +416,53 @@ impl AgentChannel { } } -type Ingress = ( - tokio::sync::mpsc::Sender<(Protocol, Payload)>, - tokio::sync::mpsc::Receiver<(Protocol, Payload)>, -); +pub struct RunningPlexer { + demuxer: JoinHandle>, + muxer: JoinHandle>, +} -type Egress = ( - tokio::sync::broadcast::Sender<(Protocol, Payload)>, - tokio::sync::broadcast::Receiver<(Protocol, Payload)>, -); +impl RunningPlexer { + pub async fn abort(self) { + self.demuxer.abort(); + self.muxer.abort(); + } +} pub struct Plexer { - clock: Instant, - bearer: SegmentBuffer, - ingress: Ingress, - egress: Egress, + demuxer: Demuxer, + muxer: Muxer, } impl Plexer { pub fn new(bearer: Bearer) -> Self { + let (r, w) = bearer.into_split(); + Self { - clock: Instant::now(), - bearer: SegmentBuffer::new(bearer), - ingress: tokio::sync::mpsc::channel(100), // TODO: define buffer - egress: tokio::sync::broadcast::channel(100000), + demuxer: Demuxer::new(r), + muxer: Muxer::new(w), } } - async fn mux(&mut self, msg: (Protocol, Payload)) -> Result<(), Error> { - self.bearer - .write_segment(msg.0, &self.clock, &msg.1) - .await - .map_err(|_| Error::PlexerMux)?; - - if tracing::event_enabled!(tracing::Level::TRACE) { - trace!( - protocol = msg.0, - data = hex::encode(&msg.1), - "write to bearer" - ); - } - - Ok(()) - } - - async fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> { - if tracing::event_enabled!(tracing::Level::TRACE) { - trace!(protocol, data = hex::encode(&payload), "read from bearer"); - } - - self.egress - .0 - .send((protocol, payload)) - .map_err(|err| Error::PlexerDemux(err.0 .0, err.0 .1))?; - - Ok(()) - } - pub fn subscribe_client(&mut self, protocol: Protocol) -> AgentChannel { - AgentChannel::for_client(protocol, &self.ingress, &self.egress) + let to_plexer = self.muxer.clone_sender(); + let from_plexer = self.demuxer.subscribe_recv(); + AgentChannel::for_client(protocol, to_plexer, from_plexer) } pub fn subscribe_server(&mut self, protocol: Protocol) -> AgentChannel { - AgentChannel::for_server(protocol, &self.ingress, &self.egress) + let to_plexer = self.muxer.clone_sender(); + let from_plexer = self.demuxer.subscribe_recv(); + AgentChannel::for_server(protocol, to_plexer, from_plexer) } - pub async fn run(&mut self) -> Result<(), Error> { - loop { - trace!("selecting"); - select! { - res = self.bearer.read_segment() => { - let x = res?; - trace!("demux selected"); - self.demux(x.0, x.1).await? - }, - Some(x) = self.ingress.1.recv() => { - trace!("mux selected"); - self.mux(x).await? - }, - _ = tokio::time::sleep(tokio::time::Duration::from_secs(5)) => { - trace!("idle plexer"); - } - else => { - error!("something else happened"); - } - } - } + pub fn spawn(self) -> RunningPlexer { + let mut demuxer = self.demuxer; + let mut muxer = self.muxer; + + let demuxer = tokio::spawn(async move { demuxer.run().await }); + let muxer = tokio::spawn(async move { muxer.run().await }); + + RunningPlexer { demuxer, muxer } } } @@ -538,12 +576,12 @@ mod tests { minicbor::encode(in_part1, &mut input).unwrap(); minicbor::encode(in_part2, &mut input).unwrap(); - let ingress = tokio::sync::mpsc::channel(100); - let egress = tokio::sync::broadcast::channel(100); + let (to_plexer, _) = tokio::sync::mpsc::channel(100); + let (into_plexer, from_plexer) = tokio::sync::broadcast::channel(100); - let channel = AgentChannel::for_client(0, &ingress, &egress); + let channel = AgentChannel::for_client(0, to_plexer, from_plexer); - egress.0.send((0x8000, input)).unwrap(); + into_plexer.send((0x8000, input)).unwrap(); let mut buf = ChannelBuffer::new(channel); @@ -560,14 +598,14 @@ mod tests { let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); minicbor::encode(msg, &mut input).unwrap(); - let ingress = tokio::sync::mpsc::channel(100); - let egress = tokio::sync::broadcast::channel(100); + let (to_plexer, _) = tokio::sync::mpsc::channel(100); + let (into_plexer, from_plexer) = tokio::sync::broadcast::channel(100); - let channel = AgentChannel::for_client(0, &ingress, &egress); + let channel = AgentChannel::for_client(0, to_plexer, from_plexer); while !input.is_empty() { let chunk = Vec::from(input.drain(0..2).as_slice()); - egress.0.send((0x8000, chunk)).unwrap(); + into_plexer.send((0x8000, chunk)).unwrap(); } let mut buf = ChannelBuffer::new(channel); diff --git a/pallas-network/tests/plexer.rs b/pallas-network/tests/plexer.rs index 8ae0720..13de2c7 100644 --- a/pallas-network/tests/plexer.rs +++ b/pallas-network/tests/plexer.rs @@ -33,7 +33,7 @@ fn random_payload(size: usize) -> Vec { #[tokio::test] async fn one_way_small_sequence_of_payloads() { - let passive = tokio::spawn(setup_passive_muxer::<50301>()); + let passive = tokio::task::spawn(setup_passive_muxer::<50301>()); // HACK: a small sleep seems to be required for Github actions runner to // formally expose the port @@ -46,8 +46,8 @@ async fn one_way_small_sequence_of_payloads() { let mut sender_channel = active.subscribe_client(3); let mut receiver_channel = passive.subscribe_server(3); - let passive_run = tokio::spawn(async move { passive.run().await }); - let active_run = tokio::spawn(async move { active.run().await }); + let passive = passive.spawn(); + let active = active.spawn(); for _ in 0..100 { let payload = random_payload(50); @@ -57,6 +57,6 @@ async fn one_way_small_sequence_of_payloads() { assert_eq!(payload, received_payload); } - passive_run.abort(); - active_run.abort(); + passive.abort().await; + active.abort().await; } diff --git a/pallas-network/tests/protocols.rs b/pallas-network/tests/protocols.rs index b6b401a..436f07f 100644 --- a/pallas-network/tests/protocols.rs +++ b/pallas-network/tests/protocols.rs @@ -19,7 +19,11 @@ use pallas_network::miniprotocols::{ use pallas_network::miniprotocols::{handshake, localstate, txsubmission, MAINNET_MAGIC}; use pallas_network::multiplexer::{Bearer, Plexer}; use std::path::Path; -use tokio::net::{TcpListener, UnixListener}; + +use tokio::net::TcpListener; + +#[cfg(unix)] +use tokio::net::UnixListener; #[tokio::test] #[ignore] @@ -172,17 +176,17 @@ pub async fn blockfetch_server_and_client_happy_path() { hex::decode("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef").unwrap(), ); + let listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30003)) + .await + .unwrap(); + let server = tokio::spawn({ let bodies = block_bodies.clone(); let point = point.clone(); async move { // server setup - let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30001)) - .await - .unwrap(); - - let mut peer_server = PeerServer::accept(&server_listener, 0).await.unwrap(); + let mut peer_server = PeerServer::accept(&listener, 0).await.unwrap(); let server_bf = peer_server.blockfetch(); @@ -214,9 +218,7 @@ pub async fn blockfetch_server_and_client_happy_path() { let client = tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(1)).await; - // client setup - - let mut client_to_server_conn = PeerClient::connect("localhost:30001", 0).await.unwrap(); + let mut client_to_server_conn = PeerClient::connect("localhost:30003", 0).await.unwrap(); let client_bf = client_to_server_conn.blockfetch(); @@ -269,7 +271,7 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { async move { // server setup - let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30001)) + let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30002)) .await .unwrap(); @@ -281,7 +283,7 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { handshake::Server::new(server_plexer.subscribe_server(0)); let mut server_cs = chainsync::N2NServer::new(server_plexer.subscribe_server(2)); - tokio::spawn(async move { server_plexer.run().await }); + let server_plexer = server_plexer.spawn(); server_hs.receive_proposed_versions().await.unwrap(); server_hs @@ -377,15 +379,13 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { assert!(server_cs.recv_while_idle().await.unwrap().is_none()); assert_eq!(*server_cs.state(), chainsync::State::Done); + + server_plexer.abort().await; } }); let client = tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(2)).await; - - // client setup - - let mut client_to_server_conn = PeerClient::connect("localhost:30001", 0).await.unwrap(); + let mut client_to_server_conn = PeerClient::connect("localhost:30002", 0).await.unwrap(); let client_cs = client_to_server_conn.chainsync(); @@ -461,15 +461,15 @@ pub async fn local_state_query_server_and_client_happy_path() { let server = tokio::spawn({ async move { // server setup - let socket_path = Path::new("node.socket"); + let socket_path = Path::new("node1.socket"); if socket_path.exists() { fs::remove_file(socket_path).unwrap(); } - let unix_listener = UnixListener::bind(socket_path).unwrap(); + let listener = UnixListener::bind(socket_path).unwrap(); - let mut server = pallas_network::facades::NodeServer::accept(&unix_listener, 0) + let mut server = pallas_network::facades::NodeServer::accept(&listener, 0) .await .unwrap(); @@ -552,7 +552,9 @@ pub async fn local_state_query_server_and_client_happy_path() { x => panic!("unexpected message from client: {x:?}"), }; - let addr_hex = "981D186018CE18F718FB185F188918A918C7186A186518AC18DD1874186D189E188410184D186F1882184D187D18C4184F1842187F18CA18A118DD"; + let addr_hex = +"981D186018CE18F718FB185F188918A918C7186A186518AC18DD1874186D189E188410184D186F1882184D187D18C4184F1842187F18CA18A118DD" +; let addr = hex::decode(addr_hex).unwrap(); let addr: Addr = addr.to_vec().into(); let addrs: Addrs = Vec::from([addr]); @@ -629,8 +631,7 @@ pub async fn local_state_query_server_and_client_happy_path() { tokio::time::sleep(Duration::from_secs(1)).await; // client setup - - let socket_path = "node.socket"; + let socket_path = "node1.socket"; let mut client = NodeClient::connect(socket_path, 0).await.unwrap(); @@ -705,7 +706,9 @@ pub async fn local_state_query_server_and_client_happy_path() { assert_eq!(result, localstate::queries_v16::StakeDistribution { pools }); - let addr_hex = "981D186018CE18F718FB185F188918A918C7186A186518AC18DD1874186D189E188410184D186F1882184D187D18C4184F1842187F18CA18A118DD"; + let addr_hex = +"981D186018CE18F718FB185F188918A918C7186A186518AC18DD1874186D189E188410184D186F1882184D187D18C4184F1842187F18CA18A118DD" +; let addr = hex::decode(addr_hex).unwrap(); let addr: Addr = addr.to_vec().into(); let addrs: Addrs = Vec::from([addr]); @@ -776,13 +779,13 @@ pub async fn local_state_query_server_and_client_happy_path() { pub async fn txsubmission_server_and_client_happy_path_n2n() { let test_txs = vec![(vec![0], vec![0, 0, 0]), (vec![1], vec![1, 1, 1])]; + let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30001)) + .await + .unwrap(); + let server = tokio::spawn({ let test_txs = test_txs.clone(); async move { - let server_listener = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 30001)) - .await - .unwrap(); - let mut peer_server = PeerServer::accept(&server_listener, 0).await.unwrap(); let server_txsub = peer_server.txsubmission(); @@ -857,7 +860,6 @@ pub async fn txsubmission_server_and_client_happy_path_n2n() { let mut mempool = test_txs.clone(); // client setup - let mut client_to_server_conn = PeerClient::connect("localhost:30001", 0).await.unwrap(); let client_txsub = client_to_server_conn.txsubmission(); @@ -941,7 +943,6 @@ pub async fn txsubmission_submit_to_mainnet_peer_n2n() { let mempool = vec![(tx_hash, tx_bytes)]; // client setup - let mut client_to_server_conn = PeerClient::connect("relays-new.cardano-mainnet.iohk.io:3001", MAINNET_MAGIC) .await