Compare commits

..

3 Commits

Author SHA1 Message Date
sad
f02b9a237b
verifiy connections 2024-03-02 10:16:29 -07:00
sad
500ce3b59e
refactor+cleanup 2024-03-02 05:23:47 -07:00
sad
e83663ffe5
base caps+sasl implementation 2024-03-02 04:54:16 -07:00
4 changed files with 185 additions and 62 deletions

View File

@ -13,6 +13,7 @@ tokio-native-tls = "0.3.1"
native-tls = "0.2.11"
rand = "0.8.5"
toml = "0.8.10"
base64 = "0.22.0"
serde = { version = "1.0.197", features = ["derive"] }
colored = "2.1.0"
futures = "0.3.30"

View File

@ -2,5 +2,7 @@ server = "irc.supernets.org"
port = 6697
use_ssl = true
nickname = "g1r"
channel = "#superbowl"
channel = "#dev"
sasl_username = "g1r"
sasl_password = "fuckyou.lol"
capabilities = ["sasl"]

View File

@ -4,7 +4,9 @@ use tokio_native_tls::native_tls::TlsConnector as NTlsConnector;
use tokio_native_tls::TlsConnector;
use tokio::sync::mpsc;
use serde::Deserialize;
use tokio_rustls::rustls::Writer;
use std::fs;
use std::future::IntoFuture;
use colored::*;
#[derive(Deserialize)]
@ -14,80 +16,150 @@ struct Config {
use_ssl: bool,
nickname: String,
channel: String,
sasl_username: Option<String>,
sasl_password: Option<String>,
capabilities: Option<Vec<String>>,
}
mod mods {
pub mod sasl;
}
use mods::sasl::{start_sasl_auth, handle_sasl_messages};
#[tokio::main(flavor = "multi_thread", worker_threads = 12)]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Loading Config...");
let config_contents = fs::read_to_string("config.toml").expect("Error reading config.toml");
let config: Config = toml::from_str(&config_contents).expect("Error parsing config.toml");
let config = loaded_config().expect("Error parsing config.toml");
println!("Config loaded!");
let addr = format!("{}:{}", config.server, config.port);
println!("Connecting to {}...", addr.green());
let tcp_stream = TcpStream::connect(&addr).await?;
println!("Connected to {}!", addr.green());
if config.use_ssl {
let tcp_stream = TcpStream::connect(format!("{}:{}", config.server, config.port)).await?;
println!("Connected to {}!", format!("{}:{}", config.server, config.port).green());
println!("Establishing TLS connection...");
let mut tls_builder = NTlsConnector::builder();
tls_builder.danger_accept_invalid_certs(true);
let tls_connector = TlsConnector::from(tls_builder.build()?);
let domain = &config.server;
let tls_stream = tls_connector.connect(domain, tcp_stream).await?;
let mut tls_stream = tls_exec (&config, tcp_stream).await?;
println!("TLS connection established!");
tls_stream.flush().await?;
let (reader, writer) = split(tls_stream);
let (tx, mut rx) = mpsc::channel(1000);
// Spawn a task to handle reading
let read_task = tokio::spawn(async move {
let mut reader = reader;
let mut buf = vec![0; 4096];
loop {
let n = match reader.read(&mut buf).await {
Ok(0) => return, // connection was closed
Ok(n) => n,
Err(e) => {
eprintln!("Error reading from socket: {:?}", e);
return;
}
};
let msg = String::from_utf8_lossy(&buf[..n]).to_string();
if tx.send(msg).await.is_err() {
eprintln!("Error sending message to the channel");
return;
}
}
});
let write_task = tokio::spawn(async move {
let mut writer = writer;
writer.write_all(format!("NICK {}\r\n", config.nickname).as_bytes()).await.unwrap();
writer.write_all(format!("USER {} 0 * :{}\r\n", config.nickname, config.nickname).as_bytes()).await.unwrap();
writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
writer.flush().await.unwrap();
while let Some(msg) = rx.recv().await {
// handle messages better
println!("{} {}", "[%] DEBUG:".bold().green(), msg.purple());
if msg.starts_with("PING") {
writer.write_all(format!("PONG {}\r\n", &msg[5..]).as_bytes()).await.unwrap();
}
// super dirty auto-rejoin on kick REWRITE THIS
if let Some(pos) = msg.find(" KICK ") {
let parts: Vec<&str> = msg[pos..].split_whitespace().collect();
if parts.len() > 3 && parts[2] == config.nickname {
writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
}
}
}
});
let _ = tokio::try_join!(read_task, write_task);
handler(tls_stream, config).await?;
} else {
println!("Non-SSL connection not implemented.");
}
Ok(())
}
/// Load the config file
fn loaded_config() -> Result<Config, Box<dyn std::error::Error>> {
let config_contents = fs::read_to_string("config.toml")?;
let config: Config = toml::from_str(&config_contents)?;
Ok(config)
}
/// Establish a TLS connection to the server
async fn tls_exec(config: &Config, tcp_stream: TcpStream) -> Result<tokio_native_tls::TlsStream<TcpStream>, Box<dyn std::error::Error>> {
let tls_builder = NTlsConnector::builder().danger_accept_invalid_certs(true).build()?;
let tls_connector = TlsConnector::from(tls_builder);
Ok(tls_connector.connect(&config.server, tcp_stream).await?)
}
/// Handle the connection to the server
async fn handler(tls_stream: tokio_native_tls::TlsStream<TcpStream>, config: Config) -> Result<(), Box<dyn std::error::Error>> {
let (reader, writer) = split(tls_stream);
let (tx, rx) = mpsc::channel(1000);
let read_task = tokio::spawn(async move {
readmsg(reader, tx).await;
});
let write_task = tokio::spawn(async move {
writemsg(writer, rx, &config).await;
});
let _ = tokio::try_join!(read_task, write_task);
Ok(())
}
/// Read messages from the server
async fn readmsg(mut reader: tokio::io::ReadHalf<tokio_native_tls::TlsStream<TcpStream>>, tx: tokio::sync::mpsc::Sender<String>) {
let mut buf = vec![0; 4096];
while let Ok (n) = reader.read(&mut buf).await {
if n == 0 { break; }
let msg = String::from_utf8_lossy(&buf[..n]).to_string();
// must pretty this up later
println!{"{}{}{} {}{} {}", "[".green().bold(), ">".yellow().bold(), "]".green().bold(), "DEBUG:".bold().yellow(), ":".bold().green(), msg.purple()};
tx.send(msg).await.unwrap();
}
}
/// Write messages to the server
async fn writemsg(mut writer: tokio::io::WriteHalf<tokio_native_tls::TlsStream<TcpStream>>, mut rx: tokio::sync::mpsc::Receiver<String>, config: &Config) {
// sasl auth
let capabilities = config.capabilities.clone();
let username = config.sasl_username.clone().unwrap();
let password = config.sasl_password.clone().unwrap();
let nickname = config.nickname.clone();
if !password.is_empty() {
println!("Starting SASL auth...");
start_sasl_auth(&mut writer, "PLAIN", &nickname, capabilities).await.unwrap();
writer.flush().await.unwrap();
} else {
nickme(&mut writer, &nickname).await.unwrap();
}
writer.flush().await.unwrap();
// THIS NEEDS TO BE REBUILT TO BE MORE MODULAR AND SECURE
while let Some(msg) = rx.recv().await {
if msg.starts_with("PING") {
let response = msg.replace("PING", "PONG");
println!("{} {} {}","[%] PONG:".bold().green(), nickname.blue(), response.purple());
writer.write_all(response.as_bytes()).await.unwrap();
writer.flush().await.unwrap();
//continue;
}
// handle sasl auth
if !password.is_empty(){
println!("Handling SASL messages...");
handle_sasl_messages(&mut writer, &msg, &username, &password, &nickname).await.unwrap();
//continue;
writer.flush().await.unwrap();
}
// new commands here
if msg.contains("001") {
println!("Setting mode");
writer.write_all(format!("MODE {} +B\r\n", nickname).as_bytes()).await.unwrap();
writer.flush().await.unwrap();
}
if msg.contains("433") {
println!("Nickname already in use, appending _ to nickname");
let new_nick = format!("{}_", nickname);
nickme(&mut writer, &new_nick).await.unwrap();
writer.flush().await.unwrap();
}
if msg.contains("376") {
println!("Joining channel");
writer.write_all(format!("JOIN {}\r\n", config.channel).as_bytes()).await.unwrap();
writer.flush().await.unwrap();
}
}
}
async fn nickme<W: tokio::io::AsyncWriteExt + Unpin>(writer: &mut W, nickname: &str) -> Result<(), Box<dyn std::error::Error>> {
writer.write_all(format!("NICK {}\r\n", nickname).as_bytes()).await?;
writer.flush().await?;
writer.write_all(format!("USER {} 0 * :{}\r\n", nickname, nickname).as_bytes()).await?;
writer.flush().await?;
Ok(())
}

48
src/mods/sasl.rs Normal file
View File

@ -0,0 +1,48 @@
// mods/sasl.rs
use base64::Engine;
pub async fn start_sasl_auth<W: tokio::io::AsyncWriteExt + Unpin>(
writer: &mut W,
mechanism: &str,
nickname: &str,
capabilities: Option<Vec<String>>) -> Result<(), Box<dyn std::error::Error>> {
writer.write_all(b"CAP LS 302\r\n").await?;
let nick_cmd = format!("NICK {}\r\n", nickname);
writer.write_all(nick_cmd.as_bytes()).await?;
let user_cmd = format!("USER {} 0 * :{}\r\n", nickname, nickname);
writer.write_all(user_cmd.as_bytes()).await?;
if let Some(caps) = capabilities {
if !caps.is_empty() {
let cap_req_cmd = format!("CAP REQ :{}\r\n", caps.join(" "));
writer.write_all(cap_req_cmd.as_bytes()).await?;
}
} else {
writer.write_all(b"CAP REQ :sasl\r\n").await?;
}
writer.flush().await?;
Ok(())
}
pub async fn handle_sasl_messages<W: tokio::io::AsyncWriteExt + Unpin>(
writer: &mut W,
message: &str,
username: &str,
password: &str,
nickname: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let nick = format!("CAP {} ACK :sasl", nickname.to_string());
if message.contains(&nick) {
writer.write_all(b"AUTHENTICATE PLAIN\r\n").await?;
} else if message.starts_with("AUTHENTICATE +") {
let auth_string = format!("\0{}\0{}", username, password);
let encoded = base64::engine::general_purpose::STANDARD.encode(auth_string);
writer.write_all(format!("AUTHENTICATE {}\r\n", encoded).as_bytes()).await?;
} else if message.contains("903 * :SASL authentication successful") {
writer.write_all(b"CAP END\r\n").await?;
}
writer.flush().await?;
Ok(())
}