248 lines
6.8 KiB
Rust
248 lines
6.8 KiB
Rust
|
use std::{
|
||
|
fmt::Display,
|
||
|
net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket},
|
||
|
str::{from_utf8, FromStr},
|
||
|
time::{Duration, Instant},
|
||
|
};
|
||
|
|
||
|
use crate::NetworkLogFile;
|
||
|
|
||
|
use anyhow::Result;
|
||
|
|
||
|
/// Maximum number of bytes carried by one datagram. `Socket` uses it both as the
/// fragment size when sending and as the receive buffer size; a datagram shorter
/// than this marks the end of a message.
const READ_BUFFER_SIZE: usize = 1 << 10;

/// Retry budget used by `Socket::send` when a send does not go through.
const RETRIES: usize = 3;
|
||
|
|
||
|
/// Event tagged with the address of the remote peer it concerns.
///
/// NOTE(review): this enum is not constructed in this file. The per-variant notes
/// below are inferred from the variant names and should be confirmed against the
/// code that emits them.
#[derive(Clone, Debug)]
pub enum SocketContent {
    /// A text payload associated with the peer.
    Message(SocketAddr, String),
    /// Presumably a peer seen for the first time.
    NewConnection(SocketAddr),
    /// Presumably a peer that exceeded the connection time-out.
    TimeOut(SocketAddr),
}
|
||
|
|
||
|
/// How much socket traffic is written to the network log.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogLevel {
    /// Traffic logging is disabled.
    None,
    /// Only non-empty messages are logged.
    Debug,
    /// Every message is logged, including empty ones.
    Verbose,
}

impl Default for LogLevel {
    /// Logging is off unless explicitly requested.
    fn default() -> Self {
        LogLevel::None
    }
}
|
||
|
|
||
|
impl FromStr for LogLevel {
|
||
|
type Err = anyhow::Error;
|
||
|
|
||
|
fn from_str(s: &str) -> Result<Self> {
|
||
|
match s {
|
||
|
"None" => Ok(Self::None),
|
||
|
"Debug" => Ok(Self::Debug),
|
||
|
"Verbose" => Ok(Self::Verbose),
|
||
|
|
||
|
_ => Err(anyhow::Error::msg(format!(
|
||
|
"Failed parsing LogLevel from {}",
|
||
|
s
|
||
|
))),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Display for LogLevel {
|
||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
match self {
|
||
|
Self::None => write!(f, "None"),
|
||
|
Self::Debug => write!(f, "Debug"),
|
||
|
Self::Verbose => write!(f, "Verbose"),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Clone)]
|
||
|
pub struct SocketCreateInfo {
|
||
|
pub port: u16,
|
||
|
pub connection_time_out: Duration,
|
||
|
pub heartbeat_interval: Duration,
|
||
|
pub log_level: LogLevel,
|
||
|
}
|
||
|
|
||
|
pub struct Socket {
|
||
|
udp_socket: UdpSocket,
|
||
|
connection_time_out: Duration,
|
||
|
|
||
|
log_level: LogLevel,
|
||
|
pub(crate) heartbeat_interval: Duration,
|
||
|
pub(crate) start: Instant,
|
||
|
}
|
||
|
|
||
|
impl Socket {
|
||
|
pub fn new(info: SocketCreateInfo) -> Result<Self> {
|
||
|
let socket = UdpSocket::bind(SocketAddr::from(SocketAddrV4::new(
|
||
|
Ipv4Addr::new(0, 0, 0, 0),
|
||
|
info.port,
|
||
|
)))?;
|
||
|
|
||
|
socket.set_nonblocking(true)?;
|
||
|
|
||
|
Ok(Self {
|
||
|
udp_socket: socket,
|
||
|
connection_time_out: info.connection_time_out,
|
||
|
|
||
|
log_level: info.log_level,
|
||
|
heartbeat_interval: info.heartbeat_interval,
|
||
|
start: Instant::now(),
|
||
|
})
|
||
|
}
|
||
|
|
||
|
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
|
||
|
self.udp_socket.set_nonblocking(nonblocking)?;
|
||
|
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
pub fn set_broadcast(&self, broadcast: bool) -> Result<()> {
|
||
|
self.udp_socket.set_broadcast(broadcast)?;
|
||
|
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
pub fn connection_time_out(&self) -> Duration {
|
||
|
self.connection_time_out
|
||
|
}
|
||
|
|
||
|
pub fn send(&self, addr: SocketAddr, mut msg: &str) -> Result<()> {
|
||
|
let total_length = msg.len();
|
||
|
|
||
|
match self.log_level {
|
||
|
LogLevel::None => (),
|
||
|
LogLevel::Debug => {
|
||
|
if total_length > 0 {
|
||
|
NetworkLogFile::log(format!("SOCKET SEND [{}]: {}", addr, msg))?;
|
||
|
}
|
||
|
}
|
||
|
LogLevel::Verbose => {
|
||
|
NetworkLogFile::log(format!("SOCKET SEND [{}]: {}", addr, msg))?;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
let mut tries = 0;
|
||
|
|
||
|
loop {
|
||
|
let mut total_bytes_send = 0;
|
||
|
|
||
|
if msg.len() >= READ_BUFFER_SIZE {
|
||
|
let fragments = msg.len() / READ_BUFFER_SIZE;
|
||
|
|
||
|
for _ in 0..fragments {
|
||
|
let (first, second) = msg.split_at(READ_BUFFER_SIZE);
|
||
|
msg = second;
|
||
|
|
||
|
total_bytes_send += self.send_bytes(addr, first.as_bytes())?.unwrap_or(0);
|
||
|
}
|
||
|
|
||
|
total_bytes_send += self.send_bytes(addr, msg.as_bytes())?.unwrap_or(0);
|
||
|
} else {
|
||
|
total_bytes_send = self.send_bytes(addr, msg.as_bytes())?.unwrap_or(0);
|
||
|
}
|
||
|
|
||
|
if total_bytes_send != total_length {
|
||
|
println!(
|
||
|
"Error when send message did not match the expected length ({} vs {})",
|
||
|
total_length, total_bytes_send,
|
||
|
);
|
||
|
|
||
|
if tries < RETRIES {
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
tries += 1;
|
||
|
} else {
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
Ok(())
|
||
|
}
|
||
|
|
||
|
#[inline]
|
||
|
fn send_bytes(&self, addr: SocketAddr, bytes: &[u8]) -> Result<Option<usize>> {
|
||
|
Self::filter_would_block(self.udp_socket.send_to(bytes, addr))
|
||
|
.map_err(|err| anyhow::Error::msg(format!("Failed to send message {:?}", err.kind())))
|
||
|
}
|
||
|
|
||
|
pub fn receive(&self) -> Result<Option<(SocketAddr, String)>> {
|
||
|
let mut address = None;
|
||
|
let mut msg = String::new();
|
||
|
let mut buffer = [0; READ_BUFFER_SIZE];
|
||
|
|
||
|
while let Some((bytes_read, addr)) = Self::filter_connection_reset(
|
||
|
Self::filter_would_block(self.udp_socket.recv_from(&mut buffer)),
|
||
|
)
|
||
|
.map_err(|err| anyhow::Error::msg(format!("Failed to receive message {:?}", err.kind())))?
|
||
|
.flatten()
|
||
|
{
|
||
|
let bytes = &buffer[0..bytes_read];
|
||
|
|
||
|
match address {
|
||
|
Some(address) => {
|
||
|
if address != addr {
|
||
|
return Err(anyhow::Error::msg(format!(
|
||
|
"Error when messages are from different sources ({} vs {})",
|
||
|
address, addr
|
||
|
)));
|
||
|
}
|
||
|
}
|
||
|
None => {
|
||
|
address = Some(addr);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
msg += from_utf8(bytes)?;
|
||
|
|
||
|
if bytes_read < READ_BUFFER_SIZE {
|
||
|
match self.log_level {
|
||
|
LogLevel::None => (),
|
||
|
LogLevel::Debug => {
|
||
|
if !msg.is_empty() {
|
||
|
NetworkLogFile::log(format!("SOCKET RECEIVE [{}]: {}", addr, msg))?;
|
||
|
}
|
||
|
}
|
||
|
LogLevel::Verbose => {
|
||
|
NetworkLogFile::log(format!("SOCKET RECEIVE [{}]: {}", addr, msg))?;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return Ok(Some((addr, msg)));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return Ok(None);
|
||
|
}
|
||
|
|
||
|
/// Filters std::io::ErrorKind::WouldBlock and creates Option<T>
|
||
|
fn filter_would_block<T: std::fmt::Debug>(
|
||
|
result: std::result::Result<T, std::io::Error>,
|
||
|
) -> std::result::Result<Option<T>, std::io::Error> {
|
||
|
match result {
|
||
|
Ok(t) => Ok(Some(t)),
|
||
|
Err(err) => match err.kind() {
|
||
|
std::io::ErrorKind::WouldBlock => Ok(None),
|
||
|
_ => Err(err),
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn filter_connection_reset<T: std::fmt::Debug>(
|
||
|
result: std::result::Result<T, std::io::Error>,
|
||
|
) -> std::result::Result<Option<T>, std::io::Error> {
|
||
|
match result {
|
||
|
Ok(t) => Ok(Some(t)),
|
||
|
Err(err) => match err.kind() {
|
||
|
std::io::ErrorKind::ConnectionReset => Ok(None),
|
||
|
_ => Err(err),
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
}
|