Added context to systems

This commit is contained in:
wrk 2023-06-09 02:02:16 +02:00
parent a1b1435b72
commit 789bfca745
5 changed files with 312 additions and 232 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
Cargo.lock

View File

@ -15,16 +15,20 @@ impl Irc {
pub(crate) async fn event_welcome(&mut self, welcome_msg: &str) {
debug!("{welcome_msg}");
self.identify().await;
let mut context = self.context.write().await;
context.identify();
context.join_config_channels();
for channel in &self.config.channels {
context.join(channel);
}
}
pub(crate) async fn event_nicknameinuse(&mut self) {
let mut context = self.context.write().await;
let new_nick = &format!("{}_", &context.config.nick);
let new_nick = format!("{}_", &self.config.nick);
warn!("Nick already in use., switching to {}", new_nick);
context.update_nick(new_nick)
context.nick(&new_nick);
self.config.nick = new_nick;
}
pub(crate) async fn event_kick(
@ -35,7 +39,7 @@ impl Irc {
reason: &str,
) {
let mut context = self.context.write().await;
if nick != &context.config.nick {
if nick != &self.config.nick {
return;
}
@ -44,7 +48,7 @@ impl Irc {
}
pub(crate) async fn event_quit<'a>(&mut self, prefix: &'a IrcPrefix<'a>) {
if prefix.nick != self.context.read().await.config.nick {
if prefix.nick != self.config.nick {
return;
}
@ -64,7 +68,7 @@ impl Irc {
channel: &str,
message: &str,
) {
let config = self.context.read().await.config.clone();
let config = self.config.clone();
if channel == &config.nick {
if message.ends_with(&format!("\x02{}\x02 isn't registered.", config.nick)) {
@ -87,12 +91,13 @@ impl Irc {
+ 1;
info!("Waiting {} seconds to register.", seconds);
/* TODO: fix this
let ctx_clone = self.context.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(seconds as u64)).await;
ctx_clone.write().await.identify();
self.identify().await;
});
*/
}
}
}
@ -107,13 +112,13 @@ impl Irc {
let sys_name;
{
let context = self.context.read().await;
if !message.starts_with(&context.config.cmdkey) {
if !message.starts_with(&self.config.cmdkey) {
return;
}
elements = message.split_whitespace();
sys_name = elements.next().unwrap()[1..].to_owned();
if context.is_owner(prefix) && sys_name == "raw" {
if prefix.owner() && sys_name == "raw" {
drop(context);
let mut context = self.context.write().await;
context.queue(&elements.collect::<Vec<_>>().join(" "));
@ -127,14 +132,14 @@ impl Irc {
let arguments = elements.collect::<Vec<_>>();
let mut context = self.context.write().await;
if !context.systems.contains_key(&sys_name) {
let resp = context.run_default_system(prefix, &arguments).await;
if !self.systems.contains_key(&sys_name) {
let resp = self.run_default_system(prefix, channel, &arguments).await;
let Response::Data(data) = resp else {
return;
};
let mut context = self.context.write().await;
for (idx, line) in data.data.iter().enumerate() {
if idx == 0 && data.highlight {
context.privmsg(channel, &format!("{}: {}", prefix.nick, line))
@ -145,11 +150,12 @@ impl Irc {
return;
}
let response = context.run_system(prefix, &arguments, &sys_name).await;
let response = self.run_system(prefix, channel, &arguments, &sys_name).await;
let Response::Data(data) = response else {
return;
};
let mut context = self.context.write().await;
for (idx, line) in data.data.iter().enumerate() {
if idx == 0 && data.highlight {
context.privmsg(channel, &format!("{}: {}", prefix.nick, line))

View File

@ -26,7 +26,7 @@ use tokio::{
fs::File,
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf},
net::TcpStream,
sync::{mpsc, RwLock},
sync::RwLock,
};
pub(crate) const MAX_MSG_LEN: usize = 512;
@ -69,11 +69,30 @@ impl Default for FloodControl {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum IrcPrefixKind {
Owner,
Admin,
#[default]
User,
}
#[derive(Clone, Debug, Default)]
pub struct IrcPrefix<'a> {
pub nick: &'a str,
pub user: Option<&'a str>,
pub host: Option<&'a str>,
kind: IrcPrefixKind,
}
impl<'a> IrcPrefix<'a> {
pub fn owner(&self) -> bool {
self.kind == IrcPrefixKind::Owner
}
pub fn admin(&self) -> bool {
self.kind == IrcPrefixKind::Admin
}
}
impl<'a> From<&'a str> for IrcPrefix<'a> {
@ -107,6 +126,7 @@ impl<'a> From<&'a str> for IrcPrefix<'a> {
nick: nick,
user: Some(user),
host: Some(user_split[1]),
..Default::default()
}
}
}
@ -163,19 +183,11 @@ pub struct IrcConfig {
*/
pub struct Context {
config: IrcConfig,
identified: bool,
pub struct IrcContext {
send_queue: VecDeque<String>,
default_system: Option<StoredSystem>,
invalid_system: Option<StoredSystem>,
systems: HashMap<String, StoredSystem>,
tasks: Vec<(Duration, StoredSystem)>,
factory: Arc<RwLock<Factory>>,
}
impl Context {
impl IrcContext {
pub fn privmsg(&mut self, channel: &str, message: &str) {
debug!("sending privmsg to {} : {}", channel, message);
self.queue(&format!("PRIVMSG {} :{}", channel, message));
@ -200,136 +212,13 @@ impl Context {
}
}
pub fn identify(&mut self) {
if self.config.nickserv_pass.is_none() || self.identified {
return;
}
self.privmsg(
"NickServ",
&format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()),
);
}
pub fn register(&mut self) {
info!(
"Registering as {}!{} ({})",
self.config.nick, self.config.user, self.config.real
);
self.queue(&format!(
"USER {} 0 * {}",
self.config.user, self.config.real
));
self.queue(&format!("NICK {}", self.config.nick));
}
fn is_owner(&self, prefix: &IrcPrefix) -> bool {
self.is_admin(prefix, &self.config.owner)
}
fn is_admin(&self, prefix: &IrcPrefix, admin: &str) -> bool {
let admin = ":".to_owned() + &admin;
let admin_prefix: IrcPrefix = admin.as_str().into();
if (admin_prefix.nick == prefix.nick || admin_prefix.nick == "*")
&& (admin_prefix.user == prefix.user || admin_prefix.user == Some("*"))
&& (admin_prefix.host == prefix.host || admin_prefix.host == Some("*"))
{
return true;
}
false
}
fn join(&mut self, channel: &str) {
pub fn join(&mut self, channel: &str) {
info!("Joining {channel}");
self.queue(&format!("JOIN {}", channel));
self.config.channels.insert(channel.to_owned());
}
fn join_config_channels(&mut self) {
for i in 0..self.config.channels.len() {
let channel = self.config.channels.iter().nth(i).unwrap();
info!("Joining {channel}");
self.queue(&format!("JOIN {}", channel))
}
}
fn update_nick(&mut self, new_nick: &str) {
self.config.nick = new_nick.to_owned();
self.queue(&format!("NICK {}", self.config.nick));
}
pub fn privmsg_all(&mut self, message: &str) {
for i in 0..self.config.channels.len() {
let channel = self.config.channels.iter().nth(i).unwrap();
debug!("sending privmsg to {} : {}", channel, message);
self.queue(&format!("PRIVMSG {} :{}", channel, message));
}
}
pub async fn run_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
arguments: &'a [&'a str],
name: &str,
) -> Response {
let system = self.systems.get_mut(name).unwrap();
system.run(prefix, arguments, &mut *self.factory.write().await)
}
pub async fn run_default_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
arguments: &'a [&'a str],
) -> Response {
if self.invalid_system.is_none() {
return Response::Empty;
}
self.default_system.as_mut().unwrap().run(
prefix,
arguments,
&mut *self.factory.write().await,
)
}
pub async fn run_invalid_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
arguments: &'a [&'a str],
) -> Response {
if self.invalid_system.is_none() {
return Response::Empty;
}
self.invalid_system.as_mut().unwrap().run(
prefix,
arguments,
&mut *self.factory.write().await,
)
}
pub async fn run_interval_tasks(&mut self, tx: mpsc::Sender<Response>) {
for (duration, mut task) in std::mem::take(&mut self.tasks) {
let fact = self.factory.clone();
let task_tx = tx.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(duration).await;
let resp = task.run(
&IrcPrefix {
nick: "",
user: None,
host: None,
},
&[],
&mut *fact.write().await,
);
task_tx.send(resp).await.unwrap();
}
});
}
pub fn nick(&mut self, nick: &str) {
self.queue(&format!("NICK {}", nick));
}
}
@ -338,10 +227,19 @@ pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {}
impl<T: AsyncRead + AsyncWrite + Send + Unpin> AsyncReadWrite for T {}
pub struct Irc {
context: Arc<RwLock<Context>>,
context: Arc<RwLock<IrcContext>>,
flood_controls: HashMap<String, FloodControl>,
stream: Option<Box<dyn AsyncReadWrite>>,
partial_line: String,
config: IrcConfig,
identified: bool,
default_system: Option<StoredSystem>,
invalid_system: Option<StoredSystem>,
systems: HashMap<String, StoredSystem>,
tasks: Vec<(Duration, StoredSystem)>,
factory: Arc<RwLock<Factory>>,
}
impl Irc {
@ -352,15 +250,8 @@ impl Irc {
let config: IrcConfig = serde_yaml::from_str(&contents).unwrap();
let context = Arc::new(RwLock::new(Context {
config,
identified: false,
let context = Arc::new(RwLock::new(IrcContext {
send_queue: VecDeque::new(),
default_system: None,
invalid_system: None,
systems: HashMap::default(),
tasks: Vec::new(),
factory: Arc::new(RwLock::new(Factory::default())),
}));
Ok(Self {
@ -368,6 +259,13 @@ impl Irc {
stream: None,
flood_controls: HashMap::default(),
partial_line: String::new(),
config,
identified: false,
default_system: None,
invalid_system: None,
systems: HashMap::default(),
tasks: Vec::new(),
factory: Arc::new(RwLock::new(Factory::default())),
})
}
@ -376,12 +274,8 @@ impl Irc {
name: &str,
system: impl for<'a> IntoSystem<I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context
.systems
.insert(name.to_owned(), Box::new(system.into_system()));
}
self.systems
.insert(name.to_owned(), Box::new(system.into_system()));
self
}
@ -389,10 +283,7 @@ impl Irc {
&mut self,
system: impl for<'a> IntoSystem<I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context.default_system = Some(Box::new(system.into_system()));
}
self.default_system = Some(Box::new(system.into_system()));
self
}
@ -400,10 +291,7 @@ impl Irc {
&mut self,
system: impl for<'a> IntoSystem<I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context.invalid_system = Some(Box::new(system.into_system()));
}
self.invalid_system = Some(Box::new(system.into_system()));
self
}
@ -412,10 +300,7 @@ impl Irc {
duration: Duration,
system: impl for<'a> IntoSystem<I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context.tasks.push((duration, Box::new(system.into_system())));
}
self.tasks.push((duration, Box::new(system.into_system())));
self
}
@ -423,30 +308,22 @@ impl Irc {
&mut self,
system: impl for<'a> IntoSystem<I, System = S>,
) -> &mut Self {
{
let mut context = self.context.write().await;
context.tasks.push((Duration::ZERO, Box::new(system.into_system())));
}
self.tasks
.push((Duration::ZERO, Box::new(system.into_system())));
self
}
pub async fn add_resource<R: Send + Sync + 'static>(&mut self, res: R) -> &mut Self {
{
let context = self.context.write().await;
context
.factory
.write()
.await
.resources
.insert(TypeId::of::<R>(), Box::new(res));
}
self.factory
.write()
.await
.resources
.insert(TypeId::of::<R>(), Box::new(res));
self
}
pub async fn connect(&mut self) -> std::io::Result<()> {
let context = self.context.read().await;
let domain = format!("{}:{}", context.config.host, context.config.port);
let domain = format!("{}:{}", self.config.host, self.config.port);
info!("Connecting to {}", domain);
@ -460,8 +337,8 @@ impl Irc {
let plain_stream = TcpStream::connect(sock).await?;
if context.config.ssl {
let stream = async_native_tls::connect(context.config.host.clone(), plain_stream)
if self.config.ssl {
let stream = async_native_tls::connect(self.config.host.clone(), plain_stream)
.await
.unwrap();
self.stream = Some(Box::new(stream));
@ -485,7 +362,7 @@ impl Irc {
let elapsed = flood_control.last_cmd.elapsed().unwrap();
if elapsed.as_secs_f32() < self.context.read().await.config.flood_interval {
if elapsed.as_secs_f32() < self.config.flood_interval {
warn!("they be floodin @ {channel}!");
return true;
}
@ -494,6 +371,50 @@ impl Irc {
false
}
fn is_owner(&self, prefix: &IrcPrefix) -> bool {
let owner = ":".to_owned() + &self.config.owner;
let owner_prefix: IrcPrefix = owner.as_str().into();
if (owner_prefix.nick == prefix.nick || owner_prefix.nick == "*")
&& (owner_prefix.user == prefix.user || owner_prefix.user == Some("*"))
&& (owner_prefix.host == prefix.host || owner_prefix.host == Some("*"))
{
return true;
}
false
}
fn is_admin(&self, prefix: &IrcPrefix) -> bool {
for admin_str in &self.config.admins {
let admin = ":".to_owned() + admin_str;
let admin_prefix: IrcPrefix = admin.as_str().into();
if (admin_prefix.nick == prefix.nick || admin_prefix.nick == "*")
&& (admin_prefix.user == prefix.user || admin_prefix.user == Some("*"))
&& (admin_prefix.host == prefix.host || admin_prefix.host == Some("*"))
{
return true;
}
}
false
}
pub fn into_message<'a>(&self, line: &'a str) -> IrcMessage<'a> {
let mut message: IrcMessage = line.into();
if let Some(prefix) = &mut message.prefix {
if self.is_owner(prefix) {
prefix.kind = IrcPrefixKind::Owner;
} else if self.is_admin(prefix) {
prefix.kind = IrcPrefixKind::Admin;
}
}
message
}
pub async fn handle_commands(&mut self, mut lines: VecDeque<String>) {
while lines.len() != 0 {
let owned_line = lines.pop_front().unwrap();
@ -501,7 +422,8 @@ impl Irc {
trace!("<< {:?}", line);
let message: IrcMessage = line.into();
let message = self.into_message(line);
self.handle_message(&message).await;
}
}
@ -548,27 +470,116 @@ impl Irc {
}
}
pub async fn register(&mut self) {
info!(
"Registering as {}!{} ({})",
self.config.nick, self.config.user, self.config.real
);
let mut context = self.context.write().await;
context.queue(&format!(
"USER {} 0 * {}",
self.config.user, self.config.real
));
context.nick(&self.config.nick);
}
pub async fn identify(&mut self) {
if self.config.nickserv_pass.is_none() || self.identified {
return;
}
self.context.write().await.privmsg(
"NickServ",
&format!("IDENTIFY {}", self.config.nickserv_pass.as_ref().unwrap()),
);
}
pub async fn run_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
channel: &'a str,
arguments: &'a [&'a str],
name: &str,
) -> Response {
let system = self.systems.get_mut(name).unwrap();
system.run(
prefix,
channel,
arguments,
&mut *self.context.write().await,
&mut *self.factory.write().await,
)
}
pub async fn run_default_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
channel: &'a str,
arguments: &'a [&'a str],
) -> Response {
if self.invalid_system.is_none() {
return Response::Empty;
}
self.default_system.as_mut().unwrap().run(
prefix,
channel,
arguments,
&mut *self.context.write().await,
&mut *self.factory.write().await,
)
}
pub async fn run_invalid_system<'a>(
&mut self,
prefix: &'a IrcPrefix<'a>,
channel: &'a str,
arguments: &'a [&'a str],
) -> Response {
if self.invalid_system.is_none() {
return Response::Empty;
}
self.invalid_system.as_mut().unwrap().run(
prefix,
channel,
arguments,
&mut *self.context.write().await,
&mut *self.factory.write().await,
)
}
pub async fn run_interval_tasks(&mut self) {
for (duration, mut task) in std::mem::take(&mut self.tasks) {
let fact = self.factory.clone();
let ctx = self.context.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(duration).await;
task.run(
&IrcPrefix::default(),
"",
&[],
&mut *ctx.write().await,
&mut *fact.write().await,
);
}
});
}
}
pub async fn run(&mut self) -> std::io::Result<()> {
self.connect().await?;
info!("Ready!");
let (tx, mut rx) = mpsc::channel::<Response>(512);
{
let mut context = self.context.write().await;
context.register();
context.run_interval_tasks(tx).await;
}
self.register().await;
self.run_interval_tasks().await;
let stream = self.stream.take().unwrap();
let (mut reader, mut writer) = tokio::io::split(stream);
let cloned_ctx = self.context.clone();
tokio::spawn(async move {
loop {
handle_rx(&mut rx, &cloned_ctx).await;
}
});
let cloned_ctx = self.context.clone();
tokio::spawn(async move {
loop {
@ -583,23 +594,9 @@ impl Irc {
}
}
async fn handle_rx(rx: &mut mpsc::Receiver<Response>, arc_context: &RwLock<Context>) {
while let Some(response) = rx.recv().await {
let mut context = arc_context.write().await;
let Response::Data(data) = response else {
continue;
};
for line in data.data {
context.privmsg_all(&line);
}
}
}
async fn send<T: AsyncWrite>(
writer: &mut WriteHalf<T>,
arc_context: &RwLock<Context>,
arc_context: &RwLock<IrcContext>,
) -> std::io::Result<()> {
let mut len;
{

View File

@ -1,6 +1,6 @@
use std::marker::PhantomData;
use crate::{factory::Factory, format::Msg, IrcPrefix};
use crate::{factory::Factory, format::Msg, IrcContext, IrcPrefix};
pub struct FunctionSystem<Input, F> {
f: F,
@ -8,7 +8,14 @@ pub struct FunctionSystem<Input, F> {
}
pub trait System {
fn run(&mut self, prefix: &IrcPrefix, arguments: &[&str], factory: &mut Factory) -> Response;
fn run(
&mut self,
prefix: &IrcPrefix,
channel: &str,
arguments: &[&str],
context: &mut IrcContext,
factory: &mut Factory,
) -> Response;
}
pub trait IntoSystem<Input> {
@ -29,7 +36,7 @@ macro_rules! impl_system {
FnMut( $($params),* ) -> R +
FnMut( $(<$params as SystemParam>::Item<'b>),* ) -> R
{
fn run(&mut self, prefix: &IrcPrefix, arguments: &[&str], factory: &mut Factory) -> Response {
fn run(&mut self, prefix: &IrcPrefix, channel: &str, arguments: &[&str], context: &mut IrcContext, factory: &mut Factory) -> Response {
fn call_inner<'a, R: IntoResponse, $($params),*>(
mut f: impl FnMut($($params),*) -> R,
$($params: $params),*
@ -42,7 +49,7 @@ macro_rules! impl_system {
return Response::InvalidArgument;
}
let $params = $params::retrieve(prefix, arguments, &factory);
let $params = $params::retrieve(prefix, channel, arguments, &context, &factory);
)*
@ -116,7 +123,9 @@ pub(crate) trait SystemParam {
type Item<'new>;
fn retrieve<'r>(
prefix: &'r IrcPrefix,
channel: &'r str,
arguments: &'r [&'r str],
context: &'r IrcContext,
factory: &'r Factory,
) -> Self::Item<'r>;
#[allow(unused_variables)]

View File

@ -4,7 +4,7 @@ use std::{
ops::{Deref, DerefMut},
};
use crate::{factory::Factory, system::SystemParam, IrcPrefix};
use crate::{factory::Factory, system::SystemParam, IrcContext, IrcPrefix};
#[derive(Debug)]
pub struct Res<'a, T: 'static> {
@ -30,7 +30,9 @@ impl<'res, T: 'static> SystemParam for Res<'res, T> {
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
_channel: &str,
_arguments: &'r [&'r str],
_context: &'r IrcContext,
factory: &'r Factory,
) -> Self::Item<'r> {
Res {
@ -79,7 +81,9 @@ impl<'res, T: 'static> SystemParam for ResMut<'res, T> {
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
_channel: &str,
_arguments: &'r [&'r str],
_context: &'r IrcContext,
factory: &'r Factory,
) -> Self::Item<'r> {
let const_ptr = &factory.resources as *const HashMap<TypeId, Box<dyn Any + Send + Sync>>;
@ -101,13 +105,38 @@ impl<'a> SystemParam for IrcPrefix<'a> {
fn retrieve<'r>(
prefix: &'r IrcPrefix,
_channel: &str,
_arguments: &'r [&'r str],
_context: &'r IrcContext,
_factory: &'r Factory,
) -> Self::Item<'r> {
prefix.clone()
}
}
pub struct Channel<'a>(&'a str);
impl<'a> Deref for Channel<'a> {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'a> SystemParam for Channel<'a> {
type Item<'new> = Channel<'new>;
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
channel: &'r str,
_arguments: &'r [&'r str],
_context: &'r IrcContext,
_factory: &'r Factory,
) -> Self::Item<'r> {
Channel(channel)
}
}
pub struct AnyArguments<'a>(&'a [&'a str]);
impl<'a> Deref for AnyArguments<'a> {
@ -123,7 +152,9 @@ impl<'a> SystemParam for AnyArguments<'a> {
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
_channel: &str,
arguments: &'r [&'r str],
_context: &'r IrcContext,
_factory: &'r Factory,
) -> Self::Item<'r> {
AnyArguments(&arguments)
@ -145,7 +176,9 @@ impl<'a, const N: usize> SystemParam for Arguments<'a, N> {
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
_channel: &str,
arguments: &'r [&'r str],
_context: &'r IrcContext,
_factory: &'r Factory,
) -> Self::Item<'r> {
Arguments(&arguments[..N])
@ -155,3 +188,36 @@ impl<'a, const N: usize> SystemParam for Arguments<'a, N> {
arguments.len() == N
}
}
pub struct Context<'a>(&'a mut IrcContext);
impl<'a> Deref for Context<'a> {
type Target = IrcContext;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'a> DerefMut for Context<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<'a> SystemParam for Context<'a> {
type Item<'new> = Context<'new>;
fn retrieve<'r>(
_prefix: &'r IrcPrefix,
_channel: &str,
_arguments: &'r [&'r str],
context: &'r IrcContext,
_factory: &'r Factory,
) -> Self::Item<'r> {
let const_ptr = context as *const IrcContext;
let mut_ptr = const_ptr as *mut IrcContext;
let ctx_mut = unsafe { &mut *mut_ptr };
Context(ctx_mut)
}
}