diff --git a/examples/block-download/src/main.rs b/examples/block-download/src/main.rs index f8fc235..271cf9b 100644 --- a/examples/block-download/src/main.rs +++ b/examples/block-download/src/main.rs @@ -1,53 +1,39 @@ use pallas::network::{ miniprotocols::{ - handshake::{n2n::VersionTable, Initiator}, - run_agent, Point, TESTNET_MAGIC, + blockfetch, + handshake::{self, n2n::VersionTable}, + Point, TESTNET_MAGIC, }, multiplexer::{bearers::Bearer, StdPlexer}, }; -use pallas::network::miniprotocols::blockfetch::{BatchClient, Observer}; - -#[derive(Debug)] -struct BlockPrinter; - -impl Observer for BlockPrinter { - fn on_block_received(&mut self, body: Vec) -> Result<(), Box> { - println!("{}", hex::encode(&body)); - println!("----------"); - Ok(()) - } -} - fn main() { env_logger::init(); let bearer = Bearer::connect_tcp("relays-new.cardano-testnet.iohkdev.io:3001").unwrap(); let mut plexer = StdPlexer::new(bearer); - let mut channel0 = plexer.use_channel(0).into(); - let mut channel3 = plexer.use_channel(3).into(); + let channel0 = plexer.use_channel(0); + let channel3 = plexer.use_channel(3); plexer.muxer.spawn(); plexer.demuxer.spawn(); let versions = VersionTable::v4_and_above(TESTNET_MAGIC); - let _last = run_agent(Initiator::initial(versions), &mut channel0).unwrap(); + let mut hs_client = handshake::N2NClient::new(channel0); + let handshake = hs_client.handshake(versions).unwrap(); - let range = ( - Point::Specific( - 63528597, - hex::decode("3f3d81c7b88f0fa28867541c5fea8794125cccf6d6c9ee0037a1dbb064130dfd") - .unwrap(), - ), - Point::Specific( - 63528597, - hex::decode("3f3d81c7b88f0fa28867541c5fea8794125cccf6d6c9ee0037a1dbb064130dfd") - .unwrap(), - ), + assert!(matches!(handshake, handshake::Confirmation::Accepted(..))); + + let point = Point::Specific( + 63528597, + hex::decode("3f3d81c7b88f0fa28867541c5fea8794125cccf6d6c9ee0037a1dbb064130dfd").unwrap(), ); - let bf = BatchClient::initial(range, BlockPrinter {}); - let bf_last = run_agent(bf, &mut channel3); - println!("{:?}", bf_last); + let mut bf_client = blockfetch::Client::new(channel3); + + let block = bf_client.fetch_single(point).unwrap(); + + println!("downloaded block of size: {}", block.len()); + println!("{}", hex::encode(&block)); } diff --git a/examples/n2c-miniprotocols/src/main.rs b/examples/n2c-miniprotocols/src/main.rs index dc36162..9a26332 100644 --- a/examples/n2c-miniprotocols/src/main.rs +++ b/examples/n2c-miniprotocols/src/main.rs @@ -1,80 +1,62 @@ use pallas::network::{ - miniprotocols::{chainsync, handshake, localstate, run_agent, Point, MAINNET_MAGIC}, + miniprotocols::{chainsync, handshake, localstate, Point, MAINNET_MAGIC}, multiplexer::{self, bearers::Bearer}, }; #[derive(Debug)] struct LoggingObserver; -impl chainsync::Observer for LoggingObserver { - fn on_roll_forward( - &mut self, - _content: chainsync::HeaderContent, - tip: &chainsync::Tip, - ) -> Result> { - log::debug!("asked to roll forward, tip at {:?}", tip); +fn do_handshake(channel: multiplexer::StdChannel) { + let mut client = handshake::N2CClient::new(channel); - Ok(chainsync::Continuation::Proceed) - } + let confirmation = client + .handshake(handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC)) + .unwrap(); - fn on_intersect_found( - &mut self, - point: &Point, - tip: &chainsync::Tip, - ) -> Result> { - log::debug!("intersect was found {:?} (tip: {:?})", point, tip); - - Ok(chainsync::Continuation::Proceed) - } - - fn on_rollback( - &mut self, - point: &Point, - ) -> Result> { - log::debug!("asked to roll back {:?}", point); - - Ok(chainsync::Continuation::Proceed) - } - - fn on_tip_reached(&mut self) -> Result> { - log::debug!("tip was reached"); - - Ok(chainsync::Continuation::Proceed) + match confirmation { + handshake::Confirmation::Accepted(v, _) => { + log::info!("hand-shake accepted, using version {}", v) + } + handshake::Confirmation::Rejected(x) => { + log::info!("hand-shake rejected with reason {:?}", x) + } } } -fn do_handshake(mut channel: multiplexer::StdChannelBuffer) { - let versions = handshake::n2c::VersionTable::v1_and_above(MAINNET_MAGIC); - let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); +fn do_localstate_query(channel: multiplexer::StdChannel) { + let mut client = localstate::ClientV10::new(channel); + client.acquire(None).unwrap(); + + let result = client + .query(localstate::queries::RequestV10::GetSystemStart) + .unwrap(); + + log::info!("system start result: {:?}", result); } -fn do_localstate_query(mut channel: multiplexer::StdChannelBuffer) { - let agent = run_agent( - localstate::OneShotClient::::initial( - None, - localstate::queries::RequestV10::GetChainPoint, - ), - &mut channel, - ); - - log::info!("state query result: {:?}", agent); -} - -fn do_chainsync(mut channel: multiplexer::StdChannelBuffer) { +fn do_chainsync(channel: multiplexer::StdChannel) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), )]; - let agent = run_agent( - chainsync::Consumer::::initial( - Some(known_points), - LoggingObserver {}, - ), - &mut channel, - ); + let mut client = chainsync::N2CClient::new(channel); - println!("{:?}", agent); + let (point, _) = client.find_intersect(known_points).unwrap(); + + log::info!("intersected point is {:?}", point); + + for _ in 0..10 { + let next = client.request_next().unwrap(); + + match next { + chainsync::NextResponse::RollForward(h, _) => { + log::info!("rolling forward, block size: {}", h.len()) + } + chainsync::NextResponse::RollBackward(x, _) => log::info!("rollback to {:?}", x), + chainsync::NextResponse::Await => log::info!("tip of chaing reached"), + }; + } } fn main() { @@ -89,9 +71,9 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use let mut plexer = multiplexer::StdPlexer::new(bearer); - let channel0 = plexer.use_channel(0).into(); - let channel7 = plexer.use_channel(7).into(); - let channel5 = plexer.use_channel(5).into(); + let channel0 = plexer.use_channel(0); + let channel7 = plexer.use_channel(7); + let channel5 = plexer.use_channel(5); plexer.muxer.spawn(); plexer.demuxer.spawn(); diff --git a/examples/n2n-miniprotocols/src/main.rs b/examples/n2n-miniprotocols/src/main.rs index 86c2cbb..30e5600 100644 --- a/examples/n2n-miniprotocols/src/main.rs +++ b/examples/n2n-miniprotocols/src/main.rs @@ -1,61 +1,29 @@ use pallas::network::{ - miniprotocols::{blockfetch, chainsync, handshake, run_agent, Point, MAINNET_MAGIC}, - multiplexer::{agents::ChannelBuffer, bearers::Bearer, StdChannel, StdPlexer}, + miniprotocols::{blockfetch, chainsync, handshake, Point, MAINNET_MAGIC}, + multiplexer::{bearers::Bearer, StdChannel, StdPlexer}, }; #[derive(Debug)] struct LoggingObserver; -impl blockfetch::Observer for LoggingObserver { - fn on_block_received(&mut self, body: Vec) -> Result<(), Box> { - log::trace!("block received: {}", hex::encode(&body)); - Ok(()) +fn do_handshake(channel: StdChannel) { + let mut client = handshake::N2NClient::new(channel); + + let confirmation = client + .handshake(handshake::n2n::VersionTable::v7_and_above(MAINNET_MAGIC)) + .unwrap(); + + match confirmation { + handshake::Confirmation::Accepted(v, _) => { + log::info!("hand-shake accepted, using version {}", v) + } + handshake::Confirmation::Rejected(x) => { + log::info!("hand-shake rejected with reason {:?}", x) + } } } -impl chainsync::Observer for LoggingObserver { - fn on_roll_forward( - &mut self, - _content: chainsync::HeaderContent, - tip: &chainsync::Tip, - ) -> Result> { - log::info!("asked to roll forward, tip at {:?}", tip); - - Ok(chainsync::Continuation::Proceed) - } - - fn on_intersect_found( - &mut self, - point: &Point, - tip: &chainsync::Tip, - ) -> Result> { - log::debug!("intersect was found {:?} (tip: {:?})", point, tip); - - Ok(chainsync::Continuation::Proceed) - } - - fn on_rollback( - &mut self, - point: &Point, - ) -> Result> { - log::debug!("asked to roll back {:?}", point); - - Ok(chainsync::Continuation::Proceed) - } - - fn on_tip_reached(&mut self) -> Result> { - log::debug!("tip was reached"); - - Ok(chainsync::Continuation::Proceed) - } -} - -fn do_handshake(mut channel: ChannelBuffer) { - let versions = handshake::n2n::VersionTable::v4_and_above(MAINNET_MAGIC); - let _last = run_agent(handshake::Initiator::initial(versions), &mut channel).unwrap(); -} - -fn do_blockfetch(mut channel: ChannelBuffer) { +fn do_blockfetch(channel: StdChannel) { let range = ( Point::Specific( 43847831, @@ -69,29 +37,38 @@ fn do_blockfetch(mut channel: ChannelBuffer) { ), ); - let agent = run_agent( - blockfetch::BatchClient::initial(range, LoggingObserver {}), - &mut channel, - ); + let mut client = blockfetch::Client::new(channel); - println!("{:?}", agent); + let blocks = client.fetch_range(range).unwrap(); + + for block in blocks { + log::info!("received block of size: {}", block.len()); + } } -fn do_chainsync(mut channel: ChannelBuffer) { +fn do_chainsync(channel: StdChannel) { let known_points = vec![Point::Specific( 43847831u64, hex::decode("15b9eeee849dd6386d3770b0745e0450190f7560e5159b1b3ab13b14b2684a45").unwrap(), )]; - let agent = run_agent( - chainsync::Consumer::::initial( - Some(known_points), - LoggingObserver {}, - ), - &mut channel, - ); + let mut client = chainsync::N2NClient::new(channel); - println!("{:?}", agent); + let (point, _) = client.find_intersect(known_points).unwrap(); + + log::info!("intersected point is {:?}", point); + + for _ in 0..10 { + let next = client.request_next().unwrap(); + + match next { + chainsync::NextResponse::RollForward(h, _) => { + log::info!("rolling forward, header size: {}", h.cbor.len()) + } + chainsync::NextResponse::RollBackward(x, _) => log::info!("rollback to {:?}", x), + chainsync::NextResponse::Await => log::info!("tip of chaing reached"), + }; + } } fn main() { @@ -106,9 +83,9 @@ fn main() { // setup the multiplexer by specifying the bearer and the IDs of the // miniprotocols to use let mut plexer = StdPlexer::new(bearer); - let channel0 = plexer.use_channel(0).into(); - let channel3 = plexer.use_channel(3).into(); - let channel2 = plexer.use_channel(2).into(); + let channel0 = plexer.use_channel(0); + let channel3 = plexer.use_channel(3); + let channel2 = plexer.use_channel(2); plexer.muxer.spawn(); plexer.demuxer.spawn(); diff --git a/pallas-miniprotocols/Cargo.toml b/pallas-miniprotocols/Cargo.toml index e0560d7..c86ae00 100644 --- a/pallas-miniprotocols/Cargo.toml +++ b/pallas-miniprotocols/Cargo.toml @@ -17,3 +17,7 @@ log = "0.4.14" hex = "0.4.3" itertools = "0.10.3" thiserror = "1.0.31" + +[dev-dependencies] +env_logger = "0.9.0" +log = "0.4.16" diff --git a/pallas-miniprotocols/src/blockfetch/client.rs b/pallas-miniprotocols/src/blockfetch/client.rs new file mode 100644 index 0000000..1ec3ddf --- /dev/null +++ b/pallas-miniprotocols/src/blockfetch/client.rs @@ -0,0 +1,186 @@ +use pallas_codec::Fragment; +use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError}; +use thiserror::Error; + +use crate::common::Point; + +use super::{Message, State}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + + #[error("inbound message is not valid for current state")] + InvalidInbound, + + #[error("outbound message is not valid for current state")] + InvalidOutbound, + + #[error("requested range doesn't contain any blocks")] + NoBlocks, + + #[error("error while sending or receiving data through the channel")] + ChannelError(ChannelError), +} + +pub type Body = Vec; + +pub type Range = (Point, Point); + +pub type HasBlocks = Option<()>; + +pub struct Client(State, ChannelBuffer) +where + H: Channel, + Message: Fragment; + +impl Client +where + H: Channel, + Message: Fragment, +{ + pub fn new(channel: H) -> Self { + Self(State::Idle, ChannelBuffer::new(channel)) + } + + pub fn state(&self) -> &State { + &self.0 + } + + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + fn has_agency(&self) -> bool { + match self.state() { + State::Idle => true, + State::Busy => false, + State::Streaming => false, + State::Done => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Idle, Message::RequestRange { .. }) => Ok(()), + (State::Idle, Message::ClientDone) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Busy, Message::StartBatch) => Ok(()), + (State::Busy, Message::NoBlocks) => Ok(()), + (State::Streaming, Message::Block { .. }) => Ok(()), + (State::Streaming, Message::BatchDone) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + + Ok(()) + } + + pub fn recv_message(&mut self) -> Result { + self.assert_agency_is_theirs()?; + let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + pub fn send_request_range(&mut self, range: (Point, Point)) -> Result<(), Error> { + let msg = Message::RequestRange { range }; + self.send_message(&msg)?; + self.0 = State::Busy; + + Ok(()) + } + + pub fn recv_while_busy(&mut self) -> Result { + match self.recv_message()? { + Message::StartBatch => { + self.0 = State::Streaming; + Ok(Some(())) + } + Message::NoBlocks => { + self.0 = State::Idle; + Ok(None) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn request_range(&mut self, range: Range) -> Result { + self.send_request_range(range)?; + self.recv_while_busy() + } + + pub fn recv_while_streaming(&mut self) -> Result, Error> { + match self.recv_message()? { + Message::Block { body } => Ok(Some(body)), + Message::BatchDone => { + self.0 = State::Idle; + Ok(None) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn fetch_single(&mut self, point: Point) -> Result { + self.request_range((point.clone(), point))? + .ok_or(Error::NoBlocks)?; + + let body = self.recv_while_streaming()?.ok_or(Error::InvalidInbound)?; + + match self.recv_while_streaming()? { + Some(_) => Err(Error::InvalidInbound), + None => Ok(body), + } + } + + pub fn fetch_range(&mut self, range: Range) -> Result, Error> { + self.request_range(range)?.ok_or(Error::NoBlocks)?; + + let mut all = vec![]; + + while let Some(block) = self.recv_while_streaming()? { + all.push(block); + } + + Ok(all) + } + + pub fn send_done(&mut self) -> Result<(), Error> { + let msg = Message::ClientDone; + self.send_message(&msg)?; + self.0 = State::Done; + + Ok(()) + } +} diff --git a/pallas-miniprotocols/src/blockfetch/codec.rs b/pallas-miniprotocols/src/blockfetch/codec.rs new file mode 100644 index 0000000..de34de3 --- /dev/null +++ b/pallas-miniprotocols/src/blockfetch/codec.rs @@ -0,0 +1,73 @@ +use pallas_codec::minicbor::{data::Tag, decode, encode, Decode, Decoder, Encode, Encoder}; + +use super::Message; + +impl Encode<()> for Message { + fn encode( + &self, + e: &mut Encoder, + _ctx: &mut (), + ) -> Result<(), encode::Error> { + match self { + Message::RequestRange { range } => { + e.array(3)?.u16(0)?; + e.encode(&range.0)?; + e.encode(&range.1)?; + Ok(()) + } + Message::ClientDone => { + e.array(1)?.u16(1)?; + Ok(()) + } + Message::StartBatch => { + e.array(1)?.u16(2)?; + Ok(()) + } + Message::NoBlocks => { + e.array(1)?.u16(3)?; + Ok(()) + } + Message::Block { body } => { + e.array(2)?.u16(4)?; + e.tag(Tag::Cbor)?; + e.bytes(body)?; + Ok(()) + } + Message::BatchDone => { + e.array(1)?.u16(5)?; + Ok(()) + } + } + } +} + +impl<'b> Decode<'b, ()> for Message { + fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { + d.array()?; + let label = d.u16()?; + + match label { + 0 => { + let point1 = d.decode()?; + let point2 = d.decode()?; + Ok(Message::RequestRange { + range: (point1, point2), + }) + } + 1 => Ok(Message::ClientDone), + 2 => Ok(Message::StartBatch), + 3 => Ok(Message::NoBlocks), + 4 => { + d.tag()?; + let body = d.bytes()?; + Ok(Message::Block { + body: Vec::from(body), + }) + } + 5 => Ok(Message::BatchDone), + _ => Err(decode::Error::message( + "unknown variant for blockfetch message", + )), + } + } +} diff --git a/pallas-miniprotocols/src/blockfetch/mod.rs b/pallas-miniprotocols/src/blockfetch/mod.rs index 23547f9..a9eaa04 100644 --- a/pallas-miniprotocols/src/blockfetch/mod.rs +++ b/pallas-miniprotocols/src/blockfetch/mod.rs @@ -1,386 +1,7 @@ -use crate::machines::{Agent, Transition}; -use crate::MachineError; +mod client; +mod codec; +mod protocol; -use crate::common::Point; - -use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; - -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum State { - Idle, - Busy, - Streaming, - Done, -} - -#[derive(Debug)] -pub enum Message { - RequestRange { range: (Point, Point) }, - ClientDone, - StartBatch, - NoBlocks, - Block { body: Vec }, - BatchDone, -} - -impl Encode<()> for Message { - fn encode( - &self, - e: &mut Encoder, - _ctx: &mut (), - ) -> Result<(), encode::Error> { - match self { - Message::RequestRange { range } => { - e.array(3)?.u16(0)?; - e.encode(&range.0)?; - e.encode(&range.1)?; - Ok(()) - } - Message::ClientDone => { - e.array(1)?.u16(1)?; - Ok(()) - } - Message::StartBatch => { - e.array(1)?.u16(2)?; - Ok(()) - } - Message::NoBlocks => { - e.array(1)?.u16(3)?; - Ok(()) - } - Message::Block { body } => { - e.array(2)?.u16(4)?; - e.bytes(body)?; - Ok(()) - } - Message::BatchDone => { - e.array(1)?.u16(5)?; - Ok(()) - } - } - } -} - -impl<'b> Decode<'b, ()> for Message { - fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { - d.array()?; - let label = d.u16()?; - - match label { - 0 => { - let point1 = d.decode()?; - let point2 = d.decode()?; - Ok(Message::RequestRange { - range: (point1, point2), - }) - } - 1 => Ok(Message::ClientDone), - 2 => Ok(Message::StartBatch), - 3 => Ok(Message::NoBlocks), - 4 => { - d.tag()?; - let body = d.bytes()?; - Ok(Message::Block { - body: Vec::from(body), - }) - } - 5 => Ok(Message::BatchDone), - _ => Err(decode::Error::message( - "unknown variant for blockfetch message", - )), - } - } -} - -pub trait Observer { - fn on_block_received(&mut self, body: Vec) -> Result<(), Box> { - log::debug!("block received, sice: {}", body.len()); - Ok(()) - } - - fn on_block_range_requested( - &self, - range: &(Point, Point), - ) -> Result<(), Box> { - log::debug!( - "block range requested, from: {:?}, to: {:?}", - range.0, - range.1 - ); - Ok(()) - } -} - -#[derive(Debug)] -pub struct NoopObserver {} - -impl Observer for NoopObserver {} - -#[derive(Debug)] -pub struct BatchClient -where - O: Observer, -{ - pub state: State, - pub range: (Point, Point), - pub observer: O, -} - -impl BatchClient -where - O: Observer, -{ - pub fn initial(range: (Point, Point), observer: O) -> Self { - Self { - state: State::Idle, - range, - observer, - } - } - - fn request_range_msg(&self) -> Message { - Message::RequestRange { - range: self.range.clone(), - } - } - - fn on_range_requested(self) -> Transition { - log::debug!("block range requested"); - - Ok(Self { - state: State::Busy, - ..self - }) - } - - fn on_block(mut self, body: Vec) -> Transition { - log::debug!("received block body, size {}", body.len()); - - self.observer - .on_block_received(body) - .map_err(MachineError::downstream)?; - - Ok(self) - } - - fn on_batch_done(self) -> Transition { - Ok(Self { - state: State::Done, - ..self - }) - } - - fn on_client_done(self) -> Transition { - Ok(Self { - state: State::Done, - ..self - }) - } -} - -impl Agent for BatchClient -where - O: Observer, -{ - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done - } - - fn has_agency(&self) -> bool { - match self.state { - State::Idle => true, - State::Busy => false, - State::Streaming => false, - State::Done => false, - } - } - - fn build_next(&self) -> Self::Message { - match self.state { - State::Idle => self.request_range_msg(), - _ => panic!("I don't have agency, don't know what to do"), - } - } - - fn apply_start(self) -> Transition { - Ok(Self { - state: State::Idle, - ..self - }) - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Idle, Message::RequestRange { .. }) => self.on_range_requested(), - (State::Idle, Message::ClientDone) => self.on_client_done(), - _ => panic!("I don't have agency, I don't expect outbound message"), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Busy, Message::StartBatch) => Ok(Self { - state: State::Streaming, - ..self - }), - (State::Busy, Message::NoBlocks) => Ok(Self { - state: State::Done, - ..self - }), - (State::Streaming, Message::Block { body }) => self.on_block(body), - (State::Streaming, Message::BatchDone) => self.on_batch_done(), - _ => panic!("I have agency, I don't expect messages"), - } - } -} - -#[derive(Debug)] -pub struct OnDemandClient -where - I: Iterator, - O: Observer, -{ - pub state: State, - pub inflight: Option<(Point, Point)>, - pub next: Option<(Point, Point)>, - pub requests: I, - pub observer: O, -} - -impl OnDemandClient -where - I: Iterator, - O: Observer, -{ - pub fn initial(requests: I, observer: O) -> Self { - Self { - state: State::Idle, - inflight: None, - next: None, - requests, - observer, - } - } - - fn wait_for_request(mut self) -> Transition { - log::debug!("waiting for external block request"); - - let next = self.requests.next(); - - match next { - Some(x) => Ok(Self { - state: State::Idle, - next: Some((x.clone(), x)), - ..self - }), - None => Ok(Self { - state: State::Done, - next: None, - ..self - }), - } - } - - fn on_range_requested(self, range: (Point, Point)) -> Transition { - log::debug!("requested block range {:?}", range); - - Ok(Self { - state: State::Busy, - inflight: Some(range), - next: None, - ..self - }) - } - - fn on_block(mut self, body: Vec) -> Transition { - log::debug!("received block body, size {}", body.len()); - - self.observer - .on_block_received(body) - .map_err(MachineError::downstream)?; - - Ok(self) - } - - fn on_batch_done(self) -> Transition { - self.wait_for_request() - } - - fn on_client_done(self) -> Transition { - Ok(Self { - state: State::Done, - ..self - }) - } -} - -impl Agent for OnDemandClient -where - I: Iterator, - O: Observer, -{ - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done - } - - fn has_agency(&self) -> bool { - match self.state { - State::Idle => true, - State::Busy => false, - State::Streaming => false, - State::Done => false, - } - } - - fn build_next(&self) -> Self::Message { - match (&self.state, &self.next) { - (State::Idle, Some(range)) => Message::RequestRange { - range: range.clone(), - }, - (State::Idle, None) => panic!("I'm idle but no more block requests available"), - _ => panic!("I don't have agency, don't know what to do"), - } - } - - fn apply_start(self) -> Transition { - self.wait_for_request() - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Idle, Message::RequestRange { range }) => self.on_range_requested(range), - (State::Idle, Message::ClientDone) => self.on_client_done(), - _ => panic!("I don't have agency, I don't expect outbound message"), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Busy, Message::StartBatch) => Ok(Self { - state: State::Streaming, - ..self - }), - (State::Busy, Message::NoBlocks) => Ok(Self { - state: State::Idle, - ..self - }), - (State::Streaming, Message::Block { body }) => self.on_block(body), - (State::Streaming, Message::BatchDone) => self.on_batch_done(), - _ => panic!("I have agency, I don't expect inbound message"), - } - } -} +pub use client::*; +pub use codec::*; +pub use protocol::*; diff --git a/pallas-miniprotocols/src/blockfetch/protocol.rs b/pallas-miniprotocols/src/blockfetch/protocol.rs new file mode 100644 index 0000000..594626e --- /dev/null +++ b/pallas-miniprotocols/src/blockfetch/protocol.rs @@ -0,0 +1,19 @@ +use crate::Point; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum State { + Idle, + Busy, + Streaming, + Done, +} + +#[derive(Debug)] +pub enum Message { + RequestRange { range: (Point, Point) }, + ClientDone, + StartBatch, + NoBlocks, + Block { body: Vec }, + BatchDone, +} diff --git a/pallas-miniprotocols/src/chainsync/agents.rs b/pallas-miniprotocols/src/chainsync/agents.rs deleted file mode 100644 index fe57455..0000000 --- a/pallas-miniprotocols/src/chainsync/agents.rs +++ /dev/null @@ -1,353 +0,0 @@ -use core::panic; -use std::fmt::Debug; -use std::marker::PhantomData; - -use pallas_codec::Fragment; - -use crate::machines::{Agent, MachineError, Transition}; - -use crate::common::Point; - -use super::{BlockContent, HeaderContent, Message, SkippedContent, State, Tip}; - -#[derive(Debug, PartialEq, Eq)] -pub enum Continuation { - Proceed, - DropOut, - Done, -} - -/// An observer of chain-sync events sent by the state-machine -pub trait Observer { - fn on_roll_forward( - &mut self, - _content: C, - tip: &Tip, - ) -> Result> { - log::debug!("asked to roll forward, tip at {:?}", tip); - - Ok(Continuation::Proceed) - } - - fn on_intersect_found( - &mut self, - point: &Point, - tip: &Tip, - ) -> Result> { - log::debug!("intersect was found {:?} (tip: {:?})", point, tip); - - Ok(Continuation::Proceed) - } - - fn on_rollback(&mut self, point: &Point) -> Result> { - log::debug!("asked to roll back {:?}", point); - - Ok(Continuation::Proceed) - } - - fn on_tip_reached(&mut self) -> Result> { - log::debug!("tip was reached"); - - Ok(Continuation::Proceed) - } -} - -#[derive(Debug)] -pub struct NoopObserver {} - -impl Observer for NoopObserver {} - -#[derive(Debug)] -pub struct Consumer -where - Self: Agent, - O: Observer, -{ - pub state: State, - pub known_points: Option>, - pub intersect: Option, - pub tip: Option, - - continuation: Continuation, - - observer: O, - - _phantom: PhantomData, -} - -impl Consumer -where - O: Observer, - Message: Fragment, - C: std::fmt::Debug + 'static, -{ - pub fn initial(known_points: Option>, observer: O) -> Self { - Self { - state: State::Idle, - intersect: None, - tip: None, - known_points, - continuation: Continuation::Proceed, - observer, - _phantom: PhantomData::default(), - } - } - - fn on_intersect_found(mut self, point: Point, tip: Tip) -> Transition { - log::debug!("intersect found: {:?} (tip: {:?})", point, tip); - - let continuation = self - .observer - .on_intersect_found(&point, &tip) - .map_err(MachineError::downstream)?; - - Ok(Self { - tip: Some(tip), - intersect: Some(point), - state: State::Idle, - continuation, - ..self - }) - } - - fn on_intersect_not_found(self, tip: Tip) -> Transition { - log::debug!("intersect not found (tip: {:?})", tip); - - Ok(Self { - tip: Some(tip), - intersect: None, - state: State::Done, - ..self - }) - } - - fn on_roll_forward(mut self, content: C, tip: Tip) -> Transition { - log::debug!("rolling forward"); - - let continuation = self - .observer - .on_roll_forward(content, &tip) - .map_err(MachineError::downstream)?; - - Ok(Self { - tip: Some(tip), - state: State::Idle, - continuation, - ..self - }) - } - - fn on_roll_backward(mut self, point: Point, tip: Tip) -> Transition { - log::debug!("rolling backward to point: {:?}", point); - - let continuation = self - .observer - .on_rollback(&point) - .map_err(MachineError::downstream)?; - - Ok(Self { - tip: Some(tip), - intersect: Some(point), - state: State::Idle, - continuation, - ..self - }) - } - - fn on_await_reply(mut self) -> Transition { - log::debug!("reached tip, await reply"); - - let continuation = self - .observer - .on_tip_reached() - .map_err(MachineError::downstream)?; - - Ok(Self { - state: State::MustReply, - continuation, - ..self - }) - } -} - -impl Agent for Consumer -where - O: Observer, - C: Debug + 'static, - Message: Fragment, -{ - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done || self.continuation == Continuation::DropOut - } - - fn has_agency(&self) -> bool { - match self.state { - State::Idle => true, - State::CanAwait => false, - State::MustReply => false, - State::Intersect => false, - State::Done => false, - } - } - - fn build_next(&self) -> Self::Message { - match (&self.state, &self.intersect, &self.continuation) { - (State::Idle, _, Continuation::Done) => Message::::Done, - (State::Idle, None, Continuation::Proceed) => match &self.known_points { - Some(x) => Message::::FindIntersect(x.clone()), - None => Message::::RequestNext, - }, - (State::Idle, Some(_), Continuation::Proceed) => Message::::RequestNext, - _ => panic!(""), - } - } - - fn apply_start(self) -> Transition { - Ok(self) - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (self.state, msg) { - (State::Idle, Message::RequestNext) => Ok(Self { - state: State::CanAwait, - ..self - }), - (State::Idle, Message::FindIntersect(_)) => Ok(Self { - state: State::Intersect, - ..self - }), - (State::Idle, Message::Done) => Ok(Self { - state: State::Done, - ..self - }), - _ => panic!(""), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::CanAwait, Message::RollForward(header, tip)) => { - self.on_roll_forward(header, tip) - } - (State::CanAwait, Message::RollBackward(point, tip)) => { - self.on_roll_backward(point, tip) - } - (State::CanAwait, Message::AwaitReply) => self.on_await_reply(), - (State::MustReply, Message::RollForward(content, tip)) => { - self.on_roll_forward(content, tip) - } - (State::MustReply, Message::RollBackward(point, tip)) => { - self.on_roll_backward(point, tip) - } - (State::Intersect, Message::IntersectFound(point, tip)) => { - self.on_intersect_found(point, tip) - } - (State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip), - (state, msg) => Err(MachineError::invalid_msg::(state, &msg)), - } - } -} - -#[derive(Debug)] -pub struct TipFinder { - pub state: State, - pub wellknown_point: Point, - pub output: Option, -} - -impl TipFinder { - pub fn initial(wellknown_point: Point) -> Self { - TipFinder { - wellknown_point, - output: None, - state: State::Idle, - } - } - - fn on_intersect_found(self, tip: Tip) -> Transition { - log::debug!("intersect found with tip: {:?}", tip); - - Ok(Self { - state: State::Done, - output: Some(tip), - ..self - }) - } - - fn on_intersect_not_found(self, tip: Tip) -> Transition { - log::debug!("intersect not found but still have a tip: {:?}", tip); - - Ok(Self { - state: State::Done, - output: Some(tip), - ..self - }) - } -} - -pub type HeaderConsumer = Consumer; - -pub type BlockConsumer = Consumer; - -impl Agent for TipFinder { - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done - } - - fn has_agency(&self) -> bool { - match self.state { - State::Idle => true, - State::CanAwait => false, - State::MustReply => false, - State::Intersect => false, - State::Done => false, - } - } - - fn build_next(&self) -> Self::Message { - match self.state { - State::Idle => { - Message::::FindIntersect(vec![self.wellknown_point.clone()]) - } - _ => panic!("I don't know what to do"), - } - } - - fn apply_start(self) -> Transition { - Ok(self) - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (self.state, msg) { - (State::Idle, Message::FindIntersect(_)) => Ok(Self { - state: State::Intersect, - ..self - }), - _ => panic!("I don't know what to do"), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Intersect, Message::IntersectFound(_point, tip)) => { - self.on_intersect_found(tip) - } - (State::Intersect, Message::IntersectNotFound(tip)) => self.on_intersect_not_found(tip), - (state, msg) => Err(MachineError::invalid_msg::(state, &msg)), - } - } -} diff --git a/pallas-miniprotocols/src/chainsync/client.rs b/pallas-miniprotocols/src/chainsync/client.rs new file mode 100644 index 0000000..18526a1 --- /dev/null +++ b/pallas-miniprotocols/src/chainsync/client.rs @@ -0,0 +1,208 @@ +use pallas_codec::Fragment; +use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError}; +use std::marker::PhantomData; +use thiserror::Error; + +use crate::common::Point; + +use super::{BlockContent, HeaderContent, Message, State, Tip}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + + #[error("inbound message is not valid for current state")] + InvalidInbound, + + #[error("outbound message is not valid for current state")] + InvalidOutbound, + + #[error("no intersection point found")] + IntersectionNotFound, + + #[error("error while sending or receiving data through the channel")] + ChannelError(ChannelError), +} + +pub type IntersectResponse = (Option, Tip); + +#[derive(Debug)] +pub enum NextResponse { + RollForward(CONTENT, Tip), + RollBackward(Point, Tip), + Await, +} + +pub struct Client(State, ChannelBuffer, PhantomData) +where + H: Channel, + Message: Fragment; + +impl Client +where + H: Channel, + Message: Fragment, +{ + pub fn new(channel: H) -> Self { + Self(State::Idle, ChannelBuffer::new(channel), PhantomData {}) + } + + pub fn state(&self) -> &State { + &self.0 + } + + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + pub fn has_agency(&self) -> bool { + match self.state() { + State::Idle => true, + State::CanAwait => false, + State::MustReply => false, + State::Intersect => false, + State::Done => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Idle, Message::RequestNext) => Ok(()), + (State::Idle, Message::FindIntersect(_)) => Ok(()), + (State::Idle, Message::Done) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::CanAwait, Message::RollForward(_, _)) => Ok(()), + (State::CanAwait, Message::RollBackward(_, _)) => Ok(()), + (State::CanAwait, Message::AwaitReply) => Ok(()), + (State::MustReply, Message::RollForward(_, _)) => Ok(()), + (State::MustReply, Message::RollBackward(_, _)) => Ok(()), + (State::Intersect, Message::IntersectFound(_, _)) => Ok(()), + (State::Intersect, Message::IntersectNotFound(_)) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + + Ok(()) + } + + pub fn recv_message(&mut self) -> Result, Error> { + self.assert_agency_is_theirs()?; + let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + pub fn send_find_intersect(&mut self, points: Vec) -> Result<(), Error> { + let msg = Message::FindIntersect(points); + self.send_message(&msg)?; + self.0 = State::Intersect; + + Ok(()) + } + + pub fn recv_intersect_response(&mut self) -> Result { + match self.recv_message()? { + Message::IntersectFound(point, tip) => { + self.0 = State::Idle; + Ok((Some(point), tip)) + } + Message::IntersectNotFound(tip) => { + self.0 = State::Idle; + Ok((None, tip)) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn find_intersect(&mut self, points: Vec) -> Result { + self.send_find_intersect(points)?; + self.recv_intersect_response() + } + + pub fn send_request_next(&mut self) -> Result<(), Error> { + let msg = Message::RequestNext; + self.send_message(&msg)?; + self.0 = State::CanAwait; + + Ok(()) + } + + pub fn recv_request_response(&mut self) -> Result, Error> { + match self.recv_message()? { + Message::AwaitReply => { + self.0 = State::MustReply; + Ok(NextResponse::Await) + } + Message::RollForward(a, b) => { + self.0 = State::Idle; + Ok(NextResponse::RollForward(a, b)) + } + Message::RollBackward(a, b) => { + self.0 = State::Idle; + Ok(NextResponse::RollBackward(a, b)) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn request_next(&mut self) -> Result, Error> { + self.send_request_next()?; + self.recv_request_response() + } + + pub fn intersect_origin(&mut self) -> Result { + let (point, _) = self.find_intersect(vec![Point::Origin])?; + + point.ok_or(Error::IntersectionNotFound) + } + + pub fn intersect_tip(&mut self) -> Result { + let (_, Tip(point, _)) = self.find_intersect(vec![Point::Origin])?; + let (point, _) = self.find_intersect(vec![point])?; + + point.ok_or(Error::IntersectionNotFound) + } + + pub fn send_done(&mut self) -> Result<(), Error> { + let msg = Message::Done; + self.send_message(&msg)?; + self.0 = State::Done; + + Ok(()) + } +} + +pub type N2NClient = Client; + +pub type N2CClient = Client; diff --git a/pallas-miniprotocols/src/chainsync/mod.rs b/pallas-miniprotocols/src/chainsync/mod.rs index 5b0b505..2ad863e 100644 --- a/pallas-miniprotocols/src/chainsync/mod.rs +++ b/pallas-miniprotocols/src/chainsync/mod.rs @@ -1,9 +1,9 @@ -mod agents; mod buffer; +mod client; mod codec; mod protocol; -pub use agents::*; pub use buffer::*; +pub use client::*; pub use codec::*; pub use protocol::*; diff --git a/pallas-miniprotocols/src/handshake/agents.rs b/pallas-miniprotocols/src/handshake/agents.rs deleted file mode 100644 index 9a4d583..0000000 --- a/pallas-miniprotocols/src/handshake/agents.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::fmt::Debug; - -use crate::{Agent, Transition}; - -use super::protocol::{Message, RefuseReason, State, VersionNumber, VersionTable}; - -#[derive(Debug)] -pub enum Output { - Pending, - Accepted(VersionNumber, D), - Refused(RefuseReason), -} - -#[derive(Debug)] -pub struct Initiator -where - D: Debug + Clone, -{ - pub state: State, - pub output: Output, - pub version_table: VersionTable, -} - -impl Initiator -where - D: Debug + Clone, -{ - pub fn initial(version_table: VersionTable) -> Self { - Initiator { - state: State::Propose, - output: Output::Pending, - version_table, - } - } -} - -impl Agent for Initiator -where - D: Debug + Clone, -{ - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done - } - - fn has_agency(&self) -> bool { - match self.state { - State::Propose => true, - State::Confirm => false, - State::Done => false, - } - } - - fn build_next(&self) -> Self::Message { - match self.state { - State::Propose => Message::Propose(self.version_table.clone()), - _ => panic!("I don't have agency, nothing to send"), - } - } - - fn apply_start(self) -> Transition { - Ok(self) - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (self.state, msg) { - (State::Propose, Message::Propose(_)) => Ok(Self { - state: State::Confirm, - ..self - }), - _ => panic!(""), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (self.state, msg) { - (State::Confirm, Message::Accept(version, data)) => Ok(Self { - state: State::Done, - output: Output::Accepted(version, data), - ..self - }), - (State::Confirm, Message::Refuse(reason)) => Ok(Self { - state: State::Done, - output: Output::Refused(reason), - ..self - }), - _ => panic!("Current state does't expect to receive a message"), - } - } -} diff --git a/pallas-miniprotocols/src/handshake/client.rs b/pallas-miniprotocols/src/handshake/client.rs new file mode 100644 index 0000000..e99d6fc --- /dev/null +++ b/pallas-miniprotocols/src/handshake/client.rs @@ -0,0 +1,139 @@ +use pallas_codec::Fragment; +use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError}; +use std::marker::PhantomData; +use thiserror::*; + +use super::{Message, RefuseReason, State, VersionNumber, VersionTable}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + + #[error("inbound message is not valid for current state")] + InvalidInbound, + + #[error("outbound message is not valid for current state")] + InvalidOutbound, + + #[error("error while sending or receiving data through the channel")] + ChannelError(ChannelError), +} + +#[derive(Debug)] +pub enum Confirmation { + Accepted(VersionNumber, D), + Rejected(RefuseReason), +} + +pub struct Client(State, ChannelBuffer, PhantomData) +where + H: Channel; + +impl Client +where + H: Channel, + D: std::fmt::Debug + Clone, + Message: Fragment, +{ + pub fn new(channel: H) -> Self { + Self(State::Propose, ChannelBuffer::new(channel), PhantomData {}) + } + + pub fn state(&self) -> &State { + &self.0 + } + + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + pub fn has_agency(&self) -> bool { + match self.state() { + State::Propose => true, + State::Confirm => false, + State::Done => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Propose, Message::Propose(_)) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Confirm, Message::Accept(..)) => Ok(()), + (State::Confirm, Message::Refuse(..)) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + + Ok(()) + } + + pub fn recv_message(&mut self) -> Result, Error> { + self.assert_agency_is_theirs()?; + let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + pub fn send_propose(&mut self, versions: VersionTable) -> Result<(), Error> { + let msg = Message::Propose(versions); + self.send_message(&msg)?; + self.0 = State::Confirm; + + Ok(()) + } + + pub fn recv_while_confirm(&mut self) -> Result, Error> { + match self.recv_message()? { + Message::Accept(v, m) => { + self.0 = State::Done; + Ok(Confirmation::Accepted(v, m)) + } + Message::Refuse(r) => { + self.0 = State::Done; + Ok(Confirmation::Rejected(r)) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn handshake(&mut self, versions: VersionTable) -> Result, Error> { + self.send_propose(versions)?; + self.recv_while_confirm() + } +} + +pub type N2NClient = Client; + +pub type N2CClient = Client; diff --git a/pallas-miniprotocols/src/handshake/mod.rs b/pallas-miniprotocols/src/handshake/mod.rs index 97598a8..f4d2634 100644 --- a/pallas-miniprotocols/src/handshake/mod.rs +++ b/pallas-miniprotocols/src/handshake/mod.rs @@ -1,8 +1,8 @@ -mod agents; +mod client; mod protocol; pub mod n2c; pub mod n2n; -pub use agents::*; +pub use client::*; pub use protocol::*; diff --git a/pallas-miniprotocols/src/localstate/client.rs b/pallas-miniprotocols/src/localstate/client.rs new file mode 100644 index 0000000..58cab05 --- /dev/null +++ b/pallas-miniprotocols/src/localstate/client.rs @@ -0,0 +1,175 @@ +use std::fmt::Debug; + +use pallas_codec::Fragment; + +use crate::common::Point; + +use pallas_multiplexer::agents::{Channel, ChannelBuffer, ChannelError}; +use std::marker::PhantomData; +use thiserror::*; + +use super::{AcquireFailure, Message, Query, State}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + #[error("inbound message is not valid for current state")] + InvalidInbound, + #[error("outbound message is not valid for current state")] + InvalidOutbound, + #[error("failure acquiring point, not found")] + AcquirePointNotFound, + #[error("failure acquiring point, too old")] + AcquirePointTooOld, + #[error("error while sending or receiving data through the channel")] + ChannelError(ChannelError), +} + +impl From for Error { + fn from(x: AcquireFailure) -> Self { + match x { + AcquireFailure::PointTooOld => Error::AcquirePointTooOld, + AcquireFailure::PointNotOnChain => Error::AcquirePointNotFound, + } + } +} + +pub struct Client(State, ChannelBuffer, PhantomData) +where + H: Channel, + Q: Query, + Message: Fragment; + +impl Client +where + H: Channel, + Q: Query, + Message: Fragment, +{ + pub fn new(channel: H) -> Self { + Self(State::Idle, ChannelBuffer::new(channel), PhantomData {}) + } + + pub fn state(&self) -> &State { + &self.0 + } + + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + #[allow(clippy::match_like_matches_macro)] + fn has_agency(&self) -> bool { + match self.state() { + State::Idle => true, + State::Acquired => true, + _ => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Idle, Message::Acquire(_)) => Ok(()), + (State::Idle, Message::Done) => Ok(()), + (State::Acquired, Message::Query(_)) => Ok(()), + (State::Acquired, Message::Release) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Acquiring, Message::Acquired) => Ok(()), + (State::Acquiring, Message::Failure(_)) => Ok(()), + (State::Querying, Message::Result(_)) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + pub fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + self.1.send_msg_chunks(msg).map_err(Error::ChannelError)?; + + Ok(()) + } + + pub fn recv_message(&mut self) -> Result, Error> { + self.assert_agency_is_theirs()?; + let msg = self.1.recv_full_msg().map_err(Error::ChannelError)?; + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + pub fn send_acquire(&mut self, point: Option) -> Result<(), Error> { + let msg = Message::::Acquire(point); + self.send_message(&msg)?; + self.0 = State::Acquiring; + + Ok(()) + } + + pub fn recv_while_acquiring(&mut self) -> Result<(), Error> { + match self.recv_message()? { + Message::Acquired => { + self.0 = State::Acquired; + Ok(()) + } + Message::Failure(x) => { + self.0 = State::Idle; + Err(Error::from(x)) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn acquire(&mut self, point: Option) -> Result<(), Error> { + self.send_acquire(point)?; + self.recv_while_acquiring() + } + + pub fn send_query(&mut self, request: Q::Request) -> Result<(), Error> { + let msg = Message::::Query(request); + self.send_message(&msg)?; + self.0 = State::Querying; + + Ok(()) + } + + pub fn recv_while_querying(&mut self) -> Result { + match self.recv_message()? { + Message::Result(x) => { + self.0 = State::Acquired; + Ok(x) + } + _ => Err(Error::InvalidInbound), + } + } + + pub fn query(&mut self, request: Q::Request) -> Result { + self.send_query(request)?; + self.recv_while_querying() + } +} + +pub type ClientV10 = Client; diff --git a/pallas-miniprotocols/src/localstate/codec.rs b/pallas-miniprotocols/src/localstate/codec.rs index 650a89f..bf82b2b 100644 --- a/pallas-miniprotocols/src/localstate/codec.rs +++ b/pallas-miniprotocols/src/localstate/codec.rs @@ -10,7 +10,7 @@ impl Encode<()> for AcquireFailure { ) -> Result<(), encode::Error> { let code = match self { AcquireFailure::PointTooOld => 0, - AcquireFailure::PointNotInChain => 1, + AcquireFailure::PointNotOnChain => 1, }; e.u16(code)?; @@ -28,7 +28,7 @@ impl<'b> Decode<'b, ()> for AcquireFailure { match code { 0 => Ok(AcquireFailure::PointTooOld), - 1 => Ok(AcquireFailure::PointNotInChain), + 1 => Ok(AcquireFailure::PointNotOnChain), _ => Err(decode::Error::message( "can't infer acquire failure from variant id", )), diff --git a/pallas-miniprotocols/src/localstate/mod.rs b/pallas-miniprotocols/src/localstate/mod.rs index 4675cdb..c327f1c 100644 --- a/pallas-miniprotocols/src/localstate/mod.rs +++ b/pallas-miniprotocols/src/localstate/mod.rs @@ -1,169 +1,8 @@ +mod client; mod codec; +mod protocol; pub mod queries; -use std::fmt::Debug; - -use pallas_codec::Fragment; - -use crate::machines::{Agent, MachineError, Transition}; - -use crate::common::Point; - -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum State { - Idle, - Acquiring, - Acquired, - Querying, - Done, -} - -#[derive(Debug)] -pub enum AcquireFailure { - PointTooOld, - PointNotInChain, -} -pub trait Query: Debug { - type Request: Clone + Debug; - type Response: Clone + Debug; -} - -#[derive(Debug)] -pub enum Message { - Acquire(Option), - Failure(AcquireFailure), - Acquired, - Query(Q::Request), - Result(Q::Response), - ReAcquire(Option), - Release, - Done, -} - -pub type Output = Result; - -#[derive(Debug)] -pub struct OneShotClient { - pub state: State, - pub check_point: Option, - pub request: Q::Request, - pub output: Option>, -} - -impl OneShotClient -where - Q: Query + 'static, - Message: Fragment, -{ - pub fn initial(check_point: Option, request: Q::Request) -> Self { - Self { - state: State::Idle, - output: None, - check_point, - request, - } - } - - fn on_acquired(self) -> Transition { - log::debug!("acquired check point for chain state"); - - Ok(Self { - state: State::Acquired, - ..self - }) - } - - fn on_result(self, response: Q::Response) -> Transition { - log::debug!("query result received: {:?}", response); - - Ok(Self { - // once we get a result, since this is a one-shot client, we mutate into Done - state: State::Done, - output: Some(Ok(response)), - ..self - }) - } - - fn on_failure(self, failure: AcquireFailure) -> Transition { - log::debug!("acquire failure: {:?}", failure); - - Ok(Self { - state: State::Idle, - output: Some(Err(failure)), - ..self - }) - } -} - -impl Agent for OneShotClient -where - Q: Query + 'static, - Message: Fragment, -{ - type Message = Message; - type State = State; - - fn state(&self) -> &Self::State { - &self.state - } - - fn is_done(&self) -> bool { - self.state == State::Done - } - - #[allow(clippy::match_like_matches_macro)] - fn has_agency(&self) -> bool { - match self.state { - State::Idle => true, - State::Acquired => true, - _ => false, - } - } - - fn build_next(&self) -> Self::Message { - match (&self.state, &self.output) { - // if we're idle and without a result, assume start of flow - (State::Idle, None) => Message::::Acquire(self.check_point.clone()), - // if we don't have an output, assume start of query - (State::Acquired, None) => Message::::Query(self.request.clone()), - // if we have an output but still acquired, release the server - (State::Acquired, Some(_)) => Message::::Release, - _ => panic!("I don't have agency, don't know what to do"), - } - } - - fn apply_start(self) -> Transition { - Ok(self) - } - - fn apply_outbound(self, msg: Self::Message) -> Transition { - match (self.state, msg) { - (State::Idle, Message::Acquire(_)) => Ok(Self { - state: State::Acquiring, - ..self - }), - (State::Acquired, Message::Query(_)) => Ok(Self { - state: State::Querying, - ..self - }), - (State::Acquired, Message::Release) => Ok(Self { - state: State::Idle, - ..self - }), - (State::Idle, Message::Done) => Ok(Self { - state: State::Done, - ..self - }), - _ => panic!(""), - } - } - - fn apply_inbound(self, msg: Self::Message) -> Transition { - match (&self.state, msg) { - (State::Acquiring, Message::Acquired) => self.on_acquired(), - (State::Acquiring, Message::Failure(failure)) => self.on_failure(failure), - (State::Querying, Message::Result(result)) => self.on_result(result), - (state, msg) => Err(MachineError::invalid_msg::(state, &msg)), - } - } -} +pub use client::*; +pub use codec::*; +pub use protocol::*; diff --git a/pallas-miniprotocols/src/localstate/protocol.rs b/pallas-miniprotocols/src/localstate/protocol.rs new file mode 100644 index 0000000..a40b0ca --- /dev/null +++ b/pallas-miniprotocols/src/localstate/protocol.rs @@ -0,0 +1,35 @@ +use std::fmt::Debug; + +use crate::common::Point; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum State { + Idle, + Acquiring, + Acquired, + Querying, + Done, +} + +#[derive(Debug)] +pub enum AcquireFailure { + PointTooOld, + PointNotOnChain, +} + +pub trait Query: Debug { + type Request: Clone + Debug; + type Response: Clone + Debug; +} + +#[derive(Debug)] +pub enum Message { + Acquire(Option), + Failure(AcquireFailure), + Acquired, + Query(Q::Request), + Result(Q::Response), + ReAcquire(Option), + Release, + Done, +} diff --git a/pallas-miniprotocols/tests/integration.rs b/pallas-miniprotocols/tests/integration.rs new file mode 100644 index 0000000..7ff7b49 --- /dev/null +++ b/pallas-miniprotocols/tests/integration.rs @@ -0,0 +1,129 @@ +use pallas_miniprotocols::{ + blockfetch, + chainsync::{self, NextResponse}, + handshake::{self, Confirmation}, + Point, +}; +use pallas_multiplexer::{bearers::Bearer, StdPlexer}; + +#[test] +#[ignore] +pub fn chainsync_happy_path() { + let bearer = Bearer::connect_tcp("preview-node.world.dev.cardano.org:30002").unwrap(); + let mut plexer = StdPlexer::new(bearer); + + let channel0 = plexer.use_channel(0); + let channel2 = plexer.use_channel(2); + + plexer.muxer.spawn(); + plexer.demuxer.spawn(); + + let mut client = handshake::N2NClient::new(channel0); + + let confirmation = client + .handshake(handshake::n2n::VersionTable::v7_and_above(2)) + .unwrap(); + + assert!(matches!(confirmation, Confirmation::Accepted(..))); + + if let Confirmation::Accepted(v, _) = confirmation { + assert!(v >= 7); + } + + let known_point = Point::Specific( + 5953863, + hex::decode("7e44cb1e230b686875ae6a256b95c9b4eea7c9e9a9d046b626ed69d4c1b9bfe1").unwrap(), + ); + + let mut client = chainsync::N2NClient::new(channel2); + + let (point, _) = client.find_intersect(vec![known_point.clone()]).unwrap(); + + assert!(matches!(client.state(), chainsync::State::Idle)); + + match point { + Some(point) => assert_eq!(point, known_point.clone()), + None => panic!("expected point"), + } + + let next = client.request_next().unwrap(); + + match next { + NextResponse::RollBackward(point, _) => assert_eq!(point, known_point.clone()), + _ => panic!("expected rollback"), + } + + assert!(matches!(client.state(), chainsync::State::Idle)); + + for _ in 0..10 { + let next = client.request_next().unwrap(); + + match next { + NextResponse::RollForward(_, _) => (), + _ => panic!("expected roll-forward"), + } + + assert!(matches!(client.state(), chainsync::State::Idle)); + } + + client.send_done().unwrap(); + + assert!(matches!(client.state(), chainsync::State::Done)); +} + +#[test] +#[ignore] +pub fn blockfetch_happy_path() { + let bearer = Bearer::connect_tcp("preview-node.world.dev.cardano.org:30002").unwrap(); + let mut plexer = StdPlexer::new(bearer); + + let channel0 = plexer.use_channel(0); + let channel3 = plexer.use_channel(3); + + plexer.muxer.spawn(); + plexer.demuxer.spawn(); + + let mut client = handshake::N2NClient::new(channel0); + + let confirmation = client + .handshake(handshake::n2n::VersionTable::v7_and_above(2)) + .unwrap(); + + assert!(matches!(confirmation, Confirmation::Accepted(..))); + + if let Confirmation::Accepted(v, _) = confirmation { + assert!(v >= 7); + } + + let known_point = Point::Specific( + 5953863, + hex::decode("7e44cb1e230b686875ae6a256b95c9b4eea7c9e9a9d046b626ed69d4c1b9bfe1").unwrap(), + ); + + let mut client = blockfetch::Client::new(channel3); + + let range_ok = client.request_range((known_point.clone(), known_point.clone())); + + assert!(matches!(client.state(), blockfetch::State::Streaming)); + + assert!(matches!(range_ok, Ok(_))); + + for _ in 0..1 { + let next = client.recv_while_streaming().unwrap(); + + match next { + Some(body) => assert_eq!(body.len(), 863), + _ => panic!("expected block body"), + } + + assert!(matches!(client.state(), blockfetch::State::Streaming)); + } + + let next = client.recv_while_streaming().unwrap(); + + assert!(matches!(next, None)); + + client.send_done().unwrap(); + + assert!(matches!(client.state(), blockfetch::State::Done)); +}