add multithreading + rewrote async
This commit is contained in:
parent
0c848899a7
commit
087cef1ffe
@ -7,11 +7,11 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.36.0", features = ["full"] }
|
||||
tokio-openssl = "0.6.4"
|
||||
tokio-socks = "0.5.1"
|
||||
openssl = "0.10.64"
|
||||
tokio-rustls = "0.25.0"
|
||||
tokio-native-tls = "0.3.1"
|
||||
native-tls = "0.2.11"
|
||||
rand = "0.8.5"
|
||||
regex = "1.10.3"
|
||||
toml = "0.8.10"
|
||||
serde = { version = "1.0.197", features = ["derive"] }
|
||||
colored = "2.1.0"
|
||||
|
105
src/main.rs
105
src/main.rs
@ -1,11 +1,11 @@
|
||||
use colored::*;
|
||||
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
|
||||
use tokio::io::{split, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_native_tls::native_tls::TlsConnector as NTlsConnector;
|
||||
use tokio_native_tls::TlsConnector;
|
||||
use tokio::sync::mpsc;
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_openssl::SslStream;
|
||||
use std::pin::Pin;
|
||||
use colored::*;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
@ -16,8 +16,8 @@ struct Config {
|
||||
channel: String,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<(dyn std::error::Error)>> {
|
||||
#[tokio::main(flavor = "multi_thread", worker_threads = 12)]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("Loading Config...");
|
||||
let config_contents = fs::read_to_string("config.toml").expect("Error reading config.toml");
|
||||
let config: Config = toml::from_str(&config_contents).expect("Error parsing config.toml");
|
||||
@ -25,51 +25,66 @@ async fn main() -> Result<(), Box<(dyn std::error::Error)>> {
|
||||
|
||||
let addr = format!("{}:{}", config.server, config.port);
|
||||
println!("Connecting to {}...", addr.green());
|
||||
let tcp_stream = TcpStream::connect(&addr).await.unwrap();
|
||||
let tcp_stream = TcpStream::connect(&addr).await?;
|
||||
println!("Connected to {}!", addr.green());
|
||||
|
||||
if config.use_ssl {
|
||||
println!("Establishing SSL connection...");
|
||||
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap().build().configure().unwrap().into_ssl(&addr).unwrap();
|
||||
connector.set_verify(SslVerifyMode::NONE);
|
||||
let mut ssl_stream = SslStream::new(connector, tcp_stream).unwrap();
|
||||
println!("Establishing TLS connection...");
|
||||
let mut tls_builder = NTlsConnector::builder();
|
||||
tls_builder.danger_accept_invalid_certs(true);
|
||||
let tls_connector = TlsConnector::from(tls_builder.build()?);
|
||||
let domain = &config.server;
|
||||
let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
|
||||
println!("TLS connection established!");
|
||||
|
||||
// Perform the SSL handshake
|
||||
match Pin::new(&mut ssl_stream).connect().await {
|
||||
Ok(_) => println!("SSL connection established!"),
|
||||
Err(e) => {
|
||||
println!("Error establishing SSL connection: {:?}", e);
|
||||
return Err(Box::new(e) as Box<dyn std::error::Error>);
|
||||
let (reader, writer) = split(tls_stream);
|
||||
let (tx, mut rx) = mpsc::channel(1000);
|
||||
// Spawn a task to handle reading
|
||||
let read_task = tokio::spawn(async move {
|
||||
let mut reader = reader;
|
||||
let mut buf = vec![0; 4096];
|
||||
loop {
|
||||
let n = match reader.read(&mut buf).await {
|
||||
Ok(0) => return, // connection was closed
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
eprintln!("Error reading from socket: {:?}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let msg = String::from_utf8_lossy(&buf[..n]).to_string();
|
||||
if tx.send(msg).await.is_err() {
|
||||
eprintln!("Error sending message to the channel");
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
println!("Sending NICK and USER commands...");
|
||||
ssl_stream.write_all(format!("NICK {}\r\n", config.nickname).as_bytes()).await.unwrap();
|
||||
ssl_stream.write_all(format!("USER {} 0 * :{}\r\n", config.nickname, config.nickname).as_bytes()).await.unwrap();
|
||||
ssl_stream.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
|
||||
let write_task = tokio::spawn(async move {
|
||||
let mut writer = writer;
|
||||
writer.write_all(format!("NICK {}\r\n", config.nickname).as_bytes()).await.unwrap();
|
||||
writer.write_all(format!("USER {} 0 * :{}\r\n", config.nickname, config.nickname).as_bytes()).await.unwrap();
|
||||
writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (read_half, write_half) = tokio::io::split(ssl_stream);
|
||||
|
||||
// split the stream and then transfer this for non-ssl
|
||||
let mut reader = BufReader::new(read_half);
|
||||
let mut writer = BufWriter::new(write_half);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
while let Some(result) = lines.next_line().await.unwrap() {
|
||||
|
||||
let received = String::from_utf8_lossy(result.as_bytes()).trim().to_string();
|
||||
println!("{} {}","[%] DEBUG:".bold().green(), received.purple());
|
||||
|
||||
let message = received.trim();
|
||||
if message.starts_with("PING") {
|
||||
println!("Sending PONG...");
|
||||
let response = message.replace("PING", "PONG");
|
||||
println!("{} {}","[%] PONG:".bold().green(), config.nickname.blue());
|
||||
writer.write_all(response.as_bytes()).await.unwrap();
|
||||
continue;
|
||||
while let Some(msg) = rx.recv().await {
|
||||
// handle messages better
|
||||
println!("{} {}", "[%] DEBUG:".bold().green(), msg.purple());
|
||||
if msg.starts_with("PING") {
|
||||
writer.write_all(format!("PONG {}\r\n", &msg[5..]).as_bytes()).await.unwrap();
|
||||
}
|
||||
// super dirty auto-rejoin on kick REWRITE THIS
|
||||
if let Some(pos) = msg.find(" KICK ") {
|
||||
let parts: Vec<&str> = msg[pos..].split_whitespace().collect();
|
||||
if parts.len() > 3 && parts[2] == config.nickname {
|
||||
writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let _ = tokio::try_join!(read_task, write_task);
|
||||
} else {
|
||||
println!("Non-SSL connection not implemented.");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user