Add support for SASL authentication
We now store SASL credentials in the database and automatically populate them on NickServ REGISTER/IDENTIFY. References: https://todo.sr.ht/~emersion/jounce/10
This commit is contained in:
parent
dad8bc2173
commit
03d5600da6
92
db.go
92
db.go
@ -12,6 +12,15 @@ type User struct {
|
||||
Password string // hashed
|
||||
}
|
||||
|
||||
type SASL struct {
|
||||
Mechanism string
|
||||
|
||||
Plain struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
ID int64
|
||||
Addr string
|
||||
@ -19,6 +28,7 @@ type Network struct {
|
||||
Username string
|
||||
Realname string
|
||||
Pass string
|
||||
SASL SASL
|
||||
}
|
||||
|
||||
type Channel struct {
|
||||
@ -45,6 +55,20 @@ func (db *DB) Close() error {
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
func fromStringPtr(ptr *string) string {
|
||||
if ptr == nil {
|
||||
return ""
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
func toStringPtr(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (db *DB) ListUsers() ([]User, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
@ -62,9 +86,7 @@ func (db *DB) ListUsers() ([]User, error) {
|
||||
if err := rows.Scan(&user.Username, &password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if password != nil {
|
||||
user.Password = *password
|
||||
}
|
||||
user.Password = fromStringPtr(password)
|
||||
users = append(users, user)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@ -78,10 +100,7 @@ func (db *DB) CreateUser(user *User) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
var password *string
|
||||
if user.Password != "" {
|
||||
password = &user.Password
|
||||
}
|
||||
password := toStringPtr(user.Password)
|
||||
_, err := db.db.Exec("INSERT INTO User(username, password) VALUES (?, ?)", user.Username, password)
|
||||
return err
|
||||
}
|
||||
@ -90,7 +109,11 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
|
||||
db.lock.RLock()
|
||||
defer db.lock.RUnlock()
|
||||
|
||||
rows, err := db.db.Query("SELECT id, addr, nick, username, realname, pass FROM Network WHERE user = ?", username)
|
||||
rows, err := db.db.Query(`SELECT id, addr, nick, username, realname, pass,
|
||||
sasl_mechanism, sasl_plain_username, sasl_plain_password
|
||||
FROM Network
|
||||
WHERE user = ?`,
|
||||
username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -100,18 +123,18 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
|
||||
for rows.Next() {
|
||||
var net Network
|
||||
var username, realname, pass *string
|
||||
if err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname, &pass); err != nil {
|
||||
var saslMechanism, saslPlainUsername, saslPlainPassword *string
|
||||
err := rows.Scan(&net.ID, &net.Addr, &net.Nick, &username, &realname,
|
||||
&pass, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if username != nil {
|
||||
net.Username = *username
|
||||
}
|
||||
if realname != nil {
|
||||
net.Realname = *realname
|
||||
}
|
||||
if pass != nil {
|
||||
net.Pass = *pass
|
||||
}
|
||||
net.Username = fromStringPtr(username)
|
||||
net.Realname = fromStringPtr(realname)
|
||||
net.Pass = fromStringPtr(pass)
|
||||
net.SASL.Mechanism = fromStringPtr(saslMechanism)
|
||||
net.SASL.Plain.Username = fromStringPtr(saslPlainUsername)
|
||||
net.SASL.Plain.Password = fromStringPtr(saslPlainPassword)
|
||||
networks = append(networks, net)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@ -125,29 +148,36 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
var netUsername, realname, pass *string
|
||||
if network.Username != "" {
|
||||
netUsername = &network.Username
|
||||
netUsername := toStringPtr(network.Username)
|
||||
realname := toStringPtr(network.Realname)
|
||||
pass := toStringPtr(network.Pass)
|
||||
|
||||
var saslMechanism, saslPlainUsername, saslPlainPassword *string
|
||||
if network.SASL.Mechanism != "" {
|
||||
saslMechanism = &network.SASL.Mechanism
|
||||
switch network.SASL.Mechanism {
|
||||
case "PLAIN":
|
||||
saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
|
||||
saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
|
||||
}
|
||||
if network.Realname != "" {
|
||||
realname = &network.Realname
|
||||
}
|
||||
if network.Pass != "" {
|
||||
pass = &network.Pass
|
||||
}
|
||||
|
||||
var err error
|
||||
if network.ID != 0 {
|
||||
_, err = db.db.Exec(`UPDATE Network
|
||||
SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?
|
||||
SET addr = ?, nick = ?, username = ?, realname = ?, pass = ?,
|
||||
sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
|
||||
WHERE id = ?`,
|
||||
network.Addr, network.Nick, netUsername, realname, pass, network.ID)
|
||||
network.Addr, network.Nick, netUsername, realname, pass,
|
||||
saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
|
||||
} else {
|
||||
var res sql.Result
|
||||
res, err = db.db.Exec(`INSERT INTO Network(user, addr, nick, username,
|
||||
realname, pass)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
username, network.Addr, network.Nick, netUsername, realname, pass)
|
||||
realname, pass, sasl_mechanism, sasl_plain_username,
|
||||
sasl_plain_password)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
username, network.Addr, network.Nick, netUsername, realname, pass,
|
||||
saslMechanism, saslPlainUsername, saslPlainPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -381,7 +381,7 @@ func (dc *downstreamConn) register() error {
|
||||
addr = addr + ":6697"
|
||||
}
|
||||
|
||||
dc.logger.Printf("trying to connect to new upstream server %q", addr)
|
||||
dc.logger.Printf("trying to connect to new network %q", addr)
|
||||
if err := sanityCheckServer(addr); err != nil {
|
||||
dc.logger.Printf("failed to connect to %q: %v", addr, err)
|
||||
return ircError{&irc.Message{
|
||||
@ -390,7 +390,7 @@ func (dc *downstreamConn) register() error {
|
||||
}}
|
||||
}
|
||||
|
||||
dc.logger.Printf("auto-adding network %q", networkName)
|
||||
dc.logger.Printf("auto-saving network %q", networkName)
|
||||
network, err = u.createNetwork(networkName, dc.nick)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -618,6 +618,10 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if upstreamName == "NickServ" {
|
||||
dc.handleNickServPRIVMSG(uc, text)
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "PRIVMSG",
|
||||
Params: []string{upstreamName, text},
|
||||
@ -629,3 +633,41 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
|
||||
username, password, ok := parseNickServCredentials(text, uc.nick)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
|
||||
n := uc.network
|
||||
n.SASL.Mechanism = "PLAIN"
|
||||
n.SASL.Plain.Username = username
|
||||
n.SASL.Plain.Password = password
|
||||
if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil {
|
||||
dc.logger.Printf("failed to save NickServ credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
|
||||
fields := strings.Fields(text)
|
||||
if len(fields) < 2 {
|
||||
return "", "", false
|
||||
}
|
||||
cmd := strings.ToUpper(fields[0])
|
||||
params := fields[1:]
|
||||
switch cmd {
|
||||
case "REGISTER":
|
||||
username = nick
|
||||
password = params[0]
|
||||
case "IDENTIFY":
|
||||
if len(params) == 1 {
|
||||
username = nick
|
||||
} else {
|
||||
username = params[0]
|
||||
}
|
||||
password = params[1]
|
||||
}
|
||||
return username, password, true
|
||||
}
|
||||
|
1
go.mod
1
go.mod
@ -3,6 +3,7 @@ module git.sr.ht/~emersion/jounce
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4
|
||||
gopkg.in/irc.v3 v3.1.0
|
||||
|
2
go.sum
2
go.sum
@ -1,3 +1,5 @@
|
||||
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b h1:uhWtEWBHgop1rqEk2klKaxPAkVDCXexai6hSuRQ7Nvs=
|
||||
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
|
||||
github.com/mattn/go-sqlite3 v1.13.0 h1:LnJI81JidiW9r7pS/hXe6cFeO5EXNq7KbfvoJLRI69c=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
|
9
irc.go
9
irc.go
@ -11,6 +11,15 @@ const (
|
||||
rpl_localusers = "265"
|
||||
rpl_globalusers = "266"
|
||||
rpl_topicwhotime = "333"
|
||||
rpl_loggedin = "900"
|
||||
rpl_loggedout = "901"
|
||||
err_nicklocked = "902"
|
||||
rpl_saslsuccess = "903"
|
||||
err_saslfail = "904"
|
||||
err_sasltoolong = "905"
|
||||
err_saslaborted = "906"
|
||||
err_saslalready = "907"
|
||||
rpl_saslmechs = "908"
|
||||
)
|
||||
|
||||
type modeSet string
|
||||
|
@ -11,6 +11,9 @@ CREATE TABLE Network (
|
||||
username VARCHAR(255),
|
||||
realname VARCHAR(255),
|
||||
pass VARCHAR(255),
|
||||
sasl_mechanism VARCHAR(255),
|
||||
sasl_plain_username VARCHAR(255),
|
||||
sasl_plain_password VARCHAR(255),
|
||||
FOREIGN KEY(user) REFERENCES User(username),
|
||||
UNIQUE(user, addr, nick)
|
||||
);
|
||||
|
190
upstream.go
190
upstream.go
@ -2,6 +2,7 @@ package jounce
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-sasl"
|
||||
"gopkg.in/irc.v3"
|
||||
)
|
||||
|
||||
@ -48,6 +50,9 @@ type upstreamConn struct {
|
||||
channels map[string]*upstreamChannel
|
||||
history map[string]uint64
|
||||
caps map[string]string
|
||||
|
||||
saslClient sasl.Client
|
||||
saslStarted bool
|
||||
}
|
||||
|
||||
func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
@ -169,28 +174,150 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
case "NOTICE":
|
||||
uc.logger.Print(msg)
|
||||
case "CAP":
|
||||
if len(msg.Params) < 2 {
|
||||
var subCmd string
|
||||
if err := parseMessageParams(msg, nil, &subCmd); err != nil {
|
||||
return err
|
||||
}
|
||||
subCmd = strings.ToUpper(subCmd)
|
||||
subParams := msg.Params[2:]
|
||||
switch subCmd {
|
||||
case "LS":
|
||||
if len(subParams) < 1 {
|
||||
return newNeedMoreParamsError(msg.Command)
|
||||
}
|
||||
caps := strings.Fields(msg.Params[len(msg.Params)-1])
|
||||
more := msg.Params[len(msg.Params)-2] == "*"
|
||||
caps := strings.Fields(subParams[len(subParams)-1])
|
||||
more := len(subParams) >= 2 && msg.Params[len(subParams)-2] == "*"
|
||||
|
||||
for _, s := range caps {
|
||||
kv := strings.SplitN(s, "=", 2)
|
||||
k := strings.ToLower(kv[0])
|
||||
var v string
|
||||
if len(kv) >= 2 {
|
||||
if len(kv) == 2 {
|
||||
v = kv[1]
|
||||
}
|
||||
uc.caps[k] = v
|
||||
}
|
||||
|
||||
if !more {
|
||||
if more {
|
||||
break // wait to receive all capabilities
|
||||
}
|
||||
|
||||
if uc.requestSASL() {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"REQ", "sasl"},
|
||||
})
|
||||
break // we'll send CAP END after authentication is completed
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"END"},
|
||||
})
|
||||
case "ACK", "NAK":
|
||||
if len(subParams) < 1 {
|
||||
return newNeedMoreParamsError(msg.Command)
|
||||
}
|
||||
caps := strings.Fields(subParams[0])
|
||||
|
||||
for _, name := range caps {
|
||||
if err := uc.handleCapAck(strings.ToLower(name), subCmd == "ACK"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if uc.saslClient == nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"END"},
|
||||
})
|
||||
}
|
||||
default:
|
||||
uc.logger.Printf("unhandled message: %v", msg)
|
||||
}
|
||||
case "AUTHENTICATE":
|
||||
if uc.saslClient == nil {
|
||||
return fmt.Errorf("received unexpected AUTHENTICATE message")
|
||||
}
|
||||
|
||||
// TODO: if a challenge is 400 bytes long, buffer it
|
||||
var challengeStr string
|
||||
if err := parseMessageParams(msg, &challengeStr); err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var challenge []byte
|
||||
if challengeStr != "+" {
|
||||
var err error
|
||||
challenge, err = base64.StdEncoding.DecodeString(challengeStr)
|
||||
if err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var resp []byte
|
||||
var err error
|
||||
if !uc.saslStarted {
|
||||
_, resp, err = uc.saslClient.Start()
|
||||
uc.saslStarted = true
|
||||
} else {
|
||||
resp, err = uc.saslClient.Next(challenge)
|
||||
}
|
||||
if err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: send response in multiple chunks if >= 400 bytes
|
||||
var respStr = "+"
|
||||
if resp != nil {
|
||||
respStr = base64.StdEncoding.EncodeToString(resp)
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{respStr},
|
||||
})
|
||||
case rpl_loggedin:
|
||||
var account string
|
||||
if err := parseMessageParams(msg, nil, nil, &account); err != nil {
|
||||
return err
|
||||
}
|
||||
uc.logger.Printf("logged in with account %q", account)
|
||||
case rpl_loggedout:
|
||||
uc.logger.Printf("logged out")
|
||||
case err_nicklocked, rpl_saslsuccess, err_saslfail, err_sasltoolong, err_saslaborted:
|
||||
var info string
|
||||
if err := parseMessageParams(msg, nil, &info); err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Command {
|
||||
case err_nicklocked:
|
||||
uc.logger.Printf("invalid nick used with SASL authentication: %v", info)
|
||||
case err_saslfail:
|
||||
uc.logger.Printf("SASL authentication failed: %v", info)
|
||||
case err_sasltoolong:
|
||||
uc.logger.Printf("SASL message too long: %v", info)
|
||||
}
|
||||
|
||||
uc.saslClient = nil
|
||||
uc.saslStarted = false
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"END"},
|
||||
})
|
||||
case irc.RPL_WELCOME:
|
||||
uc.registered = true
|
||||
uc.logger.Printf("connection registered")
|
||||
@ -439,7 +566,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
|
||||
case irc.RPL_STATSVLINE, irc.RPL_STATSPING, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE:
|
||||
// Ignore
|
||||
default:
|
||||
uc.logger.Printf("unhandled upstream message: %v", msg)
|
||||
uc.logger.Printf("unhandled message: %v", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -477,6 +604,57 @@ func (uc *upstreamConn) register() {
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) requestSASL() bool {
|
||||
if uc.network.SASL.Mechanism == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
v, ok := uc.caps["sasl"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if v != "" {
|
||||
mechanisms := strings.Split(v, ",")
|
||||
found := false
|
||||
for _, mech := range mechanisms {
|
||||
if strings.EqualFold(mech, uc.network.SASL.Mechanism) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
|
||||
auth := &uc.network.SASL
|
||||
switch name {
|
||||
case "sasl":
|
||||
if !ok {
|
||||
uc.logger.Printf("server refused to acknowledge the SASL capability")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch auth.Mechanism {
|
||||
case "PLAIN":
|
||||
uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
|
||||
uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
|
||||
default:
|
||||
return fmt.Errorf("unsupported SASL mechanism %q", name)
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{auth.Mechanism},
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) readMessages() error {
|
||||
for {
|
||||
msg, err := uc.irc.ReadMessage()
|
||||
|
Loading…
Reference in New Issue
Block a user