From fa16337d97f9e2edbae2860470f2427a4847821d Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Sat, 24 Oct 2020 15:14:23 +0200 Subject: [PATCH] Switch DB API to user IDs This commit changes the Network schema to use user IDs instead of usernames. While at it, a new UNIQUE(user, name) constraint ensures there is no conflict with custom network names. Closes: https://todo.sr.ht/~emersion/soju/86 References: https://todo.sr.ht/~emersion/soju/29 --- contrib/znc-import.go | 5 ++-- db.go | 53 ++++++++++++++++++++++++++++++++++--------- downstream.go | 4 ++-- service.go | 8 +++---- user.go | 6 ++--- 5 files changed, 54 insertions(+), 22 deletions(-) diff --git a/contrib/znc-import.go b/contrib/znc-import.go index 53ba5b3..6b4e03b 100644 --- a/contrib/znc-import.go +++ b/contrib/znc-import.go @@ -114,8 +114,9 @@ func main() { if err := db.StoreUser(u); err != nil { log.Fatalf("failed to store user %q: %v", username, err) } + userID := u.ID - l, err := db.ListNetworks(username) + l, err := db.ListNetworks(userID) if err != nil { log.Fatalf("failed to list networks for user %q: %v", username, err) } @@ -181,7 +182,7 @@ func main() { n.Realname = netRealname n.Pass = pass - if err := db.StoreNetwork(username, n); err != nil { + if err := db.StoreNetwork(userID, n); err != nil { logger.Fatalf("failed to store network: %v", err) } diff --git a/db.go b/db.go index 349499a..246bb85 100644 --- a/db.go +++ b/db.go @@ -70,7 +70,7 @@ CREATE TABLE User ( CREATE TABLE Network ( id INTEGER PRIMARY KEY, name VARCHAR(255), - user VARCHAR(255) NOT NULL, + user INTEGER NOT NULL, addr VARCHAR(255) NOT NULL, nick VARCHAR(255) NOT NULL, username VARCHAR(255), @@ -82,8 +82,9 @@ CREATE TABLE Network ( sasl_plain_password VARCHAR(255), sasl_external_cert BLOB DEFAULT NULL, sasl_external_key BLOB DEFAULT NULL, - FOREIGN KEY(user) REFERENCES User(username), - UNIQUE(user, addr, nick) + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) ); CREATE TABLE Channel ( @@ -115,6 +116,36 @@ var migrations = []string{ DROP TABLE User; ALTER TABLE UserNew RENAME TO User; `, + ` + CREATE TABLE NetworkNew ( + id INTEGER PRIMARY KEY, + name VARCHAR(255), + user INTEGER NOT NULL, + addr VARCHAR(255) NOT NULL, + nick VARCHAR(255) NOT NULL, + username VARCHAR(255), + realname VARCHAR(255), + pass VARCHAR(255), + connect_commands VARCHAR(1023), + sasl_mechanism VARCHAR(255), + sasl_plain_username VARCHAR(255), + sasl_plain_password VARCHAR(255), + sasl_external_cert BLOB DEFAULT NULL, + sasl_external_key BLOB DEFAULT NULL, + FOREIGN KEY(user) REFERENCES User(id), + UNIQUE(user, addr, nick), + UNIQUE(user, name) + ); + INSERT INTO NetworkNew + SELECT Network.id, name, User.id as user, addr, nick, + Network.username, realname, pass, connect_commands, + sasl_mechanism, sasl_plain_username, sasl_plain_password, + sasl_external_cert, sasl_external_key + FROM Network + JOIN User ON Network.user = User.username; + DROP TABLE Network; + ALTER TABLE NetworkNew RENAME TO Network; + `, } type DB struct { @@ -263,7 +294,7 @@ func (db *DB) StoreUser(user *User) error { return err } -func (db *DB) DeleteUser(username string) error { +func (db *DB) DeleteUser(id int64) error { db.lock.Lock() defer db.lock.Unlock() @@ -279,17 +310,17 @@ func (db *DB) DeleteUser(username string) error { FROM Channel JOIN Network ON Channel.network = Network.id WHERE Network.user = ? - )`, username) + )`, id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM Network WHERE user = ?", username) + _, err = tx.Exec("DELETE FROM Network WHERE user = ?", id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM User WHERE username = ?", username) + _, err = tx.Exec("DELETE FROM User WHERE id = ?", id) if err != nil { return err } @@ -297,7 +328,7 @@ func (db *DB) DeleteUser(username string) error { return tx.Commit() } -func (db *DB) ListNetworks(username string) ([]Network, error) { +func (db *DB) ListNetworks(userID int64) ([]Network, error) { db.lock.RLock() defer db.lock.RUnlock() @@ -306,7 +337,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { sasl_external_cert, sasl_external_key FROM Network WHERE user = ?`, - username) + userID) if err != nil { return nil, err } @@ -342,7 +373,7 @@ func (db *DB) ListNetworks(username string) ([]Network, error) { return networks, nil } -func (db *DB) StoreNetwork(username string, network *Network) error { +func (db *DB) StoreNetwork(userID int64, network *Network) error { db.lock.Lock() defer db.lock.Unlock() @@ -385,7 +416,7 @@ func (db *DB) StoreNetwork(username string, network *Network) error { realname, pass, connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password, sasl_external_cert, sasl_external_key) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, + userID, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands, saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob) if err != nil { diff --git a/downstream.go b/downstream.go index 3b47815..26069cf 100644 --- a/downstream.go +++ b/downstream.go @@ -1016,7 +1016,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { return } n.Nick = nick - err = dc.srv.db.StoreNetwork(dc.user.Username, &n.Network) + err = dc.srv.db.StoreNetwork(dc.user.ID, &n.Network) }) if err != nil { return err @@ -1697,7 +1697,7 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) { n.SASL.Mechanism = "PLAIN" n.SASL.Plain.Username = username n.SASL.Plain.Password = password - if err := dc.srv.db.StoreNetwork(dc.user.Username, &n.Network); err != nil { + if err := dc.srv.db.StoreNetwork(dc.user.ID, &n.Network); err != nil { dc.logger.Printf("failed to save NickServ credentials: %v", err) } } diff --git a/service.go b/service.go index c8d31f4..8d73386 100644 --- a/service.go +++ b/service.go @@ -548,7 +548,7 @@ func handleServiceCertfpGenerate(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = privKeyBytes net.SASL.Mechanism = "EXTERNAL" - if err := dc.srv.db.StoreNetwork(net.Username, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { return err } @@ -593,7 +593,7 @@ func handleServiceSASLSetPlain(dc *downstreamConn, params []string) error { net.SASL.Plain.Password = params[2] net.SASL.Mechanism = "PLAIN" - if err := dc.srv.db.StoreNetwork(net.Username, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { return err } @@ -617,7 +617,7 @@ func handleServiceSASLReset(dc *downstreamConn, params []string) error { net.SASL.External.PrivKeyBlob = nil net.SASL.Mechanism = "" - if err := dc.srv.db.StoreNetwork(dc.user.Username, &net.Network); err != nil { + if err := dc.srv.db.StoreNetwork(dc.user.ID, &net.Network); err != nil { return err } @@ -689,7 +689,7 @@ func handleUserDelete(dc *downstreamConn, params []string) error { u.stop() - if err := dc.srv.db.DeleteUser(username); err != nil { + if err := dc.srv.db.DeleteUser(dc.user.ID); err != nil { return fmt.Errorf("failed to delete user: %v", err) } diff --git a/user.go b/user.go index aded19c..3c74022 100644 --- a/user.go +++ b/user.go @@ -314,7 +314,7 @@ func (u *user) getNetworkByID(id int64) *network { func (u *user) run() { defer close(u.done) - networks, err := u.srv.db.ListNetworks(u.Username) + networks, err := u.srv.db.ListNetworks(u.ID) if err != nil { u.srv.Logger.Printf("failed to list networks for user %q: %v", u.Username, err) return @@ -508,7 +508,7 @@ func (u *user) createNetwork(record *Network) (*network, error) { } network := newNetwork(u, record, nil) - err := u.srv.db.StoreNetwork(u.Username, &network.Network) + err := u.srv.db.StoreNetwork(u.ID, &network.Network) if err != nil { return nil, err } @@ -528,7 +528,7 @@ func (u *user) updateNetwork(record *Network) (*network, error) { panic("tried updating a non-existing network") } - if err := u.srv.db.StoreNetwork(u.Username, record); err != nil { + if err := u.srv.db.StoreNetwork(u.ID, record); err != nil { return nil, err }