diff --git a/db.go b/db.go index 246bb85..5c84026 100644 --- a/db.go +++ b/db.go @@ -215,18 +215,11 @@ func (db *DB) upgrade() error { return tx.Commit() } -func fromStringPtr(ptr *string) string { - if ptr == nil { - return "" +func toNullString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", } - return *ptr -} - -func toStringPtr(s string) *string { - if s == "" { - return nil - } - return &s } func (db *DB) ListUsers() ([]User, error) { @@ -242,11 +235,11 @@ func (db *DB) ListUsers() ([]User, error) { var users []User for rows.Next() { var user User - var password *string + var password sql.NullString if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin); err != nil { return nil, err } - user.Password = fromStringPtr(password) + user.Password = password.String users = append(users, user) } if err := rows.Err(); err != nil { @@ -262,12 +255,12 @@ func (db *DB) GetUser(username string) (*User, error) { user := &User{Username: username} - var password *string + var password sql.NullString row := db.db.QueryRow("SELECT id, password, admin FROM User WHERE username = ?", username) if err := row.Scan(&user.ID, &password, &user.Admin); err != nil { return nil, err } - user.Password = fromStringPtr(password) + user.Password = password.String return user, nil } @@ -275,7 +268,7 @@ func (db *DB) StoreUser(user *User) error { db.lock.Lock() defer db.lock.Unlock() - password := toStringPtr(user.Password) + password := toNullString(user.Password) var err error if user.ID != 0 { @@ -346,24 +339,24 @@ func (db *DB) ListNetworks(userID int64) ([]Network, error) { var networks []Network for rows.Next() { var net Network - var name, username, realname, pass, connectCommands *string - var saslMechanism, saslPlainUsername, saslPlainPassword *string + var name, username, realname, pass, connectCommands sql.NullString + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname, &pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword, &net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob) if err != nil { return nil, err } - net.Name = fromStringPtr(name) - net.Username = fromStringPtr(username) - net.Realname = fromStringPtr(realname) - net.Pass = fromStringPtr(pass) - if connectCommands != nil { - net.ConnectCommands = strings.Split(*connectCommands, "\r\n") + net.Name = name.String + net.Username = username.String + net.Realname = realname.String + net.Pass = pass.String + if connectCommands.Valid { + net.ConnectCommands = strings.Split(connectCommands.String, "\r\n") } - net.SASL.Mechanism = fromStringPtr(saslMechanism) - net.SASL.Plain.Username = fromStringPtr(saslPlainUsername) - net.SASL.Plain.Password = fromStringPtr(saslPlainPassword) + net.SASL.Mechanism = saslMechanism.String + net.SASL.Plain.Username = saslPlainUsername.String + net.SASL.Plain.Password = saslPlainPassword.String networks = append(networks, net) } if err := rows.Err(); err != nil { @@ -377,19 +370,19 @@ func (db *DB) StoreNetwork(userID int64, network *Network) error { db.lock.Lock() defer db.lock.Unlock() - netName := toStringPtr(network.Name) - netUsername := toStringPtr(network.Username) - realname := toStringPtr(network.Realname) - pass := toStringPtr(network.Pass) - connectCommands := toStringPtr(strings.Join(network.ConnectCommands, "\r\n")) + netName := toNullString(network.Name) + netUsername := toNullString(network.Username) + realname := toNullString(network.Realname) + pass := toNullString(network.Pass) + connectCommands := toNullString(strings.Join(network.ConnectCommands, "\r\n")) - var saslMechanism, saslPlainUsername, saslPlainPassword *string + var saslMechanism, saslPlainUsername, saslPlainPassword sql.NullString if network.SASL.Mechanism != "" { - saslMechanism = &network.SASL.Mechanism + saslMechanism = toNullString(network.SASL.Mechanism) switch network.SASL.Mechanism { case "PLAIN": - saslPlainUsername = toStringPtr(network.SASL.Plain.Username) - saslPlainPassword = toStringPtr(network.SASL.Plain.Password) + saslPlainUsername = toNullString(network.SASL.Plain.Username) + saslPlainPassword = toNullString(network.SASL.Plain.Password) network.SASL.External.CertBlob = nil network.SASL.External.PrivKeyBlob = nil case "EXTERNAL": @@ -465,11 +458,11 @@ func (db *DB) ListChannels(networkID int64) ([]Channel, error) { var channels []Channel for rows.Next() { var ch Channel - var key *string + var key sql.NullString if err := rows.Scan(&ch.ID, &ch.Name, &key, &ch.Detached); err != nil { return nil, err } - ch.Key = fromStringPtr(key) + ch.Key = key.String channels = append(channels, ch) } if err := rows.Err(); err != nil { @@ -483,7 +476,7 @@ func (db *DB) StoreChannel(networkID int64, ch *Channel) error { db.lock.Lock() defer db.lock.Unlock() - key := toStringPtr(ch.Key) + key := toNullString(ch.Key) var err error if ch.ID != 0 {