spliting irc struct to be able to use async tasks

This commit is contained in:
wrk 2023-05-29 19:15:32 +02:00
parent 6b66e862d5
commit 734b25c403
6 changed files with 361 additions and 247 deletions

View File

@ -1,85 +1,142 @@
use std::time::Duration;
use log::{debug, info, warn}; use log::{debug, info, warn};
use std::time::Duration;
use crate::{Irc, IrcPrefix}; use crate::{Irc, IrcPrefix};
impl Irc { impl Irc {
pub(crate) fn event_ping(&mut self, ping_token: &str) { pub(crate) async fn event_ping(&mut self, ping_token: &str) {
debug!("PING {}", ping_token); debug!("PING {}", ping_token);
self.queue(&format!("PONG {}", ping_token));
self.context
.write()
.await
.queue(&format!("PONG {}", ping_token));
} }
pub(crate) fn event_welcome(&mut self, welcome_msg: &str) { pub(crate) async fn event_welcome(&mut self, welcome_msg: &str) {
debug!("{welcome_msg}"); debug!("{welcome_msg}");
// self.identify(); let mut context = self.context.write().await;
self.join_config_channels(); context.identify();
context.join_config_channels();
} }
pub(crate) fn event_nicknameinuse(&mut self) { pub(crate) async fn event_nicknameinuse(&mut self) {
let new_nick = &format!("{}_", &self.config.nick); let mut context = self.context.write().await;
let new_nick = &format!("{}_", &context.config.nick);
warn!("Nick already in use., switching to {}", new_nick); warn!("Nick already in use., switching to {}", new_nick);
self.update_nick(new_nick) context.update_nick(new_nick)
} }
pub(crate) fn event_kick(&mut self, channel: &str, nick: &str, kicker: &str, reason: &str) { pub(crate) async fn event_kick(
if nick != &self.config.nick { &mut self,
channel: &str,
nick: &str,
kicker: &str,
reason: &str,
) {
let mut context = self.context.write().await;
if nick != &context.config.nick {
return; return;
} }
warn!("We got kicked from {} by {}! ({})", channel, kicker, reason); warn!("We got kicked from {} by {}! ({})", channel, kicker, reason);
self.join(channel); context.join(channel);
} }
pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) { pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) {
if prefix.nick != self.config.nick { if prefix.nick != self.context.read().await.config.nick {
return; return;
} }
warn!("We quit. We'll reconnect in {} seconds.", 15); warn!("We quit. We'll reconnect in {} seconds.", 15);
std::thread::sleep(Duration::from_secs(15)); std::thread::sleep(Duration::from_secs(15));
self.connect().await.unwrap(); self.connect().await.unwrap();
self.register();
} }
pub(crate) fn event_invite(&mut self, prefix: &IrcPrefix, channel: &str) { pub(crate) async fn event_invite<'a>(&mut self, prefix: &'a IrcPrefix<'a>, channel: &str) {
info!("{} invited us to {}", prefix.nick, channel); info!("{} invited us to {}", prefix.nick, channel);
self.join(channel); self.context.write().await.join(channel);
} }
pub(crate) fn event_notice( pub(crate) async fn event_notice<'a>(
&mut self, &mut self,
prefix: Option<&IrcPrefix>, _prefix: Option<&IrcPrefix<'a>>,
channel: &str, channel: &str,
message: &str, message: &str,
) { ) {
//TODO, register shit let mut context = self.context.write().await;
if channel == &context.config.nick {
if message.ends_with(&format!(
"\x02{}\x02 isn't registered.",
context.config.nick
)) {
let nickserv_pass = context.config.nickserv_pass.as_ref().unwrap().to_string();
let nickserv_email = context.config.nickserv_email.as_ref().unwrap().to_string();
info!("Registering to nickserv now.");
context.privmsg(
"NickServ",
&format!("REGISTER {} {}", nickserv_pass, nickserv_email),
);
}
if message.ends_with(" seconds to register.") {
let seconds = message
.split_whitespace()
.nth(10)
.unwrap()
.parse::<usize>()
.unwrap()
+ 1;
info!("Waiting {} seconds to register.", seconds);
let ctx_clone = self.context.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(seconds as u64)).await;
ctx_clone.write().await.identify();
});
}
}
} }
pub(crate) fn event_privmsg(&mut self, prefix: &IrcPrefix, channel: &str, message: &str) { pub(crate) async fn event_privmsg<'a>(
if !message.starts_with(&self.config.cmdkey) { &mut self,
prefix: &'a IrcPrefix<'a>,
channel: &str,
message: &str,
) {
let sys_name;
{
let context = self.context.read().await;
if !message.starts_with(&context.config.cmdkey) {
return; return;
} }
let mut elements = message.split_whitespace(); let mut elements = message.split_whitespace();
let sys_name = &elements.next().unwrap()[1..]; sys_name = elements.next().unwrap()[1..].to_owned();
if self.is_owner(prefix) && sys_name == "raw" { if context.is_owner(prefix) && sys_name == "raw" {
self.queue(&elements.collect::<Vec<_>>().join(" ")); let mut context = self.context.write().await;
context.queue(&elements.collect::<Vec<_>>().join(" "));
return;
}
}
if self.is_flood(channel).await {
return; return;
} }
if self.is_flood(channel) { //TODO:
return; // MOVE RUN_SYSTEM BACK TO IRC
}
let response = self.run_system(prefix, sys_name); let mut context = self.context.write().await;
let response = context.run_system(prefix, &sys_name);
if response.0.is_none() { if response.0.is_none() {
return; return;
} }
for line in response.0.unwrap() { for line in response.0.unwrap() {
self.privmsg(channel, &line) context.privmsg(channel, &line)
} }
} }
} }

View File

@ -5,5 +5,5 @@ use std::{
#[derive(Default)] #[derive(Default)]
pub struct Factory { pub struct Factory {
pub(crate) resources: HashMap<TypeId, Box<dyn Any>>, pub(crate) resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
} }

View File

@ -3,6 +3,7 @@ pub mod factory;
pub mod irc_command; pub mod irc_command;
pub mod system; pub mod system;
pub mod system_params; pub mod system_params;
pub mod utils;
use std::{ use std::{
any::TypeId, any::TypeId,
@ -10,6 +11,7 @@ use std::{
io::ErrorKind, io::ErrorKind,
net::ToSocketAddrs, net::ToSocketAddrs,
path::Path, path::Path,
sync::Arc,
time::SystemTime, time::SystemTime,
}; };
@ -23,6 +25,7 @@ use tokio::{
fs::File, fs::File,
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream, net::TcpStream,
sync::RwLock,
}; };
pub(crate) const MAX_MSG_LEN: usize = 512; pub(crate) const MAX_MSG_LEN: usize = 512;
@ -67,7 +70,6 @@ impl Default for FloodControl {
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct IrcPrefix<'a> { pub struct IrcPrefix<'a> {
pub admin: bool,
pub nick: &'a str, pub nick: &'a str,
pub user: Option<&'a str>, pub user: Option<&'a str>,
pub host: Option<&'a str>, pub host: Option<&'a str>,
@ -101,7 +103,6 @@ impl<'a> From<&'a str> for IrcPrefix<'a> {
} }
Self { Self {
admin: false,
nick: nick, nick: nick,
user: Some(user), user: Some(user),
host: Some(user_split[1]), host: Some(user_split[1]),
@ -146,164 +147,35 @@ pub struct IrcConfig {
nick: String, nick: String,
user: String, user: String,
real: String, real: String,
nickserv_pass: String, nickserv_pass: Option<String>,
nickserv_email: String, nickserv_email: Option<String>,
cmdkey: String, cmdkey: String,
flood_interval: f32, flood_interval: f32,
owner: String, owner: String,
admins: Vec<String>, admins: Vec<String>,
} }
pub struct Irc { // TODO:
/*
split Irc into two structs, one for the context, which is Send + Sync to be usable in tasks
one for the comms.
*/
pub struct Context {
config: IrcConfig, config: IrcConfig,
stream: Stream, identified: bool,
send_queue: VecDeque<String>,
systems: HashMap<String, StoredSystem>, systems: HashMap<String, StoredSystem>,
factory: Factory, factory: Factory,
flood_controls: HashMap<String, FloodControl>,
send_queue: VecDeque<String>,
recv_queue: VecDeque<String>,
partial_line: String,
} }
impl Irc { impl Context {
pub async fn from_config(path: impl AsRef<Path>) -> std::io::Result<Self> { pub fn privmsg(&mut self, channel: &str, message: &str) {
let mut file = File::open(path).await?; debug!("sending privmsg to {} : {}", channel, message);
let mut contents = String::new(); self.queue(&format!("PRIVMSG {} :{}", channel, message));
file.read_to_string(&mut contents).await?;
let config: IrcConfig = serde_yaml::from_str(&contents).unwrap();
Ok(Self {
config,
stream: Stream::None,
systems: HashMap::default(),
factory: Factory::default(),
flood_controls: HashMap::default(),
send_queue: VecDeque::new(),
recv_queue: VecDeque::new(),
partial_line: String::new(),
})
} }
pub fn add_system<I, S: for<'a> System<'a> + 'static>(
&mut self,
name: &str,
system: impl for<'a> IntoSystem<'a, I, System = S>,
) -> &mut Self {
self.systems
.insert(name.to_owned(), Box::new(system.into_system()));
self
}
pub fn add_resource<R: 'static>(&mut self, res: R) -> &mut Self {
self.factory
.resources
.insert(TypeId::of::<R>(), Box::new(res));
self
}
pub fn run_system<'a>(&mut self, prefix: &'a IrcPrefix, name: &str) -> Response {
let system = self.systems.get_mut(name).unwrap();
system.run(prefix, &mut self.factory)
}
pub async fn connect(&mut self) -> std::io::Result<()> {
let domain = format!("{}:{}", self.config.host, self.config.port);
info!("Connecting to {}", domain);
let mut addrs = domain
.to_socket_addrs()
.expect("Unable to get addrs from domain {domain}");
let sock = addrs
.next()
.expect("Unable to get ip from addrs: {addrs:?}");
let plain_stream = TcpStream::connect(sock).await?;
if self.config.ssl {
let stream = async_native_tls::connect(self.config.host.clone(), plain_stream)
.await
.unwrap();
self.stream = Stream::Tls(stream);
return Ok(());
}
self.stream = Stream::Plain(plain_stream);
Ok(())
}
pub fn register(&mut self) {
info!(
"Registering as {}!{} ({})",
self.config.nick, self.config.user, self.config.real
);
self.queue(&format!(
"USER {} 0 * {}",
self.config.user, self.config.real
));
self.queue(&format!("NICK {}", self.config.nick));
}
async fn recv(&mut self) -> std::io::Result<()> {
let mut buf = [0; MAX_MSG_LEN];
let bytes_read = match self.stream.read(&mut buf).await {
Ok(bytes_read) => bytes_read,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_read == 0 {
return Ok(());
}
let buf = &buf[..bytes_read];
self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str();
let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect();
let len = new_lines.len();
for (index, line) in new_lines.into_iter().enumerate() {
if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" {
self.partial_line = line.to_owned();
break;
}
self.recv_queue.push_back(line.to_owned());
}
Ok(())
}
async fn send(&mut self) -> std::io::Result<()> {
while self.send_queue.len() > 0 {
let msg = self.send_queue.pop_front().unwrap();
trace!(">> {}", msg.replace("\r\n", ""));
let bytes_written = match self.stream.write(msg.as_bytes()).await {
Ok(bytes_written) => bytes_written,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_written < msg.len() {
self.send_queue.push_front(msg[bytes_written..].to_owned());
}
}
Ok(())
}
fn queue(&mut self, msg: &str) { fn queue(&mut self, msg: &str) {
let mut msg = msg.replace("\r", "").replace("\n", ""); let mut msg = msg.replace("\r", "").replace("\n", "");
@ -324,39 +196,27 @@ impl Irc {
} }
} }
pub async fn update(&mut self) -> std::io::Result<()> { pub fn identify(&mut self) {
self.recv().await?; if self.config.nickserv_pass.is_none() || self.identified {
self.handle_commands().await; return;
self.send().await?;
Ok(())
} }
pub async fn handle_commands(&mut self) { self.privmsg(
while self.recv_queue.len() != 0 { "NickServ",
let owned_line = self.recv_queue.pop_front().unwrap(); &format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()),
let line = owned_line.as_str(); );
trace!("<< {:?}", line);
let mut message: IrcMessage = line.into();
let Some(prefix) = &mut message.prefix else {
return self.handle_message(&message).await;
};
if self.is_owner(prefix) {
prefix.admin = true;
} else {
for admin in &self.config.admins {
if self.is_admin(prefix, admin) {
prefix.admin = true;
break;
}
}
} }
self.handle_message(&message).await; pub fn register(&mut self) {
} info!(
"Registering as {}!{} ({})",
self.config.nick, self.config.user, self.config.real
);
self.queue(&format!(
"USER {} 0 * {}",
self.config.user, self.config.real
));
self.queue(&format!("NICK {}", self.config.nick));
} }
fn is_owner(&self, prefix: &IrcPrefix) -> bool { fn is_owner(&self, prefix: &IrcPrefix) -> bool {
@ -396,7 +256,122 @@ impl Irc {
self.queue(&format!("NICK {}", self.config.nick)); self.queue(&format!("NICK {}", self.config.nick));
} }
fn is_flood(&mut self, channel: &str) -> bool { pub fn privmsg_all(&mut self, message: &str) {
for i in 0..self.config.channels.len() {
let channel = self.config.channels.iter().nth(i).unwrap();
debug!("sending privmsg to {} : {}", channel, message);
self.queue(&format!("PRIVMSG {} :{}", channel, message));
}
}
pub fn add_system<I, S: for<'a> System<'a> + Send + Sync + 'static>(
&mut self,
name: &str,
system: impl for<'a> IntoSystem<'a, I, System = S>,
) -> &mut Self {
self.systems
.insert(name.to_owned(), Box::new(system.into_system()));
self
}
pub fn add_resource<R: Send + Sync + 'static>(&mut self, res: R) -> &mut Self {
self.factory
.resources
.insert(TypeId::of::<R>(), Box::new(res));
self
}
pub fn run_system<'a>(&mut self, prefix: &'a IrcPrefix, name: &str) -> Response {
let system = self.systems.get_mut(name).unwrap();
system.run(prefix, &mut self.factory)
}
}
pub struct Irc {
context: Arc<RwLock<Context>>,
recv_queue: VecDeque<String>,
flood_controls: HashMap<String, FloodControl>,
stream: Stream,
partial_line: String,
}
impl Irc {
pub async fn from_config(path: impl AsRef<Path>) -> std::io::Result<Self> {
let mut file = File::open(path).await?;
let mut contents = String::new();
file.read_to_string(&mut contents).await?;
let config: IrcConfig = serde_yaml::from_str(&contents).unwrap();
let context = Arc::new(RwLock::new(Context {
config,
identified: false,
send_queue: VecDeque::new(),
systems: HashMap::default(),
factory: Factory::default(),
}));
Ok(Self {
context,
stream: Stream::None,
recv_queue: VecDeque::new(),
flood_controls: HashMap::default(),
partial_line: String::new(),
})
}
pub async fn add_system<I, S: for<'a> System<'a> + Send + Sync + 'static>(
&mut self,
name: &str,
system: impl for<'a> IntoSystem<'a, I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context.add_system(name, system);
}
self
}
pub async fn add_resource<R: Send + Sync + 'static>(&mut self, res: R) -> &mut Self {
{
let mut context = self.context.write().await;
context.add_resource(res);
}
self
}
pub async fn connect(&mut self) -> std::io::Result<()> {
let mut context = self.context.write().await;
let domain = format!("{}:{}", context.config.host, context.config.port);
info!("Connecting to {}", domain);
let mut addrs = domain
.to_socket_addrs()
.expect("Unable to get addrs from domain {domain}");
let sock = addrs
.next()
.expect("Unable to get ip from addrs: {addrs:?}");
let plain_stream = TcpStream::connect(sock).await?;
if context.config.ssl {
let stream = async_native_tls::connect(context.config.host.clone(), plain_stream)
.await
.unwrap();
self.stream = Stream::Tls(stream);
context.register();
return Ok(());
}
self.stream = Stream::Plain(plain_stream);
context.register();
Ok(())
}
async fn is_flood(&mut self, channel: &str) -> bool {
let mut flood_control = match self.flood_controls.entry(channel.to_owned()) { let mut flood_control = match self.flood_controls.entry(channel.to_owned()) {
std::collections::hash_map::Entry::Occupied(o) => o.into_mut(), std::collections::hash_map::Entry::Occupied(o) => o.into_mut(),
std::collections::hash_map::Entry::Vacant(v) => { std::collections::hash_map::Entry::Vacant(v) => {
@ -409,7 +384,8 @@ impl Irc {
let elapsed = flood_control.last_cmd.elapsed().unwrap(); let elapsed = flood_control.last_cmd.elapsed().unwrap();
if elapsed.as_secs_f32() < self.config.flood_interval {
if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval {
warn!("they be floodin @ {channel}!"); warn!("they be floodin @ {channel}!");
return true; return true;
} }
@ -418,46 +394,127 @@ impl Irc {
false false
} }
pub fn privmsg(&mut self, channel: &str, message: &str) { async fn recv(&mut self) -> std::io::Result<()> {
debug!("sending privmsg to {} : {}", channel, message); let mut buf = [0; MAX_MSG_LEN];
self.queue(&format!("PRIVMSG {} :{}", channel, message));
let bytes_read = match self.stream.read(&mut buf).await {
Ok(bytes_read) => bytes_read,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_read == 0 {
return Ok(());
} }
pub fn privmsg_all(&mut self, message: &str) { let buf = &buf[..bytes_read];
for i in 0..self.config.channels.len() {
let channel = self.config.channels.iter().nth(i).unwrap(); self.partial_line += String::from_utf8_lossy(buf).into_owned().as_str();
debug!("sending privmsg to {} : {}", channel, message); let new_lines: Vec<&str> = self.partial_line.split("\r\n").collect();
self.queue(&format!("PRIVMSG {} :{}", channel, message)); let len = new_lines.len();
for (index, line) in new_lines.into_iter().enumerate() {
if index == len - 1 && &buf[buf.len() - 3..] != b"\r\n" {
self.partial_line = line.to_owned();
break;
}
self.recv_queue.push_back(line.to_owned());
}
Ok(())
}
async fn send(&mut self) -> std::io::Result<()> {
let mut context = self.context.write().await;
while context.send_queue.len() > 0 {
let msg = context.send_queue.pop_front().unwrap();
trace!(">> {}", msg.replace("\r\n", ""));
let bytes_written = match self.stream.write(msg.as_bytes()).await {
Ok(bytes_written) => bytes_written,
Err(err) => match err.kind() {
ErrorKind::WouldBlock => {
return Ok(());
}
_ => panic!("{err}"),
},
};
if bytes_written < msg.len() {
context
.send_queue
.push_front(msg[bytes_written..].to_owned());
}
}
Ok(())
}
pub async fn handle_commands(&mut self) {
while self.recv_queue.len() != 0 {
let owned_line = self.recv_queue.pop_front().unwrap();
let line = owned_line.as_str();
trace!("<< {:?}", line);
let mut message: IrcMessage = line.into();
let Some(prefix) = &mut message.prefix else {
return self.handle_message(&message).await;
};
self.handle_message(&message).await;
} }
} }
async fn handle_message<'a>(&mut self, message: &'a IrcMessage<'a>) { async fn handle_message<'a>(&mut self, message: &'a IrcMessage<'a>) {
match message.command { match message.command {
IrcCommand::PING => self.event_ping(&message.parameters[0]), IrcCommand::PING => self.event_ping(&message.parameters[0]).await,
IrcCommand::RPL_WELCOME => self.event_welcome(&message.parameters[1..].join(" ")), IrcCommand::RPL_WELCOME => self.event_welcome(&message.parameters[1..].join(" ")).await,
IrcCommand::ERR_NICKNAMEINUSE => self.event_nicknameinuse(), IrcCommand::ERR_NICKNAMEINUSE => self.event_nicknameinuse().await,
IrcCommand::KICK => self.event_kick( IrcCommand::KICK => {
self.event_kick(
&message.parameters[0], &message.parameters[0],
&message.parameters[1], &message.parameters[1],
&message.prefix.as_ref().unwrap().nick, &message.prefix.as_ref().unwrap().nick,
&message.parameters[2..].join(" "), &message.parameters[2..].join(" "),
), )
.await
}
IrcCommand::QUIT => self.event_quit(message.prefix.as_ref().unwrap()).await, IrcCommand::QUIT => self.event_quit(message.prefix.as_ref().unwrap()).await,
IrcCommand::INVITE => self.event_invite( IrcCommand::INVITE => {
self.event_invite(
message.prefix.as_ref().unwrap(), message.prefix.as_ref().unwrap(),
&message.parameters[1][1..], &message.parameters[1][1..],
), )
IrcCommand::PRIVMSG => self.event_privmsg( .await
}
IrcCommand::PRIVMSG => {
self.event_privmsg(
message.prefix.as_ref().unwrap(), message.prefix.as_ref().unwrap(),
&message.parameters[0], &message.parameters[0],
&message.parameters[1..].join(" ")[1..], &message.parameters[1..].join(" ")[1..],
), )
IrcCommand::NOTICE => self.event_notice( .await
}
IrcCommand::NOTICE => {
self.event_notice(
message.prefix.as_ref(), message.prefix.as_ref(),
&message.parameters[0], &message.parameters[0],
&message.parameters[1..].join(" ")[1..], &message.parameters[1..].join(" ")[1..],
), )
.await
}
_ => {} _ => {}
} }
} }
pub async fn update(&mut self) -> std::io::Result<()> {
self.recv().await?;
self.handle_commands().await;
self.send().await?;
Ok(())
}
} }

View File

@ -81,7 +81,7 @@ impl_into_system!(T1, T2);
impl_into_system!(T1, T2, T3); impl_into_system!(T1, T2, T3);
impl_into_system!(T1, T2, T3, T4); impl_into_system!(T1, T2, T3, T4);
pub(crate) type StoredSystem = Box<dyn for<'a> System<'a>>; pub(crate) type StoredSystem = Box<dyn for<'a> System<'a> + Send + Sync>;
pub(crate) trait SystemParam { pub(crate) trait SystemParam {
type Item<'new>; type Item<'new>;

View File

@ -74,7 +74,7 @@ impl<'res, T: 'static> SystemParam for ResMut<'res, T> {
type Item<'new> = ResMut<'new, T>; type Item<'new> = ResMut<'new, T>;
fn retrieve<'r>(_prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r> { fn retrieve<'r>(_prefix: &'r IrcPrefix, factory: &'r Factory) -> Self::Item<'r> {
let const_ptr = &factory.resources as *const HashMap<TypeId, Box<dyn Any>>; let const_ptr = &factory.resources as *const HashMap<TypeId, Box<dyn Any + Send + Sync>>;
let mut_ptr = const_ptr as *mut HashMap<TypeId, Box<dyn Any>>; let mut_ptr = const_ptr as *mut HashMap<TypeId, Box<dyn Any>>;
let res_mut = unsafe { &mut *mut_ptr }; let res_mut = unsafe { &mut *mut_ptr };

0
src/utils.rs Normal file
View File