Add support for SASL authentication

We now store SASL credentials in the database and automatically populate
them on NickServ REGISTER/IDENTIFY.

References: https://todo.sr.ht/~emersion/jounce/10
This commit is contained in:
Simon Ser 2020-03-13 15:12:44 +01:00
parent dad8bc2173
commit 03d5600da6
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
7 changed files with 314 additions and 49 deletions

94
db.go
View File

@ -12,6 +12,15 @@ type User struct {
Password string // hashed
}
type SASL struct {
Mechanism string
Plain struct {
Username string
Password string
}
}
type Network struct {
ID int64
Addr string
@ -19,6 +28,7 @@ type Network struct {
Username string
Realname string
Pass string
SASL SASL
}
type Channel struct {
@ -45,6 +55,20 @@ func (db *DB) Close() error {
return db.Close()
}
func fromStringPtr(ptr *string) string {
if ptr == nil {
return ""
}
return *ptr
}
func toStringPtr(s string) *string {
if s == "" {
return nil
}
return &s
}
func (db *DB) ListUsers() ([]User, error) {
db.lock.RLock()
defer db.lock.RUnlock()
@ -62,9 +86,7 @@ func (db *DB) ListUsers() ([]User, error) {
if err := rows.Scan(&user.Username, &password); err != nil {
return nil, err
}
if password != nil {
user.Password = *password
}
user.Password = fromStringPtr(password)
users = append(users, user)
}
if err := rows.Err(); err != nil {
@ -78,10 +100,7 @@ func (db *DB) CreateUser(user *User) error {
db.lock.Lock()
defer db.lock.Unlock()
var password *string
if user.Password != "" {
password = &user.Password
}
password := toStringPtr(user.Password)
_, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
return err
}
@ -90,7 +109,11 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
db.lock.RLock()
defer db.lock.RUnlock()
rows, err := db.db.Query("SELECT id, addr, nick, username, realname, pass FROM Network WHERE user = ?", username)
rows, err := db.db.Query(`SELECT id, addr, nick, username, realname, pass,
sasl_mechanism, sasl_plain_username, sasl_plain_password
FROM Network
WHERE user = ?`,
username)
if err != nil {
return nil, err
}
@ -100,18 +123,18 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
for rows.Next() {
var net Network
var username, realname, pass *string
if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname, &pass); err != nil {
var saslMechanism, saslPlainUsername, saslPlainPassword *string
err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname,
&pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
if err != nil {
return nil, err
}
if username != nil {
net.Username = *username
}
if realname != nil {
net.Realname = *realname
}
if pass != nil {
net.Pass = *pass
}
net.Username = fromStringPtr(username)
net.Realname = fromStringPtr(realname)
net.Pass = fromStringPtr(pass)
net.SASL.Mechanism = fromStringPtr(saslMechanism)
net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
networks = append(networks, net)
}
if err := rows.Err(); err != nil {
@ -125,29 +148,36 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
db.lock.Lock()
defer db.lock.Unlock()
var netUsername, realname, pass *string
if network.Username != "" {
netUsername = &network.Username
}
if network.Realname != "" {
realname = &network.Realname
}
if network.Pass != "" {
pass = &network.Pass
netUsername := toStringPtr(network.Username)
realname := toStringPtr(network.Realname)
pass := toStringPtr(network.Pass)
var saslMechanism, saslPlainUsername, saslPlainPassword *string
if network.SASL.Mechanism != "" {
saslMechanism = &network.SASL.Mechanism
switch network.SASL.Mechanism {
case "PLAIN":
saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
}
}
var err error
if network.ID != 0 {
_, err = db.db.Exec(`UPDATE Network
SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?
SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
WHERE id = ?`,
network.Addr, network.Nick, netUsername, realname, pass, network.ID)
network.Addr, network.Nick, netUsername, realname, pass,
saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
} else {
var res sql.Result
res, err = db.db.Exec(`INSERT INTO Network(user, addr, nick, username,
realname, pass)
VALUES (?, ?, ?, ?, ?, ?)`,
username, network.Addr, network.Nick, netUsername, realname, pass)
realname, pass, sasl_mechanism, sasl_plain_username,
sasl_plain_password)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
username, network.Addr, network.Nick, netUsername, realname, pass,
saslMechanism, saslPlainUsername, saslPlainPassword)
if err != nil {
return err
}

View File

@ -381,7 +381,7 @@ func (dc *downstreamConn) register() error {
addr = addr + ":6697"
}
dc.logger.Printf("trying to connect to new upstream server %q", addr)
dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
@ -390,7 +390,7 @@ func (dc *downstreamConn) register() error {
}}
}
dc.logger.Printf("auto-adding network %q", networkName)
dc.logger.Printf("auto-saving network %q", networkName)
network, err = u.createNetwork(networkName, dc.nick)
if err != nil {
return err
@ -618,6 +618,10 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return err
}
if upstreamName == "NickServ" {
dc.handleNickServPRIVMSG(uc, text)
}
uc.SendMessage(&irc.Message{
Command: "PRIVMSG",
Params: []string{upstreamName, text},
@ -629,3 +633,41 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
return nil
}
func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
username, password, ok := parseNickServCredentials(text, uc.nick)
if !ok {
return
}
dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
n := uc.network
n.SASL.Mechanism = "PLAIN"
n.SASL.Plain.Username = username
n.SASL.Plain.Password = password
if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
dc.logger.Printf("failed to save NickServ credentials: %v", err)
}
}
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
fields := strings.Fields(text)
if len(fields) < 2 {
return "", "", false
}
cmd := strings.ToUpper(fields[0])
params := fields[1:]
switch cmd {
case "REGISTER":
username = nick
password = params[0]
case "IDENTIFY":
if len(params) == 1 {
username = nick
} else {
username = params[0]
}
password = params[1]
}
return username, password, true
}

1
go.mod
View File

@ -3,6 +3,7 @@ module git.sr.ht/~emersion/jounce
go 1.13
require (
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b
github.com/mattn/go-sqlite3 v2.0.3+incompatible
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4
gopkg.in/irc.v3 v3.1.0

2
go.sum
View File

@ -1,3 +1,5 @@
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b h1:uhWtEWBHgop1rqEk2klKaxPAkVDCXexai6hSuRQ7Nvs=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=

9
irc.go
View File

@ -11,6 +11,15 @@ const (
rpl_localusers = "265"
rpl_globalusers = "266"
rpl_topicwhotime = "333"
rpl_loggedin = "900"
rpl_loggedout = "901"
err_nicklocked = "902"
rpl_saslsuccess = "903"
err_saslfail = "904"
err_sasltoolong = "905"
err_saslaborted = "906"
err_saslalready = "907"
rpl_saslmechs = "908"
)
type modeSet string

View File

@ -11,6 +11,9 @@ CREATE TABLE Network (
username VARCHAR(255),
realname VARCHAR(255),
pass VARCHAR(255),
sasl_mechanism VARCHAR(255),
sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255),
FOREIGN KEY(user) REFERENCES User(username),
UNIQUE(user, addr, nick)
);

View File

@ -2,6 +2,7 @@ package jounce
import (
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net"
@ -9,6 +10,7 @@ import (
"strings"
"time"
"github.com/emersion/go-sasl"
"gopkg.in/irc.v3"
)
@ -48,6 +50,9 @@ type upstreamConn struct {
channels map[string]*upstreamChannel
history map[string]uint64
caps map[string]string
saslClient sasl.Client
saslStarted bool
}
func connectToUpstream(network *network) (*upstreamConn, error) {
@ -169,28 +174,150 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
case "NOTICE":
uc.logger.Print(msg)
case "CAP":
if len(msg.Params) < 2 {
return newNeedMoreParamsError(msg.Command)
var subCmd string
if err := parseMessageParams(msg, nil, &subCmd); err != nil {
return err
}
caps := strings.Fields(msg.Params[len(msg.Params)-1])
more := msg.Params[len(msg.Params)-2] == "*"
for _, s := range caps {
kv := strings.SplitN(s, "=", 2)
k := strings.ToLower(kv[0])
var v string
if len(kv) >= 2 {
v = kv[1]
subCmd = strings.ToUpper(subCmd)
subParams := msg.Params[2:]
switch subCmd {
case "LS":
if len(subParams) < 1 {
return newNeedMoreParamsError(msg.Command)
}
caps := strings.Fields(subParams[len(subParams)-1])
more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
for _, s := range caps {
kv := strings.SplitN(s, "=", 2)
k := strings.ToLower(kv[0])
var v string
if len(kv) == 2 {
v = kv[1]
}
uc.caps[k] = v
}
if more {
break // wait to receive all capabilities
}
if uc.requestSASL() {
uc.SendMessage(&irc.Message{
Command: "CAP",
Params: []string{"REQ", "sasl"},
})
break // we'll send CAP END after authentication is completed
}
uc.caps[k] = v
}
if !more {
uc.SendMessage(&irc.Message{
Command: "CAP",
Params: []string{"END"},
})
case "ACK", "NAK":
if len(subParams) < 1 {
return newNeedMoreParamsError(msg.Command)
}
caps := strings.Fields(subParams[0])
for _, name := range caps {
if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
return err
}
}
if uc.saslClient == nil {
uc.SendMessage(&irc.Message{
Command: "CAP",
Params: []string{"END"},
})
}
default:
uc.logger.Printf("unhandled message: %v", msg)
}
case "AUTHENTICATE":
if uc.saslClient == nil {
return fmt.Errorf("received unexpected AUTHENTICATE message")
}
// TODO: if a challenge is 400 bytes long, buffer it
var challengeStr string
if err := parseMessageParams(msg, &challengeStr); err != nil {
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
return err
}
var challenge []byte
if challengeStr != "+" {
var err error
challenge, err = base64.StdEncoding.DecodeString(challengeStr)
if err != nil {
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
return err
}
}
var resp []byte
var err error
if !uc.saslStarted {
_, resp, err = uc.saslClient.Start()
uc.saslStarted = true
} else {
resp, err = uc.saslClient.Next(challenge)
}
if err != nil {
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
return err
}
// TODO: send response in multiple chunks if >= 400 bytes
var respStr = "+"
if resp != nil {
respStr = base64.StdEncoding.EncodeToString(resp)
}
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{respStr},
})
case rpl_loggedin:
var account string
if err := parseMessageParams(msg, nil, nil, &account); err != nil {
return err
}
uc.logger.Printf("logged in with account %q", account)
case rpl_loggedout:
uc.logger.Printf("logged out")
case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
var info string
if err := parseMessageParams(msg, nil, &info); err != nil {
return err
}
switch msg.Command {
case err_nicklocked:
uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
case err_saslfail:
uc.logger.Printf("SASL authentication failed: %v", info)
case err_sasltoolong:
uc.logger.Printf("SASL message too long: %v", info)
}
uc.saslClient = nil
uc.saslStarted = false
uc.SendMessage(&irc.Message{
Command: "CAP",
Params: []string{"END"},
})
case irc.RPL_WELCOME:
uc.registered = true
uc.logger.Printf("connection registered")
@ -439,7 +566,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
// Ignore
default:
uc.logger.Printf("unhandled upstream message: %v", msg)
uc.logger.Printf("unhandled message: %v", msg)
}
return nil
}
@ -477,6 +604,57 @@ func (uc *upstreamConn) register() {
})
}
func (uc *upstreamConn) requestSASL() bool {
if uc.network.SASL.Mechanism == "" {
return false
}
v, ok := uc.caps["sasl"]
if !ok {
return false
}
if v != "" {
mechanisms := strings.Split(v, ",")
found := false
for _, mech := range mechanisms {
if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
auth := &uc.network.SASL
switch name {
case "sasl":
if !ok {
uc.logger.Printf("server refused to acknowledge the SASL capability")
return nil
}
switch auth.Mechanism {
case "PLAIN":
uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
default:
return fmt.Errorf("unsupported SASL mechanism %q", name)
}
uc.SendMessage(&irc.Message{
Command: "AUTHENTICATE",
Params: []string{auth.Mechanism},
})
}
return nil
}
func (uc *upstreamConn) readMessages() error {
for {
msg, err := uc.irc.ReadMessage()