diff --git a/src/events.rs b/src/events.rs index 7ce2c6b..e494668 100644 --- a/src/events.rs +++ b/src/events.rs @@ -1,85 +1,142 @@ -use std::time::Duration; - use log::{debug, info, warn}; +use std::time::Duration; use crate::{Irc, IrcPrefix}; impl Irc { - pub(crate) fn event_ping(&mut self, ping_token: &str) { + pub(crate) async fn event_ping(&mut self, ping_token: &str) { debug!("PING {}", ping_token); - self.queue(&format!("PONG {}", ping_token)); + + self.context + .write() + .await + .queue(&format!("PONG {}", ping_token)); } - pub(crate) fn event_welcome(&mut self, welcome_msg: &str) { + pub(crate) async fn event_welcome(&mut self, welcome_msg: &str) { debug!("{welcome_msg}"); - // self.identify(); - self.join_config_channels(); + let mut context = self.context.write().await; + context.identify(); + context.join_config_channels(); } - pub(crate) fn event_nicknameinuse(&mut self) { - let new_nick = &format!("{}_", &self.config.nick); + pub(crate) async fn event_nicknameinuse(&mut self) { + let mut context = self.context.write().await; + let new_nick = &format!("{}_", &context.config.nick); warn!("Nick already in use., switching to {}", new_nick); - self.update_nick(new_nick) + context.update_nick(new_nick) } - pub(crate) fn event_kick(&mut self, channel: &str, nick: &str, kicker: &str, reason: &str) { - if nick != &self.config.nick { + pub(crate) async fn event_kick( + &mut self, + channel: &str, + nick: &str, + kicker: &str, + reason: &str, + ) { + let mut context = self.context.write().await; + if nick != &context.config.nick { return; } warn!("We got kicked from {} by {}! ({})", channel, kicker, reason); - self.join(channel); + context.join(channel); } pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) { - if prefix.nick != self.config.nick { + if prefix.nick != self.context.read().await.config.nick { return; } warn!("We quit. We'll reconnect in {} seconds.", 15); std::thread::sleep(Duration::from_secs(15)); self.connect().await.unwrap(); - self.register(); } - pub(crate) fn event_invite(&mut self, prefix: &IrcPrefix, channel: &str) { + pub(crate) async fn event_invite<'a>(&mut self, prefix: &'a IrcPrefix<'a>, channel: &str) { info!("{} invited us to {}", prefix.nick, channel); - self.join(channel); + self.context.write().await.join(channel); } - pub(crate) fn event_notice( + pub(crate) async fn event_notice<'a>( &mut self, - prefix: Option<&IrcPrefix>, + _prefix: Option<&IrcPrefix<'a>>, channel: &str, message: &str, ) { - //TODO, register shit + let mut context = self.context.write().await; + + if channel == &context.config.nick { + if message.ends_with(&format!( + "\x02{}\x02 isn't registered.", + context.config.nick + )) { + let nickserv_pass = context.config.nickserv_pass.as_ref().unwrap().to_string(); + let nickserv_email = context.config.nickserv_email.as_ref().unwrap().to_string(); + info!("Registering to nickserv now."); + context.privmsg( + "NickServ", + &format!("REGISTER {} {}", nickserv_pass, nickserv_email), + ); + } + if message.ends_with(" seconds to register.") { + let seconds = message + .split_whitespace() + .nth(10) + .unwrap() + .parse::() + .unwrap() + + 1; + + info!("Waiting {} seconds to register.", seconds); + let ctx_clone = self.context.clone(); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(seconds as u64)).await; + ctx_clone.write().await.identify(); + }); + } + } } - pub(crate) fn event_privmsg(&mut self, prefix: &IrcPrefix, channel: &str, message: &str) { - if !message.starts_with(&self.config.cmdkey) { - return; - } - let mut elements = message.split_whitespace(); - let sys_name = &elements.next().unwrap()[1..]; + pub(crate) async fn event_privmsg<'a>( + &mut self, + prefix: &'a IrcPrefix<'a>, + channel: &str, + message: &str, + ) { + let sys_name; + { + let context = self.context.read().await; + if !message.starts_with(&context.config.cmdkey) { + return; + } + let mut elements = message.split_whitespace(); + sys_name = elements.next().unwrap()[1..].to_owned(); - if self.is_owner(prefix) && sys_name == "raw" { - self.queue(&elements.collect::>().join(" ")); + if context.is_owner(prefix) && sys_name == "raw" { + let mut context = self.context.write().await; + context.queue(&elements.collect::>().join(" ")); + return; + } + } + + if self.is_flood(channel).await { return; } - if self.is_flood(channel) { - return; - } + //TODO: + // MOVE RUN_SYSTEM BACK TO IRC - let response = self.run_system(prefix, sys_name); + let mut context = self.context.write().await; + let response = context.run_system(prefix, &sys_name); if response.0.is_none() { return; } for line in response.0.unwrap() { - self.privmsg(channel, &line) + context.privmsg(channel, &line) } } } diff --git a/src/factory.rs b/src/factory.rs index 5b640ab..f9bd719 100644 --- a/src/factory.rs +++ b/src/factory.rs @@ -5,5 +5,5 @@ use std::{ #[derive(Default)] pub struct Factory { - pub(crate) resources: HashMap>, + pub(crate) resources: HashMap>, } diff --git a/src/lib.rs b/src/lib.rs index 330a8fc..3ff4366 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod factory; pub mod irc_command; pub mod system; pub mod system_params; +pub mod utils; use std::{ any::TypeId, @@ -10,6 +11,7 @@ use std::{ io::ErrorKind, net::ToSocketAddrs, path::Path, + sync::Arc, time::SystemTime, }; @@ -23,6 +25,7 @@ use tokio::{ fs::File, io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream, + sync::RwLock, }; pub(crate) const MAX_MSG_LEN: usize = 512; @@ -67,7 +70,6 @@ impl Default for FloodControl { #[derive(Clone, Debug, Default)] pub struct IrcPrefix<'a> { - pub admin: bool, pub nick: &'a str, pub user: Option<&'a str>, pub host: Option<&'a str>, @@ -101,7 +103,6 @@ impl<'a> From<&'a str> for IrcPrefix<'a> { } Self { - admin: false, nick: nick, user: Some(user), host: Some(user_split[1]), @@ -146,164 +147,35 @@ pub struct IrcConfig { nick: String, user: String, real: String, - nickserv_pass: String, - nickserv_email: String, + nickserv_pass: Option, + nickserv_email: Option, cmdkey: String, flood_interval: f32, owner: String, admins: Vec, } -pub struct Irc { +// TODO: +/* + split Irc into two structs, one for the context, which is Send + Sync to be usable in tasks + one for the comms. + +*/ + +pub struct Context { config: IrcConfig, - stream: Stream, + identified: bool, + send_queue: VecDeque, systems: HashMap, factory: Factory, - - flood_controls: HashMap, - - send_queue: VecDeque, - recv_queue: VecDeque, - partial_line: String, } -impl Irc { - pub async fn from_config(path: impl AsRef) -> std::io::Result { - let mut file = File::open(path).await?; - let mut contents = String::new(); - file.read_to_string(&mut contents).await?; - - let config: IrcConfig = serde_yaml::from_str(&contents).unwrap(); - - Ok(Self { - config, - stream: Stream::None, - systems: HashMap::default(), - factory: Factory::default(), - flood_controls: HashMap::default(), - send_queue: VecDeque::new(), - recv_queue: VecDeque::new(), - partial_line: String::new(), - }) +impl Context { + pub fn privmsg(&mut self, channel: &str, message: &str) { + debug!("sending privmsg to {} : {}", channel, message); + self.queue(&format!("PRIVMSG {} :{}", channel, message)); } - - pub fn add_system System<'a> + 'static>( - &mut self, - name: &str, - system: impl for<'a> IntoSystem<'a, I, System = S>, - ) -> &mut Self { - self.systems - .insert(name.to_owned(), Box::new(system.into_system())); - self - } - - pub fn add_resource(&mut self, res: R) -> &mut Self { - self.factory - .resources - .insert(TypeId::of::(), Box::new(res)); - self - } - - pub fn run_system<'a>(&mut self, prefix: &'a IrcPrefix, name: &str) -> Response { - let system = self.systems.get_mut(name).unwrap(); - system.run(prefix, &mut self.factory) - } - - pub async fn connect(&mut self) -> std::io::Result<()> { - let domain = format!("{}:{}", self.config.host, self.config.port); - - info!("Connecting to {}", domain); - - let mut addrs = domain - .to_socket_addrs() - .expect("Unable to get addrs from domain {domain}"); - - let sock = addrs - .next() - .expect("Unable to get ip from addrs: {addrs:?}"); - - let plain_stream = TcpStream::connect(sock).await?; - - if self.config.ssl { - let stream = async_native_tls::connect(self.config.host.clone(), plain_stream) - .await - .unwrap(); - self.stream = Stream::Tls(stream); - return Ok(()); - } - - self.stream = Stream::Plain(plain_stream); - Ok(()) - } - - pub fn register(&mut self) { - info!( - "Registering as {}!{} ({})", - self.config.nick, self.config.user, self.config.real - ); - self.queue(&format!( - "USER {} 0 * {}", - self.config.user, self.config.real - )); - self.queue(&format!("NICK {}", self.config.nick)); - } - - async fn recv(&mut self) -> std::io::Result<()> { - let mut buf = [0; MAX_MSG_LEN]; - - let bytes_read = match self.stream.read(&mut buf).await { - Ok(bytes_read) => bytes_read, - Err(err) => match err.kind() { - ErrorKind::WouldBlock => { - return Ok(()); - } - _ => panic!("{err}"), - }, - }; - - if bytes_read == 0 { - return Ok(()); - } - - let buf = &buf[..bytes_read]; - - self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str(); - let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect(); - let len = new_lines.len(); - - for (index, line) in new_lines.into_iter().enumerate() { - if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" { - self.partial_line = line.to_owned(); - break; - } - self.recv_queue.push_back(line.to_owned()); - } - Ok(()) - } - - async fn send(&mut self) -> std::io::Result<()> { - while self.send_queue.len() > 0 { - let msg = self.send_queue.pop_front().unwrap(); - - trace!(">> {}", msg.replace("\r\n", "")); - let bytes_written = match self.stream.write(msg.as_bytes()).await { - Ok(bytes_written) => bytes_written, - Err(err) => match err.kind() { - ErrorKind::WouldBlock => { - return Ok(()); - } - _ => panic!("{err}"), - }, - }; - - if bytes_written < msg.len() { - self.send_queue.push_front(msg[bytes_written..].to_owned()); - } - } - Ok(()) - } - fn queue(&mut self, msg: &str) { let mut msg = msg.replace("\r", "").replace("\n", ""); @@ -324,39 +196,27 @@ impl Irc { } } - pub async fn update(&mut self) -> std::io::Result<()> { - self.recv().await?; - self.handle_commands().await; - self.send().await?; - Ok(()) + pub fn identify(&mut self) { + if self.config.nickserv_pass.is_none() || self.identified { + return; + } + + self.privmsg( + "NickServ", + &format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()), + ); } - pub async fn handle_commands(&mut self) { - while self.recv_queue.len() != 0 { - let owned_line = self.recv_queue.pop_front().unwrap(); - let line = owned_line.as_str(); - - trace!("<< {:?}", line); - - let mut message: IrcMessage = line.into(); - - let Some(prefix) = &mut message.prefix else { - return self.handle_message(&message).await; - }; - - if self.is_owner(prefix) { - prefix.admin = true; - } else { - for admin in &self.config.admins { - if self.is_admin(prefix, admin) { - prefix.admin = true; - break; - } - } - } - - self.handle_message(&message).await; - } + pub fn register(&mut self) { + info!( + "Registering as {}!{} ({})", + self.config.nick, self.config.user, self.config.real + ); + self.queue(&format!( + "USER {} 0 * {}", + self.config.user, self.config.real + )); + self.queue(&format!("NICK {}", self.config.nick)); } fn is_owner(&self, prefix: &IrcPrefix) -> bool { @@ -396,7 +256,122 @@ impl Irc { self.queue(&format!("NICK {}", self.config.nick)); } - fn is_flood(&mut self, channel: &str) -> bool { + pub fn privmsg_all(&mut self, message: &str) { + for i in 0..self.config.channels.len() { + let channel = self.config.channels.iter().nth(i).unwrap(); + debug!("sending privmsg to {} : {}", channel, message); + self.queue(&format!("PRIVMSG {} :{}", channel, message)); + } + } + + pub fn add_system System<'a> + Send + Sync + 'static>( + &mut self, + name: &str, + system: impl for<'a> IntoSystem<'a, I, System = S>, + ) -> &mut Self { + self.systems + .insert(name.to_owned(), Box::new(system.into_system())); + self + } + + pub fn add_resource(&mut self, res: R) -> &mut Self { + self.factory + .resources + .insert(TypeId::of::(), Box::new(res)); + self + } + + pub fn run_system<'a>(&mut self, prefix: &'a IrcPrefix, name: &str) -> Response { + let system = self.systems.get_mut(name).unwrap(); + system.run(prefix, &mut self.factory) + } +} + +pub struct Irc { + context: Arc>, + recv_queue: VecDeque, + flood_controls: HashMap, + stream: Stream, + partial_line: String, +} + +impl Irc { + pub async fn from_config(path: impl AsRef) -> std::io::Result { + let mut file = File::open(path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + + let config: IrcConfig = serde_yaml::from_str(&contents).unwrap(); + + let context = Arc::new(RwLock::new(Context { + config, + identified: false, + send_queue: VecDeque::new(), + systems: HashMap::default(), + factory: Factory::default(), + })); + + Ok(Self { + context, + stream: Stream::None, + recv_queue: VecDeque::new(), + flood_controls: HashMap::default(), + partial_line: String::new(), + }) + } + + pub async fn add_system System<'a> + Send + Sync + 'static>( + &mut self, + name: &str, + system: impl for<'a> IntoSystem<'a, I, System = S>, + ) -> &mut Self { + { + let mut context = self.context.write().await; + context.add_system(name, system); + } + self + } + + pub async fn add_resource(&mut self, res: R) -> &mut Self { + { + let mut context = self.context.write().await; + context.add_resource(res); + } + self + } + + pub async fn connect(&mut self) -> std::io::Result<()> { + let mut context = self.context.write().await; + + let domain = format!("{}:{}", context.config.host, context.config.port); + + info!("Connecting to {}", domain); + + let mut addrs = domain + .to_socket_addrs() + .expect("Unable to get addrs from domain {domain}"); + + let sock = addrs + .next() + .expect("Unable to get ip from addrs: {addrs:?}"); + + let plain_stream = TcpStream::connect(sock).await?; + + if context.config.ssl { + let stream = async_native_tls::connect(context.config.host.clone(), plain_stream) + .await + .unwrap(); + self.stream = Stream::Tls(stream); + context.register(); + return Ok(()); + } + + self.stream = Stream::Plain(plain_stream); + context.register(); + Ok(()) + } + + async fn is_flood(&mut self, channel: &str) -> bool { let mut flood_control = match self.flood_controls.entry(channel.to_owned()) { std::collections::hash_map::Entry::Occupied(o) => o.into_mut(), std::collections::hash_map::Entry::Vacant(v) => { @@ -409,7 +384,8 @@ impl Irc { let elapsed = flood_control.last_cmd.elapsed().unwrap(); - if elapsed.as_secs_f32() < self.config.flood_interval { + + if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval { warn!("they be floodin @ {channel}!"); return true; } @@ -418,46 +394,127 @@ impl Irc { false } - pub fn privmsg(&mut self, channel: &str, message: &str) { - debug!("sending privmsg to {} : {}", channel, message); - self.queue(&format!("PRIVMSG {} :{}", channel, message)); + async fn recv(&mut self) -> std::io::Result<()> { + let mut buf = [0; MAX_MSG_LEN]; + + let bytes_read = match self.stream.read(&mut buf).await { + Ok(bytes_read) => bytes_read, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(()); + } + _ => panic!("{err}"), + }, + }; + + if bytes_read == 0 { + return Ok(()); + } + + let buf = &buf[..bytes_read]; + + self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str(); + let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect(); + let len = new_lines.len(); + + for (index, line) in new_lines.into_iter().enumerate() { + if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" { + self.partial_line = line.to_owned(); + break; + } + self.recv_queue.push_back(line.to_owned()); + } + Ok(()) } - pub fn privmsg_all(&mut self, message: &str) { - for i in 0..self.config.channels.len() { - let channel = self.config.channels.iter().nth(i).unwrap(); - debug!("sending privmsg to {} : {}", channel, message); - self.queue(&format!("PRIVMSG {} :{}", channel, message)); + async fn send(&mut self) -> std::io::Result<()> { + let mut context = self.context.write().await; + while context.send_queue.len() > 0 { + let msg = context.send_queue.pop_front().unwrap(); + + trace!(">> {}", msg.replace("\r\n", "")); + let bytes_written = match self.stream.write(msg.as_bytes()).await { + Ok(bytes_written) => bytes_written, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(()); + } + _ => panic!("{err}"), + }, + }; + + if bytes_written < msg.len() { + context + .send_queue + .push_front(msg[bytes_written..].to_owned()); + } + } + Ok(()) + } + + pub async fn handle_commands(&mut self) { + while self.recv_queue.len() != 0 { + let owned_line = self.recv_queue.pop_front().unwrap(); + let line = owned_line.as_str(); + + trace!("<< {:?}", line); + + let mut message: IrcMessage = line.into(); + + let Some(prefix) = &mut message.prefix else { + return self.handle_message(&message).await; + }; + + self.handle_message(&message).await; } } async fn handle_message<'a>(&mut self, message: &'a IrcMessage<'a>) { match message.command { - IrcCommand::PING => self.event_ping(&message.parameters[0]), - IrcCommand::RPL_WELCOME => self.event_welcome(&message.parameters[1..].join(" ")), - IrcCommand::ERR_NICKNAMEINUSE => self.event_nicknameinuse(), - IrcCommand::KICK => self.event_kick( - &message.parameters[0], - &message.parameters[1], - &message.prefix.as_ref().unwrap().nick, - &message.parameters[2..].join(" "), - ), + IrcCommand::PING => self.event_ping(&message.parameters[0]).await, + IrcCommand::RPL_WELCOME => self.event_welcome(&message.parameters[1..].join(" ")).await, + IrcCommand::ERR_NICKNAMEINUSE => self.event_nicknameinuse().await, + IrcCommand::KICK => { + self.event_kick( + &message.parameters[0], + &message.parameters[1], + &message.prefix.as_ref().unwrap().nick, + &message.parameters[2..].join(" "), + ) + .await + } IrcCommand::QUIT => self.event_quit(message.prefix.as_ref().unwrap()).await, - IrcCommand::INVITE => self.event_invite( - message.prefix.as_ref().unwrap(), - &message.parameters[1][1..], - ), - IrcCommand::PRIVMSG => self.event_privmsg( - message.prefix.as_ref().unwrap(), - &message.parameters[0], - &message.parameters[1..].join(" ")[1..], - ), - IrcCommand::NOTICE => self.event_notice( - message.prefix.as_ref(), - &message.parameters[0], - &message.parameters[1..].join(" ")[1..], - ), + IrcCommand::INVITE => { + self.event_invite( + message.prefix.as_ref().unwrap(), + &message.parameters[1][1..], + ) + .await + } + IrcCommand::PRIVMSG => { + self.event_privmsg( + message.prefix.as_ref().unwrap(), + &message.parameters[0], + &message.parameters[1..].join(" ")[1..], + ) + .await + } + IrcCommand::NOTICE => { + self.event_notice( + message.prefix.as_ref(), + &message.parameters[0], + &message.parameters[1..].join(" ")[1..], + ) + .await + } _ => {} } } + + pub async fn update(&mut self) -> std::io::Result<()> { + self.recv().await?; + self.handle_commands().await; + self.send().await?; + Ok(()) + } } diff --git a/src/system.rs b/src/system.rs index 171e6f2..c03f246 100644 --- a/src/system.rs +++ b/src/system.rs @@ -81,7 +81,7 @@ impl_into_system!(T1, T2); impl_into_system!(T1, T2, T3); impl_into_system!(T1, T2, T3, T4); -pub(crate) type StoredSystem = Box System<'a>>; +pub(crate) type StoredSystem = Box System<'a> + Send + Sync>; pub(crate) trait SystemParam { type Item<'new>; diff --git a/src/system_params.rs b/src/system_params.rs index ec14a00..a292223 100644 --- a/src/system_params.rs +++ b/src/system_params.rs @@ -74,7 +74,7 @@ impl<'res, T: 'static> SystemParam for ResMut<'res, T> { type Item<'new> = ResMut<'new, T>; fn retrieve<'r>(_prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r> { - let const_ptr = &factory.resources as *const HashMap>; + let const_ptr = &factory.resources as *const HashMap>; let mut_ptr = const_ptr as *mut HashMap>; let res_mut = unsafe { &mut *mut_ptr }; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..e69de29