diff --git a/examples/n2n-miniprotocols/src/main.rs b/examples/n2n-miniprotocols/src/main.rs index dfb950f..2adf4d9 100644 --- a/examples/n2n-miniprotocols/src/main.rs +++ b/examples/n2n-miniprotocols/src/main.rs @@ -5,7 +5,7 @@ use pallas::{ miniprotocols::{blockfetch, chainsync, keepalive, Point, MAINNET_MAGIC}, }, }; -use std::time::Duration; + use thiserror::Error; use tokio::time::Instant; @@ -117,14 +117,6 @@ async fn do_chainsync( } } -async fn do_keepalive(mut keepalive_client: keepalive::Client) -> Result<(), Error> { - loop { - tokio::time::sleep(Duration::from_secs(20)).await; - keepalive_client.send_keepalive().await?; - tracing::info!("keepalive sent"); - } -} - #[tokio::main] async fn main() { tracing::subscriber::set_global_default( @@ -145,25 +137,10 @@ async fn main() { plexer, chainsync, blockfetch, - keepalive, .. } = peer; - let chainsync_handle = tokio::spawn(do_chainsync(chainsync, blockfetch)); - let keepalive_handle = tokio::spawn(do_keepalive(keepalive)); - - // If any of these concurrent tasks exit or fail, the others are canceled. - let (chainsync_result, keepalive_result) = - tokio::try_join!(chainsync_handle, keepalive_handle) - .expect("error joining tokio threads"); - - if let Err(err) = chainsync_result { - tracing::error!("chainsync error: {:?}", err); - } - - if let Err(err) = keepalive_result { - tracing::error!("keepalive error: {:?}", err); - } + do_chainsync(chainsync, blockfetch).await.unwrap(); plexer.abort().await; diff --git a/pallas-network/src/facades.rs b/pallas-network/src/facades.rs index 06df821..6cdf37a 100644 --- a/pallas-network/src/facades.rs +++ b/pallas-network/src/facades.rs @@ -1,7 +1,8 @@ use std::net::SocketAddr; use std::path::Path; +use std::time::Duration; use thiserror::Error; -use tracing::error; +use tracing::{debug, error, warn}; use tokio::net::{TcpListener, ToSocketAddrs}; @@ -30,53 +31,104 @@ pub enum Error { #[error("handshake protocol error")] HandshakeProtocol(handshake::Error), + #[error("keepalive client loop error")] + KeepAliveClientLoop(keepalive::ClientError), + + #[error("keepalive server loop error")] + KeepAliveServerLoop(keepalive::ServerError), + #[error("handshake version not accepted")] IncompatibleVersion, } +pub const DEFAULT_KEEP_ALIVE_INTERVAL_SEC: u64 = 20; + +pub type KeepAliveHandle = tokio::task::JoinHandle>; + +pub enum KeepAliveLoop { + Client(keepalive::Client, Duration), + Server(keepalive::Server), +} + +impl KeepAliveLoop { + pub fn client(client: keepalive::Client, interval: Duration) -> Self { + Self::Client(client, interval) + } + + pub fn server(server: keepalive::Server) -> Self { + Self::Server(server) + } + + pub async fn run_client( + mut client: keepalive::Client, + interval: Duration, + ) -> Result<(), Error> { + let mut interval = tokio::time::interval(interval); + + loop { + interval.tick().await; + warn!("sending keepalive request"); + + client + .keepalive_roundtrip() + .await + .map_err(Error::KeepAliveClientLoop)?; + } + } + + pub async fn run_server(mut server: keepalive::Server) -> Result<(), Error> { + loop { + debug!("waiting keepalive request"); + + server + .keepalive_roundtrip() + .await + .map_err(Error::KeepAliveServerLoop)?; + } + } + + pub fn spawn(self) -> KeepAliveHandle { + match self { + KeepAliveLoop::Client(client, interval) => { + tokio::spawn(Self::run_client(client, interval)) + } + KeepAliveLoop::Server(server) => tokio::spawn(Self::run_server(server)), + } + } +} + /// Client of N2N Ouroboros pub struct PeerClient { pub plexer: RunningPlexer, - pub handshake: handshake::N2NClient, + pub keepalive: KeepAliveHandle, pub chainsync: chainsync::N2NClient, pub blockfetch: blockfetch::Client, pub txsubmission: txsubmission::Client, - pub keepalive: keepalive::Client, } impl PeerClient { - pub fn new(bearer: Bearer) -> Self { - let mut plexer = multiplexer::Plexer::new(bearer); - - let hs_channel = plexer.subscribe_client(PROTOCOL_N2N_HANDSHAKE); - let cs_channel = plexer.subscribe_client(PROTOCOL_N2N_CHAIN_SYNC); - let bf_channel = plexer.subscribe_client(PROTOCOL_N2N_BLOCK_FETCH); - let txsub_channel = plexer.subscribe_client(PROTOCOL_N2N_TX_SUBMISSION); - let keepalive_channel = plexer.subscribe_client(PROTOCOL_N2N_KEEP_ALIVE); - - let plexer = plexer.spawn(); - - Self { - plexer, - handshake: handshake::Client::new(hs_channel), - chainsync: chainsync::Client::new(cs_channel), - blockfetch: blockfetch::Client::new(bf_channel), - txsubmission: txsubmission::Client::new(txsub_channel), - keepalive: keepalive::Client::new(keepalive_channel), - } - } - pub async fn connect(addr: impl ToSocketAddrs, magic: u64) -> Result { let bearer = Bearer::connect_tcp(addr) .await .map_err(Error::ConnectFailure)?; - let mut client = Self::new(bearer); + let mut plexer = multiplexer::Plexer::new(bearer); + + let channel = plexer.subscribe_client(PROTOCOL_N2N_HANDSHAKE); + let mut handshake = handshake::Client::new(channel); + + let cs_channel = plexer.subscribe_client(PROTOCOL_N2N_CHAIN_SYNC); + let bf_channel = plexer.subscribe_client(PROTOCOL_N2N_BLOCK_FETCH); + let txsub_channel = plexer.subscribe_client(PROTOCOL_N2N_TX_SUBMISSION); + + let channel = plexer.subscribe_client(PROTOCOL_N2N_KEEP_ALIVE); + let keepalive = keepalive::Client::new(channel); + + let plexer = plexer.spawn(); let versions = handshake::n2n::VersionTable::v7_and_above(magic); - let handshake = client - .handshake() + let handshake = handshake .handshake(versions) .await .map_err(Error::HandshakeProtocol)?; @@ -86,11 +138,21 @@ impl PeerClient { return Err(Error::IncompatibleVersion); } - Ok(client) - } + let keepalive = KeepAliveLoop::client( + keepalive, + Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_SEC), + ) + .spawn(); - pub fn handshake(&mut self) -> &mut handshake::N2NClient { - &mut self.handshake + let client = Self { + plexer, + keepalive, + chainsync: chainsync::Client::new(cs_channel), + blockfetch: blockfetch::Client::new(bf_channel), + txsubmission: txsubmission::Client::new(txsub_channel), + }; + + Ok(client) } pub fn chainsync(&mut self) -> &mut chainsync::N2NClient { @@ -114,10 +176,6 @@ impl PeerClient { &mut self.txsubmission } - pub fn keepalive(&mut self) -> &mut keepalive::Client { - &mut self.keepalive - } - pub async fn abort(self) { self.plexer.abort().await } diff --git a/pallas-network/src/miniprotocols/keepalive/client.rs b/pallas-network/src/miniprotocols/keepalive/client.rs index 8c4e35d..be315b6 100644 --- a/pallas-network/src/miniprotocols/keepalive/client.rs +++ b/pallas-network/src/miniprotocols/keepalive/client.rs @@ -27,19 +27,11 @@ pub enum ClientError { Plexer(multiplexer::Error), } -pub struct KeepAliveSharedState { - saved_cookie: u16, -} - -pub struct Client(State, multiplexer::ChannelBuffer, KeepAliveSharedState); +pub struct Client(State, multiplexer::ChannelBuffer); impl Client { pub fn new(channel: multiplexer::AgentChannel) -> Self { - Self( - State::Client, - multiplexer::ChannelBuffer::new(channel), - KeepAliveSharedState { saved_cookie: 0 }, - ) + Self(State::Client, multiplexer::ChannelBuffer::new(channel)) } pub fn state(&self) -> &State { @@ -53,7 +45,7 @@ impl Client { fn has_agency(&self) -> bool { match &self.0 { State::Client => true, - State::Server => false, + State::Server(..) => false, State::Done => false, } } @@ -84,7 +76,7 @@ impl Client { fn assert_inbound_state(&self, msg: &Message) -> Result<(), ClientError> { match (&self.0, msg) { - (State::Server, Message::ResponseKeepAlive(..)) => Ok(()), + (State::Server(..), Message::ResponseKeepAlive(..)) => Ok(()), _ => Err(ClientError::InvalidInbound), } } @@ -108,32 +100,38 @@ impl Client { Ok(msg) } - pub async fn send_keepalive(&mut self) -> Result<(), ClientError> { + pub async fn send_keepalive_request(&mut self) -> Result<(), ClientError> { // generate random cookie value - let cookie = rand::thread_rng().gen::(); + let cookie = rand::thread_rng().gen::(); let msg = Message::KeepAlive(cookie); self.send_message(&msg).await?; - self.2.saved_cookie = cookie; - self.0 = State::Server; + self.0 = State::Server(cookie); debug!("sent keepalive message with cookie {}", cookie); - self.recv_while_sending_keepalive().await?; - Ok(()) } - async fn recv_while_sending_keepalive(&mut self) -> Result<(), ClientError> { + pub async fn recv_keepalive_response(&mut self) -> Result<(), ClientError> { match self.recv_message().await? { Message::ResponseKeepAlive(cookie) => { debug!("received keepalive response with cookie {}", cookie); - if cookie == self.2.saved_cookie { - self.0 = State::Client; - Ok(()) - } else { - Err(ClientError::KeepAliveCookieMismatch) + match self.state() { + State::Server(expected) if *expected == cookie => { + self.0 = State::Client; + Ok(()) + } + State::Server(..) => Err(ClientError::KeepAliveCookieMismatch), + _ => unreachable!(), } } _ => Err(ClientError::InvalidInbound), } } + + pub async fn keepalive_roundtrip(&mut self) -> Result<(), ClientError> { + self.send_keepalive_request().await?; + self.recv_keepalive_response().await?; + + Ok(()) + } } diff --git a/pallas-network/src/miniprotocols/keepalive/protocol.rs b/pallas-network/src/miniprotocols/keepalive/protocol.rs index 121228c..95e8e76 100644 --- a/pallas-network/src/miniprotocols/keepalive/protocol.rs +++ b/pallas-network/src/miniprotocols/keepalive/protocol.rs @@ -1,15 +1,15 @@ -pub type KeepAliveCookie = u16; +pub type Cookie = u16; #[derive(Debug, PartialEq, Eq, Clone)] pub enum State { Client, - Server, + Server(Cookie), Done, } #[derive(Debug, Clone)] pub enum Message { - KeepAlive(KeepAliveCookie), - ResponseKeepAlive(KeepAliveCookie), + KeepAlive(Cookie), + ResponseKeepAlive(Cookie), Done, } diff --git a/pallas-network/src/miniprotocols/keepalive/server.rs b/pallas-network/src/miniprotocols/keepalive/server.rs index eb354ea..af83883 100644 --- a/pallas-network/src/miniprotocols/keepalive/server.rs +++ b/pallas-network/src/miniprotocols/keepalive/server.rs @@ -41,7 +41,7 @@ impl Server { fn has_agency(&self) -> bool { match &self.0 { State::Client => false, - State::Server => true, + State::Server(..) => true, State::Done => false, } } @@ -64,8 +64,7 @@ impl Server { fn assert_outbound_state(&self, msg: &Message) -> Result<(), ServerError> { match (&self.0, msg) { - (State::Server, Message::ResponseKeepAlive(..)) => Ok(()), - + (State::Server(..), Message::ResponseKeepAlive(..)) => Ok(()), _ => Err(ServerError::InvalidOutbound), } } @@ -97,33 +96,37 @@ impl Server { Ok(msg) } - pub async fn send_keepalive_response( - &mut self, - cookie: KeepAliveCookie, - ) -> Result<(), ServerError> { - let msg = Message::ResponseKeepAlive(cookie); - self.send_message(&msg).await?; - self.0 = State::Client; - debug!("sent keepalive response message with cookie {}", cookie); - - Ok(()) - } - - pub async fn keepalive_receive_and_respond(&mut self) -> Result, ServerError> { + pub async fn recv_keepalive_request(&mut self) -> Result<(), ServerError> { match self.recv_message().await? { Message::KeepAlive(cookie) => { debug!("received keepalive message with cookie {}", cookie); - - self.0 = State::Server; - Some(self.send_keepalive_response(cookie).await).transpose() + self.0 = State::Server(cookie); + Ok(()) } Message::Done => { debug!("client sent done message in keepalive protocol"); - self.0 = State::Done; - Ok(None) + Ok(()) } _ => Err(ServerError::InvalidInbound), } } + + pub async fn send_keepalive_response(&mut self) -> Result<(), ServerError> { + if let State::Server(cookie) = self.state().clone() { + let msg = Message::ResponseKeepAlive(cookie); + self.send_message(&msg).await?; + self.0 = State::Client; + debug!("sent keepalive response message with cookie {}", cookie); + } + + Ok(()) + } + + pub async fn keepalive_roundtrip(&mut self) -> Result<(), ServerError> { + self.recv_keepalive_request().await?; + self.send_keepalive_response().await?; + + Ok(()) + } }