201 lines
5.4 KiB
Rust
201 lines
5.4 KiB
Rust
use std::net::SocketAddr;
|
|
use std::path::Path;
|
|
|
|
use byteorder::{ByteOrder, NetworkEndian};
|
|
use thiserror::Error;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs, UnixStream};
|
|
use tokio::time::Instant;
|
|
use tracing::trace;
|
|
|
|
/// Number of bytes occupied by a serialized [`Header`]: a 4-byte timestamp,
/// a 2-byte protocol id and a 2-byte payload length.
const HEADER_LEN: usize = 8;
|
|
|
|
/// Timestamp value carried in the first four bytes of a segment header.
pub type Timestamp = u32;
|
|
|
|
/// Owned payload bytes of a single segment (everything after the header).
pub type Payload = Vec<u8>;
|
|
|
|
/// Protocol identifier carried in bytes 4..6 of a segment header.
pub type Protocol = u16;
|
|
|
|
#[derive(Debug)]
|
|
pub struct Header {
|
|
pub protocol: Protocol,
|
|
pub timestamp: Timestamp,
|
|
pub payload_len: u16,
|
|
}
|
|
|
|
impl From<&[u8]> for Header {
|
|
fn from(value: &[u8]) -> Self {
|
|
let timestamp = NetworkEndian::read_u32(&value[0..4]);
|
|
let protocol = NetworkEndian::read_u16(&value[4..6]);
|
|
let payload_len = NetworkEndian::read_u16(&value[6..8]);
|
|
|
|
Self {
|
|
timestamp,
|
|
protocol,
|
|
payload_len,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Header> for [u8; 8] {
|
|
fn from(value: Header) -> Self {
|
|
let mut out = [0u8; 8];
|
|
NetworkEndian::write_u32(&mut out[0..4], value.timestamp);
|
|
NetworkEndian::write_u16(&mut out[4..6], value.protocol);
|
|
NetworkEndian::write_u16(&mut out[6..8], value.payload_len);
|
|
|
|
out
|
|
}
|
|
}
|
|
|
|
pub struct Segment {
|
|
pub header: Header,
|
|
pub payload: Payload,
|
|
}
|
|
|
|
pub enum Bearer {
|
|
Tcp(TcpStream),
|
|
Unix(UnixStream),
|
|
}
|
|
|
|
/// Initial capacity, in bytes (10 KiB), reserved for the receive buffer of a
/// `SegmentBuffer`.
const BUFFER_LEN: usize = 1024 * 10;
|
|
|
|
impl Bearer {
|
|
pub async fn connect_tcp(addr: impl ToSocketAddrs) -> Result<Self, tokio::io::Error> {
|
|
let stream = TcpStream::connect(addr).await?;
|
|
Ok(Self::Tcp(stream))
|
|
}
|
|
|
|
pub async fn accept_tcp(listener: TcpListener) -> tokio::io::Result<(Self, SocketAddr)> {
|
|
let (stream, addr) = listener.accept().await?;
|
|
Ok((Self::Tcp(stream), addr))
|
|
}
|
|
|
|
pub async fn connect_unix(path: impl AsRef<Path>) -> Result<Self, tokio::io::Error> {
|
|
let stream = UnixStream::connect(path).await?;
|
|
Ok(Self::Unix(stream))
|
|
}
|
|
|
|
pub async fn readable(&self) -> tokio::io::Result<()> {
|
|
match self {
|
|
Bearer::Tcp(x) => x.readable().await,
|
|
Bearer::Unix(x) => x.readable().await,
|
|
}
|
|
}
|
|
|
|
fn try_read(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
|
|
match self {
|
|
Bearer::Tcp(x) => x.try_read(buf),
|
|
Bearer::Unix(x) => x.try_read(buf),
|
|
}
|
|
}
|
|
|
|
async fn write_all(&mut self, buf: &[u8]) -> tokio::io::Result<()> {
|
|
match self {
|
|
Bearer::Tcp(x) => x.write_all(buf).await,
|
|
Bearer::Unix(x) => x.write_all(buf).await,
|
|
}
|
|
}
|
|
|
|
async fn flush(&mut self) -> tokio::io::Result<()> {
|
|
match self {
|
|
Bearer::Tcp(x) => x.flush().await,
|
|
Bearer::Unix(x) => x.flush().await,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum Error {
|
|
#[error("no data available in bearer to complete segment")]
|
|
NoData,
|
|
|
|
#[error("unexpected I/O error")]
|
|
Io(#[source] tokio::io::Error),
|
|
}
|
|
|
|
/// Pairs a [`Bearer`] (field `0`) with a byte buffer (field `1`) holding data
/// that has been read from the bearer but not yet returned as a segment.
pub struct SegmentBuffer(Bearer, Vec<u8>);
|
|
|
|
impl SegmentBuffer {
|
|
pub fn new(bearer: Bearer) -> Self {
|
|
Self(bearer, Vec::with_capacity(BUFFER_LEN))
|
|
}
|
|
|
|
/// Cancel-safe loop that reads from bearer until certain len
|
|
async fn cancellable_read(&mut self, required: usize) -> Result<(), Error> {
|
|
loop {
|
|
self.0.readable().await.map_err(Error::Io)?;
|
|
trace!("bearer is readable");
|
|
|
|
let remaining = required - self.1.len();
|
|
let mut buf = vec![0u8; remaining];
|
|
|
|
match self.0.try_read(&mut buf) {
|
|
Ok(0) => break Err(Error::NoData),
|
|
Ok(n) => {
|
|
trace!(n, "found data on bearer");
|
|
self.1.extend_from_slice(&buf[0..n]);
|
|
|
|
if self.1.len() >= required {
|
|
break Ok(());
|
|
}
|
|
}
|
|
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {
|
|
trace!("reading from bearer would block");
|
|
continue;
|
|
}
|
|
Err(e) => {
|
|
return Err(Error::Io(e));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Peek the available data in search for a frame header
|
|
async fn peek_header(&mut self) -> Result<Header, Error> {
|
|
trace!("waiting for header buf");
|
|
self.cancellable_read(HEADER_LEN).await?;
|
|
|
|
trace!("found enough data for header");
|
|
let header = &self.1[..HEADER_LEN];
|
|
|
|
Ok(Header::from(header))
|
|
}
|
|
|
|
// Cancel-safe read of a full segment from the bearer
|
|
pub async fn read_segment(&mut self) -> Result<(Protocol, Payload), Error> {
|
|
let header = self.peek_header().await?;
|
|
|
|
trace!("waiting for full segment buf");
|
|
let segment_size = HEADER_LEN + header.payload_len as usize;
|
|
|
|
self.cancellable_read(segment_size).await?;
|
|
|
|
trace!("draining segment buffer");
|
|
let segment = self.1.drain(..segment_size);
|
|
let payload = segment.skip(HEADER_LEN).collect();
|
|
|
|
Ok((header.protocol, payload))
|
|
}
|
|
|
|
pub async fn write_segment(
|
|
&mut self,
|
|
protocol: u16,
|
|
clock: &Instant,
|
|
payload: &[u8],
|
|
) -> Result<(), std::io::Error> {
|
|
let header = Header {
|
|
protocol,
|
|
timestamp: clock.elapsed().as_micros() as u32,
|
|
payload_len: payload.len() as u16,
|
|
};
|
|
|
|
let buf: [u8; 8] = header.into();
|
|
self.0.write_all(&buf).await?;
|
|
self.0.write_all(payload).await?;
|
|
|
|
self.0.flush().await?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|