diff --git a/src/main.rs b/src/main.rs index b76ef76..190d4b3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,9 @@ use tokio_native_tls::native_tls::TlsConnector as NTlsConnector; use tokio_native_tls::TlsConnector; use tokio::sync::mpsc; use serde::Deserialize; +use tokio_rustls::rustls::Writer; use std::fs; +use std::future::IntoFuture; use colored::*; #[derive(Deserialize)] @@ -16,7 +18,7 @@ struct Config { channel: String, sasl_username: Option, sasl_password: Option, - capabilities: Option> + capabilities: Option>, } mod mods { @@ -35,99 +37,129 @@ async fn main() -> Result<(), Box> { println!("Connected to {}!", format!("{}:{}", config.server, config.port).green()); println!("Establishing TLS connection..."); - let tls_stream = tls_exec (&config, tcp_stream).await?; + let mut tls_stream = tls_exec (&config, tcp_stream).await?; println!("TLS connection established!"); + tls_stream.flush().await?; - handler(tls_stream, &config).await?; + handler(tls_stream, config).await?; } else { println!("Non-SSL connection not implemented."); } Ok(()) } - +/// Load the config file fn loaded_config() -> Result> { let config_contents = fs::read_to_string("config.toml")?; - //let config_contents = fs::read_to_string("config.toml").expect("Error reading config.toml"); let config: Config = toml::from_str(&config_contents)?; - //let config: Config = toml::from_str(&config_contents).expect("Error parsing config.toml"); Ok(config) } -//async fn tls_exec(config: &Config, tcp_stream: TcpStream) -> Result, Box> { -// let mut tls_builder = NTlsConnector::builder(); -// tls_builder.danger_accept_invalid_certs(true); -// let tls_connector = TlsConnector::from(tls_builder.build()?); -// let domain = &config.server; -// let tls_stream = tls_connector.connect(domain, tcp_stream).await?; -// println!("TLS connection established!"); -// Ok(tls_stream) -//} - +/// Establish a TLS connection to the server async fn tls_exec(config: &Config, tcp_stream: TcpStream) -> Result, Box> { let tls_builder = NTlsConnector::builder().danger_accept_invalid_certs(true).build()?; let tls_connector = TlsConnector::from(tls_builder); Ok(tls_connector.connect(&config.server, tcp_stream).await?) } -async fn handler(tls_stream: tokio_native_tls::TlsStream, config: &Config) -> Result<(), Box> { -//async fn handler(mut tls_stream: tokio_native_tls::TlsStream, config: &Config) -> Result<(), Box> { - let (mut reader, mut writer) = split(tls_stream); - let (tx, mut rx) = mpsc::channel(1000); +/// Handle the connection to the server +async fn handler(tls_stream: tokio_native_tls::TlsStream, config: Config) -> Result<(), Box> { + let (reader, writer) = split(tls_stream); + let (tx, rx) = mpsc::channel(1000); + + + let read_task = tokio::spawn(async move { - let mut buf = vec![0; 4096]; - while let Ok(n) = reader.read(&mut buf).await { - if n == 0 { break; } // connection killed x.x - let msg = String::from_utf8_lossy(&buf[..n]).to_string(); - if tx.send(msg).await.is_err() { break; } // channel killed x.x - } + readmsg(reader, tx).await; }); - //let read_task = tokio::spawn(async move { - // let mut buf = vec![0; 4096]; - // loop { - // let n = match reader.read(&mut buf).await { - // Ok(0) => return, // connection killed x.x - // Ok(n) => n, - // Err(e) => { - // eprintln!("Error reading from socket: {:?}", e); - // return; - // }, - // }; - // - // let msg = String::from_utf8_lossy(&buf[..n]).to_string(); - // if tx.send(msg).await.is_err() { - // eprintln!("Error sending message to the channel"); - // return; - // } - // } - //}); - // + let write_task = tokio::spawn(async move { - while let Some(msg) = rx.recv().await { - // new commands here - if msg.starts_with("PING") { - writer.write_all(format!("PONG {}\r\n", &msg[5..]).as_bytes()).await.unwrap(); - } - } + writemsg(writer, rx, &config).await; }); - //let write_task = tokio::spawn(async move { - // while let Some(msg) = rx.recv().await { - // if msg.starts_with("PING") { - // writer.write_all(format!("PONG {}\r\n", &msg[5..]).as_bytes()).await.unwrap(); - // } - // if let Some(username) = &config.sasl_username { - // if let Some(password) = &config.sasl_password { - // handle_sasl_messages(&mut writer, &msg, username, password).await.unwrap(); - // } - // } - // } - //}); - let _ = tokio::try_join!(read_task, write_task); - + Ok(()) } +/// Read messages from the server +async fn readmsg(mut reader: tokio::io::ReadHalf>, tx: tokio::sync::mpsc::Sender) { + let mut buf = vec![0; 4096]; + while let Ok (n) = reader.read(&mut buf).await { + if n == 0 { break; } + let msg = String::from_utf8_lossy(&buf[..n]).to_string(); + // must pretty this up later + println!{"{}{}{} {}{} {}", "[".green().bold(), ">".yellow().bold(), "]".green().bold(), "DEBUG:".bold().yellow(), ":".bold().green(), msg.purple()}; + + tx.send(msg).await.unwrap(); + } +} + +/// Write messages to the server +async fn writemsg(mut writer: tokio::io::WriteHalf>, mut rx: tokio::sync::mpsc::Receiver, config: &Config) { + // sasl auth + let capabilities = config.capabilities.clone(); + let username = config.sasl_username.clone().unwrap(); + let password = config.sasl_password.clone().unwrap(); + let nickname = config.nickname.clone(); + + + if !password.is_empty() { + println!("Starting SASL auth..."); + start_sasl_auth(&mut writer, "PLAIN", &nickname, capabilities).await.unwrap(); + writer.flush().await.unwrap(); + } else { + nickme(&mut writer, &nickname).await.unwrap(); + } + + writer.flush().await.unwrap(); + // THIS NEEDS TO BE REBUILT TO BE MORE MODULAR AND SECURE + while let Some(msg) = rx.recv().await { + + if msg.starts_with("PING") { + let response = msg.replace("PING", "PONG"); + println!("{} {} {}","[%] PONG:".bold().green(), nickname.blue(), response.purple()); + writer.write_all(response.as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); + //continue; + } + // handle sasl auth + if !password.is_empty(){ + println!("Handling SASL messages..."); + handle_sasl_messages(&mut writer, &msg, &username, &password, &nickname).await.unwrap(); + //continue; + writer.flush().await.unwrap(); + } + + // new commands here + if msg.contains("001") { + println!("Setting mode"); + writer.write_all(format!("MODE {} +B\r\n", nickname).as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); + } + + + if msg.contains("433") { + println!("Nickname already in use, appending _ to nickname"); + let new_nick = format!("{}_", nickname); + nickme(&mut writer, &new_nick).await.unwrap(); + writer.flush().await.unwrap(); + } + if msg.contains("376") { + println!("Joining channel"); + writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); + } + + } +} + +async fn nickme(writer: &mut W, nickname: &str) -> Result<(), Box> { + writer.write_all(format!("NICK {}\r\n", nickname).as_bytes()).await?; + writer.flush().await?; + writer.write_all(format!("USER {} 0 * :{}\r\n", nickname, nickname).as_bytes()).await?; + writer.flush().await?; + Ok(()) +} diff --git a/src/mods/sasl.rs b/src/mods/sasl.rs index 020c1ba..0ac0bdd 100644 --- a/src/mods/sasl.rs +++ b/src/mods/sasl.rs @@ -1,29 +1,23 @@ // mods/sasl.rs use base64::Engine; -use tokio::io::AsyncWriteExt; -/// Sends the initial commands to negotiate capabilities and start SASL authentication. pub async fn start_sasl_auth( -//pub async fn start_sasl_auth(...) -> Result<(), Box> { writer: &mut W, mechanism: &str, nickname: &str, - capabilities: &[String], // Add a parameter for capabilities -) -> Result<(), Box> { - // Request a list of capabilities from the server + capabilities: Option>) -> Result<(), Box> { writer.write_all(b"CAP LS 302\r\n").await?; - // Send NICK and USER commands let nick_cmd = format!("NICK {}\r\n", nickname); writer.write_all(nick_cmd.as_bytes()).await?; let user_cmd = format!("USER {} 0 * :{}\r\n", nickname, nickname); writer.write_all(user_cmd.as_bytes()).await?; - // Request specific capabilities, including 'sasl' for SASL authentication - if !capabilities.is_empty() { - let cap_req_cmd = format!("CAP REQ :{}\r\n", capabilities.join(" ")); - writer.write_all(cap_req_cmd.as_bytes()).await?; + if let Some(caps) = capabilities { + if !caps.is_empty() { + let cap_req_cmd = format!("CAP REQ :{}\r\n", caps.join(" ")); + writer.write_all(cap_req_cmd.as_bytes()).await?; + } } else { - // If no specific capabilities are requested, directly request 'sasl' writer.write_all(b"CAP REQ :sasl\r\n").await?; } @@ -31,15 +25,15 @@ pub async fn start_sasl_auth( Ok(()) } -/// Continues the SASL authentication process based on the server's responses. -//pub async fn handle_sasl_messages(...) -> Result<(), Box> { pub async fn handle_sasl_messages( writer: &mut W, message: &str, username: &str, password: &str, + nickname: &str, ) -> Result<(), Box> { - if message.contains("CAP * ACK :sasl") { + let nick = format!("CAP {} ACK :sasl", nickname.to_string()); + if message.contains(&nick) { writer.write_all(b"AUTHENTICATE PLAIN\r\n").await?; } else if message.starts_with("AUTHENTICATE +") { let auth_string = format!("\0{}\0{}", username, password);