diff --git a/src/main.rs b/src/main.rs index 190d4b3..28a825f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,9 +4,8 @@ use tokio_native_tls::native_tls::TlsConnector as NTlsConnector; use tokio_native_tls::TlsConnector; use tokio::sync::mpsc; use serde::Deserialize; -use tokio_rustls::rustls::Writer; use std::fs; -use std::future::IntoFuture; +use std::sync::atomic::{AtomicBool, Ordering}; use colored::*; #[derive(Deserialize)] @@ -28,24 +27,25 @@ use mods::sasl::{start_sasl_auth, handle_sasl_messages}; #[tokio::main(flavor = "multi_thread", worker_threads = 12)] async fn main() -> Result<(), Box> { + tokio::spawn(async move { println!("Loading Config..."); let config = loaded_config().expect("Error parsing config.toml"); println!("Config loaded!"); if config.use_ssl { - let tcp_stream = TcpStream::connect(format!("{}:{}", config.server, config.port)).await?; + let tcp_stream = TcpStream::connect(format!("{}:{}", config.server, config.port)).await; println!("Connected to {}!", format!("{}:{}", config.server, config.port).green()); println!("Establishing TLS connection..."); - let mut tls_stream = tls_exec (&config, tcp_stream).await?; + let mut tls_stream = tls_exec (&config, tcp_stream.unwrap()).await.unwrap(); println!("TLS connection established!"); - tls_stream.flush().await?; + tls_stream.flush().await.unwrap(); - handler(tls_stream, config).await?; + handler(tls_stream, config).await.unwrap(); } else { println!("Non-SSL connection not implemented."); } - + }).await.unwrap(); Ok(()) } /// Load the config file @@ -89,73 +89,88 @@ async fn readmsg(mut reader: tokio::io::ReadHalf".yellow().bold(), "]".green().bold(), "DEBUG:".bold().yellow(), ":".bold().green(), msg.purple()}; - - tx.send(msg).await.unwrap(); + let msg_list = String::from_utf8_lossy(&buf[..n]).to_string(); + for lines in msg_list.lines() { + let msg = lines.to_string(); + println!("{}{}{} {}{} {}", "[".green().bold(), ">".yellow().bold(), "]".green().bold(), "DEBUG:".bold().yellow(), ":".bold().green(), msg.trim().purple()); + tx.send(msg).await.unwrap(); + if buf.len() == n { + buf.resize(buf.len() * 2, 0); + } + } } } + + +static SASL_AUTH: AtomicBool = AtomicBool::new(false); /// Write messages to the server async fn writemsg(mut writer: tokio::io::WriteHalf>, mut rx: tokio::sync::mpsc::Receiver, config: &Config) { // sasl auth - let capabilities = config.capabilities.clone(); + //let capabilities = config.capabilities.clone(); let username = config.sasl_username.clone().unwrap(); let password = config.sasl_password.clone().unwrap(); let nickname = config.nickname.clone(); - - - if !password.is_empty() { + if !password.is_empty() && !SASL_AUTH.load(Ordering::Relaxed) { + let capabilities = config.capabilities.clone(); println!("Starting SASL auth..."); start_sasl_auth(&mut writer, "PLAIN", &nickname, capabilities).await.unwrap(); writer.flush().await.unwrap(); + SASL_AUTH.store(true, Ordering::Relaxed); } else { nickme(&mut writer, &nickname).await.unwrap(); + writer.flush().await.unwrap(); } + //writer.flush().await.unwrap(); + //let msg = rx.recv().await.unwrap(); + //let msg = msg.trim(); + //let parts = msg.split(' ').collect::>(); - writer.flush().await.unwrap(); // THIS NEEDS TO BE REBUILT TO BE MORE MODULAR AND SECURE while let Some(msg) = rx.recv().await { + let msg = msg.trim(); + if msg.is_empty() { + continue; + } + let parts = msg.split(' ').collect::>(); + let serv = parts.first().unwrap_or(&""); + let cmd = parts.get(1).unwrap_or(&""); - if msg.starts_with("PING") { - let response = msg.replace("PING", "PONG"); + + + println!("{} {} {} {} {}", "DEBUG:".bold().yellow(), "serv:".bold().green(), serv.purple(), "cmd:".bold().green(), cmd.purple()); + if *serv == "PING" { + let response = msg.replace("PING", "PONG") + "\r\n"; println!("{} {} {}","[%] PONG:".bold().green(), nickname.blue(), response.purple()); writer.write_all(response.as_bytes()).await.unwrap(); writer.flush().await.unwrap(); - //continue; + continue; } - // handle sasl auth - if !password.is_empty(){ + if (*cmd == "CAP" || msg.starts_with("AUTHENTICATE +") || *cmd == "903") && SASL_AUTH.load(Ordering::Relaxed) { println!("Handling SASL messages..."); - handle_sasl_messages(&mut writer, &msg, &username, &password, &nickname).await.unwrap(); - //continue; + handle_sasl_messages(&mut writer, msg.trim(), &username, &password, &nickname).await.unwrap(); writer.flush().await.unwrap(); - } - - // new commands here - if msg.contains("001") { + } + if *cmd == "001" { println!("Setting mode"); writer.write_all(format!("MODE {} +B\r\n", nickname).as_bytes()).await.unwrap(); writer.flush().await.unwrap(); } - - - if msg.contains("433") { - println!("Nickname already in use, appending _ to nickname"); - let new_nick = format!("{}_", nickname); - nickme(&mut writer, &new_nick).await.unwrap(); - writer.flush().await.unwrap(); - } - if msg.contains("376") { + + if *cmd == "376" { println!("Joining channel"); writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap(); writer.flush().await.unwrap(); } - - } + if *cmd == "PRIVMSG" { + let channel = parts[2]; + let user = parts[0]; + let host = user.split_at(user.find('!').unwrap()); + let msg = parts[3..].join(" ").replace(':', ""); + println!("{}{}{} {}{} {} {} {}", "[".green().bold(), ">".yellow().bold(), "]".green().bold(), "PRIVMSG:".bold().yellow(), ":".bold().green(), channel.yellow(), user.blue(), msg.purple()); + } + } } - async fn nickme(writer: &mut W, nickname: &str) -> Result<(), Box> { writer.write_all(format!("NICK {}\r\n", nickname).as_bytes()).await?; writer.flush().await?; diff --git a/src/mods/sasl.rs b/src/mods/sasl.rs index 0ac0bdd..bdef9d8 100644 --- a/src/mods/sasl.rs +++ b/src/mods/sasl.rs @@ -1,4 +1,5 @@ // mods/sasl.rs +use crate::nickme; use base64::Engine; pub async fn start_sasl_auth( writer: &mut W, @@ -7,10 +8,7 @@ pub async fn start_sasl_auth( capabilities: Option>) -> Result<(), Box> { writer.write_all(b"CAP LS 302\r\n").await?; - let nick_cmd = format!("NICK {}\r\n", nickname); - writer.write_all(nick_cmd.as_bytes()).await?; - let user_cmd = format!("USER {} 0 * :{}\r\n", nickname, nickname); - writer.write_all(user_cmd.as_bytes()).await?; + nickme(writer, nickname).await?; if let Some(caps) = capabilities { if !caps.is_empty() { @@ -20,7 +18,7 @@ pub async fn start_sasl_auth( } else { writer.write_all(b"CAP REQ :sasl\r\n").await?; } - + //println!("Handling SASL messages..."); writer.flush().await?; Ok(()) } @@ -32,8 +30,7 @@ pub async fn handle_sasl_messages( password: &str, nickname: &str, ) -> Result<(), Box> { - let nick = format!("CAP {} ACK :sasl", nickname.to_string()); - if message.contains(&nick) { + if message.contains(format!("CAP {} ACK :sasl", nickname).as_str()) { writer.write_all(b"AUTHENTICATE PLAIN\r\n").await?; } else if message.starts_with("AUTHENTICATE +") { let auth_string = format!("\0{}\0{}", username, password); @@ -45,4 +42,3 @@ pub async fn handle_sasl_messages( writer.flush().await?; Ok(()) } -