From 03d5600da6161a681629fe8a789e3e5045303701 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Fri, 13 Mar 2020 15:12:44 +0100 Subject: [PATCH] Add support for SASL authentication We now store SASL credentials in the database and automatically populate them on NickServ REGISTER/IDENTIFY. References: https://todo.sr.ht/~emersion/jounce/10 --- db.go | 94 +++++++++++++++-------- downstream.go | 46 ++++++++++- go.mod | 1 + go.sum | 2 + irc.go | 9 +++ schema.sql | 3 + upstream.go | 208 ++++++++++++++++++++++++++++++++++++++++++++++---- 7 files changed, 314 insertions(+), 49 deletions(-) diff --git a/db.go b/db.go index b7de816..b88ebc2 100644 --- a/db.go +++ b/db.go @@ -12,6 +12,15 @@ type User struct { Password string // hashed } +type SASL struct { + Mechanism string + + Plain struct { + Username string + Password string + } +} + type Network struct { ID int64 Addr string @@ -19,6 +28,7 @@ type Network struct { Username string Realname string Pass string + SASL SASL } type Channel struct { @@ -45,6 +55,20 @@ func (db *DB) Close() error { return db.Close() } +func fromStringPtr(ptr *string) string { + if ptr == nil { + return "" + } + return *ptr +} + +func toStringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} + func (db *DB) ListUsers() ([]User, error) { db.lock.RLock() defer db.lock.RUnlock() @@ -62,9 +86,7 @@ func (db *DB) ListUsers() ([]User, error) { if err := rows.Scan(&user.Username, &password); err != nil { return nil, err } - if password != nil { - user.Password = *password - } + user.Password = fromStringPtr(password) users = append(users, user) } if err := rows.Err(); err != nil { @@ -78,10 +100,7 @@ func (db *DB) CreateUser(user *User) error { db.lock.Lock() defer db.lock.Unlock() - var password *string - if user.Password != "" { - password = &user.Password - } + password := toStringPtr(user.Password) _, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password) return err } @@ -90,7 +109,11 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { db.lock.RLock() defer db.lock.RUnlock() - rows, err := db.db.Query("SELECT id, addr, nick, username, realname, pass FROM Network WHERE user = ?", username) + rows, err := db.db.Query(`SELECT id, addr, nick, username, realname, pass, + sasl_mechanism, sasl_plain_username, sasl_plain_password + FROM Network + WHERE user = ?`, + username) if err != nil { return nil, err } @@ -100,18 +123,18 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { for rows.Next() { var net Network var username, realname, pass *string - if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname, &pass); err != nil { + var saslMechanism, saslPlainUsername, saslPlainPassword *string + err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname, + &pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword) + if err != nil { return nil, err } - if username != nil { - net.Username = *username - } - if realname != nil { - net.Realname = *realname - } - if pass != nil { - net.Pass = *pass - } + net.Username = fromStringPtr(username) + net.Realname = fromStringPtr(realname) + net.Pass = fromStringPtr(pass) + net.SASL.Mechanism = fromStringPtr(saslMechanism) + net.SASL.Plain.Username = fromStringPtr(saslPlainUsername) + net.SASL.Plain.Password = fromStringPtr(saslPlainPassword) networks = append(networks, net) } if err := rows.Err(); err != nil { @@ -125,29 +148,36 @@ func (db *DB) StoreNetwork(username string, network *Network) error { db.lock.Lock() defer db.lock.Unlock() - var netUsername, realname, pass *string - if network.Username != "" { - netUsername = &network.Username - } - if network.Realname != "" { - realname = &network.Realname - } - if network.Pass != "" { - pass = &network.Pass + netUsername := toStringPtr(network.Username) + realname := toStringPtr(network.Realname) + pass := toStringPtr(network.Pass) + + var saslMechanism, saslPlainUsername, saslPlainPassword *string + if network.SASL.Mechanism != "" { + saslMechanism = &network.SASL.Mechanism + switch network.SASL.Mechanism { + case "PLAIN": + saslPlainUsername = toStringPtr(network.SASL.Plain.Username) + saslPlainPassword = toStringPtr(network.SASL.Plain.Password) + } } var err error if network.ID != 0 { _, err = db.db.Exec(`UPDATE Network - SET addr = ?, nick = ?, username = ?, realname = ?, pass = ? + SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?, + sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ? WHERE id = ?`, - network.Addr, network.Nick, netUsername, realname, pass, network.ID) + network.Addr, network.Nick, netUsername, realname, pass, + saslMechanism, saslPlainUsername, saslPlainPassword, network.ID) } else { var res sql.Result res, err = db.db.Exec(`INSERT INTO Network(user, addr, nick, username, - realname, pass) - VALUES (?, ?, ?, ?, ?, ?)`, - username, network.Addr, network.Nick, netUsername, realname, pass) + realname, pass, sasl_mechanism, sasl_plain_username, + sasl_plain_password) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + username, network.Addr, network.Nick, netUsername, realname, pass, + saslMechanism, saslPlainUsername, saslPlainPassword) if err != nil { return err } diff --git a/downstream.go b/downstream.go index 84f8e0b..10069f9 100644 --- a/downstream.go +++ b/downstream.go @@ -381,7 +381,7 @@ func (dc *downstreamConn) register() error { addr = addr + ":6697" } - dc.logger.Printf("trying to connect to new upstream server %q", addr) + dc.logger.Printf("trying to connect to new network %q", addr) if err := sanityCheckServer(addr); err != nil { dc.logger.Printf("failed to connect to %q: %v", addr, err) return ircError{&irc.Message{ @@ -390,7 +390,7 @@ func (dc *downstreamConn) register() error { }} } - dc.logger.Printf("auto-adding network %q", networkName) + dc.logger.Printf("auto-saving network %q", networkName) network, err = u.createNetwork(networkName, dc.nick) if err != nil { return err @@ -618,6 +618,10 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { return err } + if upstreamName == "NickServ" { + dc.handleNickServPRIVMSG(uc, text) + } + uc.SendMessage(&irc.Message{ Command: "PRIVMSG", Params: []string{upstreamName, text}, @@ -629,3 +633,41 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } return nil } + +func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) { + username, password, ok := parseNickServCredentials(text, uc.nick) + if !ok { + return + } + + dc.logger.Printf("auto-saving NickServ credentials with username %q", username) + n := uc.network + n.SASL.Mechanism = "PLAIN" + n.SASL.Plain.Username = username + n.SASL.Plain.Password = password + if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil { + dc.logger.Printf("failed to save NickServ credentials: %v", err) + } +} + +func parseNickServCredentials(text, nick string) (username, password string, ok bool) { + fields := strings.Fields(text) + if len(fields) < 2 { + return "", "", false + } + cmd := strings.ToUpper(fields[0]) + params := fields[1:] + switch cmd { + case "REGISTER": + username = nick + password = params[0] + case "IDENTIFY": + if len(params) == 1 { + username = nick + } else { + username = params[0] + } + password = params[1] + } + return username, password, true +} diff --git a/go.mod b/go.mod index 13373b9..909f6cf 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.sr.ht/~emersion/jounce go 1.13 require ( + github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b github.com/mattn/go-sqlite3 v2.0.3+incompatible golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 gopkg.in/irc.v3 v3.1.0 diff --git a/go.sum b/go.sum index 423a1f4..5547595 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b h1:uhWtEWBHgop1rqEk2klKaxPAkVDCXexai6hSuRQ7Nvs= +github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k= github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/irc.go b/irc.go index d377ada..98df2ab 100644 --- a/irc.go +++ b/irc.go @@ -11,6 +11,15 @@ const ( rpl_localusers = "265" rpl_globalusers = "266" rpl_topicwhotime = "333" + rpl_loggedin = "900" + rpl_loggedout = "901" + err_nicklocked = "902" + rpl_saslsuccess = "903" + err_saslfail = "904" + err_sasltoolong = "905" + err_saslaborted = "906" + err_saslalready = "907" + rpl_saslmechs = "908" ) type modeSet string diff --git a/schema.sql b/schema.sql index b7c7ab1..0ba4907 100644 --- a/schema.sql +++ b/schema.sql @@ -11,6 +11,9 @@ CREATE TABLE Network ( username VARCHAR(255), realname VARCHAR(255), pass VARCHAR(255), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), FOREIGN KEY(user) REFERENCES User(username), UNIQUE(user, addr, nick) ); diff --git a/upstream.go b/upstream.go index a4d1e2f..480afa5 100644 --- a/upstream.go +++ b/upstream.go @@ -2,6 +2,7 @@ package jounce import ( "crypto/tls" + "encoding/base64" "fmt" "io" "net" @@ -9,6 +10,7 @@ import ( "strings" "time" + "github.com/emersion/go-sasl" "gopkg.in/irc.v3" ) @@ -48,6 +50,9 @@ type upstreamConn struct { channels map[string]*upstreamChannel history map[string]uint64 caps map[string]string + + saslClient sasl.Client + saslStarted bool } func connectToUpstream(network *network) (*upstreamConn, error) { @@ -169,28 +174,150 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { case "NOTICE": uc.logger.Print(msg) case "CAP": - if len(msg.Params) < 2 { - return newNeedMoreParamsError(msg.Command) + var subCmd string + if err := parseMessageParams(msg, nil, &subCmd); err != nil { + return err } - caps := strings.Fields(msg.Params[len(msg.Params)-1]) - more := msg.Params[len(msg.Params)-2] == "*" - - for _, s := range caps { - kv := strings.SplitN(s, "=", 2) - k := strings.ToLower(kv[0]) - var v string - if len(kv) >= 2 { - v = kv[1] + subCmd = strings.ToUpper(subCmd) + subParams := msg.Params[2:] + switch subCmd { + case "LS": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[len(subParams)-1]) + more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*" + + for _, s := range caps { + kv := strings.SplitN(s, "=", 2) + k := strings.ToLower(kv[0]) + var v string + if len(kv) == 2 { + v = kv[1] + } + uc.caps[k] = v + } + + if more { + break // wait to receive all capabilities + } + + if uc.requestSASL() { + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"REQ", "sasl"}, + }) + break // we'll send CAP END after authentication is completed } - uc.caps[k] = v - } - if !more { uc.SendMessage(&irc.Message{ Command: "CAP", Params: []string{"END"}, }) + case "ACK", "NAK": + if len(subParams) < 1 { + return newNeedMoreParamsError(msg.Command) + } + caps := strings.Fields(subParams[0]) + + for _, name := range caps { + if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil { + return err + } + } + + if uc.saslClient == nil { + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) + } + default: + uc.logger.Printf("unhandled message: %v", msg) } + case "AUTHENTICATE": + if uc.saslClient == nil { + return fmt.Errorf("received unexpected AUTHENTICATE message") + } + + // TODO: if a challenge is 400 bytes long, buffer it + var challengeStr string + if err := parseMessageParams(msg, &challengeStr); err != nil { + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + var challenge []byte + if challengeStr != "+" { + var err error + challenge, err = base64.StdEncoding.DecodeString(challengeStr) + if err != nil { + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + } + + var resp []byte + var err error + if !uc.saslStarted { + _, resp, err = uc.saslClient.Start() + uc.saslStarted = true + } else { + resp, err = uc.saslClient.Next(challenge) + } + if err != nil { + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{"*"}, + }) + return err + } + + // TODO: send response in multiple chunks if >= 400 bytes + var respStr = "+" + if resp != nil { + respStr = base64.StdEncoding.EncodeToString(resp) + } + + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{respStr}, + }) + case rpl_loggedin: + var account string + if err := parseMessageParams(msg, nil, nil, &account); err != nil { + return err + } + uc.logger.Printf("logged in with account %q", account) + case rpl_loggedout: + uc.logger.Printf("logged out") + case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted: + var info string + if err := parseMessageParams(msg, nil, &info); err != nil { + return err + } + switch msg.Command { + case err_nicklocked: + uc.logger.Printf("invalid nick used with SASL authentication: %v", info) + case err_saslfail: + uc.logger.Printf("SASL authentication failed: %v", info) + case err_sasltoolong: + uc.logger.Printf("SASL message too long: %v", info) + } + + uc.saslClient = nil + uc.saslStarted = false + + uc.SendMessage(&irc.Message{ + Command: "CAP", + Params: []string{"END"}, + }) case irc.RPL_WELCOME: uc.registered = true uc.logger.Printf("connection registered") @@ -439,7 +566,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: // Ignore default: - uc.logger.Printf("unhandled upstream message: %v", msg) + uc.logger.Printf("unhandled message: %v", msg) } return nil } @@ -477,6 +604,57 @@ func (uc *upstreamConn) register() { }) } +func (uc *upstreamConn) requestSASL() bool { + if uc.network.SASL.Mechanism == "" { + return false + } + + v, ok := uc.caps["sasl"] + if !ok { + return false + } + if v != "" { + mechanisms := strings.Split(v, ",") + found := false + for _, mech := range mechanisms { + if strings.EqualFold(mech, uc.network.SASL.Mechanism) { + found = true + break + } + } + if !found { + return false + } + } + + return true +} + +func (uc *upstreamConn) handleCapAck(name string, ok bool) error { + auth := &uc.network.SASL + switch name { + case "sasl": + if !ok { + uc.logger.Printf("server refused to acknowledge the SASL capability") + return nil + } + + switch auth.Mechanism { + case "PLAIN": + uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username) + uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password) + default: + return fmt.Errorf("unsupported SASL mechanism %q", name) + } + + uc.SendMessage(&irc.Message{ + Command: "AUTHENTICATE", + Params: []string{auth.Mechanism}, + }) + } + return nil +} + func (uc *upstreamConn) readMessages() error { for { msg, err := uc.irc.ReadMessage()