diff --git a/pallas-handshake/src/common.rs b/pallas-handshake/src/common.rs index 4e1bc73..a3334f5 100644 --- a/pallas-handshake/src/common.rs +++ b/pallas-handshake/src/common.rs @@ -1,5 +1,5 @@ use itertools::Itertools; -use pallas_machines::{DecodePayload, EncodePayload, PayloadEncoder}; +use pallas_machines::{DecodePayload, EncodePayload, MachineError, PayloadEncoder}; use std::{collections::HashMap, fmt::Debug}; pub const TESTNET_MAGIC: u64 = 1097911063; @@ -17,10 +17,7 @@ impl EncodePayload for VersionTable where T: Debug + Clone + EncodePayload + DecodePayload, { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { e.map(self.values.len() as u64)?; for key in self.values.keys().sorted() { @@ -44,10 +41,7 @@ pub enum RefuseReason { } impl EncodePayload for RefuseReason { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { match self { RefuseReason::VersionMismatch(versions) => { e.array(2)?; @@ -69,7 +63,7 @@ impl EncodePayload for RefuseReason { } RefuseReason::Refused(version, msg) => { e.array(3)?; - e.u16(1)?; + e.u16(2)?; e.u64(*version)?; e.str(msg)?; @@ -78,3 +72,32 @@ impl EncodePayload for RefuseReason { } } } + +impl DecodePayload for RefuseReason { + fn decode_payload( + d: &mut pallas_machines::PayloadDecoder, + ) -> Result> { + d.array()?; + + match d.u16()? { + 0 => { + let versions = d.array_iter::()?; + let versions = versions.try_collect()?; + Ok(RefuseReason::VersionMismatch(versions)) + } + 1 => { + let version = d.u64()?; + let msg = d.str()?; + + Ok(RefuseReason::HandshakeDecodeError(version, msg.to_string())) + } + 2 => { + let version = d.u64()?; + let msg = d.str()?; + + Ok(RefuseReason::Refused(version, msg.to_string())) + } + x => Err(Box::new(MachineError::BadLabel(x))), + } + } +} diff --git a/pallas-handshake/src/n2c.rs b/pallas-handshake/src/n2c.rs index 6388474..cab38ca 100644 --- a/pallas-handshake/src/n2c.rs +++ b/pallas-handshake/src/n2c.rs @@ -1,9 +1,10 @@ use core::panic; use std::collections::HashMap; +use itertools::Merge; use pallas_machines::{ - Agent, DecodePayload, EncodePayload, MachineError, MachineOutput, - PayloadDecoder, PayloadEncoder, + Agent, DecodePayload, EncodePayload, MachineError, MachineOutput, PayloadDecoder, + PayloadEncoder, }; use crate::common::{NetworkMagic, RefuseReason, VersionNumber}; @@ -42,13 +43,10 @@ impl VersionTable { } #[derive(Debug, Clone)] -pub struct VersionData (NetworkMagic,); +pub struct VersionData(NetworkMagic); impl EncodePayload for VersionData { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { e.u64(self.0)?; Ok(()) @@ -56,9 +54,7 @@ impl EncodePayload for VersionData { } impl DecodePayload for VersionData { - fn decode_payload( - d: &mut PayloadDecoder, - ) -> Result> { + fn decode_payload(d: &mut PayloadDecoder) -> Result> { let network_magic = d.u64()?; Ok(Self(network_magic)) @@ -73,10 +69,7 @@ pub enum Message { } impl EncodePayload for Message { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { match self { Message::Propose(version_table) => { e.array(2)?.u16(0)?; @@ -98,24 +91,22 @@ impl EncodePayload for Message { } impl DecodePayload for Message { - fn decode_payload( - d: &mut PayloadDecoder, - ) -> Result> { + fn decode_payload(d: &mut PayloadDecoder) -> Result> { d.array()?; - let msg = match d.u16()? { + match d.u16()? { 0 => todo!(), 1 => { let version_number = d.u64()?; let version_data = VersionData::decode_payload(d)?; - - Message::Accept(version_number, version_data) + Ok(Message::Accept(version_number, version_data)) } - 2 => todo!(), - x => return Err(Box::new(MachineError::BadLabel(x))), - }; - - Ok(msg) + 2 => { + let reason = RefuseReason::decode_payload(d)?; + Ok(Message::Refuse(reason)) + } + x => Err(Box::new(MachineError::BadLabel(x))), + } } } @@ -165,10 +156,7 @@ impl Agent for Client { } } - fn send_next( - self, - tx: &impl MachineOutput, - ) -> Result> { + fn send_next(self, tx: &impl MachineOutput) -> Result> { match self.state { State::Propose => { tx.send_msg(&Message::Propose(self.version_table.clone()))?; @@ -182,10 +170,7 @@ impl Agent for Client { } } - fn receive_next( - self, - msg: Self::Message, - ) -> Result> { + fn receive_next(self, msg: Self::Message) -> Result> { match (self.state, msg) { (State::Confirm, Message::Accept(version, data)) => Ok(Self { state: State::Done, diff --git a/pallas-handshake/src/n2n.rs b/pallas-handshake/src/n2n.rs index 3c2e62a..10b8f5a 100644 --- a/pallas-handshake/src/n2n.rs +++ b/pallas-handshake/src/n2n.rs @@ -2,8 +2,8 @@ use core::panic; use std::collections::HashMap; use pallas_machines::{ - Agent, DecodePayload, EncodePayload, MachineError, MachineOutput, - PayloadDecoder, PayloadEncoder, + Agent, DecodePayload, EncodePayload, MachineError, MachineOutput, PayloadDecoder, + PayloadEncoder, }; use crate::common::{RefuseReason, VersionNumber}; @@ -28,6 +28,17 @@ impl VersionTable { VersionTable { values } } + + pub fn v6_and_above(network_magic: u64) -> VersionTable { + let values = vec![ + (PROTOCOL_V6, VersionData::new(network_magic, false)), + (PROTOCOL_V7, VersionData::new(network_magic, false)), + ] + .into_iter() + .collect::>(); + + VersionTable { values } + } } #[derive(Debug, Clone)] @@ -37,10 +48,7 @@ pub struct VersionData { } impl VersionData { - pub fn new( - network_magic: u64, - initiator_and_responder_diffusion_mode: bool, - ) -> Self { + pub fn new(network_magic: u64, initiator_and_responder_diffusion_mode: bool) -> Self { VersionData { network_magic, initiator_and_responder_diffusion_mode, @@ -49,10 +57,7 @@ impl VersionData { } impl EncodePayload for VersionData { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { e.array(2)? .u64(self.network_magic)? .bool(self.initiator_and_responder_diffusion_mode)?; @@ -62,9 +67,7 @@ impl EncodePayload for VersionData { } impl DecodePayload for VersionData { - fn decode_payload( - d: &mut PayloadDecoder, - ) -> Result> { + fn decode_payload(d: &mut PayloadDecoder) -> Result> { d.array()?; let network_magic = d.u64()?; let initiator_and_responder_diffusion_mode = d.bool()?; @@ -84,10 +87,7 @@ pub enum Message { } impl EncodePayload for Message { - fn encode_payload( - &self, - e: &mut PayloadEncoder, - ) -> Result<(), Box> { + fn encode_payload(&self, e: &mut PayloadEncoder) -> Result<(), Box> { match self { Message::Propose(version_table) => { e.array(2)?.u16(0)?; @@ -109,24 +109,22 @@ impl EncodePayload for Message { } impl DecodePayload for Message { - fn decode_payload( - d: &mut PayloadDecoder, - ) -> Result> { + fn decode_payload(d: &mut PayloadDecoder) -> Result> { d.array()?; - let msg = match d.u16()? { + match d.u16()? { 0 => todo!(), 1 => { let version_number = d.u64()?; let version_data = VersionData::decode_payload(d)?; - - Message::Accept(version_number, version_data) + Ok(Message::Accept(version_number, version_data)) } - 2 => todo!(), - x => return Err(Box::new(MachineError::BadLabel(x))), - }; - - Ok(msg) + 2 => { + let reason = RefuseReason::decode_payload(d)?; + Ok(Message::Refuse(reason)) + } + x => Err(Box::new(MachineError::BadLabel(x))), + } } } @@ -176,10 +174,7 @@ impl Agent for Client { } } - fn send_next( - self, - tx: &impl MachineOutput, - ) -> Result> { + fn send_next(self, tx: &impl MachineOutput) -> Result> { match self.state { State::Propose => { tx.send_msg(&Message::Propose(self.version_table.clone()))?; @@ -193,10 +188,7 @@ impl Agent for Client { } } - fn receive_next( - self, - msg: Self::Message, - ) -> Result> { + fn receive_next(self, msg: Self::Message) -> Result> { match (self.state, msg) { (State::Confirm, Message::Accept(version, data)) => Ok(Self { state: State::Done,