From 087cef1ffe97185daf3dd4ad8627eafefee394b3 Mon Sep 17 00:00:00 2001 From: sad Date: Sat, 2 Mar 2024 03:01:30 -0700 Subject: [PATCH] add multithreading + rewrote async --- Cargo.toml | 6 +-- src/main.rs | 105 ++++++++++++++++++++++++++++++---------------------- 2 files changed, 63 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8a7557c..725dbf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,11 +7,11 @@ edition = "2021" [dependencies] tokio = { version = "1.36.0", features = ["full"] } -tokio-openssl = "0.6.4" tokio-socks = "0.5.1" -openssl = "0.10.64" +tokio-rustls = "0.25.0" +tokio-native-tls = "0.3.1" +native-tls = "0.2.11" rand = "0.8.5" -regex = "1.10.3" toml = "0.8.10" serde = { version = "1.0.197", features = ["derive"] } colored = "2.1.0" diff --git a/src/main.rs b/src/main.rs index f72e4cf..55c5f18 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ -use colored::*; -use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; +use tokio::io::{split, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector as NTlsConnector; +use tokio_native_tls::TlsConnector; +use tokio::sync::mpsc; use serde::Deserialize; use std::fs; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; -use tokio::net::TcpStream; -use tokio_openssl::SslStream; -use std::pin::Pin; +use colored::*; #[derive(Deserialize)] struct Config { @@ -16,8 +16,8 @@ struct Config { channel: String, } -#[tokio::main] -async fn main() -> Result<(), Box<(dyn std::error::Error)>> { +#[tokio::main(flavor = "multi_thread", worker_threads = 12)] +async fn main() -> Result<(), Box> { println!("Loading Config..."); let config_contents = fs::read_to_string("config.toml").expect("Error reading config.toml"); let config: Config = toml::from_str(&config_contents).expect("Error parsing config.toml"); @@ -25,51 +25,66 @@ async fn main() -> Result<(), Box<(dyn std::error::Error)>> { let addr = format!("{}:{}", config.server, config.port); println!("Connecting to {}...", addr.green()); - let tcp_stream = TcpStream::connect(&addr).await.unwrap(); + let tcp_stream = TcpStream::connect(&addr).await?; println!("Connected to {}!", addr.green()); if config.use_ssl { - println!("Establishing SSL connection..."); - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap().build().configure().unwrap().into_ssl(&addr).unwrap(); - connector.set_verify(SslVerifyMode::NONE); - let mut ssl_stream = SslStream::new(connector, tcp_stream).unwrap(); + println!("Establishing TLS connection..."); + let mut tls_builder = NTlsConnector::builder(); + tls_builder.danger_accept_invalid_certs(true); + let tls_connector = TlsConnector::from(tls_builder.build()?); + let domain = &config.server; + let tls_stream = tls_connector.connect(domain, tcp_stream).await?; + println!("TLS connection established!"); - // Perform the SSL handshake - match Pin::new(&mut ssl_stream).connect().await { - Ok(_) => println!("SSL connection established!"), - Err(e) => { - println!("Error establishing SSL connection: {:?}", e); - return Err(Box::new(e) as Box); + let (reader, writer) = split(tls_stream); + let (tx, mut rx) = mpsc::channel(1000); + // Spawn a task to handle reading + let read_task = tokio::spawn(async move { + let mut reader = reader; + let mut buf = vec![0; 4096]; + loop { + let n = match reader.read(&mut buf).await { + Ok(0) => return, // connection was closed + Ok(n) => n, + Err(e) => { + eprintln!("Error reading from socket: {:?}", e); + return; + } + }; + + let msg = String::from_utf8_lossy(&buf[..n]).to_string(); + if tx.send(msg).await.is_err() { + eprintln!("Error sending message to the channel"); + return; + } } - }; + }); - println!("Sending NICK and USER commands..."); - ssl_stream.write_all(format!("NICK {}\r\n", config.nickname).as_bytes()).await.unwrap(); - ssl_stream.write_all(format!("USER {} 0 * :{}\r\n", config.nickname, config.nickname).as_bytes()).await.unwrap(); - ssl_stream.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap(); + let write_task = tokio::spawn(async move { + let mut writer = writer; + writer.write_all(format!("NICK {}\r\n", config.nickname).as_bytes()).await.unwrap(); + writer.write_all(format!("USER {} 0 * :{}\r\n", config.nickname, config.nickname).as_bytes()).await.unwrap(); + writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap(); + writer.flush().await.unwrap(); - - let (read_half, write_half) = tokio::io::split(ssl_stream); - - // split the stream and then transfer this for non-ssl - let mut reader = BufReader::new(read_half); - let mut writer = BufWriter::new(write_half); - let mut lines = reader.lines(); - - while let Some(result) = lines.next_line().await.unwrap() { - - let received = String::from_utf8_lossy(result.as_bytes()).trim().to_string(); - println!("{} {}","[%] DEBUG:".bold().green(), received.purple()); - - let message = received.trim(); - if message.starts_with("PING") { - println!("Sending PONG..."); - let response = message.replace("PING", "PONG"); - println!("{} {}","[%] PONG:".bold().green(), config.nickname.blue()); - writer.write_all(response.as_bytes()).await.unwrap(); - continue; + while let Some(msg) = rx.recv().await { + // handle messages better + println!("{} {}", "[%] DEBUG:".bold().green(), msg.purple()); + if msg.starts_with("PING") { + writer.write_all(format!("PONG {}\r\n", &msg[5..]).as_bytes()).await.unwrap(); + } + // super dirty auto-rejoin on kick REWRITE THIS + if let Some(pos) = msg.find(" KICK ") { + let parts: Vec<&str> = msg[pos..].split_whitespace().collect(); + if parts.len() > 3 && parts[2] == config.nickname { + writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap(); + } + } } - } + }); + + let _ = tokio::try_join!(read_task, write_task); } else { println!("Non-SSL connection not implemented."); }