diff --git a/examples/n2c-miniprotocols/src/main.rs b/examples/n2c-miniprotocols/src/main.rs index 4dec5f2..6ca0458 100644 --- a/examples/n2c-miniprotocols/src/main.rs +++ b/examples/n2c-miniprotocols/src/main.rs @@ -9,7 +9,7 @@ async fn do_localstate_query(client: &mut NodeClient) { let result = client .statequery() - .query(localstate::queries::RequestV10::GetSystemStart) + .query(localstate::queries::Request::GetSystemStart) .await .unwrap(); diff --git a/pallas-network/src/facades.rs b/pallas-network/src/facades.rs index 0e99c5d..2cd8b5d 100644 --- a/pallas-network/src/facades.rs +++ b/pallas-network/src/facades.rs @@ -150,7 +150,7 @@ pub struct NodeClient { pub plexer_handle: JoinHandle>, pub handshake: handshake::Confirmation, pub chainsync: chainsync::N2CClient, - pub statequery: localstate::ClientV10, + pub statequery: localstate::Client, } impl NodeClient { @@ -236,7 +236,7 @@ impl NodeClient { &mut self.chainsync } - pub fn statequery(&mut self) -> &mut localstate::ClientV10 { + pub fn statequery(&mut self) -> &mut localstate::Client { &mut self.statequery } @@ -250,7 +250,7 @@ pub struct NodeServer { pub plexer_handle: JoinHandle>, pub version: (VersionNumber, n2c::VersionData), pub chainsync: chainsync::N2CServer, - // statequery: localstate::Server, + pub statequery: localstate::Server, } #[cfg(not(target_os = "windows"))] @@ -264,11 +264,11 @@ impl NodeServer { let hs_channel = server_plexer.subscribe_server(PROTOCOL_N2C_HANDSHAKE); let cs_channel = server_plexer.subscribe_server(PROTOCOL_N2C_CHAIN_SYNC); - // let sq_channel = server_plexer.subscribe_server(PROTOCOL_N2C_STATE_QUERY); + let sq_channel = server_plexer.subscribe_server(PROTOCOL_N2C_STATE_QUERY); let mut server_hs: handshake::Server = handshake::Server::new(hs_channel); let server_cs = chainsync::N2CServer::new(cs_channel); - // let server_sq = localstate::Server::new(sq_channel); + let server_sq = localstate::Server::new(sq_channel); let plexer_handle = tokio::spawn(async move { server_plexer.run().await }); @@ -282,7 +282,7 @@ impl NodeServer { plexer_handle, version: ver, chainsync: server_cs, - // statequery: server_sq + statequery: server_sq }) } else { plexer_handle.abort(); @@ -294,9 +294,9 @@ impl NodeServer { &mut self.chainsync } - // pub fn statequery(&mut self) -> &mut localstate::Server { - // &mut self.statequery - // } + pub fn statequery(&mut self) -> &mut localstate::Server { + &mut self.statequery + } pub fn abort(&mut self) { self.plexer_handle.abort(); diff --git a/pallas-network/src/miniprotocols/handshake/n2c.rs b/pallas-network/src/miniprotocols/handshake/n2c.rs index 5d7a0bd..e0b7bd3 100644 --- a/pallas-network/src/miniprotocols/handshake/n2c.rs +++ b/pallas-network/src/miniprotocols/handshake/n2c.rs @@ -83,6 +83,12 @@ impl VersionTable { #[derive(Debug, Clone, PartialEq)] pub struct VersionData(NetworkMagic, Option); +impl VersionData { + pub fn new(magic: NetworkMagic, param: Option) -> Self { + Self(magic, param) + } +} + impl Encode<()> for VersionData { fn encode( &self, diff --git a/pallas-network/src/miniprotocols/localstate/client.rs b/pallas-network/src/miniprotocols/localstate/client.rs index ef5e2cc..f14d79c 100644 --- a/pallas-network/src/miniprotocols/localstate/client.rs +++ b/pallas-network/src/miniprotocols/localstate/client.rs @@ -10,7 +10,7 @@ use crate::miniprotocols::Point; use crate::multiplexer; #[derive(Error, Debug)] -pub enum Error { +pub enum ClientError { #[error("attempted to receive message while agency is ours")] AgencyIsOurs, #[error("attempted to send message while agency is theirs")] @@ -27,21 +27,21 @@ pub enum Error { Plexer(multiplexer::Error), } -impl From for Error { +impl From for ClientError { fn from(x: AcquireFailure) -> Self { match x { - AcquireFailure::PointTooOld => Error::AcquirePointTooOld, - AcquireFailure::PointNotOnChain => Error::AcquirePointNotFound, + AcquireFailure::PointTooOld => ClientError::AcquirePointTooOld, + AcquireFailure::PointNotOnChain => ClientError::AcquirePointNotFound, } } } -pub struct Client(State, multiplexer::ChannelBuffer, PhantomData) +pub struct GenericClient(State, multiplexer::ChannelBuffer, PhantomData) where Q: Query, Message: Fragment; -impl Client +impl GenericClient where Q: Query, Message: Fragment, @@ -71,58 +71,59 @@ where } } - fn assert_agency_is_ours(&self) -> Result<(), Error> { + fn assert_agency_is_ours(&self) -> Result<(), ClientError> { if !self.has_agency() { - Err(Error::AgencyIsTheirs) + Err(ClientError::AgencyIsTheirs) } else { Ok(()) } } - fn assert_agency_is_theirs(&self) -> Result<(), Error> { + fn assert_agency_is_theirs(&self) -> Result<(), ClientError> { if self.has_agency() { - Err(Error::AgencyIsOurs) + Err(ClientError::AgencyIsOurs) } else { Ok(()) } } - fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + fn assert_outbound_state(&self, msg: &Message) -> Result<(), ClientError> { match (&self.0, msg) { (State::Idle, Message::Acquire(_)) => Ok(()), (State::Idle, Message::Done) => Ok(()), (State::Acquired, Message::Query(_)) => Ok(()), + (State::Acquired, Message::ReAcquire(_)) => Ok(()), (State::Acquired, Message::Release) => Ok(()), - _ => Err(Error::InvalidOutbound), + _ => Err(ClientError::InvalidOutbound), } } - fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + fn assert_inbound_state(&self, msg: &Message) -> Result<(), ClientError> { match (&self.0, msg) { (State::Acquiring, Message::Acquired) => Ok(()), (State::Acquiring, Message::Failure(_)) => Ok(()), (State::Querying, Message::Result(_)) => Ok(()), - _ => Err(Error::InvalidInbound), + _ => Err(ClientError::InvalidInbound), } } - pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + pub async fn send_message(&mut self, msg: &Message) -> Result<(), ClientError> { self.assert_agency_is_ours()?; self.assert_outbound_state(msg)?; - self.1.send_msg_chunks(msg).await.map_err(Error::Plexer)?; + self.1.send_msg_chunks(msg).await.map_err(ClientError::Plexer)?; Ok(()) } - pub async fn recv_message(&mut self) -> Result, Error> { + pub async fn recv_message(&mut self) -> Result, ClientError> { self.assert_agency_is_theirs()?; - let msg = self.1.recv_full_msg().await.map_err(Error::Plexer)?; + let msg = self.1.recv_full_msg().await.map_err(ClientError::Plexer)?; self.assert_inbound_state(&msg)?; Ok(msg) } - pub async fn send_acquire(&mut self, point: Option) -> Result<(), Error> { + pub async fn send_acquire(&mut self, point: Option) -> Result<(), ClientError> { let msg = Message::::Acquire(point); self.send_message(&msg).await?; self.0 = State::Acquiring; @@ -130,7 +131,31 @@ where Ok(()) } - pub async fn recv_while_acquiring(&mut self) -> Result<(), Error> { + pub async fn send_reacquire(&mut self, point: Option) -> Result<(), ClientError> { + let msg = Message::::ReAcquire(point); + self.send_message(&msg).await?; + self.0 = State::Acquiring; + + Ok(()) + } + + pub async fn send_release(&mut self) -> Result<(), ClientError> { + let msg = Message::::Release; + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + pub async fn send_done(&mut self) -> Result<(), ClientError> { + let msg = Message::::Done; + self.send_message(&msg).await?; + self.0 = State::Done; + + Ok(()) + } + + pub async fn recv_while_acquiring(&mut self) -> Result<(), ClientError> { match self.recv_message().await? { Message::Acquired => { self.0 = State::Acquired; @@ -138,18 +163,18 @@ where } Message::Failure(x) => { self.0 = State::Idle; - Err(Error::from(x)) + Err(ClientError::from(x)) } - _ => Err(Error::InvalidInbound), + _ => Err(ClientError::InvalidInbound), } } - pub async fn acquire(&mut self, point: Option) -> Result<(), Error> { + pub async fn acquire(&mut self, point: Option) -> Result<(), ClientError> { self.send_acquire(point).await?; self.recv_while_acquiring().await } - pub async fn send_query(&mut self, request: Q::Request) -> Result<(), Error> { + pub async fn send_query(&mut self, request: Q::Request) -> Result<(), ClientError> { let msg = Message::::Query(request); self.send_message(&msg).await?; self.0 = State::Querying; @@ -157,20 +182,20 @@ where Ok(()) } - pub async fn recv_while_querying(&mut self) -> Result { + pub async fn recv_while_querying(&mut self) -> Result { match self.recv_message().await? { Message::Result(x) => { self.0 = State::Acquired; Ok(x) } - _ => Err(Error::InvalidInbound), + _ => Err(ClientError::InvalidInbound), } } - pub async fn query(&mut self, request: Q::Request) -> Result { + pub async fn query(&mut self, request: Q::Request) -> Result { self.send_query(request).await?; self.recv_while_querying().await } } -pub type ClientV10 = Client; +pub type Client = GenericClient; diff --git a/pallas-network/src/miniprotocols/localstate/codec.rs b/pallas-network/src/miniprotocols/localstate/codec.rs index bf82b2b..7b747eb 100644 --- a/pallas-network/src/miniprotocols/localstate/codec.rs +++ b/pallas-network/src/miniprotocols/localstate/codec.rs @@ -68,13 +68,11 @@ where } Message::Query(query) => { e.array(2)?.u16(3)?; - e.array(1)?; e.encode(query)?; Ok(()) } Message::Result(result) => { e.array(2)?.u16(4)?; - e.array(1)?; e.encode(result)?; Ok(()) } diff --git a/pallas-network/src/miniprotocols/localstate/mod.rs b/pallas-network/src/miniprotocols/localstate/mod.rs index c327f1c..1e478a4 100644 --- a/pallas-network/src/miniprotocols/localstate/mod.rs +++ b/pallas-network/src/miniprotocols/localstate/mod.rs @@ -2,7 +2,9 @@ mod client; mod codec; mod protocol; pub mod queries; +mod server; pub use client::*; pub use codec::*; pub use protocol::*; +pub use server::*; diff --git a/pallas-network/src/miniprotocols/localstate/queries.rs b/pallas-network/src/miniprotocols/localstate/queries.rs index e766efd..7cae029 100644 --- a/pallas-network/src/miniprotocols/localstate/queries.rs +++ b/pallas-network/src/miniprotocols/localstate/queries.rs @@ -2,36 +2,224 @@ use pallas_codec::minicbor::{decode, encode, Decode, Decoder, Encode, Encoder}; use super::Query; -#[derive(Debug, Clone)] -pub struct BlockQuery {} +// https://github.com/input-output-hk/ouroboros-consensus/blob/main/ouroboros-consensus-cardano/src/shelley/Ouroboros/Consensus/Shelley/Ledger/Query.hs +#[derive(Debug, Clone, PartialEq)] +#[repr(u16)] +pub enum BlockQuery { + GetLedgerTip, + GetEpochNo, + // GetNonMyopicMemberRewards(()), + GetCurrentPParams, + GetProposedPParamsUpdates, + GetStakeDistribution, + // GetUTxOByAddress(()), + // GetUTxOWhole, (Response too large for now) + // DebugEpochState, (Response too large for now) + // GetCBOR(()), + // GetFilteredDelegationsAndRewardAccounts(()), + GetGenesisConfig, + // DebugNewEpochState, (Response too large for now) + DebugChainDepState, + GetRewardProvenance, + // GetUTxOByTxIn(()), + GetStakePools, + // GetStakePoolParams(()), + GetRewardInfoPools, + // GetPoolState(()), + // GetStakeSnapshots(()), + // GetPoolDistr(()), + // GetStakeDelegDeposits(()), + // GetConstitutionHash, +} -#[derive(Debug, Clone)] -pub enum RequestV10 { +impl Encode<()> for BlockQuery { + fn encode( + &self, + e: &mut Encoder, + _ctx: &mut (), + ) -> Result<(), encode::Error> { + e.array(2)?; + e.u16(0)?; + e.array(2)?; + /* + TODO: Think this is era or something? First fetch era with + [3, [0, [2, [1]]]], then use it here? + */ + e.u16(5)?; + match self { + BlockQuery::GetLedgerTip => { + e.array(1)?; + e.u16(0)?; + } + BlockQuery::GetEpochNo => { + e.array(1)?; + e.u16(1)?; + } + // BlockQuery::GetNonMyopicMemberRewards(()) => { + // e.array(X)?; + // e.u16(2)?; + // } + BlockQuery::GetCurrentPParams => { + e.array(1)?; + e.u16(3)?; + } + BlockQuery::GetProposedPParamsUpdates => { + e.array(1)?; + e.u16(4)?; + } + BlockQuery::GetStakeDistribution => { + e.array(1)?; + e.u16(5)?; + } + // BlockQuery::GetUTxOByAddress(()) => { + // e.array(X)?; + // e.u16(6)?; + // } + // BlockQuery::GetUTxOWhole => { + // e.array(1)?; + // e.u16(7)?; + // } + // BlockQuery::DebugEpochState => { + // e.array(1)?; + // e.u16(8)?; + // } + // BlockQuery::GetCBOR(()) => { + // e.array(X)?; + // e.u16(9)?; + // } + // BlockQuery::GetFilteredDelegationsAndRewardAccounts(()) => { + // e.array(X)?; + // e.u16(10)?; + // } + BlockQuery::GetGenesisConfig => { + e.array(1)?; + e.u16(11)?; + } + // BlockQuery::DebugNewEpochState => { + // e.array(1)?; + // e.u16(12)?; + // } + BlockQuery::DebugChainDepState => { + e.array(1)?; + e.u16(13)?; + } + BlockQuery::GetRewardProvenance => { + e.array(1)?; + e.u16(14)?; + } + // BlockQuery::GetUTxOByTxIn(()) => { + // e.array(X)?; + // e.u16(15)?; + // } + BlockQuery::GetStakePools => { + e.array(1)?; + e.u16(16)?; + } + // BlockQuery::GetStakePoolParams(()) => { + // e.array(X)?; + // e.u16(17)?; + // } + BlockQuery::GetRewardInfoPools => { + e.array(1)?; + e.u16(18)?; + } + // BlockQuery::GetPoolState(()) => { + // e.array(X)?; + // e.u16(19)?; + // } + // BlockQuery::GetStakeSnapshots(()) => { + // e.array(X)?; + // e.u16(20)?; + // } + // BlockQuery::GetPoolDistr(()) => { + // e.array(X)?; + // e.u16(21)?; + // } + // BlockQuery::GetStakeDelegDeposits(()) => { + // e.array(X)?; + // e.u16(22)?; + // } + // BlockQuery::GetConstitutionHash => { + // e.array(1)?; + // e.u16(23)?; + // } + } + Ok(()) + } +} + +impl<'b> Decode<'b, ()> for BlockQuery { + fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { + d.array()?; + d.u16()?; + d.array()?; + d.u16()?; + d.array()?; + + match d.u16()? { + 0 => Ok(Self::GetLedgerTip), + 1 => Ok(Self::GetEpochNo), + // 2 => Ok(Self::GetNonMyopicMemberRewards(())), + 3 => Ok(Self::GetCurrentPParams), + 4 => Ok(Self::GetProposedPParamsUpdates), + 5 => Ok(Self::GetStakeDistribution), + // 6 => Ok(Self::GetUTxOByAddress(())), + // 7 => Ok(Self::GetUTxOWhole), + // 8 => Ok(Self::DebugEpochState), + // 9 => Ok(Self::GetCBOR(())), + // 10 => Ok(Self::GetFilteredDelegationsAndRewardAccounts(())), + 11 => Ok(Self::GetGenesisConfig), + // 12 => Ok(Self::DebugNewEpochState), + 13 => Ok(Self::DebugChainDepState), + 14 => Ok(Self::GetRewardProvenance), + // 15 => Ok(Self::GetUTxOByTxIn(())), + 16 => Ok(Self::GetStakePools), + // 17 => Ok(Self::GetStakePoolParams(())), + 18 => Ok(Self::GetRewardInfoPools), + // 19 => Ok(Self::GetPoolState(())), + // 20 => Ok(Self::GetStakeSnapshots(())), + // 21 => Ok(Self::GetPoolDistr(())), + // 22 => Ok(Self::GetStakeDelegDeposits(())), + // 23 => Ok(Self::GetConstitutionHash), + _ => unreachable!(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Request { BlockQuery(BlockQuery), GetSystemStart, GetChainBlockNo, GetChainPoint, } -impl Encode<()> for RequestV10 { +impl Encode<()> for Request { fn encode( &self, e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { match self { - Self::BlockQuery(..) => { - todo!() + Self::BlockQuery(q) => { + e.array(2)?; + e.u16(0)?; + e.encode(q)?; + + Ok(()) } Self::GetSystemStart => { + e.array(1)?; e.u16(1)?; Ok(()) } Self::GetChainBlockNo => { + e.array(1)?; e.u16(2)?; Ok(()) } Self::GetChainPoint => { + e.array(1)?; e.u16(3)?; Ok(()) } @@ -39,22 +227,48 @@ impl Encode<()> for RequestV10 { } } -impl<'b> Decode<'b, ()> for RequestV10 { - fn decode(_d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { - todo!() +impl<'b> Decode<'b, ()> for Request { + fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result { + let size = match d.array()? { + Some(l) => l, + None => return Err(decode::Error::message("unexpected indefinite len list")), + }; + + let tag = d.u16()?; + + match (size, tag) { + (2, 0) => Ok(Self::BlockQuery(d.decode()?)), + (1, 1) => Ok(Self::GetSystemStart), + (1, 2) => Ok(Self::GetChainBlockNo), + (1, 3) => Ok(Self::GetChainPoint), + _ => { + return Err(decode::Error::message( + "invalid (size, tag) for lsq request", + )) + } + } } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct GenericResponse(Vec); +impl GenericResponse { + /// "bytes" must be valid CBOR + pub fn new(bytes: Vec) -> Self { + Self(bytes) + } +} + impl Encode<()> for GenericResponse { fn encode( &self, - _e: &mut Encoder, + e: &mut Encoder, _ctx: &mut (), ) -> Result<(), encode::Error> { - todo!() + e.writer_mut() + .write_all(&self.0) + .map_err(|e| encode::Error::write(e)) } } @@ -69,10 +283,11 @@ impl<'b> Decode<'b, ()> for GenericResponse { } } +/// Queries available as of N2C V16 #[derive(Debug, Clone)] -pub struct QueryV10 {} +pub struct QueryV16 {} -impl Query for QueryV10 { - type Request = RequestV10; +impl Query for QueryV16 { + type Request = Request; type Response = GenericResponse; } diff --git a/pallas-network/src/miniprotocols/localstate/server.rs b/pallas-network/src/miniprotocols/localstate/server.rs new file mode 100644 index 0000000..d3ceba2 --- /dev/null +++ b/pallas-network/src/miniprotocols/localstate/server.rs @@ -0,0 +1,184 @@ +use std::fmt::Debug; + +use pallas_codec::Fragment; + +use std::marker::PhantomData; +use thiserror::*; + +use super::{AcquireFailure, Message, Query, State}; +use crate::miniprotocols::Point; +use crate::multiplexer; + +#[derive(Error, Debug)] +pub enum Error { + #[error("attempted to receive message while agency is ours")] + AgencyIsOurs, + #[error("attempted to send message while agency is theirs")] + AgencyIsTheirs, + #[error("inbound message is not valid for current state")] + InvalidInbound, + #[error("outbound message is not valid for current state")] + InvalidOutbound, + #[error("error while sending or receiving data through the channel")] + Plexer(multiplexer::Error), +} + +/// Request received from the client to acquire the ledger +pub struct ClientAcquireRequest(pub Option); + +/// Request received from the client when in the Acquired state +#[derive(Debug)] +pub enum ClientQueryRequest { + ReAcquire(Option), + Query(Q::Request), + Release, +} + +pub struct GenericServer(State, multiplexer::ChannelBuffer, PhantomData) +where + Q: Query, + Message: Fragment; + +impl GenericServer +where + Q: Query, + Message: Fragment, +{ + pub fn new(channel: multiplexer::AgentChannel) -> Self { + Self( + State::Idle, + multiplexer::ChannelBuffer::new(channel), + PhantomData {}, + ) + } + + pub fn state(&self) -> &State { + &self.0 + } + + pub fn is_done(&self) -> bool { + self.0 == State::Done + } + + fn has_agency(&self) -> bool { + match self.state() { + State::Acquiring => true, + State::Querying => true, + _ => false, + } + } + + fn assert_agency_is_ours(&self) -> Result<(), Error> { + if !self.has_agency() { + Err(Error::AgencyIsTheirs) + } else { + Ok(()) + } + } + + fn assert_agency_is_theirs(&self) -> Result<(), Error> { + if self.has_agency() { + Err(Error::AgencyIsOurs) + } else { + Ok(()) + } + } + + fn assert_outbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Acquiring, Message::Acquired) => Ok(()), + (State::Acquiring, Message::Failure(_)) => Ok(()), + (State::Querying, Message::Result(_)) => Ok(()), + _ => Err(Error::InvalidOutbound), + } + } + + fn assert_inbound_state(&self, msg: &Message) -> Result<(), Error> { + match (&self.0, msg) { + (State::Idle, Message::Acquire(_)) => Ok(()), + (State::Idle, Message::Done) => Ok(()), + (State::Acquired, Message::Query(_)) => Ok(()), + (State::Acquired, Message::ReAcquire(_)) => Ok(()), + (State::Acquired, Message::Release) => Ok(()), + _ => Err(Error::InvalidInbound), + } + } + + pub async fn send_message(&mut self, msg: &Message) -> Result<(), Error> { + self.assert_agency_is_ours()?; + self.assert_outbound_state(msg)?; + self.1.send_msg_chunks(msg).await.map_err(Error::Plexer)?; + + Ok(()) + } + + pub async fn recv_message(&mut self) -> Result, Error> { + self.assert_agency_is_theirs()?; + let msg = self.1.recv_full_msg().await.map_err(Error::Plexer)?; + self.assert_inbound_state(&msg)?; + + Ok(msg) + } + + pub async fn send_failure(&mut self, reason: AcquireFailure) -> Result<(), Error> { + let msg = Message::::Failure(reason); + self.send_message(&msg).await?; + self.0 = State::Idle; + + Ok(()) + } + + pub async fn send_acquired(&mut self) -> Result<(), Error> { + let msg = Message::::Acquired; + self.send_message(&msg).await?; + self.0 = State::Acquired; + + Ok(()) + } + + pub async fn send_result(&mut self, response: Q::Response) -> Result<(), Error> { + let msg = Message::::Result(response); + self.send_message(&msg).await?; + self.0 = State::Acquired; + + Ok(()) + } + + /// Receive a message from the Client when the protocol is in the Idle state + /// + /// Returns the client's request to acquire the ledger or None if a Done + /// message was received from the client causing the protocol to finish. + pub async fn recv_while_idle(&mut self) -> Result, Error> { + match self.recv_message().await? { + Message::Acquire(point) => { + self.0 = State::Acquiring; + Ok(Some(ClientAcquireRequest(point))) + } + Message::Done => { + self.0 = State::Done; + Ok(None) + } + _ => Err(Error::InvalidInbound), + } + } + + pub async fn recv_while_acquired(&mut self) -> Result, Error> { + match self.recv_message().await? { + Message::ReAcquire(point) => { + self.0 = State::Acquiring; + Ok(ClientQueryRequest::ReAcquire(point)) + } + Message::Query(query) => { + self.0 = State::Querying; + Ok(ClientQueryRequest::Query(query)) + } + Message::Release => { + self.0 = State::Idle; + Ok(ClientQueryRequest::Release) + } + _ => Err(Error::InvalidInbound), + } + } +} + +pub type Server = GenericServer; diff --git a/pallas-network/tests/protocols.rs b/pallas-network/tests/protocols.rs index 59de6a5..ad5cb34 100644 --- a/pallas-network/tests/protocols.rs +++ b/pallas-network/tests/protocols.rs @@ -1,15 +1,23 @@ +use std::fs; use std::net::{Ipv4Addr, SocketAddrV4}; use std::time::Duration; -use pallas_network::facades::{PeerClient, PeerServer}; +use pallas_network::facades::{NodeClient, PeerClient, PeerServer}; use pallas_network::miniprotocols::blockfetch::BlockRequest; +use pallas_network::miniprotocols::handshake::n2c; +use pallas_network::miniprotocols::handshake::n2n::VersionData; +use pallas_network::miniprotocols::localstate::queries::{GenericResponse, Request}; +use pallas_network::miniprotocols::localstate::{ClientAcquireRequest, ClientQueryRequest}; use pallas_network::miniprotocols::chainsync::{ClientRequest, HeaderContent, Tip}; use pallas_network::miniprotocols::{ blockfetch, chainsync::{self, NextResponse}, Point, }; -use tokio::net::TcpListener; +use pallas_network::miniprotocols::{handshake, localstate}; +use pallas_network::multiplexer::{Bearer, Plexer}; +use std::path::Path; +use tokio::net::{TcpListener, UnixListener}; #[tokio::test] #[ignore] @@ -247,6 +255,150 @@ pub async fn blockfetch_server_and_client_happy_path() { _ = tokio::join!(client, server); } +#[tokio::test] +#[ignore] +pub async fn local_state_query_server_and_client_happy_path() { + let server = tokio::spawn({ + async move { + // server setup + let socket_path = Path::new("node.socket"); + + if socket_path.exists() { + fs::remove_file(&socket_path).unwrap(); + } + + let unix_listener = UnixListener::bind(socket_path).unwrap(); + + let (bearer, _) = Bearer::accept_unix(&unix_listener).await.unwrap(); + + let mut server_plexer = Plexer::new(bearer); + + let mut server_hs: handshake::Server = + handshake::Server::new(server_plexer.subscribe_server(0)); + + let mut server_sq: localstate::Server = + localstate::Server::new(server_plexer.subscribe_server(7)); + + tokio::spawn(async move { server_plexer.run().await }); + + server_hs.receive_proposed_versions().await.unwrap(); + server_hs + .accept_version(10, n2c::VersionData::new(0, Some(false))) + .await + .unwrap(); + + // server receives range from client, sends blocks + + let ClientAcquireRequest(maybe_point) = + server_sq.recv_while_idle().await.unwrap().unwrap(); + + assert_eq!(maybe_point, Some(Point::Origin)); + assert_eq!(*server_sq.state(), localstate::State::Acquiring); + + // server_bf.send_block_range(bodies).await.unwrap(); + + server_sq.send_acquired().await.unwrap(); + + assert_eq!(*server_sq.state(), localstate::State::Acquired); + + // server receives query from client + + let query = match server_sq.recv_while_acquired().await.unwrap() { + ClientQueryRequest::Query(q) => q, + x => panic!("unexpected message from client: {x:?}"), + }; + + assert_eq!( + query, + Request::BlockQuery(localstate::queries::BlockQuery::GetStakePools) + ); + + assert_eq!(*server_sq.state(), localstate::State::Querying); + + server_sq + .send_result(GenericResponse::new(hex::decode("82011A008BD423").unwrap())) + .await + .unwrap(); + + assert_eq!(*server_sq.state(), localstate::State::Acquired); + + // server receives reaquire from the client + + let maybe_point = match server_sq.recv_while_acquired().await.unwrap() { + ClientQueryRequest::ReAcquire(p) => p, + x => panic!("unexpected message from client: {x:?}"), + }; + + assert_eq!(maybe_point, Some(Point::Specific(1337, vec![1, 2, 3]))); + assert_eq!(*server_sq.state(), localstate::State::Acquiring); + + server_sq.send_acquired().await.unwrap(); + + // server receives release from the client + + match server_sq.recv_while_acquired().await.unwrap() { + ClientQueryRequest::Release => (), + x => panic!("unexpected message from client: {x:?}"), + }; + + assert!(server_sq.recv_while_idle().await.unwrap().is_none()); + + assert_eq!(*server_sq.state(), localstate::State::Done); + } + }); + + let client = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + + // client setup + + let socket_path = "node.socket"; + + let mut client_to_server_conn = NodeClient::connect(socket_path, 0).await.unwrap(); + + let client_sq = client_to_server_conn.statequery(); + + // client sends acquire + + client_sq.send_acquire(Some(Point::Origin)).await.unwrap(); + + client_sq.recv_while_acquiring().await.unwrap(); + + assert_eq!(*client_sq.state(), localstate::State::Acquired); + + // client sends a BlockQuery + + client_sq + .send_query(Request::BlockQuery( + localstate::queries::BlockQuery::GetStakePools, + )) + .await + .unwrap(); + + let resp = client_sq.recv_while_querying().await.unwrap(); + + assert_eq!( + resp, + GenericResponse::new(hex::decode("82011A008BD423").unwrap()) + ); + + // client sends a ReAquire + + client_sq + .send_reacquire(Some(Point::Specific(1337, vec![1, 2, 3]))) + .await + .unwrap(); + + client_sq.recv_while_acquiring().await.unwrap(); + + client_sq.send_release().await.unwrap(); + + client_sq.send_done().await.unwrap(); + }); + + _ = tokio::join!(client, server); +} + #[tokio::test] #[ignore] pub async fn chainsync_server_and_client_happy_path_n2n() { @@ -263,9 +415,21 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { .await .unwrap(); - let mut peer_server = PeerServer::accept(&server_listener, 0).await.unwrap(); + let (bearer, _) = Bearer::accept_tcp(&server_listener).await.unwrap(); - let server_cs = peer_server.chainsync(); + let mut server_plexer = Plexer::new(bearer); + + let mut server_hs: handshake::Server = + handshake::Server::new(server_plexer.subscribe_server(0)); + let mut server_cs = chainsync::N2NServer::new(server_plexer.subscribe_server(2)); + + tokio::spawn(async move { server_plexer.run().await }); + + server_hs.receive_proposed_versions().await.unwrap(); + server_hs + .accept_version(10, VersionData::new(0, false)) + .await + .unwrap(); // server receives find intersect from client, sends intersect point @@ -359,7 +523,7 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { }); let client = tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(2)).await; // client setup @@ -432,6 +596,4 @@ pub async fn chainsync_server_and_client_happy_path_n2n() { }); _ = tokio::join!(client, server); -} - -// TODO: redo txsubmission client test +} \ No newline at end of file