From 789bfca7458d101811e77a2f19c368cc773325b1 Mon Sep 17 00:00:00 2001 From: wrk Date: Fri, 9 Jun 2023 02:02:16 +0200 Subject: [PATCH] Added context to systems --- .gitignore | 2 + src/events.rs | 36 ++-- src/lib.rs | 421 +++++++++++++++++++++---------------------- src/system.rs | 17 +- src/system_params.rs | 68 ++++++- 5 files changed, 312 insertions(+), 232 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..869df07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock \ No newline at end of file diff --git a/src/events.rs b/src/events.rs index 1f7641e..ecc7f83 100644 --- a/src/events.rs +++ b/src/events.rs @@ -15,16 +15,20 @@ impl Irc { pub(crate) async fn event_welcome(&mut self, welcome_msg: &str) { debug!("{welcome_msg}"); + self.identify().await; + let mut context = self.context.write().await; - context.identify(); - context.join_config_channels(); + for channel in &self.config.channels { + context.join(channel); + } } pub(crate) async fn event_nicknameinuse(&mut self) { let mut context = self.context.write().await; - let new_nick = &format!("{}_", &context.config.nick); + let new_nick = format!("{}_", &self.config.nick); warn!("Nick already in use., switching to {}", new_nick); - context.update_nick(new_nick) + context.nick(&new_nick); + self.config.nick = new_nick; } pub(crate) async fn event_kick( @@ -35,7 +39,7 @@ impl Irc { reason: &str, ) { let mut context = self.context.write().await; - if nick != &context.config.nick { + if nick != &self.config.nick { return; } @@ -44,7 +48,7 @@ impl Irc { } pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) { - if prefix.nick != self.context.read().await.config.nick { + if prefix.nick != self.config.nick { return; } @@ -64,7 +68,7 @@ impl Irc { channel: &str, message: &str, ) { - let config = self.context.read().await.config.clone(); + let config = self.config.clone(); if channel == &config.nick { if message.ends_with(&format!("\x02{}\x02 isn't registered.", config.nick)) { @@ -87,12 +91,13 @@ impl Irc { + 1; info!("Waiting {} seconds to register.", seconds); + /* TODO: fix this let ctx_clone = self.context.clone(); - tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(seconds as u64)).await; - ctx_clone.write().await.identify(); + self.identify().await; }); + */ } } } @@ -107,13 +112,13 @@ impl Irc { let sys_name; { let context = self.context.read().await; - if !message.starts_with(&context.config.cmdkey) { + if !message.starts_with(&self.config.cmdkey) { return; } elements = message.split_whitespace(); sys_name = elements.next().unwrap()[1..].to_owned(); - if context.is_owner(prefix) && sys_name == "raw" { + if prefix.owner() && sys_name == "raw" { drop(context); let mut context = self.context.write().await; context.queue(&elements.collect::>().join(" ")); @@ -127,14 +132,14 @@ impl Irc { let arguments = elements.collect::>(); - let mut context = self.context.write().await; - if !context.systems.contains_key(&sys_name) { - let resp = context.run_default_system(prefix, &arguments).await; + if !self.systems.contains_key(&sys_name) { + let resp = self.run_default_system(prefix, channel, &arguments).await; let Response::Data(data) = resp else { return; }; + let mut context = self.context.write().await; for (idx, line) in data.data.iter().enumerate() { if idx == 0 && data.highlight { context.privmsg(channel, &format!("{}: {}", prefix.nick, line)) @@ -145,11 +150,12 @@ impl Irc { return; } - let response = context.run_system(prefix, &arguments, &sys_name).await; + let response = self.run_system(prefix, channel, &arguments, &sys_name).await; let Response::Data(data) = response else { return; }; + let mut context = self.context.write().await; for (idx, line) in data.data.iter().enumerate() { if idx == 0 && data.highlight { context.privmsg(channel, &format!("{}: {}", prefix.nick, line)) diff --git a/src/lib.rs b/src/lib.rs index 1d9ab9c..0331da1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,7 +26,7 @@ use tokio::{ fs::File, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}, net::TcpStream, - sync::{mpsc, RwLock}, + sync::RwLock, }; pub(crate) const MAX_MSG_LEN: usize = 512; @@ -69,11 +69,30 @@ impl Default for FloodControl { } } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum IrcPrefixKind { + Owner, + Admin, + #[default] + User, +} + #[derive(Clone, Debug, Default)] pub struct IrcPrefix<'a> { pub nick: &'a str, pub user: Option<&'a str>, pub host: Option<&'a str>, + kind: IrcPrefixKind, +} + +impl<'a> IrcPrefix<'a> { + pub fn owner(&self) -> bool { + self.kind == IrcPrefixKind::Owner + } + + pub fn admin(&self) -> bool { + self.kind == IrcPrefixKind::Admin + } } impl<'a> From<&'a str> for IrcPrefix<'a> { @@ -107,6 +126,7 @@ impl<'a> From<&'a str> for IrcPrefix<'a> { nick: nick, user: Some(user), host: Some(user_split[1]), + ..Default::default() } } } @@ -163,19 +183,11 @@ pub struct IrcConfig { */ -pub struct Context { - config: IrcConfig, - identified: bool, +pub struct IrcContext { send_queue: VecDeque, - - default_system: Option, - invalid_system: Option, - systems: HashMap, - tasks: Vec<(Duration, StoredSystem)>, - factory: Arc>, } -impl Context { +impl IrcContext { pub fn privmsg(&mut self, channel: &str, message: &str) { debug!("sending privmsg to {} : {}", channel, message); self.queue(&format!("PRIVMSG {} :{}", channel, message)); @@ -200,136 +212,13 @@ impl Context { } } - pub fn identify(&mut self) { - if self.config.nickserv_pass.is_none() || self.identified { - return; - } - - self.privmsg( - "NickServ", - &format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()), - ); - } - - pub fn register(&mut self) { - info!( - "Registering as {}!{} ({})", - self.config.nick, self.config.user, self.config.real - ); - self.queue(&format!( - "USER {} 0 * {}", - self.config.user, self.config.real - )); - self.queue(&format!("NICK {}", self.config.nick)); - } - - fn is_owner(&self, prefix: &IrcPrefix) -> bool { - self.is_admin(prefix, &self.config.owner) - } - - fn is_admin(&self, prefix: &IrcPrefix, admin: &str) -> bool { - let admin = ":".to_owned() + &admin; - let admin_prefix: IrcPrefix = admin.as_str().into(); - - if (admin_prefix.nick == prefix.nick || admin_prefix.nick == "*") - && (admin_prefix.user == prefix.user || admin_prefix.user == Some("*")) - && (admin_prefix.host == prefix.host || admin_prefix.host == Some("*")) - { - return true; - } - - false - } - - fn join(&mut self, channel: &str) { + pub fn join(&mut self, channel: &str) { info!("Joining {channel}"); self.queue(&format!("JOIN {}", channel)); - self.config.channels.insert(channel.to_owned()); } - fn join_config_channels(&mut self) { - for i in 0..self.config.channels.len() { - let channel = self.config.channels.iter().nth(i).unwrap(); - info!("Joining {channel}"); - self.queue(&format!("JOIN {}", channel)) - } - } - - fn update_nick(&mut self, new_nick: &str) { - self.config.nick = new_nick.to_owned(); - self.queue(&format!("NICK {}", self.config.nick)); - } - - pub fn privmsg_all(&mut self, message: &str) { - for i in 0..self.config.channels.len() { - let channel = self.config.channels.iter().nth(i).unwrap(); - debug!("sending privmsg to {} : {}", channel, message); - self.queue(&format!("PRIVMSG {} :{}", channel, message)); - } - } - - pub async fn run_system<'a>( - &mut self, - prefix: &'a IrcPrefix<'a>, - arguments: &'a [&'a str], - name: &str, - ) -> Response { - let system = self.systems.get_mut(name).unwrap(); - system.run(prefix, arguments, &mut *self.factory.write().await) - } - - pub async fn run_default_system<'a>( - &mut self, - prefix: &'a IrcPrefix<'a>, - arguments: &'a [&'a str], - ) -> Response { - if self.invalid_system.is_none() { - return Response::Empty; - } - - self.default_system.as_mut().unwrap().run( - prefix, - arguments, - &mut *self.factory.write().await, - ) - } - - pub async fn run_invalid_system<'a>( - &mut self, - prefix: &'a IrcPrefix<'a>, - arguments: &'a [&'a str], - ) -> Response { - if self.invalid_system.is_none() { - return Response::Empty; - } - - self.invalid_system.as_mut().unwrap().run( - prefix, - arguments, - &mut *self.factory.write().await, - ) - } - - pub async fn run_interval_tasks(&mut self, tx: mpsc::Sender) { - for (duration, mut task) in std::mem::take(&mut self.tasks) { - let fact = self.factory.clone(); - let task_tx = tx.clone(); - tokio::spawn(async move { - loop { - tokio::time::sleep(duration).await; - let resp = task.run( - &IrcPrefix { - nick: "", - user: None, - host: None, - }, - &[], - &mut *fact.write().await, - ); - task_tx.send(resp).await.unwrap(); - } - }); - } + pub fn nick(&mut self, nick: &str) { + self.queue(&format!("NICK {}", nick)); } } @@ -338,10 +227,19 @@ pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {} impl AsyncReadWrite for T {} pub struct Irc { - context: Arc>, + context: Arc>, flood_controls: HashMap, stream: Option>, partial_line: String, + + config: IrcConfig, + identified: bool, + + default_system: Option, + invalid_system: Option, + systems: HashMap, + tasks: Vec<(Duration, StoredSystem)>, + factory: Arc>, } impl Irc { @@ -352,15 +250,8 @@ impl Irc { let config: IrcConfig = serde_yaml::from_str(&contents).unwrap(); - let context = Arc::new(RwLock::new(Context { - config, - identified: false, + let context = Arc::new(RwLock::new(IrcContext { send_queue: VecDeque::new(), - default_system: None, - invalid_system: None, - systems: HashMap::default(), - tasks: Vec::new(), - factory: Arc::new(RwLock::new(Factory::default())), })); Ok(Self { @@ -368,6 +259,13 @@ impl Irc { stream: None, flood_controls: HashMap::default(), partial_line: String::new(), + config, + identified: false, + default_system: None, + invalid_system: None, + systems: HashMap::default(), + tasks: Vec::new(), + factory: Arc::new(RwLock::new(Factory::default())), }) } @@ -376,12 +274,8 @@ impl Irc { name: &str, system: impl for<'a> IntoSystem, ) -> &mut Self { - { - let mut context = self.context.write().await; - context - .systems - .insert(name.to_owned(), Box::new(system.into_system())); - } + self.systems + .insert(name.to_owned(), Box::new(system.into_system())); self } @@ -389,10 +283,7 @@ impl Irc { &mut self, system: impl for<'a> IntoSystem, ) -> &mut Self { - { - let mut context = self.context.write().await; - context.default_system = Some(Box::new(system.into_system())); - } + self.default_system = Some(Box::new(system.into_system())); self } @@ -400,10 +291,7 @@ impl Irc { &mut self, system: impl for<'a> IntoSystem, ) -> &mut Self { - { - let mut context = self.context.write().await; - context.invalid_system = Some(Box::new(system.into_system())); - } + self.invalid_system = Some(Box::new(system.into_system())); self } @@ -412,10 +300,7 @@ impl Irc { duration: Duration, system: impl for<'a> IntoSystem, ) -> &mut Self { - { - let mut context = self.context.write().await; - context.tasks.push((duration, Box::new(system.into_system()))); - } + self.tasks.push((duration, Box::new(system.into_system()))); self } @@ -423,30 +308,22 @@ impl Irc { &mut self, system: impl for<'a> IntoSystem, ) -> &mut Self { - { - let mut context = self.context.write().await; - context.tasks.push((Duration::ZERO, Box::new(system.into_system()))); - } + self.tasks + .push((Duration::ZERO, Box::new(system.into_system()))); self } pub async fn add_resource(&mut self, res: R) -> &mut Self { - { - let context = self.context.write().await; - context - .factory - .write() - .await - .resources - .insert(TypeId::of::(), Box::new(res)); - } + self.factory + .write() + .await + .resources + .insert(TypeId::of::(), Box::new(res)); self } pub async fn connect(&mut self) -> std::io::Result<()> { - let context = self.context.read().await; - - let domain = format!("{}:{}", context.config.host, context.config.port); + let domain = format!("{}:{}", self.config.host, self.config.port); info!("Connecting to {}", domain); @@ -460,8 +337,8 @@ impl Irc { let plain_stream = TcpStream::connect(sock).await?; - if context.config.ssl { - let stream = async_native_tls::connect(context.config.host.clone(), plain_stream) + if self.config.ssl { + let stream = async_native_tls::connect(self.config.host.clone(), plain_stream) .await .unwrap(); self.stream = Some(Box::new(stream)); @@ -485,7 +362,7 @@ impl Irc { let elapsed = flood_control.last_cmd.elapsed().unwrap(); - if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval { + if elapsed.as_secs_f32() < self.config.flood_interval { warn!("they be floodin @ {channel}!"); return true; } @@ -494,6 +371,50 @@ impl Irc { false } + fn is_owner(&self, prefix: &IrcPrefix) -> bool { + let owner = ":".to_owned() + &self.config.owner; + let owner_prefix: IrcPrefix = owner.as_str().into(); + + if (owner_prefix.nick == prefix.nick || owner_prefix.nick == "*") + && (owner_prefix.user == prefix.user || owner_prefix.user == Some("*")) + && (owner_prefix.host == prefix.host || owner_prefix.host == Some("*")) + { + return true; + } + + false + } + + fn is_admin(&self, prefix: &IrcPrefix) -> bool { + for admin_str in &self.config.admins { + let admin = ":".to_owned() + admin_str; + let admin_prefix: IrcPrefix = admin.as_str().into(); + + if (admin_prefix.nick == prefix.nick || admin_prefix.nick == "*") + && (admin_prefix.user == prefix.user || admin_prefix.user == Some("*")) + && (admin_prefix.host == prefix.host || admin_prefix.host == Some("*")) + { + return true; + } + } + + false + } + + pub fn into_message<'a>(&self, line: &'a str) -> IrcMessage<'a> { + let mut message: IrcMessage = line.into(); + + if let Some(prefix) = &mut message.prefix { + if self.is_owner(prefix) { + prefix.kind = IrcPrefixKind::Owner; + } else if self.is_admin(prefix) { + prefix.kind = IrcPrefixKind::Admin; + } + } + + message + } + pub async fn handle_commands(&mut self, mut lines: VecDeque) { while lines.len() != 0 { let owned_line = lines.pop_front().unwrap(); @@ -501,7 +422,8 @@ impl Irc { trace!("<< {:?}", line); - let message: IrcMessage = line.into(); + let message = self.into_message(line); + self.handle_message(&message).await; } } @@ -548,27 +470,116 @@ impl Irc { } } + pub async fn register(&mut self) { + info!( + "Registering as {}!{} ({})", + self.config.nick, self.config.user, self.config.real + ); + let mut context = self.context.write().await; + + context.queue(&format!( + "USER {} 0 * {}", + self.config.user, self.config.real + )); + context.nick(&self.config.nick); + } + + pub async fn identify(&mut self) { + if self.config.nickserv_pass.is_none() || self.identified { + return; + } + + self.context.write().await.privmsg( + "NickServ", + &format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()), + ); + } + + pub async fn run_system<'a>( + &mut self, + prefix: &'a IrcPrefix<'a>, + channel: &'a str, + arguments: &'a [&'a str], + name: &str, + ) -> Response { + let system = self.systems.get_mut(name).unwrap(); + system.run( + prefix, + channel, + arguments, + &mut *self.context.write().await, + &mut *self.factory.write().await, + ) + } + + pub async fn run_default_system<'a>( + &mut self, + prefix: &'a IrcPrefix<'a>, + channel: &'a str, + arguments: &'a [&'a str], + ) -> Response { + if self.invalid_system.is_none() { + return Response::Empty; + } + + self.default_system.as_mut().unwrap().run( + prefix, + channel, + arguments, + &mut *self.context.write().await, + &mut *self.factory.write().await, + ) + } + + pub async fn run_invalid_system<'a>( + &mut self, + prefix: &'a IrcPrefix<'a>, + channel: &'a str, + arguments: &'a [&'a str], + ) -> Response { + if self.invalid_system.is_none() { + return Response::Empty; + } + + self.invalid_system.as_mut().unwrap().run( + prefix, + channel, + arguments, + &mut *self.context.write().await, + &mut *self.factory.write().await, + ) + } + + pub async fn run_interval_tasks(&mut self) { + for (duration, mut task) in std::mem::take(&mut self.tasks) { + let fact = self.factory.clone(); + let ctx = self.context.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(duration).await; + task.run( + &IrcPrefix::default(), + "", + &[], + &mut *ctx.write().await, + &mut *fact.write().await, + ); + } + }); + } + } + pub async fn run(&mut self) -> std::io::Result<()> { self.connect().await?; info!("Ready!"); - let (tx, mut rx) = mpsc::channel::(512); - { - let mut context = self.context.write().await; - context.register(); - context.run_interval_tasks(tx).await; - } + self.register().await; + + self.run_interval_tasks().await; let stream = self.stream.take().unwrap(); let (mut reader, mut writer) = tokio::io::split(stream); - let cloned_ctx = self.context.clone(); - tokio::spawn(async move { - loop { - handle_rx(&mut rx, &cloned_ctx).await; - } - }); - let cloned_ctx = self.context.clone(); tokio::spawn(async move { loop { @@ -583,23 +594,9 @@ impl Irc { } } -async fn handle_rx(rx: &mut mpsc::Receiver, arc_context: &RwLock) { - while let Some(response) = rx.recv().await { - let mut context = arc_context.write().await; - - let Response::Data(data) = response else { - continue; - }; - - for line in data.data { - context.privmsg_all(&line); - } - } -} - async fn send( writer: &mut WriteHalf, - arc_context: &RwLock, + arc_context: &RwLock, ) -> std::io::Result<()> { let mut len; { diff --git a/src/system.rs b/src/system.rs index ffe3513..bab0784 100644 --- a/src/system.rs +++ b/src/system.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::{factory::Factory, format::Msg, IrcPrefix}; +use crate::{factory::Factory, format::Msg, IrcContext, IrcPrefix}; pub struct FunctionSystem { f: F, @@ -8,7 +8,14 @@ pub struct FunctionSystem { } pub trait System { - fn run(&mut self, prefix: &IrcPrefix, arguments: &[&str], factory: &mut Factory) -> Response; + fn run( + &mut self, + prefix: &IrcPrefix, + channel: &str, + arguments: &[&str], + context: &mut IrcContext, + factory: &mut Factory, + ) -> Response; } pub trait IntoSystem { @@ -29,7 +36,7 @@ macro_rules! impl_system { FnMut( $($params),* ) -> R + FnMut( $(<$params as SystemParam>::Item<'b>),* ) -> R { - fn run(&mut self, prefix: &IrcPrefix, arguments: &[&str], factory: &mut Factory) -> Response { + fn run(&mut self, prefix: &IrcPrefix, channel: &str, arguments: &[&str], context: &mut IrcContext, factory: &mut Factory) -> Response { fn call_inner<'a, R: IntoResponse, $($params),*>( mut f: impl FnMut($($params),*) -> R, $($params: $params),* @@ -42,7 +49,7 @@ macro_rules! impl_system { return Response::InvalidArgument; } - let $params = $params::retrieve(prefix, arguments, &factory); + let $params = $params::retrieve(prefix, channel, arguments, &context, &factory); )* @@ -116,7 +123,9 @@ pub(crate) trait SystemParam { type Item<'new>; fn retrieve<'r>( prefix: &'r IrcPrefix, + channel: &'r str, arguments: &'r [&'r str], + context: &'r IrcContext, factory: &'r Factory, ) -> Self::Item<'r>; #[allow(unused_variables)] diff --git a/src/system_params.rs b/src/system_params.rs index b8b0dc3..bf82a3e 100644 --- a/src/system_params.rs +++ b/src/system_params.rs @@ -4,7 +4,7 @@ use std::{ ops::{Deref, DerefMut}, }; -use crate::{factory::Factory, system::SystemParam, IrcPrefix}; +use crate::{factory::Factory, system::SystemParam, IrcContext, IrcPrefix}; #[derive(Debug)] pub struct Res<'a, T: 'static> { @@ -30,7 +30,9 @@ impl<'res, T: 'static> SystemParam for Res<'res, T> { fn retrieve<'r>( _prefix: &'r IrcPrefix, + _channel: &str, _arguments: &'r [&'r str], + _context: &'r IrcContext, factory: &'r Factory, ) -> Self::Item<'r> { Res { @@ -79,7 +81,9 @@ impl<'res, T: 'static> SystemParam for ResMut<'res, T> { fn retrieve<'r>( _prefix: &'r IrcPrefix, + _channel: &str, _arguments: &'r [&'r str], + _context: &'r IrcContext, factory: &'r Factory, ) -> Self::Item<'r> { let const_ptr = &factory.resources as *const HashMap>; @@ -101,13 +105,38 @@ impl<'a> SystemParam for IrcPrefix<'a> { fn retrieve<'r>( prefix: &'r IrcPrefix, + _channel: &str, _arguments: &'r [&'r str], + _context: &'r IrcContext, _factory: &'r Factory, ) -> Self::Item<'r> { prefix.clone() } } +pub struct Channel<'a>(&'a str); +impl<'a> Deref for Channel<'a> { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a> SystemParam for Channel<'a> { + type Item<'new> = Channel<'new>; + + fn retrieve<'r>( + _prefix: &'r IrcPrefix, + channel: &'r str, + _arguments: &'r [&'r str], + _context: &'r IrcContext, + _factory: &'r Factory, + ) -> Self::Item<'r> { + Channel(channel) + } +} + pub struct AnyArguments<'a>(&'a [&'a str]); impl<'a> Deref for AnyArguments<'a> { @@ -123,7 +152,9 @@ impl<'a> SystemParam for AnyArguments<'a> { fn retrieve<'r>( _prefix: &'r IrcPrefix, + _channel: &str, arguments: &'r [&'r str], + _context: &'r IrcContext, _factory: &'r Factory, ) -> Self::Item<'r> { AnyArguments(&arguments) @@ -145,7 +176,9 @@ impl<'a, const N: usize> SystemParam for Arguments<'a, N> { fn retrieve<'r>( _prefix: &'r IrcPrefix, + _channel: &str, arguments: &'r [&'r str], + _context: &'r IrcContext, _factory: &'r Factory, ) -> Self::Item<'r> { Arguments(&arguments[..N]) @@ -155,3 +188,36 @@ impl<'a, const N: usize> SystemParam for Arguments<'a, N> { arguments.len() == N } } + +pub struct Context<'a>(&'a mut IrcContext); + +impl<'a> Deref for Context<'a> { + type Target = IrcContext; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a> DerefMut for Context<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl<'a> SystemParam for Context<'a> { + type Item<'new> = Context<'new>; + + fn retrieve<'r>( + _prefix: &'r IrcPrefix, + _channel: &str, + _arguments: &'r [&'r str], + context: &'r IrcContext, + _factory: &'r Factory, + ) -> Self::Item<'r> { + let const_ptr = context as *const IrcContext; + let mut_ptr = const_ptr as *mut IrcContext; + let ctx_mut = unsafe { &mut *mut_ptr }; + Context(ctx_mut) + } +}