commit 9972c1d2947703a131e70113897b4ee2cedaad73 Author: wrk Date: Mon May 29 16:11:41 2023 +0200 init diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d5a01c4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ircie" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tokio = { version = "1.28.2", features = ["full"] } +async-native-tls = { version = "0.5.0", default-features = false, features = [ "runtime-tokio" ] } +serde = { version = "1.0.163", features = ["derive"] } +serde_yaml = "0.9.21" \ No newline at end of file diff --git a/src/events.rs b/src/events.rs new file mode 100644 index 0000000..383daa9 --- /dev/null +++ b/src/events.rs @@ -0,0 +1,71 @@ +use std::time::Duration; + +use crate::{Irc, IrcPrefix}; + +impl Irc { + pub(crate) fn event_ping(&mut self, ping_token: &str) { + self.queue(&format!("PONG {}", ping_token)); + } + + pub(crate) fn event_welcome(&mut self) { + // self.identify(); + self.join_config_channels(); + } + + pub(crate) fn event_nicknameinuse(&mut self) { + self.update_nick(&format!("{}_", &self.config.nick)) + } + + pub(crate) fn event_kick(&mut self, channel: &str, nick: &str, message: &str) { + if nick != &self.config.nick { + return; + } + + println!("we got kicked!"); + println!("{message}"); + + self.join(channel); + } + + pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) { + if prefix.nick != self.config.nick { + return; + } + + println!("need to reconnect."); + std::thread::sleep(Duration::from_secs(15)); + self.connect().await.unwrap(); + self.register(); + } + + pub(crate) fn event_invite(&mut self, prefix: &IrcPrefix, channel: &str) { + println!("{} invited us to {}", prefix.nick, channel); + } + + pub(crate) fn event_notice( + &mut self, + prefix: Option<&IrcPrefix>, + channel: &str, + message: &str, + ) { + //TODO, register shit + } + + pub(crate) fn event_privmsg(&mut self, prefix: &IrcPrefix, channel: &str, message: &str) { + if !message.starts_with(&self.config.cmdkey) { + return; + } + let mut elements = message.split_whitespace(); + let sys_name = &elements.next().unwrap()[1..]; + + if self.is_owner(prefix) && sys_name == "raw" { + self.queue(&elements.collect::>().join(" ")); + return; + } + + if self.is_flood(channel) { + return; + } + self.run_system(prefix, sys_name); + } +} diff --git a/src/factory.rs b/src/factory.rs new file mode 100644 index 0000000..5b640ab --- /dev/null +++ b/src/factory.rs @@ -0,0 +1,9 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, +}; + +#[derive(Default)] +pub struct Factory { + pub(crate) resources: HashMap>, +} diff --git a/src/irc_command.rs b/src/irc_command.rs new file mode 100644 index 0000000..ae1497e --- /dev/null +++ b/src/irc_command.rs @@ -0,0 +1,242 @@ +macro_rules! make_irc_command_enum { + + ($variant:ident) => { + $variant + }; + + ($($variant:ident: $value:expr),+) => { + #[allow(non_camel_case_types)] + pub enum IrcCommand { + UNKNOWN, + $($variant),+ + } + + impl From<&str> for IrcCommand { + fn from(command_str: &str) -> Self { + match command_str { + $($value => Self::$variant,)+ + _ => Self::UNKNOWN, + } + } + } + + }; +} + +make_irc_command_enum!( + ADMIN: "ADMIN", + AWAY: "AWAY", + CNOTICE: "CNOTICE", + CPRIVMSG: "CPRIVMSG", + CONNECT: "CONNECT", + DIE: "DIE", + ENCAP: "ENCAP", + ERROR: "ERROR", + HELP: "HELP", + INFO: "INFO", + INVITE: "INVITE", + ISON: "ISON", + JOIN: "JOIN", + KICK: "KICK", + KILL: "KILL", + KNOCK: "KNOCK", + LINKS: "LINKS", + LIST: "LIST", + LUSERS: "LUSERS", + MODE: "MODE", + MOTD: "MOTD", + NAMES: "NAMES", + NICK: "NICK", + NOTICE: "NOTICE", + OPER: "OPER", + PART: "PART", + PASS: "PASS", + PING: "PING", + PONG: "PONG", + PRIVMSG: "PRIVMSG", + QUIT: "QUIT", + REHASH: "REHASH", + RULES: "RULES", + SERVER: "SERVER", + SERVICE: "SERVICE", + SERVLIST: "SERVLIST", + SQUERY: "SQUERY", + SQUIT: "SQUIT", + SETNAME: "SETNAME", + SILENCE: "SILENCE", + STATS: "STATS", + SUMMON: "SUMMON", + TIME: "TIME", + TOPIC: "TOPIC", + TRACE: "TRACE", + USER: "USER", + USERHOST: "USERHOST", + USERIP: "USERIP", + USERS: "USERS", + VERSION: "VERSION", + WALLOPS: "WALLOPS", + WATCH: "WATCH", + WHO: "WHO", + WHOIS: "WHOIS", + WHOWAS: "WHOWAS", + RPL_WELCOME: "001", + RPL_YOURHOST: "002", + RPL_CREATED: "003", + RPL_MYINFO: "004", + RPL_BOUNCE: "005", + RPL_TRACELINK: "200", + RPL_TRACECONNECTING: "201", + RPL_TRACEHANDSHAKE: "202", + RPL_TRACEUNKNOWN: "203", + RPL_TRACEOPERATOR: "204", + RPL_TRACEUSER: "205", + RPL_TRACESERVER: "206", + RPL_TRACESERVICE: "207", + RPL_TRACENEWTYPE: "208", + RPL_TRACECLASS: "209", + RPL_TRACERECONNECT: "210", + RPL_STATSLINKINFO: "211", + RPL_STATSCOMMANDS: "212", + RPL_STATSCLINE: "213", + RPL_STATSNLINE: "214", + RPL_STATSILINE: "215", + RPL_STATSKLINE: "216", + RPL_STATSQLINE: "217", + RPL_STATSYLINE: "218", + RPL_ENDOFSTATS: "219", + RPL_UMODEIS: "221", + RPL_SERVICEINFO: "231", + RPL_ENDOFSERVICES: "232", + RPL_SERVICE: "233", + RPL_SERVLIST: "234", + RPL_SERVLISTEND: "235", + RPL_STATSVLINE: "240", + RPL_STATSLLINE: "241", + RPL_STATSUPTIME: "242", + RPL_STATSOLINE: "243", + RPL_STATSHLINE: "244", + RPL_STATSPING: "246", + RPL_STATSBLINE: "247", + RPL_STATSDLINE: "250", + RPL_LUSERCLIENT: "251", + RPL_LUSEROP: "252", + RPL_LUSERUNKNOWN: "253", + RPL_LUSERCHANNELS: "254", + RPL_LUSERME: "255", + RPL_ADMINME: "256", + RPL_ADMINLOC1: "257", + RPL_ADMINLOC2: "258", + RPL_ADMINEMAIL: "259", + RPL_TRACELOG: "261", + RPL_TRACEEND: "262", + RPL_TRYAGAIN: "263", + RPL_NONE: "300", + RPL_AWAY: "301", + RPL_USERHOST: "302", + RPL_ISON: "303", + RPL_UNAWAY: "305", + RPL_NOWAWAY: "306", + RPL_WHOISUSER: "311", + RPL_WHOISSERVER: "312", + RPL_WHOISOPERATOR: "313", + RPL_WHOWASUSER: "314", + RPL_ENDOFWHO: "315", + RPL_WHOISCHANOP: "316", + RPL_WHOISIDLE: "317", + RPL_ENDOFWHOIS: "318", + RPL_WHOISCHANNELS: "319", + RPL_LISTSTART: "321", + RPL_LIST: "322", + RPL_LISTEND: "323", + RPL_CHANNELMODEIS: "324", + RPL_UNIQOPIS: "325", + RPL_NOTOPIC: "331", + RPL_TOPIC: "332", + RPL_INVITING: "341", + RPL_SUMMONING: "342", + RPL_INVITELIST: "346", + RPL_ENDOFINVITELIST: "347", + RPL_EXCEPTLIST: "348", + RPL_ENDOFEXCEPTLIST: "349", + RPL_VERSION: "351", + RPL_WHOREPLY: "352", + RPL_NAMREPLY: "353", + RPL_KILLDONE: "361", + RPL_CLOSING: "362", + RPL_CLOSEEND: "363", + RPL_LINKS: "364", + RPL_ENDOFLINKS: "365", + RPL_ENDOFNAMES: "366", + RPL_BANLIST: "367", + RPL_ENDOFBANLIST: "368", + RPL_ENDOFWHOWAS: "369", + RPL_INFO: "371", + RPL_MOTD: "372", + RPL_INFOSTART: "373", + RPL_ENDOFINFO: "374", + RPL_MOTDSTART: "375", + RPL_ENDOFMOTD: "376", + RPL_YOUREOPER: "381", + RPL_REHASHING: "382", + RPL_YOURESERVICE: "383", + RPL_MYPORTIS: "384", + RPL_TIME: "391", + RPL_USERSSTART: "392", + RPL_USERS: "393", + RPL_ENDOFUSERS: "394", + RPL_NOUSERS: "395", + ERR_NOSUCHNICK: "401", + ERR_NOSUCHSERVER: "402", + ERR_NOSUCHCHANNEL: "403", + ERR_CANNOTSENDTOCHAN: "404", + ERR_TOOMANYCHANNELS: "405", + ERR_WASNOSUCHNICK: "406", + ERR_TOOMANYTARGETS: "407", + ERR_NOSUCHSERVICE: "408", + ERR_NOORIGIN: "409", + ERR_NORECIPIENT: "411", + ERR_NOTEXTTOSEND: "412", + ERR_NOTOPLEVEL: "413", + ERR_WILDTOPLEVEL: "414", + ERR_BADMASK: "415", + ERR_UNKNOWNCOMMAND: "421", + ERR_NOMOTD: "422", + ERR_NOADMININFO: "423", + ERR_FILEERROR: "424", + ERR_NONICKNAMEGIVEN: "431", + ERR_ERRONEUSNICKNAME: "432", + ERR_NICKNAMEINUSE: "433", + ERR_NICKCOLLISION: "436", + ERR_UNAVAILRESOURCE: "437", + ERR_USERNOTINCHANNEL: "441", + ERR_NOTONCHANNEL: "442", + ERR_USERONCHANNEL: "443", + ERR_NOLOGIN: "444", + ERR_SUMMONDISABLED: "445", + ERR_USERSDISABLED: "446", + ERR_NOTREGISTERED: "451", + ERR_NEEDMOREPARAMS: "461", + ERR_ALREADYREGISTERED: "462", + ERR_NOPERMFORHOST: "463", + ERR_PASSWDMISMATCH: "464", + ERR_YOUREBANNEDCREEP: "465", + ERR_YOUWILLBEBANNED: "466", + ERR_KEYSET: "467", + ERR_CHANNELISFULL: "471", + ERR_UNKNOWNMODE: "472", + ERR_INVITEONLYCHAN: "473", + ERR_BANNEDFROMCHAN: "474", + ERR_BADCHANNELKEY: "475", + ERR_BADCHANMASK: "476", + ERR_NOCHANMODES: "477", + ERR_BANLISTFULL: "478", + ERR_NOPRIVILEGES: "481", + ERR_CHANOPRIVSNEEDED: "482", + ERR_CANTKILLSERVER: "483", + ERR_RESTRICTED: "484", + ERR_UNIQOPRIVSNEEDED: "485", + ERR_NOOPERHOST: "491", + ERR_NOSERVICEHOST: "492", + ERR_UMODEUNKNOWNFLAG: "501", + ERR_USERSDONTMATCH: "502" +); diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b61e736 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,447 @@ +pub mod events; +pub mod factory; +pub mod irc_command; +pub mod system; +pub mod system_params; + +use std::{ + any::TypeId, + collections::{HashMap, VecDeque}, + io::ErrorKind, + net::ToSocketAddrs, + path::Path, + time::SystemTime, +}; + +use async_native_tls::TlsStream; +use factory::Factory; +use irc_command::IrcCommand; +use serde::{Deserialize, Serialize}; +use system::{IntoSystem, StoredSystem, System}; +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; + +pub(crate) const MAX_MSG_LEN: usize = 512; + +#[derive(Default)] +pub enum Stream { + Plain(TcpStream), + Tls(TlsStream), + #[default] + None, +} + +impl Stream { + pub async fn read(&mut self, buf: &mut [u8]) -> std::result::Result { + match self { + Stream::Plain(stream) => stream.read(buf).await, + Stream::Tls(stream) => stream.read(buf).await, + Stream::None => panic!("No stream."), + } + } + + pub async fn write(&mut self, buf: &[u8]) -> std::result::Result { + match self { + Stream::Plain(stream) => stream.write(buf).await, + Stream::Tls(stream) => stream.write(buf).await, + Stream::None => panic!("No stream."), + } + } +} + +pub struct FloodControl { + last_cmd: SystemTime, +} + +impl Default for FloodControl { + fn default() -> Self { + Self { + last_cmd: SystemTime::now(), + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct IrcPrefix<'a> { + pub admin: bool, + pub nick: &'a str, + pub user: Option<&'a str>, + pub host: Option<&'a str>, +} + +impl<'a> From<&'a str> for IrcPrefix<'a> { + fn from(prefix_str: &'a str) -> Self { + let prefix_str = &prefix_str[1..]; + + let nick_split: Vec<&str> = prefix_str.split('!').collect(); + let nick = nick_split[0]; + + // we only have a nick + if nick_split.len() == 1 { + return Self { + nick, + ..Default::default() + }; + } + + let user_split: Vec<&str> = nick_split[1].split('@').collect(); + let user = user_split[0]; + + // we don't have an host + if user_split.len() == 1 { + return Self { + nick: nick, + user: Some(user), + ..Default::default() + }; + } + + Self { + admin: false, + nick: nick, + user: Some(user), + host: Some(user_split[1]), + } + } +} + +pub struct IrcMessage<'a> { + prefix: Option>, + command: IrcCommand, + parameters: Vec<&'a str>, +} + +impl<'a> From<&'a str> for IrcMessage<'a> { + fn from(line: &'a str) -> Self { + let mut elements = line.split_whitespace(); + + let tmp = elements.next().unwrap(); + + if tmp.chars().next().unwrap() == ':' { + return Self { + prefix: Some(tmp.into()), + command: elements.next().unwrap().into(), + parameters: elements.collect(), + }; + } + + Self { + prefix: None, + command: tmp.into(), + parameters: elements.collect(), + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct IrcConfig { + host: String, + port: u16, + ssl: bool, + channels: Vec, + nick: String, + user: String, + real: String, + nickserv_pass: String, + nickserv_email: String, + cmdkey: String, + flood_interval: f32, + owner: String, + admins: Vec, +} + +pub struct Irc { + config: IrcConfig, + stream: Stream, + + systems: HashMap, + factory: Factory, + + flood_controls: HashMap, + + send_queue: VecDeque, + recv_queue: VecDeque, + partial_line: String, +} + +impl Irc { + pub async fn from_config(path: impl AsRef) -> std::io::Result { + let mut file = File::open(path).await?; + let mut contents = String::new(); + file.read_to_string(&mut contents).await?; + + let config: IrcConfig = serde_yaml::from_str(&contents).unwrap(); + + Ok(Self { + config, + stream: Stream::None, + systems: HashMap::default(), + factory: Factory::default(), + flood_controls: HashMap::default(), + send_queue: VecDeque::new(), + recv_queue: VecDeque::new(), + partial_line: String::new(), + }) + } + + pub fn add_system System<'a> + 'static>( + &mut self, + name: &str, + system: impl for<'a> IntoSystem<'a, I, System = S>, + ) -> &mut Self { + self.systems + .insert(name.to_owned(), Box::new(system.into_system())); + self + } + + pub fn add_resource(&mut self, res: R) -> &mut Self { + self.factory + .resources + .insert(TypeId::of::(), Box::new(res)); + self + } + + pub fn run_system<'a>(&mut self, prefix: &'a IrcPrefix, name: &str) { + let system = self.systems.get_mut(name).unwrap(); + system.run(prefix, &mut self.factory); + } + + pub async fn connect(&mut self) -> std::io::Result<()> { + let domain = format!("{}:{}", self.config.host, self.config.port); + + let mut addrs = domain + .to_socket_addrs() + .expect("Unable to get addrs from domain {domain}"); + + let sock = addrs + .next() + .expect("Unable to get ip from addrs: {addrs:?}"); + + let plain_stream = TcpStream::connect(sock).await?; + + if self.config.ssl { + let stream = async_native_tls::connect(self.config.host.clone(), plain_stream) + .await + .unwrap(); + self.stream = Stream::Tls(stream); + return Ok(()); + } + + self.stream = Stream::Plain(plain_stream); + Ok(()) + } + + pub fn register(&mut self) { + self.queue(&format!( + "USER {} 0 * {}", + self.config.user, self.config.real + )); + self.queue(&format!("NICK {}", self.config.nick)); + } + + async fn recv(&mut self) -> std::io::Result<()> { + let mut buf = [0; MAX_MSG_LEN]; + + let bytes_read = match self.stream.read(&mut buf).await { + Ok(bytes_read) => bytes_read, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(()); + } + _ => panic!("{err}"), + }, + }; + + if bytes_read == 0 { + return Ok(()); + } + + let buf = &buf[..bytes_read]; + + self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str(); + let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect(); + let len = new_lines.len(); + + for (index, line) in new_lines.into_iter().enumerate() { + if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" { + self.partial_line = line.to_owned(); + break; + } + self.recv_queue.push_back(line.to_owned()); + } + Ok(()) + } + + async fn send(&mut self) -> std::io::Result<()> { + while self.send_queue.len() > 0 { + let msg = self.send_queue.pop_front().unwrap(); + + let bytes_written = match self.stream.write(msg.as_bytes()).await { + Ok(bytes_written) => bytes_written, + Err(err) => match err.kind() { + ErrorKind::WouldBlock => { + return Ok(()); + } + _ => panic!("{err}"), + }, + }; + + if bytes_written < msg.len() { + self.send_queue.push_front(msg[bytes_written..].to_owned()); + } + } + Ok(()) + } + + fn queue(&mut self, msg: &str) { + let mut msg = msg.replace("\r", "").replace("\n", ""); + + if msg.len() > MAX_MSG_LEN - "\r\n".len() { + let mut i = 0; + + while i < msg.len() { + let max = (MAX_MSG_LEN - "\r\n".len()).min(msg[i..].len()); + + let mut m = msg[i..(i + max)].to_owned(); + println!(">> {:?}", m); + m = m + "\r\n"; + self.send_queue.push_back(m); + i += MAX_MSG_LEN - "\r\n".len() + } + } else { + println!(">> {:?}", msg); + msg = msg + "\r\n"; + self.send_queue.push_back(msg); + } + } + + pub async fn update(&mut self) -> std::io::Result<()> { + self.recv().await?; + self.send().await?; + self.handle_commands().await; + Ok(()) + } + + pub async fn handle_commands(&mut self) { + while self.recv_queue.len() != 0 { + let owned_line = self.recv_queue.pop_front().unwrap(); + let line = owned_line.as_str(); + + println!("<< {:?}", line); + + let mut message: IrcMessage = line.into(); + + let Some(prefix) = &mut message.prefix else { + return self.handle_message(&message).await; + }; + + if self.is_owner(prefix) { + prefix.admin = true; + } else { + for admin in &self.config.admins { + if self.is_admin(prefix, admin) { + prefix.admin = true; + break; + } + } + } + + self.handle_message(&message).await; + } + } + + fn is_owner(&self, prefix: &IrcPrefix) -> bool { + self.is_admin(prefix, &self.config.owner) + } + + fn is_admin(&self, prefix: &IrcPrefix, admin: &str) -> bool { + let admin = ":".to_owned() + &admin; + let admin_prefix: IrcPrefix = admin.as_str().into(); + + if (admin_prefix.nick == prefix.nick || admin_prefix.nick == "*") + && (admin_prefix.user == prefix.user || admin_prefix.user == Some("*")) + && (admin_prefix.host == prefix.host || admin_prefix.host == Some("*")) + { + return true; + } + + false + } + + fn join(&mut self, channel: &str) { + self.queue(&format!("JOIN {}", channel)) + } + + fn join_config_channels(&mut self) { + for i in 0..self.config.channels.len() { + let channel = &self.config.channels[i]; + self.queue(&format!("JOIN {}", channel)) + } + } + + fn update_nick(&mut self, new_nick: &str) { + self.config.nick = new_nick.to_owned(); + self.queue(&format!("NICK {}", self.config.nick)); + } + + fn is_flood(&mut self, channel: &str) -> bool { + let mut flood_control = match self.flood_controls.entry(channel.to_owned()) { + std::collections::hash_map::Entry::Occupied(o) => o.into_mut(), + std::collections::hash_map::Entry::Vacant(v) => v.insert(FloodControl { + last_cmd: SystemTime::now(), + }), + }; + + let elapsed = flood_control.last_cmd.elapsed().unwrap(); + + if elapsed.as_secs_f32() < self.config.flood_interval { + return true; + } + + flood_control.last_cmd = SystemTime::now(); + false + } + + pub fn privmsg(&mut self, channel: &str, message: &str) { + self.queue(&format!("PRIVMSG {} :{}", channel, message)); + } + + pub fn privmsg_all(&mut self, message: &str) { + for i in 0..self.config.channels.len() { + let channel = &self.config.channels[i]; + self.queue(&format!("PRIVMSG {} :{}", channel, message)); + } + } + + async fn handle_message<'a>(&mut self, message: &'a IrcMessage<'a>) { + match message.command { + IrcCommand::PING => self.event_ping(&message.parameters[0]), + IrcCommand::RPL_WELCOME => self.event_welcome(), + IrcCommand::ERR_NICKNAMEINUSE => self.event_nicknameinuse(), + IrcCommand::KICK => self.event_kick( + message.parameters[0], + message.parameters[1], + &message.parameters[3..].join(" "), + ), + IrcCommand::QUIT => self.event_quit(message.prefix.as_ref().unwrap()).await, + IrcCommand::INVITE => self.event_invite( + message.prefix.as_ref().unwrap(), + &message.parameters[0][1..], + ), + IrcCommand::PRIVMSG => self.event_privmsg( + message.prefix.as_ref().unwrap(), + &message.parameters[0], + &message.parameters[1..].join(" ")[1..], + ), + IrcCommand::NOTICE => self.event_notice( + message.prefix.as_ref(), + &message.parameters[0], + &message.parameters[1..].join(" ")[1..], + ), + _ => {} + } + } +} diff --git a/src/system.rs b/src/system.rs new file mode 100644 index 0000000..d53ddbd --- /dev/null +++ b/src/system.rs @@ -0,0 +1,102 @@ +use std::marker::PhantomData; + +use crate::{factory::Factory, IrcPrefix}; + +pub struct FunctionSystem { + f: F, + marker: PhantomData Input>, +} + +pub trait System<'a> { + fn run(&mut self, prefix: &'a IrcPrefix, factory: &'a mut Factory) -> Response; +} + +pub trait IntoSystem<'a, Input> { + type System: System<'a>; + + fn into_system(self) -> Self::System; +} + +macro_rules! impl_system { + ( + $($params:ident),* + ) => { + #[allow(non_snake_case)] + #[allow(unused)] + impl System<'_> for FunctionSystem<($($params,)*), F> + where + for<'a, 'b> &'a mut F: + FnMut( $($params),* ) -> R + + FnMut( $(<$params as SystemParam>::Item<'b>),* ) -> R + { + fn run(&mut self, prefix: &IrcPrefix, factory: &mut Factory) -> Response { + fn call_inner<'a, R: IntoResponse, $($params),*>( + mut f: impl FnMut($($params),*) -> R, + $($params: $params),* + ) -> Response { + f($($params),*).response() + } + + $( + let $params = $params::retrieve(prefix, &factory); + )* + + call_inner(&mut self.f, $($params),*) + } + } + } +} + +macro_rules! impl_into_system { + ( + $($params:ident),* + ) => { + impl IntoSystem<'_, ($($params,)*)> for F + where + for<'a, 'b> &'a mut F: + FnMut( $($params),* ) -> R + + FnMut( $(<$params as SystemParam>::Item<'b>),* ) -> R + { + type System = FunctionSystem<($($params,)*), Self>; + + fn into_system(self) -> Self::System { + FunctionSystem { + f: self, + marker: Default::default(), + } + } + } + } +} + +impl_system!(); +impl_system!(T1); +impl_system!(T1, T2); +impl_system!(T1, T2, T3); +impl_system!(T1, T2, T3, T4); + +impl_into_system!(); +impl_into_system!(T1); +impl_into_system!(T1, T2); +impl_into_system!(T1, T2, T3); +impl_into_system!(T1, T2, T3, T4); + +pub(crate) type StoredSystem = Box System<'a>>; + +pub(crate) trait SystemParam { + type Item<'new>; + fn retrieve<'r>(prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r>; +} + +#[derive(Clone)] +pub struct Response(pub Option>); + +pub trait IntoResponse { + fn response(self) -> Response; +} + +impl IntoResponse for () { + fn response(self) -> Response { + Response(None) + } +} diff --git a/src/system_params.rs b/src/system_params.rs new file mode 100644 index 0000000..ec14a00 --- /dev/null +++ b/src/system_params.rs @@ -0,0 +1,97 @@ +use std::{ + any::{Any, TypeId}, + collections::HashMap, + ops::{Deref, DerefMut}, +}; + +use crate::{factory::Factory, system::SystemParam, IrcPrefix}; + +#[derive(Debug)] +pub struct Res<'a, T: 'static> { + value: &'a T, +} + +impl<'a, T: 'static> Deref for Res<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: 'static> AsRef for Res<'a, T> { + fn as_ref(&self) -> &T { + self.value + } +} + +impl<'res, T: 'static> SystemParam for Res<'res, T> { + type Item<'new> = Res<'new, T>; + + fn retrieve<'r>(_prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r> { + Res { + value: &factory + .resources + .get(&TypeId::of::()) + .unwrap() + .downcast_ref() + .unwrap(), + } + } +} + +pub struct ResMut<'a, T: 'static> { + value: &'a mut T, +} + +impl<'a, T: 'static> Deref for ResMut<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: 'static> DerefMut for ResMut<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} + +impl<'a, T: 'static> AsRef for ResMut<'a, T> { + fn as_ref(&self) -> &T { + self.value + } +} + +impl<'a, T: 'static> AsMut for ResMut<'a, T> { + fn as_mut(&mut self) -> &mut T { + &mut self.value + } +} + +impl<'res, T: 'static> SystemParam for ResMut<'res, T> { + type Item<'new> = ResMut<'new, T>; + + fn retrieve<'r>(_prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r> { + let const_ptr = &factory.resources as *const HashMap>; + let mut_ptr = const_ptr as *mut HashMap>; + let res_mut = unsafe { &mut *mut_ptr }; + + ResMut { + value: res_mut + .get_mut(&TypeId::of::()) + .unwrap() + .downcast_mut() + .unwrap(), + } + } +} + +impl<'a> SystemParam for IrcPrefix<'a> { + type Item<'new> = IrcPrefix<'new>; + + fn retrieve<'r>(prefix: &'r IrcPrefix, _factory: &'r Factory) -> Self::Item<'r> { + prefix.clone() + } +}