diff --git a/examples/block-download/src/main.rs b/examples/block-download/src/main.rs index 9dec596..a7deda0 100644 --- a/examples/block-download/src/main.rs +++ b/examples/block-download/src/main.rs @@ -5,7 +5,7 @@ use pallas::network::{ handshake::{n2n::VersionTable, Initiator}, run_agent, Point, MAINNET_MAGIC, }, - multiplexer::Multiplexer, + multiplexer::{spawn_demuxer, spawn_muxer, use_channel, StdPlexer}, }; use pallas::network::miniprotocols::blockfetch::{BatchClient, Observer}; @@ -30,11 +30,15 @@ fn main() { bearer.set_nodelay(true).unwrap(); bearer.set_keepalive_ms(Some(30_000u32)).unwrap(); - let mut muxer = Multiplexer::setup(bearer, &[0, 3]).unwrap(); + let mut plexer = StdPlexer::new(bearer); + let mut channel0 = use_channel(&mut plexer, 0); + let mut channel3 = use_channel(&mut plexer, 3); + + spawn_muxer(plexer.muxer); + spawn_demuxer(plexer.demuxer); - let mut hs_channel = muxer.use_channel(0); let versions = VersionTable::v4_and_above(MAINNET_MAGIC); - let _last = run_agent(Initiator::initial(versions), &mut hs_channel).unwrap(); + let _last = run_agent(Initiator::initial(versions), &mut channel0).unwrap(); let range = ( Point::Specific( @@ -49,8 +53,7 @@ fn main() { ), ); - let mut bf_channel = muxer.use_channel(3); let bf = BatchClient::initial(range, BlockPrinter {}); - let bf_last = run_agent(bf, &mut bf_channel); + let bf_last = run_agent(bf, &mut channel3); println!("{:?}", bf_last); } diff --git a/examples/n2c-miniprotocols/src/main.rs b/examples/n2c-miniprotocols/src/main.rs index c24cba1..2f79782 100644 --- a/examples/n2c-miniprotocols/src/main.rs +++ b/examples/n2c-miniprotocols/src/main.rs @@ -1,6 +1,6 @@ use pallas::network::{ miniprotocols::{chainsync, handshake, localstate, run_agent, Point, MAINNET_MAGIC}, - multiplexer::Multiplexer, + multiplexer, }; use std::os::unix::net::UnixStream; @@ -45,15 +45,12 @@ impl chainsync::Observer for LoggingObserver { } } -fn do_handshake(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(0); +fn do_handshake(mut channel: multiplexer::StdChannel) { let versions = handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC); let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); } -fn do_localstate_query(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(7); - +fn do_localstate_query(mut channel: multiplexer::StdChannel) { let agent = run_agent( localstate::OneShotClient::::initial( None, @@ -65,9 +62,7 @@ fn do_localstate_query(muxer: &mut Multiplexer) { log::info!("state query result: {:?}", agent); } -fn do_chainsync(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(5); - +fn do_chainsync(mut channel: multiplexer::StdChannel) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), @@ -95,14 +90,20 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use - let mut muxer = Multiplexer::setup(bearer, &[0, 4, 5]).unwrap(); + let mut plexer = multiplexer::StdPlexer::new(bearer); + let channel0 = multiplexer::use_channel(&mut plexer, 0); + let channel7 = multiplexer::use_channel(&mut plexer, 7); + let channel5 = multiplexer::use_channel(&mut plexer, 5); + + multiplexer::spawn_muxer(plexer.muxer); + multiplexer::spawn_demuxer(plexer.demuxer); // execute the required handshake against the relay - do_handshake(&mut muxer); + do_handshake(channel0); // execute an arbitrary "Local State" query against the node - do_localstate_query(&mut muxer); + do_localstate_query(channel7); // execute the chainsync flow from an arbitrary point in the chain - do_chainsync(&mut muxer); + do_chainsync(channel5); } diff --git a/examples/n2n-miniprotocols/src/main.rs b/examples/n2n-miniprotocols/src/main.rs index 9f73dff..3ff4227 100644 --- a/examples/n2n-miniprotocols/src/main.rs +++ b/examples/n2n-miniprotocols/src/main.rs @@ -2,7 +2,7 @@ use net2::TcpStreamExt; use pallas::network::{ miniprotocols::{blockfetch, chainsync, handshake, run_agent, Point, MAINNET_MAGIC}, - multiplexer::Multiplexer, + multiplexer::{spawn_demuxer, spawn_muxer, use_channel, StdChannel, StdPlexer}, }; use std::net::TcpStream; @@ -54,15 +54,12 @@ impl chainsync::Observer for LoggingObserver { } } -fn do_handshake(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(0); +fn do_handshake(mut channel: StdChannel) { let versions = handshake::n2n::VersionTable::v4_and_above(MAINNET_MAGIC); let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); } -fn do_blockfetch(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(3); - +fn do_blockfetch(mut channel: StdChannel) { let range = ( Point::Specific( 43847831, @@ -84,9 +81,7 @@ fn do_blockfetch(muxer: &mut Multiplexer) { println!("{:?}", agent); } -fn do_chainsync(muxer: &mut Multiplexer) { - let mut channel = muxer.use_channel(2); - +fn do_chainsync(mut channel: StdChannel) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), @@ -116,14 +111,20 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use - let mut muxer = Multiplexer::setup(bearer, &[0, 2, 3, 4]).unwrap(); + let mut plexer = StdPlexer::new(bearer); + let channel0 = use_channel(&mut plexer, 0); + let channel3 = use_channel(&mut plexer, 3); + let channel2 = use_channel(&mut plexer, 2); + + spawn_muxer(plexer.muxer); + spawn_demuxer(plexer.demuxer); // execute the required handshake against the relay - do_handshake(&mut muxer); + do_handshake(channel0); // fetch an arbitrary batch of block - do_blockfetch(&mut muxer); + do_blockfetch(channel3); // execute the chainsync flow from an arbitrary point in the chain - do_chainsync(&mut muxer); + do_chainsync(channel2); } diff --git a/pallas-addresses/Cargo.toml b/pallas-addresses/Cargo.toml new file mode 100644 index 0000000..fc4e231 --- /dev/null +++ b/pallas-addresses/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "pallas-primitives" +description = "Ledger primitives and cbor codec for the different Cardano eras" +version = "0.9.1" +edition = "2021" +repository = "https://github.com/txpipe/pallas" +homepage = "https://github.com/txpipe/pallas" +documentation = "https://docs.rs/pallas-byron" +license = "Apache-2.0" +readme = "README.md" +authors = [ + "Santiago Carmuega ", +] + +[dependencies] +hex = "0.4.3" +log = "0.4.14" +pallas-crypto = { version = "0.9.0", path = "../pallas-crypto" } +pallas-codec = { version = "0.9.0", path = "../pallas-codec" } +base58 = "0.2.0" +bech32 = "0.8.1" +serde = { version ="1.0.136", optional = true } +serde_json = { version ="1.0.79", optional = true } + +[features] +json = ["serde", "serde_json"] +default = ["json"] diff --git a/pallas-miniprotocols/Cargo.toml b/pallas-miniprotocols/Cargo.toml index 4ecf57e..9cdb24f 100644 --- a/pallas-miniprotocols/Cargo.toml +++ b/pallas-miniprotocols/Cargo.toml @@ -13,8 +13,8 @@ authors = [ ] [dependencies] -pallas-multiplexer = { version = "0.9.0", path = "../pallas-multiplexer/" } pallas-codec = { version = "0.9.0", path = "../pallas-codec/" } +pallas-multiplexer = { version = "0.9.0", path = "../pallas-multiplexer/" } log = "0.4.14" hex = "0.4.3" itertools = "0.10.3" diff --git a/pallas-miniprotocols/src/blockfetch/mod.rs b/pallas-miniprotocols/src/blockfetch/mod.rs index 980f8d6..6f8c2e7 100644 --- a/pallas-miniprotocols/src/blockfetch/mod.rs +++ b/pallas-miniprotocols/src/blockfetch/mod.rs @@ -1,4 +1,5 @@ use crate::machines::{Agent, Transition}; +use crate::MachineError; use crate::common::Point; @@ -155,7 +156,9 @@ where fn on_block(mut self, body: Vec) -> Transition { log::debug!("received block body, size {}", body.len()); - self.observer.on_block_received(body)?; + self.observer + .on_block_received(body) + .map_err(MachineError::downstream)?; Ok(self) } @@ -180,6 +183,11 @@ where O: Observer, { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done @@ -294,7 +302,9 @@ where fn on_block(mut self, body: Vec) -> Transition { log::debug!("received block body, size {}", body.len()); - self.observer.on_block_received(body)?; + self.observer + .on_block_received(body) + .map_err(MachineError::downstream)?; Ok(self) } @@ -317,6 +327,11 @@ where O: Observer, { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done diff --git a/pallas-miniprotocols/src/chainsync/agents.rs b/pallas-miniprotocols/src/chainsync/agents.rs index 00b9e6b..f6a1fdd 100644 --- a/pallas-miniprotocols/src/chainsync/agents.rs +++ b/pallas-miniprotocols/src/chainsync/agents.rs @@ -44,6 +44,7 @@ pub trait Observer { Ok(Continuation::Proceed) } + fn on_tip_reached(&mut self) -> Result> { log::debug!("tip was reached"); @@ -59,6 +60,7 @@ impl Observer for NoopObserver {} #[derive(Debug)] pub struct Consumer where + Self: Agent, O: Observer, { pub state: State, @@ -77,6 +79,7 @@ impl Consumer where O: Observer, Message: Fragment, + C: std::fmt::Debug + 'static, { pub fn initial(known_points: Option>, observer: O) -> Self { Self { @@ -93,7 +96,10 @@ where fn on_intersect_found(mut self, point: Point, tip: Tip) -> Transition { log::debug!("intersect found: {:?} (tip: {:?})", point, tip); - let continuation = self.observer.on_intersect_found(&point, &tip)?; + let continuation = self + .observer + .on_intersect_found(&point, &tip) + .map_err(MachineError::downstream)?; Ok(Self { tip: Some(tip), @@ -118,7 +124,10 @@ where fn on_roll_forward(mut self, content: C, tip: Tip) -> Transition { log::debug!("rolling forward"); - let continuation = self.observer.on_roll_forward(content, &tip)?; + let continuation = self + .observer + .on_roll_forward(content, &tip) + .map_err(MachineError::downstream)?; Ok(Self { tip: Some(tip), @@ -131,7 +140,10 @@ where fn on_roll_backward(mut self, point: Point, tip: Tip) -> Transition { log::debug!("rolling backward to point: {:?}", point); - let continuation = self.observer.on_rollback(&point)?; + let continuation = self + .observer + .on_rollback(&point) + .map_err(MachineError::downstream)?; Ok(Self { tip: Some(tip), @@ -145,7 +157,10 @@ where fn on_await_reply(mut self) -> Transition { log::debug!("reached tip, await reply"); - let continuation = self.observer.on_tip_reached()?; + let continuation = self + .observer + .on_tip_reached() + .map_err(MachineError::downstream)?; Ok(Self { state: State::MustReply, @@ -162,6 +177,11 @@ where Message: Fragment, { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done || self.continuation == Continuation::DropOut @@ -230,7 +250,7 @@ where self.on_intersect_found(point, tip) } (State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip), - (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()), + (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)), } } } @@ -278,6 +298,11 @@ pub type BlockConsumer = Consumer; impl Agent for TipFinder { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done @@ -322,7 +347,7 @@ impl Agent for TipFinder { self.on_intersect_found(tip) } (State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip), - (state, msg) => Err(MachineError::InvalidMsgForState(state.clone(), msg).into()), + (state, msg) => Err(MachineError::InvalidMsgForState(state.clone(), msg)), } } } diff --git a/pallas-miniprotocols/src/handshake/agents.rs b/pallas-miniprotocols/src/handshake/agents.rs index 0087aa6..9a4d583 100644 --- a/pallas-miniprotocols/src/handshake/agents.rs +++ b/pallas-miniprotocols/src/handshake/agents.rs @@ -39,6 +39,11 @@ where D: Debug + Clone, { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done diff --git a/pallas-miniprotocols/src/lib.rs b/pallas-miniprotocols/src/lib.rs index 00ea91f..67706c3 100644 --- a/pallas-miniprotocols/src/lib.rs +++ b/pallas-miniprotocols/src/lib.rs @@ -1,6 +1,5 @@ mod common; mod machines; -mod payloads; pub mod blockfetch; pub mod chainsync; @@ -10,4 +9,3 @@ pub mod txsubmission; pub use common::*; pub use machines::*; -pub use payloads::*; diff --git a/pallas-miniprotocols/src/localstate/mod.rs b/pallas-miniprotocols/src/localstate/mod.rs index ad278f1..cdea7ed 100644 --- a/pallas-miniprotocols/src/localstate/mod.rs +++ b/pallas-miniprotocols/src/localstate/mod.rs @@ -52,7 +52,7 @@ pub struct OneShotClient { impl OneShotClient where - Q: Query, + Q: Query + 'static, Message: Fragment, { pub fn initial(check_point: Option, request: Q::Request) -> Self { @@ -101,6 +101,11 @@ where Message: Fragment, { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done @@ -158,7 +163,7 @@ where (State::Acquiring, Message::Acquired) => self.on_acquired(), (State::Acquiring, Message::Failure(failure)) => self.on_failure(failure), (State::Querying, Message::Result(result)) => self.on_result(result), - (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()), + (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)), } } } diff --git a/pallas-miniprotocols/src/machines.rs b/pallas-miniprotocols/src/machines.rs index 11d4a3f..f151a56 100644 --- a/pallas-miniprotocols/src/machines.rs +++ b/pallas-miniprotocols/src/machines.rs @@ -1,84 +1,31 @@ -pub use crate::payloads::*; -use pallas_codec::{minicbor, Fragment}; -use pallas_multiplexer::{Channel, Payload}; +use pallas_codec::Fragment; +use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError}; use std::cell::Cell; -use std::fmt::{Debug, Display}; -use std::sync::mpsc::Sender; #[derive(Debug)] -pub enum MachineError -where - State: Debug, - Msg: Debug, -{ - InvalidMsgForState(State, Msg), +pub enum MachineError { + InvalidMsgForState(A::State, A::Message), + ChannelError(ChannelError), + DownstreamError(Box), } -impl Display for MachineError -where - S: Debug, - M: Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MachineError::InvalidMsgForState(msg, state) => { - write!( - f, - "received invalid message ({:?}) for current state ({:?})", - msg, state - ) - } - } +impl MachineError { + pub fn channel(err: ChannelError) -> Self { + Self::ChannelError(err) + } + + pub fn downstream(err: Box) -> Self { + Self::DownstreamError(err) } } -impl std::error::Error for MachineError -where - S: Debug, - M: Debug, -{ -} - -#[derive(Debug)] -pub enum CodecError { - BadLabel(u16), - UnexpectedCbor(&'static str), -} - -impl std::error::Error for CodecError {} - -impl Display for CodecError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CodecError::BadLabel(label) => { - write!(f, "unknown message label: {}", label) - } - CodecError::UnexpectedCbor(msg) => { - write!(f, "unexpected cbor: {}", msg) - } - } - } -} - -pub trait MachineOutput { - fn send_msg(&self, data: &impl Fragment) -> Result<(), Box>; -} - -impl MachineOutput for Sender { - fn send_msg(&self, data: &impl Fragment) -> Result<(), Box> { - let mut payload = Vec::new(); - minicbor::encode(data, &mut payload)?; - self.send(payload)?; - - Ok(()) - } -} - -pub type Transition = Result>; +pub type Transition = Result>; pub trait Agent: Sized { type Message; + type State; + fn state(&self) -> &Self::State; fn is_done(&self) -> bool; fn has_agency(&self) -> bool; fn build_next(&self) -> Self::Message; @@ -87,36 +34,38 @@ pub trait Agent: Sized { fn apply_inbound(self, msg: Self::Message) -> Transition; } -pub struct Runner +pub struct Runner<'c, A, C> where A: Agent, + C: Channel, { agent: Cell>, - buffer: Vec, + buffer: ChannelBuffer<'c, C>, } -impl<'a, A> Runner +impl<'c, A, C> Runner<'c, A, C> where A: Agent, - A::Message: Fragment + Debug, + A::Message: Fragment + std::fmt::Debug, + C: Channel, { - pub fn new(agent: A) -> Self { + pub fn new(agent: A, channel: &'c mut C) -> Self { Self { agent: Cell::new(Some(agent)), - buffer: Vec::new(), + buffer: ChannelBuffer::new(channel), } } - pub fn start(&mut self) -> Result<(), Error> { + pub fn start(&mut self) -> Result<(), MachineError> { let prev = self.agent.take().unwrap(); let next = prev.apply_start()?; self.agent.set(Some(next)); Ok(()) } - pub fn run_step(&mut self, channel: &mut Channel) -> Result { + pub fn run_step(&mut self) -> Result> { let prev = self.agent.take().unwrap(); - let next = run_agent_step(prev, channel, &mut self.buffer)?; + let next = run_agent_step(prev, &mut self.buffer)?; let is_done = next.is_done(); self.agent.set(Some(next)); @@ -124,35 +73,35 @@ where Ok(is_done) } - pub fn fulfill(mut self, channel: &mut Channel) -> Result<(), Error> { + pub fn fulfill(mut self) -> Result<(), MachineError> { self.start()?; - while self.run_step(channel)? {} + while self.run_step()? {} Ok(()) } } -pub fn run_agent_step(agent: T, channel: &mut Channel, buffer: &mut Vec) -> Transition +pub fn run_agent_step(agent: A, channel: &mut ChannelBuffer) -> Transition where - T: Agent, - T::Message: Fragment + Debug, + A: Agent, + A::Message: Fragment + std::fmt::Debug, + C: Channel, { - let Channel(tx, rx) = channel; - match agent.has_agency() { true => { let msg = agent.build_next(); log::trace!("processing outbound msg: {:?}", msg); - let mut payload = Vec::new(); - minicbor::encode(&msg, &mut payload)?; - tx.send(payload)?; + channel + .send_msg_chunks(&msg) + .map_err(MachineError::channel)?; agent.apply_outbound(msg) } false => { - let msg = read_until_full_msg::(buffer, rx).unwrap(); + let msg = channel.recv_full_msg().map_err(MachineError::channel)?; + log::trace!("procesing inbound msg: {:?}", msg); agent.apply_inbound(msg) @@ -160,17 +109,18 @@ where } } -pub fn run_agent(agent: T, channel: &mut Channel) -> Result> +pub fn run_agent(agent: A, channel: &mut C) -> Transition where - T: Agent, - T::Message: Fragment + Debug, + A: Agent, + A::Message: Fragment + std::fmt::Debug, + C: Channel, { - let mut buffer = Vec::new(); + let mut buffer = ChannelBuffer::new(channel); let mut agent = agent.apply_start()?; while !agent.is_done() { - agent = run_agent_step(agent, channel, &mut buffer)?; + agent = run_agent_step(agent, &mut buffer)?; } Ok(agent) diff --git a/pallas-miniprotocols/src/payloads.rs b/pallas-miniprotocols/src/payloads.rs deleted file mode 100644 index c053e20..0000000 --- a/pallas-miniprotocols/src/payloads.rs +++ /dev/null @@ -1,104 +0,0 @@ -use pallas_codec::{minicbor, Fragment}; -use pallas_multiplexer::Payload; -use std::sync::mpsc::Receiver; - -pub type Error = Box; - -enum Decoding { - Done(M, usize), - NotEnoughData, - UnexpectedError(Error), -} - -fn try_decode_message(buffer: &[u8]) -> Decoding -where - M: Fragment, -{ - let mut decoder = minicbor::Decoder::new(buffer); - let maybe_msg = decoder.decode(); - - match maybe_msg { - Ok(msg) => Decoding::Done(msg, decoder.position()), - Err(err) if err.is_end_of_input() => Decoding::NotEnoughData, - Err(err) => Decoding::UnexpectedError(Box::new(err)), - } -} - -/// Reads from the receiver until a complete message is found -pub fn read_until_full_msg( - buffer: &mut Vec, - receiver: &mut Receiver, -) -> Result -where - M: Fragment, -{ - // do an eager reading if buffer is empty, no point in going through the error - // handling - if buffer.is_empty() { - let chunk = receiver.recv()?; - buffer.extend(chunk); - } - - let decoding = try_decode_message::(buffer); - - match decoding { - Decoding::Done(msg, pos) => { - buffer.drain(0..pos); - Ok(msg) - } - Decoding::UnexpectedError(err) => Err(err), - Decoding::NotEnoughData => { - let chunk = receiver.recv()?; - buffer.extend(chunk); - - read_until_full_msg::(buffer, receiver) - } - } -} - -#[cfg(test)] -mod tests { - use crate::read_until_full_msg; - use pallas_codec::minicbor; - use std::sync::mpsc::channel; - - #[test] - fn multiple_messages_in_same_payload() { - let mut input = Vec::new(); - let in_part1 = (1u8, 2u8, 3u8); - let in_part2 = (6u8, 5u8, 4u8); - - minicbor::encode(in_part1, &mut input).unwrap(); - minicbor::encode(in_part2, &mut input).unwrap(); - - let (tx, mut rx) = channel(); - tx.send(input).unwrap(); - - let mut output = Vec::new(); - let out_part1 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap(); - let out_part2 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap(); - - assert_eq!(in_part1, out_part1); - assert_eq!(in_part2, out_part2); - } - - #[test] - fn fragmented_message_in_multiple_payload() { - let mut input = Vec::new(); - let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); - minicbor::encode(msg, &mut input).unwrap(); - - let (tx, mut rx) = channel(); - - while !input.is_empty() { - let chunk = Vec::from(input.drain(0..2).as_slice()); - tx.send(chunk).unwrap(); - } - - let mut output = Vec::new(); - let out_msg = - read_until_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>(&mut output, &mut rx).unwrap(); - - assert_eq!(msg, out_msg); - } -} diff --git a/pallas-miniprotocols/src/txsubmission/mod.rs b/pallas-miniprotocols/src/txsubmission/mod.rs index 9a08c01..e703203 100644 --- a/pallas-miniprotocols/src/txsubmission/mod.rs +++ b/pallas-miniprotocols/src/txsubmission/mod.rs @@ -239,6 +239,11 @@ impl NaiveProvider { impl Agent for NaiveProvider { type Message = Message; + type State = State; + + fn state(&self) -> &Self::State { + &self.state + } fn is_done(&self) -> bool { self.state == State::Done @@ -295,7 +300,7 @@ impl Agent for NaiveProvider { ..self }), (State::Idle, Message::RequestTxs(ids)) => self.on_txs_request(ids), - (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()), + (_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)), } } } diff --git a/pallas-multiplexer/Cargo.toml b/pallas-multiplexer/Cargo.toml index fabb7ce..18a131e 100644 --- a/pallas-multiplexer/Cargo.toml +++ b/pallas-multiplexer/Cargo.toml @@ -13,10 +13,15 @@ authors = [ ] [dependencies] +pallas-codec = { version = "0.9.0", path = "../pallas-codec/" } log = "0.4.14" byteorder = "1.4.3" hex = "0.4.3" +rand = "0.8.4" [dev-dependencies] -rand = "0.8.4" env_logger = "0.9.0" + +[features] +std = [] +default = ["std"] \ No newline at end of file diff --git a/pallas-multiplexer/examples/listener.rs b/pallas-multiplexer/examples/listener.rs deleted file mode 100644 index b89ba27..0000000 --- a/pallas-multiplexer/examples/listener.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::{net::TcpListener, thread, time::Duration}; - -use log::info; -use pallas_multiplexer::{Channel, Multiplexer}; - -const PROTOCOLS: [u16; 2] = [0x8002u16, 0x8003u16]; - -fn main() { - env_logger::init(); - - let server = TcpListener::bind("0.0.0.0:3001").unwrap(); - - info!("listening for connections on port 3001"); - let (bearer, _) = server.accept().unwrap(); - - let mut muxer = Multiplexer::setup(bearer, &PROTOCOLS).unwrap(); - - for protocol in PROTOCOLS { - let handle = muxer.use_channel(protocol); - - thread::spawn(move || { - info!("starting thread for protocol: {}", protocol); - - let Channel(_, rx) = handle; - - loop { - let payload = rx.recv().unwrap(); - info!( - "got message within thread, id:{}, length:{}", - protocol, - payload.len() - ); - } - }); - } - - loop { - thread::sleep(Duration::from_secs(6000)); - } -} diff --git a/pallas-multiplexer/examples/sender.rs b/pallas-multiplexer/examples/sender.rs deleted file mode 100644 index 3627342..0000000 --- a/pallas-multiplexer/examples/sender.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::{net::TcpStream, thread, time::Duration}; - -use log::info; -use pallas_multiplexer::{Channel, Multiplexer}; - -const PROTOCOLS: [u16; 2] = [0x0002u16, 0x0003u16]; - -fn main() { - env_logger::init(); - - info!("connecting to tcp socket on 127.0.0.1:3001"); - let bearer = TcpStream::connect("127.0.0.1:3001").unwrap(); - let mut muxer = Multiplexer::setup(bearer, &PROTOCOLS).unwrap(); - - for protocol in PROTOCOLS { - let handle = muxer.use_channel(protocol); - - thread::spawn(move || { - let Channel(tx, _) = handle; - - loop { - let payload = vec![1; 65545]; - info!("sending dumb payload for protocol: {}", protocol); - tx.send(payload).unwrap(); - thread::sleep(Duration::from_millis(500u64 + (protocol as u64 * 10u64))); - } - }); - } - - loop { - thread::sleep(Duration::from_secs(6000)); - } -} diff --git a/pallas-multiplexer/src/agents.rs b/pallas-multiplexer/src/agents.rs new file mode 100644 index 0000000..ee31e76 --- /dev/null +++ b/pallas-multiplexer/src/agents.rs @@ -0,0 +1,103 @@ +//! Interface to interact with the multiplexer as an agent + +use crate::Payload; + +use pallas_codec::{minicbor, Fragment}; + +#[derive(Debug)] +pub enum ChannelError { + NotConnected(Option), + Encoding(String), + Decoding(String), +} + +/// A raw link to the ingress / egress of the multiplexer +pub trait Channel { + fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError>; + fn dequeue_chunk(&mut self) -> Result; +} + +/// Protocol value that defines max segment length +pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535; + +enum Decoding { + Done(M, usize), + NotEnoughData, + UnexpectedError(Box), +} + +fn try_decode_message(buffer: &[u8]) -> Decoding +where + M: Fragment, +{ + let mut decoder = minicbor::Decoder::new(buffer); + let maybe_msg = decoder.decode(); + + match maybe_msg { + Ok(msg) => Decoding::Done(msg, decoder.position()), + Err(err) if err.is_end_of_input() => Decoding::NotEnoughData, + Err(err) => Decoding::UnexpectedError(Box::new(err)), + } +} + +/// A channel abstraction to hide the complexity of partial payloads +pub struct ChannelBuffer<'c, C: Channel> { + channel: &'c mut C, + temp: Vec, +} + +impl<'c, C: Channel> ChannelBuffer<'c, C> { + pub fn new(channel: &'c mut C) -> Self { + Self { + channel, + temp: Vec::new(), + } + } + + /// Enqueues a msg as a sequence payload chunks + pub fn send_msg_chunks(&mut self, msg: &M) -> Result<(), ChannelError> + where + M: Fragment, + { + let mut payload = Vec::new(); + minicbor::encode(&msg, &mut payload) + .map_err(|err| ChannelError::Encoding(err.to_string()))?; + + let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH); + + for chunk in chunks { + self.channel.enqueue_chunk(Vec::from(chunk))?; + } + + Ok(()) + } + + /// Reads from the channel until a complete message is found + pub fn recv_full_msg(&mut self) -> Result + where + M: Fragment, + { + // do an eager reading if buffer is empty, no point in going through the error + // handling + if self.temp.is_empty() { + let chunk = self.channel.dequeue_chunk()?; + self.temp.extend(chunk); + } + + let decoding = try_decode_message::(&self.temp); + + match decoding { + Decoding::Done(msg, pos) => { + self.temp.drain(0..pos); + Ok(msg) + } + Decoding::UnexpectedError(err) => Err(ChannelError::Decoding(err.to_string())), + Decoding::NotEnoughData => { + let chunk = self.channel.dequeue_chunk()?; + self.temp.extend(chunk); + + self.recv_full_msg() + } + } + } +} diff --git a/pallas-multiplexer/src/bearers.rs b/pallas-multiplexer/src/bearers.rs index bf980ba..2520fd7 100644 --- a/pallas-multiplexer/src/bearers.rs +++ b/pallas-multiplexer/src/bearers.rs @@ -5,35 +5,62 @@ use std::io::{Read, Write}; use std::os::unix::net::UnixStream; use std::{net::TcpStream, time::Instant}; -use crate::{Bearer, Payload}; +use crate::Payload; + +pub struct Segment { + pub protocol: u16, + pub timestamp: u32, + pub payload: Payload, +} + +pub trait Bearer: Read + Write + Send + Sync + Sized { + type Error: std::error::Error; + + fn read_segment(&mut self) -> Result, Self::Error>; + + fn write_segment(&mut self, segment: Segment) -> Result<(), Self::Error>; + + fn clone(&self) -> Self; +} + +impl Segment { + pub fn new(clock: Instant, protocol: u16, payload: Payload) -> Self { + Segment { + timestamp: clock.elapsed().as_micros() as u32, + protocol, + payload, + } + } +} + +fn write_segment(writer: &mut impl Write, segment: Segment) -> Result<(), std::io::Error> { + let Segment { + timestamp, + protocol, + payload, + } = segment; -fn write_segment( - writer: &mut impl Write, - clock: Instant, - protocol_id: u16, - payload: &[u8], -) -> Result<(), std::io::Error> { let mut msg = Vec::new(); - msg.write_u32::(clock.elapsed().as_micros() as u32)?; - msg.write_u16::(protocol_id)?; + msg.write_u32::(timestamp)?; + msg.write_u16::(protocol)?; msg.write_u16::(payload.len() as u16)?; if log_enabled!(log::Level::Trace) { trace!( "sending segment, header {:?}, protocol id: {}, payload length: {}", hex::encode(&msg), - protocol_id, + protocol, payload.len() ); } - msg.write_all(payload)?; + msg.write_all(&payload)?; writer.write_all(&msg)?; writer.flush() } -fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io::Error> { +fn read_segment(reader: &mut impl Read) -> Result { let mut header = [0u8; 8]; reader.read_exact(&mut header)?; @@ -43,12 +70,12 @@ fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io:: } let length = NetworkEndian::read_u16(&header[6..]) as usize; - let id = NetworkEndian::read_u16(&header[4..6]) as usize ^ 0x8000; - let ts = NetworkEndian::read_u32(&header[0..4]); + let protocol = NetworkEndian::read_u16(&header[4..6]) as usize ^ 0x8000; + let timestamp = NetworkEndian::read_u32(&header[0..4]); debug!( "parsed inbound msg, protocol id: {}, ts: {}, payload length: {}", - id, ts, length + protocol, timestamp, length ); let mut payload = vec![0u8; length]; @@ -58,44 +85,54 @@ fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io:: trace!("read segment payload: {:?}", hex::encode(&payload)); } - Ok((id as u16, ts, payload)) + Ok(Segment { + protocol: protocol as u16, + timestamp, + payload, + }) +} + +fn read_segment_with_timeout(reader: &mut impl Read) -> Result, std::io::Error> { + match read_segment(reader) { + Ok(s) => Ok(Some(s)), + Err(err) => match err.kind() { + std::io::ErrorKind::WouldBlock => Ok(None), + std::io::ErrorKind::TimedOut => Ok(None), + std::io::ErrorKind::Interrupted => Ok(None), + _ => todo!(), + }, + } } impl Bearer for TcpStream { + type Error = std::io::Error; + fn clone(&self) -> Self { self.try_clone().expect("error cloning tcp stream") } - fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error> { - read_segment(self) + fn read_segment(&mut self) -> Result, std::io::Error> { + read_segment_with_timeout(self) } - fn write_segment( - &mut self, - clock: Instant, - protocol_id: u16, - partial_payload: &[u8], - ) -> Result<(), std::io::Error> { - write_segment(self, clock, protocol_id, partial_payload) + fn write_segment(&mut self, segment: Segment) -> Result<(), std::io::Error> { + write_segment(self, segment) } } #[cfg(target_family = "unix")] impl Bearer for UnixStream { + type Error = std::io::Error; + fn clone(&self) -> Self { self.try_clone().expect("error cloning unix stream") } - fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error> { - read_segment(self) + fn read_segment(&mut self) -> Result, std::io::Error> { + read_segment_with_timeout(self) } - fn write_segment( - &mut self, - clock: Instant, - protocol_id: u16, - partial_payload: &[u8], - ) -> Result<(), std::io::Error> { - write_segment(self, clock, protocol_id, partial_payload) + fn write_segment(&mut self, segment: Segment) -> Result<(), std::io::Error> { + write_segment(self, segment) } } diff --git a/pallas-multiplexer/src/demux.rs b/pallas-multiplexer/src/demux.rs new file mode 100644 index 0000000..4342085 --- /dev/null +++ b/pallas-multiplexer/src/demux.rs @@ -0,0 +1,83 @@ +use std::collections::HashMap; + +use crate::{bearers::Bearer, std::Cancel, Payload}; + +pub struct EgressError(pub Payload); + +pub trait Egress { + fn send(&self, payload: Payload) -> Result<(), EgressError>; +} + +pub enum DemuxError { + BearerError(B::Error), + EgressDisconnected(u16, Payload), + EgressUnknown(u16, Payload), +} + +pub enum TickOutcome { + Busy, + Idle, +} + +/// A demuxer that reads from a bearer into the corresponding egress +pub struct Demuxer { + bearer: B, + egress: HashMap, +} + +impl Demuxer +where + B: Bearer, + E: Egress, +{ + pub fn new(bearer: B) -> Self { + Demuxer { + bearer, + egress: Default::default(), + } + } + + pub fn register(&mut self, id: u16, tx: E) { + self.egress.insert(id, tx); + } + + fn dispatch(&self, protocol: u16, payload: Payload) -> Result<(), DemuxError> { + match self.egress.get(&protocol) { + Some(tx) => match tx.send(payload) { + Err(EgressError(p)) => Err(DemuxError::EgressDisconnected(protocol, p)), + Ok(_) => Ok(()), + }, + None => Err(DemuxError::EgressUnknown(protocol, payload)), + } + } + + pub fn tick(&mut self) -> Result> { + match self.bearer.read_segment() { + Err(err) => Err(DemuxError::BearerError(err)), + Ok(None) => Ok(TickOutcome::Idle), + Ok(Some(segment)) => match self.dispatch(segment.protocol, segment.payload) { + Err(err) => Err(err), + Ok(()) => Ok(TickOutcome::Busy), + }, + } + } + + pub fn block(&mut self, cancel: Cancel) -> Result<(), B::Error> { + loop { + match self.tick() { + Ok(TickOutcome::Busy) => (), + Ok(TickOutcome::Idle) => match cancel.is_set() { + true => break Ok(()), + false => (), + }, + Err(DemuxError::BearerError(err)) => return Err(err), + Err(DemuxError::EgressDisconnected(id, _)) => { + log::warn!("disconnected protocol {}", id) + } + Err(DemuxError::EgressUnknown(id, _)) => { + log::warn!("unknown protocol {}", id) + } + } + } + } +} diff --git a/pallas-multiplexer/src/lib.rs b/pallas-multiplexer/src/lib.rs index 0a77822..bb4b852 100644 --- a/pallas-multiplexer/src/lib.rs +++ b/pallas-multiplexer/src/lib.rs @@ -1,184 +1,41 @@ -mod bearers; - -use std::{ - collections::HashMap, - io::{Read, Write}, - sync::mpsc::{self, Receiver, Sender, TryRecvError}, - thread::{self, JoinHandle}, - time::{Duration, Instant}, -}; - -use log::{debug, error, warn}; - -pub trait Bearer: Read + Write + Send + Sync + Sized { - fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error>; - - fn write_segment( - &mut self, - clock: Instant, - protocol_id: u16, - partial_payload: &[u8], - ) -> Result<(), std::io::Error>; - - fn clone(&self) -> Self; -} - -const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535; +pub mod agents; +pub mod bearers; +pub mod demux; +pub mod mux; pub type Payload = Vec; -enum TxStepError { - BearerError(std::io::Error), - IngressDisconnected, - IngressEmpty, +pub struct Multiplexer +where + B: bearers::Bearer, + I: mux::Ingress, + E: demux::Egress, +{ + pub muxer: mux::Muxer, + pub demuxer: demux::Demuxer, } -fn tx_step( - bearer: &mut TBearer, - ingress_id: u16, - ingress_rx: &mut Receiver, - clock: Instant, -) -> Result<(), TxStepError> +impl Multiplexer where - TBearer: Bearer, + B: bearers::Bearer, + I: mux::Ingress, + E: demux::Egress, { - match ingress_rx.try_recv() { - Ok(payload) => { - let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH); - - for chunk in chunks { - bearer - .write_segment(clock, ingress_id, chunk) - .map_err(TxStepError::BearerError)?; - } - - Ok(()) - } - Err(TryRecvError::Disconnected) => Err(TxStepError::IngressDisconnected), - Err(TryRecvError::Empty) => Err(TxStepError::IngressEmpty), - } -} - -fn tx_loop(bearer: &mut TBearer, ingress: MuxIngress) -where - TBearer: Bearer, -{ - let mut rx_map: HashMap<_, _> = ingress.into_iter().collect(); - - loop { - let clock = Instant::now(); - - rx_map.retain(|id, rx| match tx_step(bearer, *id, rx, clock) { - Err(TxStepError::BearerError(err)) => { - error!("{:?}", err); - panic!(); - } - Err(TxStepError::IngressDisconnected) => { - warn!("protocol handle {} disconnected", id); - false - } - Err(TxStepError::IngressEmpty) => { - thread::sleep(Duration::from_millis(10)); - true - } - Ok(_) => true, - }); - } -} - -fn rx_loop(bearer: &mut TBearer, egress: DemuxerEgress) -where - TBearer: Bearer, -{ - let mut tx_map: HashMap<_, _> = egress.into_iter().collect(); - - loop { - match bearer.read_segment() { - Err(err) => { - error!("{:?}", err); - panic!(); - } - Ok(segment) => { - let (id, _ts, payload) = segment; - match tx_map.get(&id) { - Some(tx) => match tx.send(payload) { - Err(err) => { - error!("error sending egress tx to protocol, removing protocol from egress output. {:?}", err); - tx_map.remove(&id); - } - Ok(_) => { - debug!("successful tx to egress protocol"); - } - }, - None => warn!("received segment for protocol id not being demuxed {}", id), - } - } + pub fn new(bearer: B) -> Self { + Multiplexer { + muxer: mux::Muxer::new(bearer.clone()), + demuxer: demux::Demuxer::new(bearer.clone()), } } -} -pub struct Channel(pub Sender, pub Receiver); - -type ChannelProtocolHandle = (u16, Channel); -type ChannelIngressHandle = (u16, Receiver); -type ChannelEgressHandle = (u16, Sender); -type MuxIngress = Vec; -type DemuxerEgress = Vec; - -pub struct Multiplexer { - tx_thread: JoinHandle<()>, - rx_thread: JoinHandle<()>, - io_handles: HashMap, -} - -impl Multiplexer { - pub fn setup( - bearer: TBearer, - protocols: &[u16], - ) -> Result> - where - TBearer: Bearer + 'static, - { - let handles = protocols.iter().map(|id| { - let (demux_tx, demux_rx) = mpsc::channel::(); - let (mux_tx, mux_rx) = mpsc::channel::(); - - let channel = Channel(mux_tx, demux_rx); - - let protocol_handle: ChannelProtocolHandle = (*id, channel); - let ingress_handle: ChannelIngressHandle = (*id, mux_rx); - let egress_handle: ChannelEgressHandle = (*id, demux_tx); - - (protocol_handle, (ingress_handle, egress_handle)) - }); - - let (protocol_handles, multiplex_handles): (Vec<_>, Vec<_>) = handles.into_iter().unzip(); - - let (ingress, egress): (Vec<_>, Vec<_>) = multiplex_handles.into_iter().unzip(); - - let mut tx_bearer = bearer.clone(); - let tx_thread = thread::spawn(move || tx_loop(&mut tx_bearer, ingress)); - - let mut rx_bearer = bearer.clone(); - let rx_thread = thread::spawn(move || rx_loop(&mut rx_bearer, egress)); - - let io_handles: HashMap = protocol_handles.into_iter().collect(); - - Ok(Multiplexer { - io_handles, - tx_thread, - rx_thread, - }) - } - - pub fn use_channel(&mut self, protocol_id: u16) -> Channel { - self.io_handles - .remove(&protocol_id) - .expect("requested channel not found in multiplexer") - } - - pub fn join(self) { - self.tx_thread.join().expect("error joining tx loop thread"); - self.rx_thread.join().expect("error joining rx loop thread"); + pub fn register_channel(&mut self, protocol: u16, ingress: I, egress: E) { + self.muxer.register(protocol, ingress); + self.demuxer.register(protocol, egress); } } + +#[cfg(feature = "std")] +mod std; + +#[cfg(feature = "std")] +pub use crate::std::*; diff --git a/pallas-multiplexer/src/mux.rs b/pallas-multiplexer/src/mux.rs new file mode 100644 index 0000000..fdd4461 --- /dev/null +++ b/pallas-multiplexer/src/mux.rs @@ -0,0 +1,122 @@ +use std::{collections::HashMap, time::Instant}; + +use rand::seq::SliceRandom; +use rand::thread_rng; + +use crate::{ + bearers::{Bearer, Segment}, + std::Cancel, + Payload, +}; + +pub enum IngressError { + Disconnected, + Empty, +} + +/// Source of payloads for a particular protocol +/// +/// To be implemented by any mechanism that allows to submit a payloads from a +/// particular protocol that need to be muxed by the multiplexer. +pub trait Ingress { + fn try_recv(&mut self) -> Result; +} + +type Message = (u16, Payload); + +pub enum TickOutcome +where + TBearer: Bearer, +{ + BearerError(TBearer::Error), + Idle, + Busy, +} + +pub struct Muxer { + bearer: B, + ingress: HashMap, + clock: Instant, +} + +impl Muxer +where + B: Bearer, + I: Ingress, +{ + pub fn new(bearer: B) -> Self { + Self { + bearer, + ingress: Default::default(), + clock: Instant::now(), + } + } + + /// Register the receiver end of an ingress channel + pub fn register(&mut self, id: u16, rx: I) { + self.ingress.insert(id, rx); + } + + /// Remove a protocol from the ingress + /// + /// Meant to be used after a receive error in a previous tick + pub fn deregister(&mut self, id: u16) { + self.ingress.remove(&id); + } + + #[inline] + fn randomize_ids(&self) -> Vec { + let mut rng = thread_rng(); + let mut keys: Vec<_> = self.ingress.keys().cloned().collect(); + keys.shuffle(&mut rng); + keys + } + + /// Select the next segment to be muxed + /// + /// This method iterates over the existing receivers checking for the first + /// available message. The order of the checks is random to ensure a fair + /// use of the multiplexer amongst all protocols. + pub fn select(&mut self) -> Option { + for id in self.randomize_ids() { + let rx = self.ingress.get_mut(&id).unwrap(); + + match rx.try_recv() { + Ok(payload) => return Some((id, payload)), + Err(IngressError::Disconnected) => { + self.deregister(id); + } + _ => (), + }; + } + + None + } + + pub fn tick(&mut self) -> TickOutcome { + match self.select() { + Some((id, payload)) => { + let segment = Segment::new(self.clock, id, payload); + + match self.bearer.write_segment(segment) { + Err(err) => TickOutcome::BearerError(err), + _ => TickOutcome::Busy, + } + } + None => TickOutcome::Idle, + } + } + + pub fn block(&mut self, cancel: Cancel) -> Result<(), B::Error> { + loop { + match self.tick() { + TickOutcome::BearerError(err) => return Err(err), + TickOutcome::Idle => match cancel.is_set() { + true => break Ok(()), + false => std::thread::yield_now(), + }, + TickOutcome::Busy => (), + } + } + } +} diff --git a/pallas-multiplexer/src/std.rs b/pallas-multiplexer/src/std.rs new file mode 100644 index 0000000..bd9865e --- /dev/null +++ b/pallas-multiplexer/src/std.rs @@ -0,0 +1,123 @@ +use crate::{agents, bearers::Bearer, demux, mux, Payload}; + +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + mpsc::{channel, Receiver, SendError, Sender, TryRecvError}, + Arc, + }, + thread::{spawn, JoinHandle}, +}; + +pub type StdIngress = Receiver; + +impl mux::Ingress for StdIngress { + fn try_recv(&mut self) -> Result { + match Receiver::try_recv(self) { + Ok(x) => Ok(x), + Err(TryRecvError::Disconnected) => Err(mux::IngressError::Disconnected), + Err(TryRecvError::Empty) => Err(mux::IngressError::Empty), + } + } +} + +pub type StdEgress = Sender; + +impl demux::Egress for StdEgress { + fn send(&self, payload: Payload) -> Result<(), demux::EgressError> { + match Sender::send(self, payload) { + Ok(_) => Ok(()), + Err(SendError(p)) => Err(demux::EgressError(p)), + } + } +} + +pub type StdPlexer = crate::Multiplexer; + +pub type StdChannel = (Sender, Receiver); + +impl agents::Channel for StdChannel { + fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> { + match self.0.send(payload) { + Ok(_) => Ok(()), + Err(SendError(payload)) => Err(agents::ChannelError::NotConnected(Some(payload))), + } + } + + fn dequeue_chunk(&mut self) -> Result { + match self.1.recv() { + Ok(payload) => Ok(payload), + Err(_) => Err(agents::ChannelError::NotConnected(None)), + } + } +} + +pub fn use_channel(plexer: &mut StdPlexer, protocol: u16) -> StdChannel { + let (demux_tx, demux_rx) = channel::(); + let (mux_tx, mux_rx) = channel::(); + + plexer.register_channel(protocol, mux_rx, demux_tx); + + (mux_tx, demux_rx) +} + +#[derive(Clone, Debug, Default)] +pub struct Cancel(Arc); + +impl Cancel { + pub fn set(&self) { + self.0.store(true, Ordering::SeqCst); + } + + pub fn is_set(&self) -> bool { + self.0.load(Ordering::SeqCst) + } +} + +#[derive(Debug)] +pub struct Loop +where + B: Bearer, +{ + cancel: Cancel, + thread: JoinHandle>, +} + +impl Loop +where + B: Bearer, +{ + pub fn cancel(&self) { + self.cancel.set(); + } + + pub fn join(self) -> Result<(), B::Error> { + self.thread.join().unwrap() + } +} + +pub fn spawn_muxer(mut muxer: mux::Muxer) -> Loop +where + B: Bearer + 'static, + B::Error: Send, + I: mux::Ingress + Send + 'static, +{ + let cancel = Cancel::default(); + let cancel2 = cancel.clone(); + let thread = spawn(move || muxer.block(cancel2)); + + Loop { cancel, thread } +} + +pub fn spawn_demuxer(mut demuxer: demux::Demuxer) -> Loop +where + B: Bearer + 'static, + B::Error: Send, + E: demux::Egress + Send + 'static, +{ + let cancel = Cancel::default(); + let cancel2 = cancel.clone(); + let thread = spawn(move || demuxer.block(cancel2)); + + Loop { cancel, thread } +} diff --git a/pallas-multiplexer/tests/integration.rs b/pallas-multiplexer/tests/integration.rs index 337a137..6c6c0c9 100644 --- a/pallas-multiplexer/tests/integration.rs +++ b/pallas-multiplexer/tests/integration.rs @@ -1,25 +1,37 @@ use std::{ net::{Ipv4Addr, SocketAddrV4, TcpListener, TcpStream}, thread::{self, JoinHandle}, + time::Duration, }; use log::info; -use pallas_multiplexer::{Channel, Multiplexer}; +use pallas_codec::minicbor; +use pallas_multiplexer::{ + agents::{Channel, ChannelBuffer}, + spawn_demuxer, spawn_muxer, use_channel, StdPlexer, +}; use rand::{distributions::Uniform, Rng}; -fn setup_passive_muxer() -> JoinHandle { +fn setup_passive_muxer() -> JoinHandle> { thread::spawn(|| { let server = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, P)).unwrap(); info!("listening for connections on port {}", P); let (bearer, _) = server.accept().unwrap(); - Multiplexer::setup(bearer, &[0x8003u16]).unwrap() + + bearer.set_nonblocking(true).unwrap(); + + bearer + .set_read_timeout(Some(Duration::from_secs(3))) + .unwrap(); + + StdPlexer::new(bearer) }) } -fn setup_active_muxer() -> JoinHandle { +fn setup_active_muxer() -> JoinHandle> { thread::spawn(|| { let bearer = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, P)).unwrap(); - Multiplexer::setup(bearer, &[0x0003u16]).unwrap() + StdPlexer::new(bearer) }) } @@ -29,29 +41,7 @@ fn random_payload(size: usize) -> Vec { } #[test] -fn one_way_small_payload_is_consistent() { - let passive = setup_passive_muxer::<50201>(); - - // HACK: a small sleep seems to be required for Github actions runner to - // formally expose the port - thread::sleep(std::time::Duration::from_secs(1)); - - let active = setup_active_muxer::<50201>(); - - let mut active_muxer = active.join().unwrap(); - let mut passive_muxer = passive.join().unwrap(); - - let Channel(tx, _) = active_muxer.use_channel(0x0003u16); - let Channel(_, rx) = passive_muxer.use_channel(0x8003u16); - - let payload = random_payload(50); - tx.send(payload.clone()).unwrap(); - let received_payload = rx.recv().unwrap(); - assert_eq!(payload, received_payload) -} - -#[test] -fn one_way_small_sequence_of_payloads_are_consistent() { +fn one_way_small_sequence_of_payloads() { let passive = setup_passive_muxer::<50301>(); // HACK: a small sleep seems to be required for Github actions runner to @@ -60,16 +50,101 @@ fn one_way_small_sequence_of_payloads_are_consistent() { let active = setup_active_muxer::<50301>(); - let mut active_muxer = active.join().unwrap(); - let mut passive_muxer = passive.join().unwrap(); + let mut active_plexer = active.join().unwrap(); + let mut passive_plexer = passive.join().unwrap(); - let Channel(tx, _) = active_muxer.use_channel(0x0003u16); - let Channel(_, rx) = passive_muxer.use_channel(0x8003u16); + let mut sender_channel = use_channel(&mut active_plexer, 0x0003u16); + let mut receiver_channel = use_channel(&mut passive_plexer, 0x8003u16); + + let loop1 = spawn_muxer(active_plexer.muxer); + let loop2 = spawn_demuxer(passive_plexer.demuxer); for _ in [0..100] { let payload = random_payload(50); - tx.send(payload.clone()).unwrap(); - let received_payload = rx.recv().unwrap(); - assert_eq!(payload, received_payload) + sender_channel.enqueue_chunk(payload.clone()).unwrap(); + let received_payload = receiver_channel.dequeue_chunk().unwrap(); + assert_eq!(payload, received_payload); } + + loop1.cancel(); + loop1.join().unwrap(); + + loop2.cancel(); + loop2.join().unwrap(); +} + +#[test] +fn threads_cancel_while_still_sending() { + let passive = setup_passive_muxer::<50401>(); + + // HACK: a small sleep seems to be required for Github actions runner to + // formally expose the port + thread::sleep(std::time::Duration::from_secs(1)); + + let active = setup_active_muxer::<50401>(); + + let mut active_plexer = active.join().unwrap(); + let mut passive_plexer = passive.join().unwrap(); + + let mut sender_channel = use_channel(&mut active_plexer, 0x0003u16); + let mut receiver_channel = use_channel(&mut passive_plexer, 0x8003u16); + + let loop1 = spawn_muxer(active_plexer.muxer); + let loop2 = spawn_demuxer(passive_plexer.demuxer); + + thread::spawn(move || loop { + let payload = random_payload(50); + sender_channel.enqueue_chunk(payload.clone()).unwrap(); + let received_payload = receiver_channel.dequeue_chunk().unwrap(); + assert_eq!(payload, received_payload); + }); + + thread::sleep(Duration::from_secs(5)); + + loop1.cancel(); + loop1.join().unwrap(); + + loop2.cancel(); + loop2.join().unwrap(); +} + +#[test] +fn multiple_messages_in_same_payload() { + let mut input = Vec::new(); + let in_part1 = (1u8, 2u8, 3u8); + let in_part2 = (6u8, 5u8, 4u8); + + minicbor::encode(in_part1, &mut input).unwrap(); + minicbor::encode(in_part2, &mut input).unwrap(); + + let mut channel = std::sync::mpsc::channel(); + channel.0.send(input).unwrap(); + + let mut buf = ChannelBuffer::new(&mut channel); + + let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); + let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap(); + + assert_eq!(in_part1, out_part1); + assert_eq!(in_part2, out_part2); +} + +#[test] +fn fragmented_message_in_multiple_payload() { + let mut input = Vec::new(); + let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); + minicbor::encode(msg, &mut input).unwrap(); + + let mut channel = std::sync::mpsc::channel(); + + while !input.is_empty() { + let chunk = Vec::from(input.drain(0..2).as_slice()); + channel.0.send(chunk).unwrap(); + } + + let mut buf = ChannelBuffer::new(&mut channel); + + let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap(); + + assert_eq!(msg, out_msg); }