feat(multiplexer): Allow fine-grained control of concurrency strategy (#106)

This commit is contained in:
Santiago Carmuega 2022-06-03 21:37:38 -03:00 committed by GitHub
parent f5b7c13c86
commit a39682a38d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 822 additions and 559 deletions

View file

@ -5,7 +5,7 @@ use pallas::network::{
handshake::{n2n::VersionTable, Initiator},
run_agent, Point, MAINNET_MAGIC,
},
multiplexer::Multiplexer,
multiplexer::{spawn_demuxer, spawn_muxer, use_channel, StdPlexer},
};
use pallas::network::miniprotocols::blockfetch::{BatchClient, Observer};
@ -30,11 +30,15 @@ fn main() {
bearer.set_nodelay(true).unwrap();
bearer.set_keepalive_ms(Some(30_000u32)).unwrap();
let mut muxer = Multiplexer::setup(bearer, &[0, 3]).unwrap();
let mut plexer = StdPlexer::new(bearer);
let mut channel0 = use_channel(&mut plexer, 0);
let mut channel3 = use_channel(&mut plexer, 3);
spawn_muxer(plexer.muxer);
spawn_demuxer(plexer.demuxer);
let mut hs_channel = muxer.use_channel(0);
let versions = VersionTable::v4_and_above(MAINNET_MAGIC);
let _last = run_agent(Initiator::initial(versions), &mut hs_channel).unwrap();
let _last = run_agent(Initiator::initial(versions), &mut channel0).unwrap();
let range = (
Point::Specific(
@ -49,8 +53,7 @@ fn main() {
),
);
let mut bf_channel = muxer.use_channel(3);
let bf = BatchClient::initial(range, BlockPrinter {});
let bf_last = run_agent(bf, &mut bf_channel);
let bf_last = run_agent(bf, &mut channel3);
println!("{:?}", bf_last);
}

View file

@ -1,6 +1,6 @@
use pallas::network::{
miniprotocols::{chainsync, handshake, localstate, run_agent, Point, MAINNET_MAGIC},
multiplexer::Multiplexer,
multiplexer,
};
use std::os::unix::net::UnixStream;
@ -45,15 +45,12 @@ impl chainsync::Observer<chainsync::HeaderContent> for LoggingObserver {
}
}
fn do_handshake(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(0);
fn do_handshake(mut channel: multiplexer::StdChannel) {
let versions = handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC);
let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap();
}
fn do_localstate_query(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(7);
fn do_localstate_query(mut channel: multiplexer::StdChannel) {
let agent = run_agent(
localstate::OneShotClient::<localstate::queries::QueryV10>::initial(
None,
@ -65,9 +62,7 @@ fn do_localstate_query(muxer: &mut Multiplexer) {
log::info!("state query result: {:?}", agent);
}
fn do_chainsync(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(5);
fn do_chainsync(mut channel: multiplexer::StdChannel) {
let known_points = vec![Point::Specific(
43847831u64,
hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(),
@ -95,14 +90,20 @@ fn main() {
// setup the multiplexer by specifying the bearer and the IDs of the
// miniprotocols to use
let mut muxer = Multiplexer::setup(bearer, &[0, 4, 5]).unwrap();
let mut plexer = multiplexer::StdPlexer::new(bearer);
let channel0 = multiplexer::use_channel(&mut plexer, 0);
let channel7 = multiplexer::use_channel(&mut plexer, 7);
let channel5 = multiplexer::use_channel(&mut plexer, 5);
multiplexer::spawn_muxer(plexer.muxer);
multiplexer::spawn_demuxer(plexer.demuxer);
// execute the required handshake against the relay
do_handshake(&mut muxer);
do_handshake(channel0);
// execute an arbitrary "Local State" query against the node
do_localstate_query(&mut muxer);
do_localstate_query(channel7);
// execute the chainsync flow from an arbitrary point in the chain
do_chainsync(&mut muxer);
do_chainsync(channel5);
}

View file

@ -2,7 +2,7 @@ use net2::TcpStreamExt;
use pallas::network::{
miniprotocols::{blockfetch, chainsync, handshake, run_agent, Point, MAINNET_MAGIC},
multiplexer::Multiplexer,
multiplexer::{spawn_demuxer, spawn_muxer, use_channel, StdChannel, StdPlexer},
};
use std::net::TcpStream;
@ -54,15 +54,12 @@ impl chainsync::Observer<chainsync::HeaderContent> for LoggingObserver {
}
}
fn do_handshake(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(0);
fn do_handshake(mut channel: StdChannel) {
let versions = handshake::n2n::VersionTable::v4_and_above(MAINNET_MAGIC);
let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap();
}
fn do_blockfetch(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(3);
fn do_blockfetch(mut channel: StdChannel) {
let range = (
Point::Specific(
43847831,
@ -84,9 +81,7 @@ fn do_blockfetch(muxer: &mut Multiplexer) {
println!("{:?}", agent);
}
fn do_chainsync(muxer: &mut Multiplexer) {
let mut channel = muxer.use_channel(2);
fn do_chainsync(mut channel: StdChannel) {
let known_points = vec![Point::Specific(
43847831u64,
hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(),
@ -116,14 +111,20 @@ fn main() {
// setup the multiplexer by specifying the bearer and the IDs of the
// miniprotocols to use
let mut muxer = Multiplexer::setup(bearer, &[0, 2, 3, 4]).unwrap();
let mut plexer = StdPlexer::new(bearer);
let channel0 = use_channel(&mut plexer, 0);
let channel3 = use_channel(&mut plexer, 3);
let channel2 = use_channel(&mut plexer, 2);
spawn_muxer(plexer.muxer);
spawn_demuxer(plexer.demuxer);
// execute the required handshake against the relay
do_handshake(&mut muxer);
do_handshake(channel0);
// fetch an arbitrary batch of block
do_blockfetch(&mut muxer);
do_blockfetch(channel3);
// execute the chainsync flow from an arbitrary point in the chain
do_chainsync(&mut muxer);
do_chainsync(channel2);
}

View file

@ -0,0 +1,27 @@
[package]
name = "pallas-primitives"
description = "Ledger primitives and cbor codec for the different Cardano eras"
version = "0.9.1"
edition = "2021"
repository = "https://github.com/txpipe/pallas"
homepage = "https://github.com/txpipe/pallas"
documentation = "https://docs.rs/pallas-byron"
license = "Apache-2.0"
readme = "README.md"
authors = [
"Santiago Carmuega <santiago@carmuega.me>",
]
[dependencies]
hex = "0.4.3"
log = "0.4.14"
pallas-crypto = { version = "0.9.0", path = "../pallas-crypto" }
pallas-codec = { version = "0.9.0", path = "../pallas-codec" }
base58 = "0.2.0"
bech32 = "0.8.1"
serde = { version ="1.0.136", optional = true }
serde_json = { version ="1.0.79", optional = true }
[features]
json = ["serde", "serde_json"]
default = ["json"]

View file

@ -13,8 +13,8 @@ authors = [
]
[dependencies]
pallas-multiplexer = { version = "0.9.0", path = "../pallas-multiplexer/" }
pallas-codec = { version = "0.9.0", path = "../pallas-codec/" }
pallas-multiplexer = { version = "0.9.0", path = "../pallas-multiplexer/" }
log = "0.4.14"
hex = "0.4.3"
itertools = "0.10.3"

View file

@ -1,4 +1,5 @@
use crate::machines::{Agent, Transition};
use crate::MachineError;
use crate::common::Point;
@ -155,7 +156,9 @@ where
fn on_block(mut self, body: Vec<u8>) -> Transition<Self> {
log::debug!("received block body, size {}", body.len());
self.observer.on_block_received(body)?;
self.observer
.on_block_received(body)
.map_err(MachineError::downstream)?;
Ok(self)
}
@ -180,6 +183,11 @@ where
O: Observer,
{
type Message = Message;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done
@ -294,7 +302,9 @@ where
fn on_block(mut self, body: Vec<u8>) -> Transition<Self> {
log::debug!("received block body, size {}", body.len());
self.observer.on_block_received(body)?;
self.observer
.on_block_received(body)
.map_err(MachineError::downstream)?;
Ok(self)
}
@ -317,6 +327,11 @@ where
O: Observer,
{
type Message = Message;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done

View file

@ -44,6 +44,7 @@ pub trait Observer<C> {
Ok(Continuation::Proceed)
}
fn on_tip_reached(&mut self) -> Result<Continuation, Box<dyn std::error::Error>> {
log::debug!("tip was reached");
@ -59,6 +60,7 @@ impl<C> Observer<C> for NoopObserver {}
#[derive(Debug)]
pub struct Consumer<C, O>
where
Self: Agent,
O: Observer<C>,
{
pub state: State,
@ -77,6 +79,7 @@ impl<C, O> Consumer<C, O>
where
O: Observer<C>,
Message<C>: Fragment,
C: std::fmt::Debug + 'static,
{
pub fn initial(known_points: Option<Vec<Point>>, observer: O) -> Self {
Self {
@ -93,7 +96,10 @@ where
fn on_intersect_found(mut self, point: Point, tip: Tip) -> Transition<Self> {
log::debug!("intersect found: {:?} (tip: {:?})", point, tip);
let continuation = self.observer.on_intersect_found(&point, &tip)?;
let continuation = self
.observer
.on_intersect_found(&point, &tip)
.map_err(MachineError::downstream)?;
Ok(Self {
tip: Some(tip),
@ -118,7 +124,10 @@ where
fn on_roll_forward(mut self, content: C, tip: Tip) -> Transition<Self> {
log::debug!("rolling forward");
let continuation = self.observer.on_roll_forward(content, &tip)?;
let continuation = self
.observer
.on_roll_forward(content, &tip)
.map_err(MachineError::downstream)?;
Ok(Self {
tip: Some(tip),
@ -131,7 +140,10 @@ where
fn on_roll_backward(mut self, point: Point, tip: Tip) -> Transition<Self> {
log::debug!("rolling backward to point: {:?}", point);
let continuation = self.observer.on_rollback(&point)?;
let continuation = self
.observer
.on_rollback(&point)
.map_err(MachineError::downstream)?;
Ok(Self {
tip: Some(tip),
@ -145,7 +157,10 @@ where
fn on_await_reply(mut self) -> Transition<Self> {
log::debug!("reached tip, await reply");
let continuation = self.observer.on_tip_reached()?;
let continuation = self
.observer
.on_tip_reached()
.map_err(MachineError::downstream)?;
Ok(Self {
state: State::MustReply,
@ -162,6 +177,11 @@ where
Message<C>: Fragment,
{
type Message = Message<C>;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done || self.continuation == Continuation::DropOut
@ -230,7 +250,7 @@ where
self.on_intersect_found(point, tip)
}
(State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)),
}
}
}
@ -278,6 +298,11 @@ pub type BlockConsumer<O> = Consumer<BlockContent, O>;
impl Agent for TipFinder {
type Message = Message<SkippedContent>;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done
@ -322,7 +347,7 @@ impl Agent for TipFinder {
self.on_intersect_found(tip)
}
(State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip),
(state, msg) => Err(MachineError::InvalidMsgForState(state.clone(), msg).into()),
(state, msg) => Err(MachineError::InvalidMsgForState(state.clone(), msg)),
}
}
}

View file

@ -39,6 +39,11 @@ where
D: Debug + Clone,
{
type Message = Message<D>;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done

View file

@ -1,6 +1,5 @@
mod common;
mod machines;
mod payloads;
pub mod blockfetch;
pub mod chainsync;
@ -10,4 +9,3 @@ pub mod txsubmission;
pub use common::*;
pub use machines::*;
pub use payloads::*;

View file

@ -52,7 +52,7 @@ pub struct OneShotClient<Q: Query> {
impl<Q> OneShotClient<Q>
where
Q: Query,
Q: Query + 'static,
Message<Q>: Fragment,
{
pub fn initial(check_point: Option<Point>, request: Q::Request) -> Self {
@ -101,6 +101,11 @@ where
Message<Q>: Fragment,
{
type Message = Message<Q>;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done
@ -158,7 +163,7 @@ where
(State::Acquiring, Message::Acquired) => self.on_acquired(),
(State::Acquiring, Message::Failure(failure)) => self.on_failure(failure),
(State::Querying, Message::Result(result)) => self.on_result(result),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)),
}
}
}

View file

@ -1,84 +1,31 @@
pub use crate::payloads::*;
use pallas_codec::{minicbor, Fragment};
use pallas_multiplexer::{Channel, Payload};
use pallas_codec::Fragment;
use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError};
use std::cell::Cell;
use std::fmt::{Debug, Display};
use std::sync::mpsc::Sender;
#[derive(Debug)]
pub enum MachineError<State, Msg>
where
State: Debug,
Msg: Debug,
{
InvalidMsgForState(State, Msg),
pub enum MachineError<A: Agent> {
InvalidMsgForState(A::State, A::Message),
ChannelError(ChannelError),
DownstreamError(Box<dyn std::error::Error>),
}
impl<S, M> Display for MachineError<S, M>
where
S: Debug,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MachineError::InvalidMsgForState(msg, state) => {
write!(
f,
"received invalid message ({:?}) for current state ({:?})",
msg, state
)
}
}
impl<A: Agent> MachineError<A> {
pub fn channel(err: ChannelError) -> Self {
Self::ChannelError(err)
}
pub fn downstream(err: Box<dyn std::error::Error>) -> Self {
Self::DownstreamError(err)
}
}
impl<S, M> std::error::Error for MachineError<S, M>
where
S: Debug,
M: Debug,
{
}
#[derive(Debug)]
pub enum CodecError {
BadLabel(u16),
UnexpectedCbor(&'static str),
}
impl std::error::Error for CodecError {}
impl Display for CodecError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CodecError::BadLabel(label) => {
write!(f, "unknown message label: {}", label)
}
CodecError::UnexpectedCbor(msg) => {
write!(f, "unexpected cbor: {}", msg)
}
}
}
}
pub trait MachineOutput {
fn send_msg(&self, data: &impl Fragment) -> Result<(), Box<dyn std::error::Error>>;
}
impl MachineOutput for Sender<Payload> {
fn send_msg(&self, data: &impl Fragment) -> Result<(), Box<dyn std::error::Error>> {
let mut payload = Vec::new();
minicbor::encode(data, &mut payload)?;
self.send(payload)?;
Ok(())
}
}
pub type Transition<T> = Result<T, Box<dyn std::error::Error>>;
pub type Transition<A> = Result<A, MachineError<A>>;
pub trait Agent: Sized {
type Message;
type State;
fn state(&self) -> &Self::State;
fn is_done(&self) -> bool;
fn has_agency(&self) -> bool;
fn build_next(&self) -> Self::Message;
@ -87,36 +34,38 @@ pub trait Agent: Sized {
fn apply_inbound(self, msg: Self::Message) -> Transition<Self>;
}
pub struct Runner<A>
pub struct Runner<'c, A, C>
where
A: Agent,
C: Channel,
{
agent: Cell<Option<A>>,
buffer: Vec<u8>,
buffer: ChannelBuffer<'c, C>,
}
impl<'a, A> Runner<A>
impl<'c, A, C> Runner<'c, A, C>
where
A: Agent,
A::Message: Fragment + Debug,
A::Message: Fragment + std::fmt::Debug,
C: Channel,
{
pub fn new(agent: A) -> Self {
pub fn new(agent: A, channel: &'c mut C) -> Self {
Self {
agent: Cell::new(Some(agent)),
buffer: Vec::new(),
buffer: ChannelBuffer::new(channel),
}
}
pub fn start(&mut self) -> Result<(), Error> {
pub fn start(&mut self) -> Result<(), MachineError<A>> {
let prev = self.agent.take().unwrap();
let next = prev.apply_start()?;
self.agent.set(Some(next));
Ok(())
}
pub fn run_step(&mut self, channel: &mut Channel) -> Result<bool, Error> {
pub fn run_step(&mut self) -> Result<bool, MachineError<A>> {
let prev = self.agent.take().unwrap();
let next = run_agent_step(prev, channel, &mut self.buffer)?;
let next = run_agent_step(prev, &mut self.buffer)?;
let is_done = next.is_done();
self.agent.set(Some(next));
@ -124,35 +73,35 @@ where
Ok(is_done)
}
pub fn fulfill(mut self, channel: &mut Channel) -> Result<(), Error> {
pub fn fulfill(mut self) -> Result<(), MachineError<A>> {
self.start()?;
while self.run_step(channel)? {}
while self.run_step()? {}
Ok(())
}
}
pub fn run_agent_step<T>(agent: T, channel: &mut Channel, buffer: &mut Vec<u8>) -> Transition<T>
pub fn run_agent_step<A, C>(agent: A, channel: &mut ChannelBuffer<C>) -> Transition<A>
where
T: Agent,
T::Message: Fragment + Debug,
A: Agent,
A::Message: Fragment + std::fmt::Debug,
C: Channel,
{
let Channel(tx, rx) = channel;
match agent.has_agency() {
true => {
let msg = agent.build_next();
log::trace!("processing outbound msg: {:?}", msg);
let mut payload = Vec::new();
minicbor::encode(&msg, &mut payload)?;
tx.send(payload)?;
channel
.send_msg_chunks(&msg)
.map_err(MachineError::channel)?;
agent.apply_outbound(msg)
}
false => {
let msg = read_until_full_msg::<T::Message>(buffer, rx).unwrap();
let msg = channel.recv_full_msg().map_err(MachineError::channel)?;
log::trace!("procesing inbound msg: {:?}", msg);
agent.apply_inbound(msg)
@ -160,17 +109,18 @@ where
}
}
pub fn run_agent<T>(agent: T, channel: &mut Channel) -> Result<T, Box<dyn std::error::Error>>
pub fn run_agent<A, C>(agent: A, channel: &mut C) -> Transition<A>
where
T: Agent,
T::Message: Fragment + Debug,
A: Agent,
A::Message: Fragment + std::fmt::Debug,
C: Channel,
{
let mut buffer = Vec::new();
let mut buffer = ChannelBuffer::new(channel);
let mut agent = agent.apply_start()?;
while !agent.is_done() {
agent = run_agent_step(agent, channel, &mut buffer)?;
agent = run_agent_step(agent, &mut buffer)?;
}
Ok(agent)

View file

@ -1,104 +0,0 @@
use pallas_codec::{minicbor, Fragment};
use pallas_multiplexer::Payload;
use std::sync::mpsc::Receiver;
pub type Error = Box<dyn std::error::Error>;
enum Decoding<M> {
Done(M, usize),
NotEnoughData,
UnexpectedError(Error),
}
fn try_decode_message<M>(buffer: &[u8]) -> Decoding<M>
where
M: Fragment,
{
let mut decoder = minicbor::Decoder::new(buffer);
let maybe_msg = decoder.decode();
match maybe_msg {
Ok(msg) => Decoding::Done(msg, decoder.position()),
Err(err) if err.is_end_of_input() => Decoding::NotEnoughData,
Err(err) => Decoding::UnexpectedError(Box::new(err)),
}
}
/// Reads from the receiver until a complete message is found
pub fn read_until_full_msg<M>(
buffer: &mut Vec<u8>,
receiver: &mut Receiver<Payload>,
) -> Result<M, Error>
where
M: Fragment,
{
// do an eager reading if buffer is empty, no point in going through the error
// handling
if buffer.is_empty() {
let chunk = receiver.recv()?;
buffer.extend(chunk);
}
let decoding = try_decode_message::<M>(buffer);
match decoding {
Decoding::Done(msg, pos) => {
buffer.drain(0..pos);
Ok(msg)
}
Decoding::UnexpectedError(err) => Err(err),
Decoding::NotEnoughData => {
let chunk = receiver.recv()?;
buffer.extend(chunk);
read_until_full_msg::<M>(buffer, receiver)
}
}
}
#[cfg(test)]
mod tests {
use crate::read_until_full_msg;
use pallas_codec::minicbor;
use std::sync::mpsc::channel;
#[test]
fn multiple_messages_in_same_payload() {
let mut input = Vec::new();
let in_part1 = (1u8, 2u8, 3u8);
let in_part2 = (6u8, 5u8, 4u8);
minicbor::encode(in_part1, &mut input).unwrap();
minicbor::encode(in_part2, &mut input).unwrap();
let (tx, mut rx) = channel();
tx.send(input).unwrap();
let mut output = Vec::new();
let out_part1 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap();
let out_part2 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap();
assert_eq!(in_part1, out_part1);
assert_eq!(in_part2, out_part2);
}
#[test]
fn fragmented_message_in_multiple_payload() {
let mut input = Vec::new();
let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
minicbor::encode(msg, &mut input).unwrap();
let (tx, mut rx) = channel();
while !input.is_empty() {
let chunk = Vec::from(input.drain(0..2).as_slice());
tx.send(chunk).unwrap();
}
let mut output = Vec::new();
let out_msg =
read_until_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>(&mut output, &mut rx).unwrap();
assert_eq!(msg, out_msg);
}
}

View file

@ -239,6 +239,11 @@ impl NaiveProvider {
impl Agent for NaiveProvider {
type Message = Message;
type State = State;
fn state(&self) -> &Self::State {
&self.state
}
fn is_done(&self) -> bool {
self.state == State::Done
@ -295,7 +300,7 @@ impl Agent for NaiveProvider {
..self
}),
(State::Idle, Message::RequestTxs(ids)) => self.on_txs_request(ids),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg).into()),
(_, msg) => Err(MachineError::InvalidMsgForState(self.state, msg)),
}
}
}

View file

@ -13,10 +13,15 @@ authors = [
]
[dependencies]
pallas-codec = { version = "0.9.0", path = "../pallas-codec/" }
log = "0.4.14"
byteorder = "1.4.3"
hex = "0.4.3"
rand = "0.8.4"
[dev-dependencies]
rand = "0.8.4"
env_logger = "0.9.0"
[features]
std = []
default = ["std"]

View file

@ -1,40 +0,0 @@
use std::{net::TcpListener, thread, time::Duration};
use log::info;
use pallas_multiplexer::{Channel, Multiplexer};
const PROTOCOLS: [u16; 2] = [0x8002u16, 0x8003u16];
fn main() {
env_logger::init();
let server = TcpListener::bind("0.0.0.0:3001").unwrap();
info!("listening for connections on port 3001");
let (bearer, _) = server.accept().unwrap();
let mut muxer = Multiplexer::setup(bearer, &PROTOCOLS).unwrap();
for protocol in PROTOCOLS {
let handle = muxer.use_channel(protocol);
thread::spawn(move || {
info!("starting thread for protocol: {}", protocol);
let Channel(_, rx) = handle;
loop {
let payload = rx.recv().unwrap();
info!(
"got message within thread, id:{}, length:{}",
protocol,
payload.len()
);
}
});
}
loop {
thread::sleep(Duration::from_secs(6000));
}
}

View file

@ -1,33 +0,0 @@
use std::{net::TcpStream, thread, time::Duration};
use log::info;
use pallas_multiplexer::{Channel, Multiplexer};
const PROTOCOLS: [u16; 2] = [0x0002u16, 0x0003u16];
fn main() {
env_logger::init();
info!("connecting to tcp socket on 127.0.0.1:3001");
let bearer = TcpStream::connect("127.0.0.1:3001").unwrap();
let mut muxer = Multiplexer::setup(bearer, &PROTOCOLS).unwrap();
for protocol in PROTOCOLS {
let handle = muxer.use_channel(protocol);
thread::spawn(move || {
let Channel(tx, _) = handle;
loop {
let payload = vec![1; 65545];
info!("sending dumb payload for protocol: {}", protocol);
tx.send(payload).unwrap();
thread::sleep(Duration::from_millis(500u64 + (protocol as u64 * 10u64)));
}
});
}
loop {
thread::sleep(Duration::from_secs(6000));
}
}

View file

@ -0,0 +1,103 @@
//! Interface to interact with the multiplexer as an agent
use crate::Payload;
use pallas_codec::{minicbor, Fragment};
#[derive(Debug)]
pub enum ChannelError {
NotConnected(Option<Payload>),
Encoding(String),
Decoding(String),
}
/// A raw link to the ingress / egress of the multiplexer
pub trait Channel {
fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), ChannelError>;
fn dequeue_chunk(&mut self) -> Result<Payload, ChannelError>;
}
/// Protocol value that defines max segment length
pub const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535;
enum Decoding<M> {
Done(M, usize),
NotEnoughData,
UnexpectedError(Box<dyn std::error::Error>),
}
fn try_decode_message<M>(buffer: &[u8]) -> Decoding<M>
where
M: Fragment,
{
let mut decoder = minicbor::Decoder::new(buffer);
let maybe_msg = decoder.decode();
match maybe_msg {
Ok(msg) => Decoding::Done(msg, decoder.position()),
Err(err) if err.is_end_of_input() => Decoding::NotEnoughData,
Err(err) => Decoding::UnexpectedError(Box::new(err)),
}
}
/// A channel abstraction to hide the complexity of partial payloads
pub struct ChannelBuffer<'c, C: Channel> {
channel: &'c mut C,
temp: Vec<u8>,
}
impl<'c, C: Channel> ChannelBuffer<'c, C> {
pub fn new(channel: &'c mut C) -> Self {
Self {
channel,
temp: Vec::new(),
}
}
/// Enqueues a msg as a sequence payload chunks
pub fn send_msg_chunks<M>(&mut self, msg: &M) -> Result<(), ChannelError>
where
M: Fragment,
{
let mut payload = Vec::new();
minicbor::encode(&msg, &mut payload)
.map_err(|err| ChannelError::Encoding(err.to_string()))?;
let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH);
for chunk in chunks {
self.channel.enqueue_chunk(Vec::from(chunk))?;
}
Ok(())
}
/// Reads from the channel until a complete message is found
pub fn recv_full_msg<M>(&mut self) -> Result<M, ChannelError>
where
M: Fragment,
{
// do an eager reading if buffer is empty, no point in going through the error
// handling
if self.temp.is_empty() {
let chunk = self.channel.dequeue_chunk()?;
self.temp.extend(chunk);
}
let decoding = try_decode_message::<M>(&self.temp);
match decoding {
Decoding::Done(msg, pos) => {
self.temp.drain(0..pos);
Ok(msg)
}
Decoding::UnexpectedError(err) => Err(ChannelError::Decoding(err.to_string())),
Decoding::NotEnoughData => {
let chunk = self.channel.dequeue_chunk()?;
self.temp.extend(chunk);
self.recv_full_msg()
}
}
}
}

View file

@ -5,35 +5,62 @@ use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::{net::TcpStream, time::Instant};
use crate::{Bearer, Payload};
use crate::Payload;
pub struct Segment {
pub protocol: u16,
pub timestamp: u32,
pub payload: Payload,
}
pub trait Bearer: Read + Write + Send + Sync + Sized {
type Error: std::error::Error;
fn read_segment(&mut self) -> Result<Option<Segment>, Self::Error>;
fn write_segment(&mut self, segment: Segment) -> Result<(), Self::Error>;
fn clone(&self) -> Self;
}
impl Segment {
pub fn new(clock: Instant, protocol: u16, payload: Payload) -> Self {
Segment {
timestamp: clock.elapsed().as_micros() as u32,
protocol,
payload,
}
}
}
fn write_segment(writer: &mut impl Write, segment: Segment) -> Result<(), std::io::Error> {
let Segment {
timestamp,
protocol,
payload,
} = segment;
fn write_segment(
writer: &mut impl Write,
clock: Instant,
protocol_id: u16,
payload: &[u8],
) -> Result<(), std::io::Error> {
let mut msg = Vec::new();
msg.write_u32::<NetworkEndian>(clock.elapsed().as_micros() as u32)?;
msg.write_u16::<NetworkEndian>(protocol_id)?;
msg.write_u32::<NetworkEndian>(timestamp)?;
msg.write_u16::<NetworkEndian>(protocol)?;
msg.write_u16::<NetworkEndian>(payload.len() as u16)?;
if log_enabled!(log::Level::Trace) {
trace!(
"sending segment, header {:?}, protocol id: {}, payload length: {}",
hex::encode(&msg),
protocol_id,
protocol,
payload.len()
);
}
msg.write_all(payload)?;
msg.write_all(&payload)?;
writer.write_all(&msg)?;
writer.flush()
}
fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io::Error> {
fn read_segment(reader: &mut impl Read) -> Result<Segment, std::io::Error> {
let mut header = [0u8; 8];
reader.read_exact(&mut header)?;
@ -43,12 +70,12 @@ fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io::
}
let length = NetworkEndian::read_u16(&header[6..]) as usize;
let id = NetworkEndian::read_u16(&header[4..6]) as usize ^ 0x8000;
let ts = NetworkEndian::read_u32(&header[0..4]);
let protocol = NetworkEndian::read_u16(&header[4..6]) as usize ^ 0x8000;
let timestamp = NetworkEndian::read_u32(&header[0..4]);
debug!(
"parsed inbound msg, protocol id: {}, ts: {}, payload length: {}",
id, ts, length
protocol, timestamp, length
);
let mut payload = vec![0u8; length];
@ -58,44 +85,54 @@ fn read_segment(reader: &mut impl Read) -> Result<(u16, u32, Payload), std::io::
trace!("read segment payload: {:?}", hex::encode(&payload));
}
Ok((id as u16, ts, payload))
Ok(Segment {
protocol: protocol as u16,
timestamp,
payload,
})
}
fn read_segment_with_timeout(reader: &mut impl Read) -> Result<Option<Segment>, std::io::Error> {
match read_segment(reader) {
Ok(s) => Ok(Some(s)),
Err(err) => match err.kind() {
std::io::ErrorKind::WouldBlock => Ok(None),
std::io::ErrorKind::TimedOut => Ok(None),
std::io::ErrorKind::Interrupted => Ok(None),
_ => todo!(),
},
}
}
impl Bearer for TcpStream {
type Error = std::io::Error;
fn clone(&self) -> Self {
self.try_clone().expect("error cloning tcp stream")
}
fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error> {
read_segment(self)
fn read_segment(&mut self) -> Result<Option<Segment>, std::io::Error> {
read_segment_with_timeout(self)
}
fn write_segment(
&mut self,
clock: Instant,
protocol_id: u16,
partial_payload: &[u8],
) -> Result<(), std::io::Error> {
write_segment(self, clock, protocol_id, partial_payload)
fn write_segment(&mut self, segment: Segment) -> Result<(), std::io::Error> {
write_segment(self, segment)
}
}
#[cfg(target_family = "unix")]
impl Bearer for UnixStream {
type Error = std::io::Error;
fn clone(&self) -> Self {
self.try_clone().expect("error cloning unix stream")
}
fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error> {
read_segment(self)
fn read_segment(&mut self) -> Result<Option<Segment>, std::io::Error> {
read_segment_with_timeout(self)
}
fn write_segment(
&mut self,
clock: Instant,
protocol_id: u16,
partial_payload: &[u8],
) -> Result<(), std::io::Error> {
write_segment(self, clock, protocol_id, partial_payload)
fn write_segment(&mut self, segment: Segment) -> Result<(), std::io::Error> {
write_segment(self, segment)
}
}

View file

@ -0,0 +1,83 @@
use std::collections::HashMap;
use crate::{bearers::Bearer, std::Cancel, Payload};
pub struct EgressError(pub Payload);
pub trait Egress {
fn send(&self, payload: Payload) -> Result<(), EgressError>;
}
pub enum DemuxError<B: Bearer> {
BearerError(B::Error),
EgressDisconnected(u16, Payload),
EgressUnknown(u16, Payload),
}
pub enum TickOutcome {
Busy,
Idle,
}
/// A demuxer that reads from a bearer into the corresponding egress
pub struct Demuxer<B, E> {
bearer: B,
egress: HashMap<u16, E>,
}
impl<B, E> Demuxer<B, E>
where
B: Bearer,
E: Egress,
{
pub fn new(bearer: B) -> Self {
Demuxer {
bearer,
egress: Default::default(),
}
}
pub fn register(&mut self, id: u16, tx: E) {
self.egress.insert(id, tx);
}
fn dispatch(&self, protocol: u16, payload: Payload) -> Result<(), DemuxError<B>> {
match self.egress.get(&protocol) {
Some(tx) => match tx.send(payload) {
Err(EgressError(p)) => Err(DemuxError::EgressDisconnected(protocol, p)),
Ok(_) => Ok(()),
},
None => Err(DemuxError::EgressUnknown(protocol, payload)),
}
}
pub fn tick(&mut self) -> Result<TickOutcome, DemuxError<B>> {
match self.bearer.read_segment() {
Err(err) => Err(DemuxError::BearerError(err)),
Ok(None) => Ok(TickOutcome::Idle),
Ok(Some(segment)) => match self.dispatch(segment.protocol, segment.payload) {
Err(err) => Err(err),
Ok(()) => Ok(TickOutcome::Busy),
},
}
}
pub fn block(&mut self, cancel: Cancel) -> Result<(), B::Error> {
loop {
match self.tick() {
Ok(TickOutcome::Busy) => (),
Ok(TickOutcome::Idle) => match cancel.is_set() {
true => break Ok(()),
false => (),
},
Err(DemuxError::BearerError(err)) => return Err(err),
Err(DemuxError::EgressDisconnected(id, _)) => {
log::warn!("disconnected protocol {}", id)
}
Err(DemuxError::EgressUnknown(id, _)) => {
log::warn!("unknown protocol {}", id)
}
}
}
}
}

View file

@ -1,184 +1,41 @@
mod bearers;
use std::{
collections::HashMap,
io::{Read, Write},
sync::mpsc::{self, Receiver, Sender, TryRecvError},
thread::{self, JoinHandle},
time::{Duration, Instant},
};
use log::{debug, error, warn};
pub trait Bearer: Read + Write + Send + Sync + Sized {
fn read_segment(&mut self) -> Result<(u16, u32, Payload), std::io::Error>;
fn write_segment(
&mut self,
clock: Instant,
protocol_id: u16,
partial_payload: &[u8],
) -> Result<(), std::io::Error>;
fn clone(&self) -> Self;
}
const MAX_SEGMENT_PAYLOAD_LENGTH: usize = 65535;
pub mod agents;
pub mod bearers;
pub mod demux;
pub mod mux;
pub type Payload = Vec<u8>;
enum TxStepError {
BearerError(std::io::Error),
IngressDisconnected,
IngressEmpty,
pub struct Multiplexer<B, I, E>
where
B: bearers::Bearer,
I: mux::Ingress,
E: demux::Egress,
{
pub muxer: mux::Muxer<B, I>,
pub demuxer: demux::Demuxer<B, E>,
}
fn tx_step<TBearer>(
bearer: &mut TBearer,
ingress_id: u16,
ingress_rx: &mut Receiver<Payload>,
clock: Instant,
) -> Result<(), TxStepError>
impl<B, I, E> Multiplexer<B, I, E>
where
TBearer: Bearer,
B: bearers::Bearer,
I: mux::Ingress,
E: demux::Egress,
{
match ingress_rx.try_recv() {
Ok(payload) => {
let chunks = payload.chunks(MAX_SEGMENT_PAYLOAD_LENGTH);
for chunk in chunks {
bearer
.write_segment(clock, ingress_id, chunk)
.map_err(TxStepError::BearerError)?;
}
Ok(())
}
Err(TryRecvError::Disconnected) => Err(TxStepError::IngressDisconnected),
Err(TryRecvError::Empty) => Err(TxStepError::IngressEmpty),
}
}
fn tx_loop<TBearer>(bearer: &mut TBearer, ingress: MuxIngress)
where
TBearer: Bearer,
{
let mut rx_map: HashMap<_, _> = ingress.into_iter().collect();
loop {
let clock = Instant::now();
rx_map.retain(|id, rx| match tx_step(bearer, *id, rx, clock) {
Err(TxStepError::BearerError(err)) => {
error!("{:?}", err);
panic!();
}
Err(TxStepError::IngressDisconnected) => {
warn!("protocol handle {} disconnected", id);
false
}
Err(TxStepError::IngressEmpty) => {
thread::sleep(Duration::from_millis(10));
true
}
Ok(_) => true,
});
}
}
fn rx_loop<TBearer>(bearer: &mut TBearer, egress: DemuxerEgress)
where
TBearer: Bearer,
{
let mut tx_map: HashMap<_, _> = egress.into_iter().collect();
loop {
match bearer.read_segment() {
Err(err) => {
error!("{:?}", err);
panic!();
}
Ok(segment) => {
let (id, _ts, payload) = segment;
match tx_map.get(&id) {
Some(tx) => match tx.send(payload) {
Err(err) => {
error!("error sending egress tx to protocol, removing protocol from egress output. {:?}", err);
tx_map.remove(&id);
}
Ok(_) => {
debug!("successful tx to egress protocol");
}
},
None => warn!("received segment for protocol id not being demuxed {}", id),
}
}
pub fn new(bearer: B) -> Self {
Multiplexer {
muxer: mux::Muxer::new(bearer.clone()),
demuxer: demux::Demuxer::new(bearer.clone()),
}
}
}
pub struct Channel(pub Sender<Payload>, pub Receiver<Payload>);
type ChannelProtocolHandle = (u16, Channel);
type ChannelIngressHandle = (u16, Receiver<Payload>);
type ChannelEgressHandle = (u16, Sender<Payload>);
type MuxIngress = Vec<ChannelIngressHandle>;
type DemuxerEgress = Vec<ChannelEgressHandle>;
pub struct Multiplexer {
tx_thread: JoinHandle<()>,
rx_thread: JoinHandle<()>,
io_handles: HashMap<u16, Channel>,
}
impl Multiplexer {
pub fn setup<TBearer>(
bearer: TBearer,
protocols: &[u16],
) -> Result<Multiplexer, Box<dyn std::error::Error>>
where
TBearer: Bearer + 'static,
{
let handles = protocols.iter().map(|id| {
let (demux_tx, demux_rx) = mpsc::channel::<Payload>();
let (mux_tx, mux_rx) = mpsc::channel::<Payload>();
let channel = Channel(mux_tx, demux_rx);
let protocol_handle: ChannelProtocolHandle = (*id, channel);
let ingress_handle: ChannelIngressHandle = (*id, mux_rx);
let egress_handle: ChannelEgressHandle = (*id, demux_tx);
(protocol_handle, (ingress_handle, egress_handle))
});
let (protocol_handles, multiplex_handles): (Vec<_>, Vec<_>) = handles.into_iter().unzip();
let (ingress, egress): (Vec<_>, Vec<_>) = multiplex_handles.into_iter().unzip();
let mut tx_bearer = bearer.clone();
let tx_thread = thread::spawn(move || tx_loop(&mut tx_bearer, ingress));
let mut rx_bearer = bearer.clone();
let rx_thread = thread::spawn(move || rx_loop(&mut rx_bearer, egress));
let io_handles: HashMap<u16, Channel> = protocol_handles.into_iter().collect();
Ok(Multiplexer {
io_handles,
tx_thread,
rx_thread,
})
}
pub fn use_channel(&mut self, protocol_id: u16) -> Channel {
self.io_handles
.remove(&protocol_id)
.expect("requested channel not found in multiplexer")
}
pub fn join(self) {
self.tx_thread.join().expect("error joining tx loop thread");
self.rx_thread.join().expect("error joining rx loop thread");
pub fn register_channel(&mut self, protocol: u16, ingress: I, egress: E) {
self.muxer.register(protocol, ingress);
self.demuxer.register(protocol, egress);
}
}
#[cfg(feature = "std")]
mod std;
#[cfg(feature = "std")]
pub use crate::std::*;

View file

@ -0,0 +1,122 @@
use std::{collections::HashMap, time::Instant};
use rand::seq::SliceRandom;
use rand::thread_rng;
use crate::{
bearers::{Bearer, Segment},
std::Cancel,
Payload,
};
pub enum IngressError {
Disconnected,
Empty,
}
/// Source of payloads for a particular protocol
///
/// To be implemented by any mechanism that allows to submit a payloads from a
/// particular protocol that need to be muxed by the multiplexer.
pub trait Ingress {
fn try_recv(&mut self) -> Result<Payload, IngressError>;
}
type Message = (u16, Payload);
pub enum TickOutcome<TBearer>
where
TBearer: Bearer,
{
BearerError(TBearer::Error),
Idle,
Busy,
}
pub struct Muxer<B, I> {
bearer: B,
ingress: HashMap<u16, I>,
clock: Instant,
}
impl<B, I> Muxer<B, I>
where
B: Bearer,
I: Ingress,
{
pub fn new(bearer: B) -> Self {
Self {
bearer,
ingress: Default::default(),
clock: Instant::now(),
}
}
/// Register the receiver end of an ingress channel
pub fn register(&mut self, id: u16, rx: I) {
self.ingress.insert(id, rx);
}
/// Remove a protocol from the ingress
///
/// Meant to be used after a receive error in a previous tick
pub fn deregister(&mut self, id: u16) {
self.ingress.remove(&id);
}
#[inline]
fn randomize_ids(&self) -> Vec<u16> {
let mut rng = thread_rng();
let mut keys: Vec<_> = self.ingress.keys().cloned().collect();
keys.shuffle(&mut rng);
keys
}
/// Select the next segment to be muxed
///
/// This method iterates over the existing receivers checking for the first
/// available message. The order of the checks is random to ensure a fair
/// use of the multiplexer amongst all protocols.
pub fn select(&mut self) -> Option<Message> {
for id in self.randomize_ids() {
let rx = self.ingress.get_mut(&id).unwrap();
match rx.try_recv() {
Ok(payload) => return Some((id, payload)),
Err(IngressError::Disconnected) => {
self.deregister(id);
}
_ => (),
};
}
None
}
pub fn tick(&mut self) -> TickOutcome<B> {
match self.select() {
Some((id, payload)) => {
let segment = Segment::new(self.clock, id, payload);
match self.bearer.write_segment(segment) {
Err(err) => TickOutcome::BearerError(err),
_ => TickOutcome::Busy,
}
}
None => TickOutcome::Idle,
}
}
pub fn block(&mut self, cancel: Cancel) -> Result<(), B::Error> {
loop {
match self.tick() {
TickOutcome::BearerError(err) => return Err(err),
TickOutcome::Idle => match cancel.is_set() {
true => break Ok(()),
false => std::thread::yield_now(),
},
TickOutcome::Busy => (),
}
}
}
}

View file

@ -0,0 +1,123 @@
use crate::{agents, bearers::Bearer, demux, mux, Payload};
use std::{
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{channel, Receiver, SendError, Sender, TryRecvError},
Arc,
},
thread::{spawn, JoinHandle},
};
pub type StdIngress = Receiver<Payload>;
impl mux::Ingress for StdIngress {
fn try_recv(&mut self) -> Result<Payload, mux::IngressError> {
match Receiver::try_recv(self) {
Ok(x) => Ok(x),
Err(TryRecvError::Disconnected) => Err(mux::IngressError::Disconnected),
Err(TryRecvError::Empty) => Err(mux::IngressError::Empty),
}
}
}
pub type StdEgress = Sender<Payload>;
impl demux::Egress for StdEgress {
fn send(&self, payload: Payload) -> Result<(), demux::EgressError> {
match Sender::send(self, payload) {
Ok(_) => Ok(()),
Err(SendError(p)) => Err(demux::EgressError(p)),
}
}
}
pub type StdPlexer<B> = crate::Multiplexer<B, StdIngress, StdEgress>;
pub type StdChannel = (Sender<Payload>, Receiver<Payload>);
impl agents::Channel for StdChannel {
fn enqueue_chunk(&mut self, payload: Payload) -> Result<(), agents::ChannelError> {
match self.0.send(payload) {
Ok(_) => Ok(()),
Err(SendError(payload)) => Err(agents::ChannelError::NotConnected(Some(payload))),
}
}
fn dequeue_chunk(&mut self) -> Result<Payload, agents::ChannelError> {
match self.1.recv() {
Ok(payload) => Ok(payload),
Err(_) => Err(agents::ChannelError::NotConnected(None)),
}
}
}
pub fn use_channel<B: Bearer>(plexer: &mut StdPlexer<B>, protocol: u16) -> StdChannel {
let (demux_tx, demux_rx) = channel::<Payload>();
let (mux_tx, mux_rx) = channel::<Payload>();
plexer.register_channel(protocol, mux_rx, demux_tx);
(mux_tx, demux_rx)
}
#[derive(Clone, Debug, Default)]
pub struct Cancel(Arc<AtomicBool>);
impl Cancel {
pub fn set(&self) {
self.0.store(true, Ordering::SeqCst);
}
pub fn is_set(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}
#[derive(Debug)]
pub struct Loop<B>
where
B: Bearer,
{
cancel: Cancel,
thread: JoinHandle<Result<(), B::Error>>,
}
impl<B> Loop<B>
where
B: Bearer,
{
pub fn cancel(&self) {
self.cancel.set();
}
pub fn join(self) -> Result<(), B::Error> {
self.thread.join().unwrap()
}
}
pub fn spawn_muxer<B, I>(mut muxer: mux::Muxer<B, I>) -> Loop<B>
where
B: Bearer + 'static,
B::Error: Send,
I: mux::Ingress + Send + 'static,
{
let cancel = Cancel::default();
let cancel2 = cancel.clone();
let thread = spawn(move || muxer.block(cancel2));
Loop { cancel, thread }
}
pub fn spawn_demuxer<B, E>(mut demuxer: demux::Demuxer<B, E>) -> Loop<B>
where
B: Bearer + 'static,
B::Error: Send,
E: demux::Egress + Send + 'static,
{
let cancel = Cancel::default();
let cancel2 = cancel.clone();
let thread = spawn(move || demuxer.block(cancel2));
Loop { cancel, thread }
}

View file

@ -1,25 +1,37 @@
use std::{
net::{Ipv4Addr, SocketAddrV4, TcpListener, TcpStream},
thread::{self, JoinHandle},
time::Duration,
};
use log::info;
use pallas_multiplexer::{Channel, Multiplexer};
use pallas_codec::minicbor;
use pallas_multiplexer::{
agents::{Channel, ChannelBuffer},
spawn_demuxer, spawn_muxer, use_channel, StdPlexer,
};
use rand::{distributions::Uniform, Rng};
fn setup_passive_muxer<const P: u16>() -> JoinHandle<Multiplexer> {
fn setup_passive_muxer<const P: u16>() -> JoinHandle<StdPlexer<TcpStream>> {
thread::spawn(|| {
let server = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, P)).unwrap();
info!("listening for connections on port {}", P);
let (bearer, _) = server.accept().unwrap();
Multiplexer::setup(bearer, &[0x8003u16]).unwrap()
bearer.set_nonblocking(true).unwrap();
bearer
.set_read_timeout(Some(Duration::from_secs(3)))
.unwrap();
StdPlexer::new(bearer)
})
}
fn setup_active_muxer<const P: u16>() -> JoinHandle<Multiplexer> {
fn setup_active_muxer<const P: u16>() -> JoinHandle<StdPlexer<TcpStream>> {
thread::spawn(|| {
let bearer = TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, P)).unwrap();
Multiplexer::setup(bearer, &[0x0003u16]).unwrap()
StdPlexer::new(bearer)
})
}
@ -29,29 +41,7 @@ fn random_payload(size: usize) -> Vec<u8> {
}
#[test]
fn one_way_small_payload_is_consistent() {
let passive = setup_passive_muxer::<50201>();
// HACK: a small sleep seems to be required for Github actions runner to
// formally expose the port
thread::sleep(std::time::Duration::from_secs(1));
let active = setup_active_muxer::<50201>();
let mut active_muxer = active.join().unwrap();
let mut passive_muxer = passive.join().unwrap();
let Channel(tx, _) = active_muxer.use_channel(0x0003u16);
let Channel(_, rx) = passive_muxer.use_channel(0x8003u16);
let payload = random_payload(50);
tx.send(payload.clone()).unwrap();
let received_payload = rx.recv().unwrap();
assert_eq!(payload, received_payload)
}
#[test]
fn one_way_small_sequence_of_payloads_are_consistent() {
fn one_way_small_sequence_of_payloads() {
let passive = setup_passive_muxer::<50301>();
// HACK: a small sleep seems to be required for Github actions runner to
@ -60,16 +50,101 @@ fn one_way_small_sequence_of_payloads_are_consistent() {
let active = setup_active_muxer::<50301>();
let mut active_muxer = active.join().unwrap();
let mut passive_muxer = passive.join().unwrap();
let mut active_plexer = active.join().unwrap();
let mut passive_plexer = passive.join().unwrap();
let Channel(tx, _) = active_muxer.use_channel(0x0003u16);
let Channel(_, rx) = passive_muxer.use_channel(0x8003u16);
let mut sender_channel = use_channel(&mut active_plexer, 0x0003u16);
let mut receiver_channel = use_channel(&mut passive_plexer, 0x8003u16);
let loop1 = spawn_muxer(active_plexer.muxer);
let loop2 = spawn_demuxer(passive_plexer.demuxer);
for _ in [0..100] {
let payload = random_payload(50);
tx.send(payload.clone()).unwrap();
let received_payload = rx.recv().unwrap();
assert_eq!(payload, received_payload)
sender_channel.enqueue_chunk(payload.clone()).unwrap();
let received_payload = receiver_channel.dequeue_chunk().unwrap();
assert_eq!(payload, received_payload);
}
loop1.cancel();
loop1.join().unwrap();
loop2.cancel();
loop2.join().unwrap();
}
#[test]
fn threads_cancel_while_still_sending() {
let passive = setup_passive_muxer::<50401>();
// HACK: a small sleep seems to be required for Github actions runner to
// formally expose the port
thread::sleep(std::time::Duration::from_secs(1));
let active = setup_active_muxer::<50401>();
let mut active_plexer = active.join().unwrap();
let mut passive_plexer = passive.join().unwrap();
let mut sender_channel = use_channel(&mut active_plexer, 0x0003u16);
let mut receiver_channel = use_channel(&mut passive_plexer, 0x8003u16);
let loop1 = spawn_muxer(active_plexer.muxer);
let loop2 = spawn_demuxer(passive_plexer.demuxer);
thread::spawn(move || loop {
let payload = random_payload(50);
sender_channel.enqueue_chunk(payload.clone()).unwrap();
let received_payload = receiver_channel.dequeue_chunk().unwrap();
assert_eq!(payload, received_payload);
});
thread::sleep(Duration::from_secs(5));
loop1.cancel();
loop1.join().unwrap();
loop2.cancel();
loop2.join().unwrap();
}
#[test]
fn multiple_messages_in_same_payload() {
let mut input = Vec::new();
let in_part1 = (1u8, 2u8, 3u8);
let in_part2 = (6u8, 5u8, 4u8);
minicbor::encode(in_part1, &mut input).unwrap();
minicbor::encode(in_part2, &mut input).unwrap();
let mut channel = std::sync::mpsc::channel();
channel.0.send(input).unwrap();
let mut buf = ChannelBuffer::new(&mut channel);
let out_part1 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
let out_part2 = buf.recv_full_msg::<(u8, u8, u8)>().unwrap();
assert_eq!(in_part1, out_part1);
assert_eq!(in_part2, out_part2);
}
#[test]
fn fragmented_message_in_multiple_payload() {
let mut input = Vec::new();
let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8);
minicbor::encode(msg, &mut input).unwrap();
let mut channel = std::sync::mpsc::channel();
while !input.is_empty() {
let chunk = Vec::from(input.drain(0..2).as_slice());
channel.0.send(chunk).unwrap();
}
let mut buf = ChannelBuffer::new(&mut channel);
let out_msg = buf.recv_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>().unwrap();
assert_eq!(msg, out_msg);
}