Implement TLS fingerprint pinning

Closes: https://todo.sr.ht/~emersion/soju/56
This commit is contained in:
rj1 2022-12-10 02:12:46 -06:00 committed by Simon Ser
parent 2604a14b7f
commit be185fba33
6 changed files with 86 additions and 21 deletions

View File

@ -130,6 +130,7 @@ type Network struct {
Realname string Realname string
Pass string Pass string
ConnectCommands []string ConnectCommands []string
CertFP string
SASL SASL SASL SASL
AutoAway bool AutoAway bool
Enabled bool Enabled bool

View File

@ -44,6 +44,7 @@ CREATE TABLE "Network" (
nick VARCHAR(255), nick VARCHAR(255),
username VARCHAR(255), username VARCHAR(255),
realname VARCHAR(255), realname VARCHAR(255),
certfp TEXT,
pass VARCHAR(255), pass VARCHAR(255),
connect_commands VARCHAR(1023), connect_commands VARCHAR(1023),
sasl_mechanism sasl_mechanism, sasl_mechanism sasl_mechanism,
@ -165,6 +166,7 @@ var postgresMigrations = []string{
SET NOT NULL; SET NOT NULL;
`, `,
`ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`, `ALTER TABLE "Network" ADD COLUMN auto_away BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
} }
type PostgresDB struct { type PostgresDB struct {
@ -380,7 +382,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
SELECT id, name, addr, nick, username, realname, pass, connect_commands, sasl_mechanism, SELECT id, name, addr, nick, username, realname, certfp, pass, connect_commands, sasl_mechanism,
sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled
FROM "Network" FROM "Network"
WHERE "user" = $1`, userID) WHERE "user" = $1`, userID)
@ -392,9 +394,9 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
var networks []Network var networks []Network
for rows.Next() { for rows.Next() {
var net Network var net Network
var name, nick, username, realname, pass, connectCommands sql.NullString var name, nick, username, realname, certfp, pass, connectCommands sql.NullString
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, &certfp,
&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled) &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled)
if err != nil { if err != nil {
@ -404,6 +406,7 @@ func (db *PostgresDB) ListNetworks(ctx context.Context, userID int64) ([]Network
net.Nick = nick.String net.Nick = nick.String
net.Username = username.String net.Username = username.String
net.Realname = realname.String net.Realname = realname.String
net.CertFP = certfp.String
net.Pass = pass.String net.Pass = pass.String
if connectCommands.Valid { if connectCommands.Valid {
net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
@ -428,6 +431,7 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N
nick := toNullString(network.Nick) nick := toNullString(network.Nick)
netUsername := toNullString(network.Username) netUsername := toNullString(network.Username)
realname := toNullString(network.Realname) realname := toNullString(network.Realname)
certfp := toNullString(network.CertFP)
pass := toNullString(network.Pass) pass := toNullString(network.Pass)
connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n"))
@ -450,23 +454,23 @@ func (db *PostgresDB) StoreNetwork(ctx context.Context, userID int64, network *N
var err error var err error
if network.ID == 0 { if network.ID == 0 {
err = db.db.QueryRowContext(ctx, ` err = db.db.QueryRowContext(ctx, `
INSERT INTO "Network" ("user", name, addr, nick, username, realname, pass, connect_commands, INSERT INTO "Network" ("user", name, addr, nick, username, realname, certfp, pass, connect_commands,
sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert,
sasl_external_key, auto_away, enabled) sasl_external_key, auto_away, enabled)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id`, RETURNING id`,
userID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, userID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled).Scan(&network.ID) network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled).Scan(&network.ID)
} else { } else {
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE "Network" UPDATE "Network"
SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, pass = $7, SET name = $2, addr = $3, nick = $4, username = $5, realname = $6, certfp = $7, pass = $8,
connect_commands = $8, sasl_mechanism = $9, sasl_plain_username = $10, connect_commands = $9, sasl_mechanism = $10, sasl_plain_username = $11,
sasl_plain_password = $11, sasl_external_cert = $12, sasl_external_key = $13, sasl_plain_password = $12, sasl_external_cert = $13, sasl_external_key = $14,
auto_away = $14, enabled = $15 auto_away = $14, enabled = $15
WHERE id = $1`, WHERE id = $1`,
network.ID, netName, network.Addr, nick, netUsername, realname, pass, connectCommands, network.ID, netName, network.Addr, nick, netUsername, realname, certfp, pass, connectCommands,
saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled) network.SASL.External.PrivKeyBlob, network.AutoAway, network.Enabled)
} }

View File

@ -42,6 +42,7 @@ CREATE TABLE Network (
nick TEXT, nick TEXT,
username TEXT, username TEXT,
realname TEXT, realname TEXT,
certfp TEXT,
pass TEXT, pass TEXT,
connect_commands TEXT, connect_commands TEXT,
sasl_mechanism TEXT, sasl_mechanism TEXT,
@ -250,6 +251,7 @@ var sqliteMigrations = []string{
`, `,
"ALTER TABLE User ADD COLUMN nick TEXT;", "ALTER TABLE User ADD COLUMN nick TEXT;",
"ALTER TABLE Network ADD COLUMN auto_away INTEGER NOT NULL DEFAULT 1;", "ALTER TABLE Network ADD COLUMN auto_away INTEGER NOT NULL DEFAULT 1;",
"ALTER TABLE Network ADD COLUMN certfp TEXT;",
} }
type SqliteDB struct { type SqliteDB struct {
@ -488,7 +490,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, ` rows, err := db.db.QueryContext(ctx, `
SELECT id, name, addr, nick, username, realname, pass, SELECT id, name, addr, nick, username, realname, certfp, pass,
connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
sasl_external_cert, sasl_external_key, auto_away, enabled sasl_external_cert, sasl_external_key, auto_away, enabled
FROM Network FROM Network
@ -502,9 +504,9 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
var networks []Network var networks []Network
for rows.Next() { for rows.Next() {
var net Network var net Network
var name, nick, username, realname, pass, connectCommands sql.NullString var name, nick, username, realname, certfp, pass, connectCommands sql.NullString
var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString
err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, err := rows.Scan(&net.ID, &name, &net.Addr, &nick, &username, &realname, &certfp,
&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled) &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob, &net.AutoAway, &net.Enabled)
if err != nil { if err != nil {
@ -514,6 +516,7 @@ func (db *SqliteDB) ListNetworks(ctx context.Context, userID int64) ([]Network,
net.Nick = nick.String net.Nick = nick.String
net.Username = username.String net.Username = username.String
net.Realname = realname.String net.Realname = realname.String
net.CertFP = certfp.String
net.Pass = pass.String net.Pass = pass.String
if connectCommands.Valid { if connectCommands.Valid {
net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") net.ConnectCommands = strings.Split(connectCommands.String, "\r\n")
@ -556,6 +559,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
sql.Named("nick", toNullString(network.Nick)), sql.Named("nick", toNullString(network.Nick)),
sql.Named("username", toNullString(network.Username)), sql.Named("username", toNullString(network.Username)),
sql.Named("realname", toNullString(network.Realname)), sql.Named("realname", toNullString(network.Realname)),
sql.Named("certfp", toNullString(network.CertFP)),
sql.Named("pass", toNullString(network.Pass)), sql.Named("pass", toNullString(network.Pass)),
sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))), sql.Named("connect_commands", toNullString(strings.Join(network.ConnectCommands, "\r\n"))),
sql.Named("sasl_mechanism", saslMechanism), sql.Named("sasl_mechanism", saslMechanism),
@ -575,7 +579,7 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE Network UPDATE Network
SET name = :name, addr = :addr, nick = :nick, username = :username, SET name = :name, addr = :addr, nick = :nick, username = :username,
realname = :realname, pass = :pass, connect_commands = :connect_commands, realname = :realname, certfp = :certfp, pass = :pass, connect_commands = :connect_commands,
sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password, sasl_mechanism = :sasl_mechanism, sasl_plain_username = :sasl_plain_username, sasl_plain_password = :sasl_plain_password,
sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key, sasl_external_cert = :sasl_external_cert, sasl_external_key = :sasl_external_key,
auto_away = :auto_away, enabled = :enabled auto_away = :auto_away, enabled = :enabled
@ -583,10 +587,10 @@ func (db *SqliteDB) StoreNetwork(ctx context.Context, userID int64, network *Net
} else { } else {
var res sql.Result var res sql.Result
res, err = db.db.ExecContext(ctx, ` res, err = db.db.ExecContext(ctx, `
INSERT INTO Network(user, name, addr, nick, username, realname, pass, INSERT INTO Network(user, name, addr, nick, username, realname, certfp, pass,
connect_commands, sasl_mechanism, sasl_plain_username, connect_commands, sasl_mechanism, sasl_plain_username,
sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled) sasl_plain_password, sasl_external_cert, sasl_external_key, auto_away, enabled)
VALUES (:user, :name, :addr, :nick, :username, :realname, :pass, VALUES (:user, :name, :addr, :nick, :username, :realname, :certfp, :pass,
:connect_commands, :sasl_mechanism, :sasl_plain_username, :connect_commands, :sasl_mechanism, :sasl_plain_username,
:sasl_plain_password, :sasl_external_cert, :sasl_external_key, :auto_away, :enabled)`, :sasl_plain_password, :sasl_external_cert, :sasl_external_key, :auto_away, :enabled)`,
args...) args...)

View File

@ -213,6 +213,12 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just
Connect with the specified real name. By default, the account's realname Connect with the specified real name. By default, the account's realname
is used if set, otherwise the network's nickname is used. is used if set, otherwise the network's nickname is used.
*-certfp* <fingerprint>
Instead of using certificate authorities to check the server's TLS
certificate, check whether the server certificate matches the provided
fingerprint. This can be used to connect to servers using self-signed
certificates. The fingerprint format is SHA512.
*-nick* <nickname> *-nick* <nickname>
Connect with the specified nickname. By default, the account's username Connect with the specified nickname. By default, the account's username
is used. is used.

View File

@ -201,7 +201,7 @@ func init() {
"network": { "network": {
children: serviceCommandSet{ children: serviceCommandSet{
"create": { "create": {
usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", usage: "-addr <addr> [-name name] [-username username] [-pass pass] [-realname realname] [-certfp fingerprint] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...",
desc: "add a new network", desc: "add a new network",
handle: handleServiceNetworkCreate, handle: handleServiceNetworkCreate,
}, },
@ -210,7 +210,7 @@ func init() {
handle: handleServiceNetworkStatus, handle: handleServiceNetworkStatus,
}, },
"update": { "update": {
usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...", usage: "[name] [-addr addr] [-name name] [-username username] [-pass pass] [-realname realname] [-certfp fingerprint] [-nick nick] [-auto-away auto-away] [-enabled enabled] [-connect-command command]...",
desc: "update a network", desc: "update a network",
handle: handleServiceNetworkUpdate, handle: handleServiceNetworkUpdate,
}, },
@ -435,9 +435,9 @@ func getNetworkFromArg(dc *downstreamConn, params []string) (*network, []string,
type networkFlagSet struct { type networkFlagSet struct {
*flag.FlagSet *flag.FlagSet
Addr, Name, Nick, Username, Pass, Realname *string Addr, Name, Nick, Username, Pass, Realname, CertFP *string
AutoAway, Enabled *bool AutoAway, Enabled *bool
ConnectCommands []string ConnectCommands []string
} }
func newNetworkFlagSet() *networkFlagSet { func newNetworkFlagSet() *networkFlagSet {
@ -448,6 +448,7 @@ func newNetworkFlagSet() *networkFlagSet {
fs.Var(stringPtrFlag{&fs.Username}, "username", "") fs.Var(stringPtrFlag{&fs.Username}, "username", "")
fs.Var(stringPtrFlag{&fs.Pass}, "pass", "") fs.Var(stringPtrFlag{&fs.Pass}, "pass", "")
fs.Var(stringPtrFlag{&fs.Realname}, "realname", "") fs.Var(stringPtrFlag{&fs.Realname}, "realname", "")
fs.Var(stringPtrFlag{&fs.CertFP}, "fingerprint", "")
fs.Var(boolPtrFlag{&fs.AutoAway}, "auto-away", "") fs.Var(boolPtrFlag{&fs.AutoAway}, "auto-away", "")
fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "") fs.Var(boolPtrFlag{&fs.Enabled}, "enabled", "")
fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "") fs.Var((*stringSliceFlag)(&fs.ConnectCommands), "connect-command", "")
@ -484,6 +485,19 @@ func (fs *networkFlagSet) update(network *database.Network) error {
if fs.Realname != nil { if fs.Realname != nil {
network.Realname = *fs.Realname network.Realname = *fs.Realname
} }
if fs.CertFP != nil {
certFP := strings.ReplaceAll(*fs.CertFP, ":", "")
if _, err := hex.DecodeString(certFP); err != nil {
return fmt.Errorf("the certificate fingerprint must be hex-encoded")
}
if len(certFP) == 64 {
network.CertFP = "sha-256:" + certFP
} else if len(certFP) == 128 {
network.CertFP = "sha-512:" + certFP
} else {
return fmt.Errorf("the certificate fingerprint must be a SHA256 or SHA512 hash")
}
}
if fs.AutoAway != nil { if fs.AutoAway != nil {
network.AutoAway = *fs.AutoAway network.AutoAway = *fs.AutoAway
} }

View File

@ -4,9 +4,11 @@ import (
"context" "context"
"crypto" "crypto"
"crypto/sha256" "crypto/sha256"
"crypto/sha512"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -285,6 +287,40 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
} }
if network.CertFP != "" {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return fmt.Errorf("the server didn't present any TLS certificate")
}
parts := strings.SplitN(network.CertFP, ":", 2)
algo, localCertFP := parts[0], parts[1]
for _, rawCert := range rawCerts {
var remoteCertFP string
switch algo {
case "sha-512":
sum := sha512.Sum512(rawCert)
remoteCertFP = hex.EncodeToString(sum[:])
case "sha-256":
sum := sha256.Sum256(rawCert)
remoteCertFP = hex.EncodeToString(sum[:])
}
if remoteCertFP == localCertFP {
return nil // fingerprints match
}
}
// Fingerprints don't match, let's give the user a fingerprint
// they can use to connect
sum := sha512.Sum512(rawCerts[0])
remoteCertFP := hex.EncodeToString(sum[:])
return fmt.Errorf("the configured TLS certificate fingerprint doesn't match the server's - %s", remoteCertFP)
}
}
netConn, err = dialer.DialContext(ctx, "tcp", addr) netConn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err) return nil, fmt.Errorf("failed to dial %q: %v", addr, err)