diff --git a/pallas-network/src/multiplexer.rs b/pallas-network/src/multiplexer.rs index 3bd5286..c2af636 100644 --- a/pallas-network/src/multiplexer.rs +++ b/pallas-network/src/multiplexer.rs @@ -1,5 +1,7 @@ //! A multiplexer of several mini-protocols through a single bearer +use std::collections::HashMap; + use byteorder::{ByteOrder, NetworkEndian}; use pallas_codec::{minicbor, Fragment}; use thiserror::Error; @@ -7,7 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::task::JoinHandle; use tokio::time::Instant; use tokio::{select, sync::mpsc::error::SendError}; -use tracing::{debug, error, trace}; +use tracing::{debug, error, trace, warn}; type IOResult = tokio::io::Result; @@ -221,18 +223,16 @@ pub enum Error { AbortFailure, } -type Egress = ( - tokio::sync::broadcast::Sender<(Protocol, Payload)>, - tokio::sync::broadcast::Receiver<(Protocol, Payload)>, -); +type EgressChannel = tokio::sync::mpsc::Sender; +type Egress = HashMap; -const EGRESS_MSG_QUEUE_BUFFER: usize = 100_000; +const EGRESS_MSG_QUEUE_BUFFER: usize = 100; pub struct Demuxer(BearerReadHalf, Egress); impl Demuxer { pub fn new(bearer: BearerReadHalf) -> Self { - let egress = tokio::sync::broadcast::channel(EGRESS_MSG_QUEUE_BUFFER); + let egress = HashMap::new(); Self(bearer, egress) } @@ -250,27 +250,35 @@ impl Demuxer { Ok((header.protocol, buf)) } - fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> { - if tracing::event_enabled!(tracing::Level::TRACE) { - trace!(protocol, data = hex::encode(&payload), "read from bearer"); - } + async fn demux(&mut self, protocol: Protocol, payload: Payload) -> Result<(), Error> { + let channel = self.1.get(&protocol); - self.1 - .0 - .send((protocol, payload)) - .map_err(|err| Error::PlexerDemux(err.0 .0, err.0 .1))?; + if let Some(sender) = channel { + sender + .send(payload) + .await + .map_err(|err| Error::PlexerDemux(protocol, err.0))?; + } else { + warn!(protocol, "message for unregistered protocol"); + } Ok(()) } - pub fn subscribe_recv(&self) -> tokio::sync::broadcast::Receiver<(Protocol, Payload)> { - self.1 .0.subscribe() + pub fn subscribe(&mut self, protocol: Protocol) -> tokio::sync::mpsc::Receiver { + let (sender, recv) = tokio::sync::mpsc::channel(EGRESS_MSG_QUEUE_BUFFER); + + // keep track of the sender + self.1.insert(protocol, sender); + + // return the receiver for the agent + recv } pub async fn tick(&mut self) -> Result<(), Error> { let (protocol, payload) = self.read_segment().await?; trace!(protocol, "demux happening"); - self.demux(protocol, payload) + self.demux(protocol, payload).await } pub async fn run(&mut self) -> Result<(), Error> { @@ -357,11 +365,10 @@ impl Muxer { } type ToPlexerPort = tokio::sync::mpsc::Sender<(Protocol, Payload)>; -type FromPlexerPort = tokio::sync::broadcast::Receiver<(Protocol, Payload)>; +type FromPlexerPort = tokio::sync::mpsc::Receiver; pub struct AgentChannel { - enqueue_protocol: Protocol, - dequeue_protocol: Protocol, + protocol: Protocol, to_plexer: ToPlexerPort, from_plexer: FromPlexerPort, } @@ -373,8 +380,7 @@ impl AgentChannel { from_plexer: FromPlexerPort, ) -> Self { Self { - enqueue_protocol: protocol, - dequeue_protocol: protocol ^ 0x8000, + protocol, from_plexer, to_plexer, } @@ -386,8 +392,7 @@ impl AgentChannel { from_plexer: FromPlexerPort, ) -> Self { Self { - enqueue_protocol: protocol ^ 0x8000, - dequeue_protocol: protocol, + protocol, from_plexer, to_plexer, } @@ -395,24 +400,13 @@ impl AgentChannel { pub async fn enqueue_chunk(&mut self, chunk: Payload) -> Result<(), Error> { self.to_plexer - .send((self.enqueue_protocol, chunk)) + .send((self.protocol, chunk)) .await .map_err(|SendError((protocol, payload))| Error::AgentEnqueue(protocol, payload)) } pub async fn dequeue_chunk(&mut self) -> Result { - loop { - let (protocol, payload) = self - .from_plexer - .recv() - .await - .map_err(|_| Error::AgentDequeue)?; - - if protocol == self.dequeue_protocol { - trace!(protocol, "message for our protocol"); - break Ok(payload); - } - } + self.from_plexer.recv().await.ok_or(Error::AgentDequeue) } } @@ -445,14 +439,14 @@ impl Plexer { pub fn subscribe_client(&mut self, protocol: Protocol) -> AgentChannel { let to_plexer = self.muxer.clone_sender(); - let from_plexer = self.demuxer.subscribe_recv(); + let from_plexer = self.demuxer.subscribe(protocol ^ 0x8000); AgentChannel::for_client(protocol, to_plexer, from_plexer) } pub fn subscribe_server(&mut self, protocol: Protocol) -> AgentChannel { let to_plexer = self.muxer.clone_sender(); - let from_plexer = self.demuxer.subscribe_recv(); - AgentChannel::for_server(protocol, to_plexer, from_plexer) + let from_plexer = self.demuxer.subscribe(protocol); + AgentChannel::for_server(protocol ^ 0x8000, to_plexer, from_plexer) } pub fn spawn(self) -> RunningPlexer { @@ -577,11 +571,11 @@ mod tests { minicbor::encode(in_part2, &mut input).unwrap(); let (to_plexer, _) = tokio::sync::mpsc::channel(100); - let (into_plexer, from_plexer) = tokio::sync::broadcast::channel(100); + let (into_plexer, from_plexer) = tokio::sync::mpsc::channel(100); let channel = AgentChannel::for_client(0, to_plexer, from_plexer); - into_plexer.send((0x8000, input)).unwrap(); + into_plexer.send(input).await.unwrap(); let mut buf = ChannelBuffer::new(channel); @@ -599,13 +593,13 @@ mod tests { minicbor::encode(msg, &mut input).unwrap(); let (to_plexer, _) = tokio::sync::mpsc::channel(100); - let (into_plexer, from_plexer) = tokio::sync::broadcast::channel(100); + let (into_plexer, from_plexer) = tokio::sync::mpsc::channel(100); let channel = AgentChannel::for_client(0, to_plexer, from_plexer); while !input.is_empty() { let chunk = Vec::from(input.drain(0..2).as_slice()); - into_plexer.send((0x8000, chunk)).unwrap(); + into_plexer.send(chunk).await.unwrap(); } let mut buf = ChannelBuffer::new(channel);