diff --git a/src/events.rs b/src/events.rs index e494668..57831d6 100644 --- a/src/events.rs +++ b/src/events.rs @@ -64,7 +64,7 @@ impl Irc { channel: &str, message: &str, ) { - let mut context = self.context.write().await; + let context = self.context.read().await; if channel == &context.config.nick { if message.ends_with(&format!( @@ -74,6 +74,7 @@ impl Irc { let nickserv_pass = context.config.nickserv_pass.as_ref().unwrap().to_string(); let nickserv_email = context.config.nickserv_email.as_ref().unwrap().to_string(); info!("Registering to nickserv now."); + let mut context = self.context.write().await; context.privmsg( "NickServ", &format!("REGISTER {} {}", nickserv_pass, nickserv_email), diff --git a/src/irc_command.rs b/src/irc_command.rs index ae1497e..8c359ad 100644 --- a/src/irc_command.rs +++ b/src/irc_command.rs @@ -6,6 +6,7 @@ macro_rules! make_irc_command_enum { ($($variant:ident: $value:expr),+) => { #[allow(non_camel_case_types)] + #[derive(Debug)] pub enum IrcCommand { UNKNOWN, $($variant),+ @@ -19,7 +20,6 @@ macro_rules! make_irc_command_enum { } } } - }; } diff --git a/src/lib.rs b/src/lib.rs index 3ff4366..d5a7f2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize}; use system::{IntoSystem, Response, StoredSystem, System}; use tokio::{ fs::File, - io::{AsyncReadExt, AsyncWriteExt}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}, net::TcpStream, sync::RwLock, }; @@ -287,11 +287,14 @@ impl Context { } } +pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {} + +impl AsyncReadWrite for T {} + pub struct Irc { context: Arc>, - recv_queue: VecDeque, flood_controls: HashMap, - stream: Stream, + stream: Option>, partial_line: String, } @@ -313,8 +316,7 @@ impl Irc { Ok(Self { context, - stream: Stream::None, - recv_queue: VecDeque::new(), + stream: None, flood_controls: HashMap::default(), partial_line: String::new(), }) @@ -341,7 +343,7 @@ impl Irc { } pub async fn connect(&mut self) -> std::io::Result<()> { - let mut context = self.context.write().await; + let context = self.context.read().await; let domain = format!("{}:{}", context.config.host, context.config.port); @@ -361,13 +363,11 @@ impl Irc { let stream = async_native_tls::connect(context.config.host.clone(), plain_stream) .await .unwrap(); - self.stream = Stream::Tls(stream); - context.register(); + self.stream = Some(Box::new(stream)); return Ok(()); } - self.stream = Stream::Plain(plain_stream); - context.register(); + self.stream = Some(Box::new(plain_stream)); Ok(()) } @@ -384,7 +384,6 @@ impl Irc { let elapsed = flood_control.last_cmd.elapsed().unwrap(); - if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval { warn!("they be floodin @ {channel}!"); return true; @@ -394,77 +393,14 @@ impl Irc { false } - async fn recv(&mut self) -> std::io::Result<()> { - let mut buf = [0; MAX_MSG_LEN]; - - let bytes_read = match self.stream.read(&mut buf).await { - Ok(bytes_read) => bytes_read, - Err(err) => match err.kind() { - ErrorKind::WouldBlock => { - return Ok(()); - } - _ => panic!("{err}"), - }, - }; - - if bytes_read == 0 { - return Ok(()); - } - - let buf = &buf[..bytes_read]; - - self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str(); - let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect(); - let len = new_lines.len(); - - for (index, line) in new_lines.into_iter().enumerate() { - if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" { - self.partial_line = line.to_owned(); - break; - } - self.recv_queue.push_back(line.to_owned()); - } - Ok(()) - } - - async fn send(&mut self) -> std::io::Result<()> { - let mut context = self.context.write().await; - while context.send_queue.len() > 0 { - let msg = context.send_queue.pop_front().unwrap(); - - trace!(">> {}", msg.replace("\r\n", "")); - let bytes_written = match self.stream.write(msg.as_bytes()).await { - Ok(bytes_written) => bytes_written, - Err(err) => match err.kind() { - ErrorKind::WouldBlock => { - return Ok(()); - } - _ => panic!("{err}"), - }, - }; - - if bytes_written < msg.len() { - context - .send_queue - .push_front(msg[bytes_written..].to_owned()); - } - } - Ok(()) - } - - pub async fn handle_commands(&mut self) { - while self.recv_queue.len() != 0 { - let owned_line = self.recv_queue.pop_front().unwrap(); + pub async fn handle_commands(&mut self, mut lines: VecDeque) { + while lines.len() != 0 { + let owned_line = lines.pop_front().unwrap(); let line = owned_line.as_str(); trace!("<< {:?}", line); - let mut message: IrcMessage = line.into(); - - let Some(prefix) = &mut message.prefix else { - return self.handle_message(&message).await; - }; - + let message: IrcMessage = line.into(); self.handle_message(&message).await; } } @@ -511,10 +447,98 @@ impl Irc { } } - pub async fn update(&mut self) -> std::io::Result<()> { - self.recv().await?; - self.handle_commands().await; - self.send().await?; - Ok(()) + pub async fn run(&mut self) -> std::io::Result<()> { + info!("Ready!"); + { + let mut context = self.context.write().await; + context.register(); + } + + let stream = self.stream.take().unwrap(); + + let (mut reader, mut writer) = tokio::io::split(stream); + + let cloned_ctx = self.context.clone(); + tokio::spawn(async move { + loop { + send(&mut writer, &cloned_ctx).await.unwrap(); + } + }); + + loop { + let lines = recv(&mut reader, &mut self.partial_line).await?; + self.handle_commands(lines.into()).await; + } } } + +async fn send( + writer: &mut WriteHalf, + arc_context: &RwLock, +) -> std::io::Result<()> { + let mut len; + { + let context = arc_context.read().await; + len = context.send_queue.len(); + } + + while len > 0 { + let mut context = arc_context.write().await; + let msg = context.send_queue.pop_front().unwrap(); + len -= 1; + + trace!(">> {}", msg.replace("\r\n", "")); + let bytes_written = match writer.write(msg.as_bytes()).await { + Ok(bytes_written) => bytes_written, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(()); + } + _ => panic!("{err}"), + }, + }; + + if bytes_written < msg.len() { + context + .send_queue + .push_front(msg[bytes_written..].to_owned()); + } + } + Ok(()) +} + +async fn recv( + reader: &mut ReadHalf, + partial_line: &mut String, +) -> std::io::Result> { + let mut buf = [0; MAX_MSG_LEN]; + let mut lines = vec![]; + let bytes_read = match reader.read(&mut buf).await { + Ok(bytes_read) => bytes_read, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(lines); + } + _ => panic!("{err}"), + }, + }; + + if bytes_read == 0 { + return Ok(lines); + } + + let buf = &buf[..bytes_read]; + + *partial_line += String::from_utf8_lossy(buf).into_owned().as_str(); + let new_lines: Vec<&str> = partial_line.split("\r\n").collect(); + let len = new_lines.len(); + + for (index, line) in new_lines.into_iter().enumerate() { + if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" { + *partial_line = line.to_owned(); + break; + } + lines.push(line.to_owned()); + } + Ok(lines) +}