finilized async shit

This commit is contained in:
ayywrk 2023-05-30 00:49:52 +02:00
parent 734b25c403
commit 794a08e32c
3 changed files with 110 additions and 85 deletions

View File

@ -64,7 +64,7 @@ impl Irc {
channel: &str, channel: &str,
message: &str, message: &str,
) { ) {
let mut context = self.context.write().await; let context = self.context.read().await;
if channel == &context.config.nick { if channel == &context.config.nick {
if message.ends_with(&format!( if message.ends_with(&format!(
@ -74,6 +74,7 @@ impl Irc {
let nickserv_pass = context.config.nickserv_pass.as_ref().unwrap().to_string(); let nickserv_pass = context.config.nickserv_pass.as_ref().unwrap().to_string();
let nickserv_email = context.config.nickserv_email.as_ref().unwrap().to_string(); let nickserv_email = context.config.nickserv_email.as_ref().unwrap().to_string();
info!("Registering to nickserv now."); info!("Registering to nickserv now.");
let mut context = self.context.write().await;
context.privmsg( context.privmsg(
"NickServ", "NickServ",
&format!("REGISTER {} {}", nickserv_pass, nickserv_email), &format!("REGISTER {} {}", nickserv_pass, nickserv_email),

View File

@ -6,6 +6,7 @@ macro_rules! make_irc_command_enum {
($($variant:ident: $value:expr),+) => { ($($variant:ident: $value:expr),+) => {
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum IrcCommand { pub enum IrcCommand {
UNKNOWN, UNKNOWN,
$($variant),+ $($variant),+
@ -19,7 +20,6 @@ macro_rules! make_irc_command_enum {
} }
} }
} }
}; };
} }

View File

@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize};
use system::{IntoSystem, Response, StoredSystem, System}; use system::{IntoSystem, Response, StoredSystem, System};
use tokio::{ use tokio::{
fs::File, fs::File,
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf},
net::TcpStream, net::TcpStream,
sync::RwLock, sync::RwLock,
}; };
@ -287,11 +287,14 @@ impl Context {
} }
} }
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {}
impl<T: AsyncRead + AsyncWrite + Send + Unpin> AsyncReadWrite for T {}
pub struct Irc { pub struct Irc {
context: Arc<RwLock<Context>>, context: Arc<RwLock<Context>>,
recv_queue: VecDeque<String>,
flood_controls: HashMap<String, FloodControl>, flood_controls: HashMap<String, FloodControl>,
stream: Stream, stream: Option<Box<dyn AsyncReadWrite>>,
partial_line: String, partial_line: String,
} }
@ -313,8 +316,7 @@ impl Irc {
Ok(Self { Ok(Self {
context, context,
stream: Stream::None, stream: None,
recv_queue: VecDeque::new(),
flood_controls: HashMap::default(), flood_controls: HashMap::default(),
partial_line: String::new(), partial_line: String::new(),
}) })
@ -341,7 +343,7 @@ impl Irc {
} }
pub async fn connect(&mut self) -> std::io::Result<()> { pub async fn connect(&mut self) -> std::io::Result<()> {
let mut context = self.context.write().await; let context = self.context.read().await;
let domain = format!("{}:{}", context.config.host, context.config.port); let domain = format!("{}:{}", context.config.host, context.config.port);
@ -361,13 +363,11 @@ impl Irc {
let stream = async_native_tls::connect(context.config.host.clone(), plain_stream) let stream = async_native_tls::connect(context.config.host.clone(), plain_stream)
.await .await
.unwrap(); .unwrap();
self.stream = Stream::Tls(stream); self.stream = Some(Box::new(stream));
context.register();
return Ok(()); return Ok(());
} }
self.stream = Stream::Plain(plain_stream); self.stream = Some(Box::new(plain_stream));
context.register();
Ok(()) Ok(())
} }
@ -384,7 +384,6 @@ impl Irc {
let elapsed = flood_control.last_cmd.elapsed().unwrap(); let elapsed = flood_control.last_cmd.elapsed().unwrap();
if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval { if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval {
warn!("they be floodin @ {channel}!"); warn!("they be floodin @ {channel}!");
return true; return true;
@ -394,77 +393,14 @@ impl Irc {
false false
} }
async fn recv(&mut self) -> std::io::Result<()> { pub async fn handle_commands(&mut self, mut lines: VecDeque<String>) {
let mut buf = [0; MAX_MSG_LEN]; while lines.len() != 0 {
let owned_line = lines.pop_front().unwrap();
let bytes_read = match self.stream.read(&mut buf).await {
Ok(bytes_read) => bytes_read,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_read == 0 {
return Ok(());
}
let buf = &buf[..bytes_read];
self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str();
let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect();
let len = new_lines.len();
for (index, line) in new_lines.into_iter().enumerate() {
if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" {
self.partial_line = line.to_owned();
break;
}
self.recv_queue.push_back(line.to_owned());
}
Ok(())
}
async fn send(&mut self) -> std::io::Result<()> {
let mut context = self.context.write().await;
while context.send_queue.len() > 0 {
let msg = context.send_queue.pop_front().unwrap();
trace!(">> {}", msg.replace("\r\n", ""));
let bytes_written = match self.stream.write(msg.as_bytes()).await {
Ok(bytes_written) => bytes_written,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_written < msg.len() {
context
.send_queue
.push_front(msg[bytes_written..].to_owned());
}
}
Ok(())
}
pub async fn handle_commands(&mut self) {
while self.recv_queue.len() != 0 {
let owned_line = self.recv_queue.pop_front().unwrap();
let line = owned_line.as_str(); let line = owned_line.as_str();
trace!("<< {:?}", line); trace!("<< {:?}", line);
let mut message: IrcMessage = line.into(); let message: IrcMessage = line.into();
let Some(prefix) = &mut message.prefix else {
return self.handle_message(&message).await;
};
self.handle_message(&message).await; self.handle_message(&message).await;
} }
} }
@ -511,10 +447,98 @@ impl Irc {
} }
} }
pub async fn update(&mut self) -> std::io::Result<()> { pub async fn run(&mut self) -> std::io::Result<()> {
self.recv().await?; info!("Ready!");
self.handle_commands().await; {
self.send().await?; let mut context = self.context.write().await;
Ok(()) context.register();
}
let stream = self.stream.take().unwrap();
let (mut reader, mut writer) = tokio::io::split(stream);
let cloned_ctx = self.context.clone();
tokio::spawn(async move {
loop {
send(&mut writer, &cloned_ctx).await.unwrap();
}
});
loop {
let lines = recv(&mut reader, &mut self.partial_line).await?;
self.handle_commands(lines.into()).await;
}
} }
} }
async fn send<T: AsyncWrite>(
writer: &mut WriteHalf<T>,
arc_context: &RwLock<Context>,
) -> std::io::Result<()> {
let mut len;
{
let context = arc_context.read().await;
len = context.send_queue.len();
}
while len > 0 {
let mut context = arc_context.write().await;
let msg = context.send_queue.pop_front().unwrap();
len -= 1;
trace!(">> {}", msg.replace("\r\n", ""));
let bytes_written = match writer.write(msg.as_bytes()).await {
Ok(bytes_written) => bytes_written,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_written < msg.len() {
context
.send_queue
.push_front(msg[bytes_written..].to_owned());
}
}
Ok(())
}
async fn recv<T: AsyncRead>(
reader: &mut ReadHalf<T>,
partial_line: &mut String,
) -> std::io::Result<Vec<String>> {
let mut buf = [0; MAX_MSG_LEN];
let mut lines = vec![];
let bytes_read = match reader.read(&mut buf).await {
Ok(bytes_read) => bytes_read,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(lines);
}
_ => panic!("{err}"),
},
};
if bytes_read == 0 {
return Ok(lines);
}
let buf = &buf[..bytes_read];
*partial_line += String::from_utf8_lossy(buf).into_owned().as_str();
let new_lines: Vec<&str> = partial_line.split("\r\n").collect();
let len = new_lines.len();
for (index, line) in new_lines.into_iter().enumerate() {
if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" {
*partial_line = line.to_owned();
break;
}
lines.push(line.to_owned());
}
Ok(lines)
}