From 1fbb4efeefee016372ed4263e483d450e9e76c1e Mon Sep 17 00:00:00 2001 From: Santiago Carmuega Date: Wed, 16 Mar 2022 18:17:10 -0300 Subject: [PATCH] fix(miniprotocols): Handle regression related to multi-msg payloads (#76) --- examples/block-download/src/main.rs | 14 ++--- pallas-codec/src/lib.rs | 20 +------- pallas-miniprotocols/src/machines.rs | 7 ++- pallas-miniprotocols/src/payloads.rs | 77 +++++++++++++++++++++++++--- 4 files changed, 79 insertions(+), 39 deletions(-) diff --git a/examples/block-download/src/main.rs b/examples/block-download/src/main.rs index 517e796..95d2176 100644 --- a/examples/block-download/src/main.rs +++ b/examples/block-download/src/main.rs @@ -8,10 +8,7 @@ use pallas::network::{ multiplexer::Multiplexer, }; -use pallas::{ - ledger::primitives::{alonzo::*, Fragment}, - network::miniprotocols::blockfetch::{BatchClient, Observer}, -}; +use pallas::network::miniprotocols::blockfetch::{BatchClient, Observer}; use std::net::TcpStream; @@ -22,7 +19,6 @@ impl Observer for BlockPrinter { fn on_block_received(&self, body: Vec) -> Result<(), Box> { println!("{}", hex::encode(&body)); println!("----------"); - BlockWrapper::decode_fragment(&body[..])?; Ok(()) } } @@ -42,13 +38,13 @@ fn main() { let range = ( Point::Specific( - 4492794, - hex::decode("5c196e7394ace0449ba5a51c919369699b13896e97432894b4f0354dce8670b6") + 97, + hex::decode("cf7fa60bbd210273d79fa48d11ab1d141242af32b231cc40ce3411230a8d3c61") .unwrap(), ), Point::Specific( - 4492794, - hex::decode("5c196e7394ace0449ba5a51c919369699b13896e97432894b4f0354dce8670b6") + 99, + hex::decode("a52cca923a67326ea9c409e958a17a77990be72f3607625ec5b3d456202e223e") .unwrap(), ), ); diff --git a/pallas-codec/src/lib.rs b/pallas-codec/src/lib.rs index 06bdbfd..db64582 100644 --- a/pallas-codec/src/lib.rs +++ b/pallas-codec/src/lib.rs @@ -1,25 +1,9 @@ -use minicbor::encode::Write; - /// Shared re-export of minicbor lib across all Pallas pub use minicbor; /// Round-trip friendly common helper structs pub mod utils; -pub trait Fragment: Sized { - fn read_cbor(buffer: &[u8]) -> Result; - fn write_cbor(&self, write: W) -> Result<(), minicbor::encode::Error>; -} +pub trait Fragment: Sized + for<'b> minicbor::Decode<'b> + minicbor::Encode {} -impl Fragment for T -where - T: for<'b> minicbor::Decode<'b> + minicbor::Encode, -{ - fn read_cbor(buffer: &[u8]) -> Result { - minicbor::decode(buffer) - } - - fn write_cbor(&self, write: W) -> Result<(), minicbor::encode::Error> { - minicbor::encode(self, write) - } -} +impl Fragment for T where T: for<'b> minicbor::Decode<'b> + minicbor::Encode + Sized {} diff --git a/pallas-miniprotocols/src/machines.rs b/pallas-miniprotocols/src/machines.rs index 15f8c69..060a8a7 100644 --- a/pallas-miniprotocols/src/machines.rs +++ b/pallas-miniprotocols/src/machines.rs @@ -1,5 +1,5 @@ pub use crate::payloads::*; -use pallas_codec::Fragment; +use pallas_codec::{minicbor, Fragment}; use pallas_multiplexer::{Channel, Payload}; use std::fmt::{Debug, Display}; use std::sync::mpsc::Sender; @@ -66,7 +66,7 @@ pub trait MachineOutput { impl MachineOutput for Sender { fn send_msg(&self, data: &impl Fragment) -> Result<(), Box> { let mut payload = Vec::new(); - data.write_cbor(&mut payload)?; + minicbor::encode(data, &mut payload)?; self.send(payload)?; Ok(()) @@ -92,6 +92,7 @@ where let Channel(tx, rx) = channel; let mut agent = agent; + let mut buffer = Vec::new(); while !agent.is_done() { log::debug!("evaluating agent {:?}", agent); @@ -101,8 +102,6 @@ where agent = agent.send_next(tx)?; } false => { - let mut buffer = Vec::new(); - let msg = read_until_full_msg::(&mut buffer, rx).unwrap(); log::trace!("procesing inbound msg: {:?}", msg); agent = agent.receive_next(msg)?; diff --git a/pallas-miniprotocols/src/payloads.rs b/pallas-miniprotocols/src/payloads.rs index 08f437b..c053e20 100644 --- a/pallas-miniprotocols/src/payloads.rs +++ b/pallas-miniprotocols/src/payloads.rs @@ -1,11 +1,11 @@ -use pallas_codec::Fragment; +use pallas_codec::{minicbor, Fragment}; use pallas_multiplexer::Payload; use std::sync::mpsc::Receiver; pub type Error = Box; enum Decoding { - Done(M), + Done(M, usize), NotEnoughData, UnexpectedError(Error), } @@ -14,15 +14,17 @@ fn try_decode_message(buffer: &[u8]) -> Decoding where M: Fragment, { - let maybe_msg: Result = M::read_cbor(buffer); + let mut decoder = minicbor::Decoder::new(buffer); + let maybe_msg = decoder.decode(); match maybe_msg { - Ok(msg) => Decoding::Done(msg), + Ok(msg) => Decoding::Done(msg, decoder.position()), Err(err) if err.is_end_of_input() => Decoding::NotEnoughData, Err(err) => Decoding::UnexpectedError(Box::new(err)), } } +/// Reads from the receiver until a complete message is found pub fn read_until_full_msg( buffer: &mut Vec, receiver: &mut Receiver, @@ -30,14 +32,73 @@ pub fn read_until_full_msg( where M: Fragment, { - let chunk = receiver.recv()?; - buffer.extend(chunk); + // do an eager reading if buffer is empty, no point in going through the error + // handling + if buffer.is_empty() { + let chunk = receiver.recv()?; + buffer.extend(chunk); + } let decoding = try_decode_message::(buffer); match decoding { - Decoding::Done(msg) => Ok(msg), + Decoding::Done(msg, pos) => { + buffer.drain(0..pos); + Ok(msg) + } Decoding::UnexpectedError(err) => Err(err), - Decoding::NotEnoughData => read_until_full_msg::(buffer, receiver), + Decoding::NotEnoughData => { + let chunk = receiver.recv()?; + buffer.extend(chunk); + + read_until_full_msg::(buffer, receiver) + } + } +} + +#[cfg(test)] +mod tests { + use crate::read_until_full_msg; + use pallas_codec::minicbor; + use std::sync::mpsc::channel; + + #[test] + fn multiple_messages_in_same_payload() { + let mut input = Vec::new(); + let in_part1 = (1u8, 2u8, 3u8); + let in_part2 = (6u8, 5u8, 4u8); + + minicbor::encode(in_part1, &mut input).unwrap(); + minicbor::encode(in_part2, &mut input).unwrap(); + + let (tx, mut rx) = channel(); + tx.send(input).unwrap(); + + let mut output = Vec::new(); + let out_part1 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap(); + let out_part2 = read_until_full_msg::<(u8, u8, u8)>(&mut output, &mut rx).unwrap(); + + assert_eq!(in_part1, out_part1); + assert_eq!(in_part2, out_part2); + } + + #[test] + fn fragmented_message_in_multiple_payload() { + let mut input = Vec::new(); + let msg = (11u8, 12u8, 13u8, 14u8, 15u8, 16u8, 17u8); + minicbor::encode(msg, &mut input).unwrap(); + + let (tx, mut rx) = channel(); + + while !input.is_empty() { + let chunk = Vec::from(input.drain(0..2).as_slice()); + tx.send(chunk).unwrap(); + } + + let mut output = Vec::new(); + let out_msg = + read_until_full_msg::<(u8, u8, u8, u8, u8, u8, u8)>(&mut output, &mut rx).unwrap(); + + assert_eq!(msg, out_msg); } }