soju/db.go
delthas 7f74055380 Add support for custom network on-connect commands
Some servers use custom IRC bots with custom commands for registering to
specific services after connection.

This adds support for setting custom raw IRC messages, that will be
sent after registering to a network.

It also adds support for a custom flag.Value type for string
slice flags (flags taking several string values).
2020-04-16 17:38:47 +02:00

421 lines
9.5 KiB
Go

package soju
import (
"database/sql"
"fmt"
"strings"
"sync"
_ "github.com/mattn/go-sqlite3"
)
// User is a bouncer account as stored in the User table.
//
// Username is the table's primary key. Password holds the hashed
// password, never the clear text.
type User struct {
	Username, Password string
}
// SASL holds the SASL authentication settings of a network.
//
// Mechanism names the mechanism to use; it is empty when SASL is not
// configured. Plain carries the credentials used by the "PLAIN" mechanism.
type SASL struct {
	Mechanism string
	Plain     struct{ Username, Password string }
}
// Network describes one upstream IRC network configured for a user, as
// stored in the Network table. Optional string fields are stored as NULL
// when empty.
type Network struct {
// ID is the database row ID; zero means the network has not been stored yet.
ID int64
// Name is an optional display name; GetName falls back to Addr when it is empty.
Name string
// Addr is the address of the upstream server.
Addr string
// Nick is the nickname used on this network.
Nick string
// Username is optional.
Username string
// Realname is optional.
Realname string
// Pass is optional.
Pass string
// ConnectCommands are stored in a single column, joined with "\r\n".
ConnectCommands []string
// SASL holds the optional SASL settings.
SASL SASL
}
// GetName returns the name to display for the network: the configured
// Name when there is one, otherwise the server address.
func (net *Network) GetName() string {
	if net.Name == "" {
		return net.Addr
	}
	return net.Name
}
// Channel is a channel saved for a network, as stored in the Channel table.
//
// ID is the database row ID (zero until the channel has been stored),
// Name is the channel name and Key is the optional channel key.
type Channel struct {
	ID        int64
	Name, Key string
}
// ErrNoSuchChannel is returned by GetChannel when no channel row matches
// the requested network and name.
var ErrNoSuchChannel = fmt.Errorf("soju: no such channel")
// schema is the complete, up-to-date database schema. It is executed as a
// whole when the database reports schema version 0 (see upgrade), so it
// must already include every column added by the migrations list.
const schema = `
CREATE TABLE User (
username VARCHAR(255) PRIMARY KEY,
password VARCHAR(255) NOT NULL
);
CREATE TABLE Network (
id INTEGER PRIMARY KEY,
name VARCHAR(255),
user VARCHAR(255) NOT NULL,
addr VARCHAR(255) NOT NULL,
nick VARCHAR(255) NOT NULL,
username VARCHAR(255),
realname VARCHAR(255),
pass VARCHAR(255),
connect_commands VARCHAR(1023),
sasl_mechanism VARCHAR(255),
sasl_plain_username VARCHAR(255),
sasl_plain_password VARCHAR(255),
FOREIGN KEY(user) REFERENCES User(username),
UNIQUE(user, addr, nick)
);
CREATE TABLE Channel (
id INTEGER PRIMARY KEY,
network INTEGER NOT NULL,
name VARCHAR(255) NOT NULL,
key VARCHAR(255),
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, name)
);
`
// migrations lists the statements that upgrade an existing database.
// migrations[i] is executed to move a database from schema version i to
// version i+1, and len(migrations) is the schema version this build
// expects. Entries must only ever be appended.
var migrations = []string{
"", // migration #0 is reserved for schema initialization
"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
}
// DB wraps a SQL database handle and serializes access to it.
type DB struct {
// lock is taken for reading by query methods and for writing by methods
// that modify the database.
lock sync.RWMutex
// db is the underlying database/sql handle.
db *sql.DB
}
// OpenSQLDB opens the database identified by driver and source and brings
// its schema up to date.
//
// It returns an error if the handle cannot be opened or if the schema
// upgrade fails. In the latter case the handle is closed before returning,
// so a failed open does not leak it.
func OpenSQLDB(driver, source string) (*DB, error) {
	sqlDB, err := sql.Open(driver, source)
	if err != nil {
		return nil, err
	}

	db := &DB{db: sqlDB}
	if err := db.upgrade(); err != nil {
		// The caller never receives the handle on failure, so release it
		// here. The upgrade error is the one worth reporting; a secondary
		// close error is deliberately dropped.
		_ = sqlDB.Close()
		return nil, err
	}
	return db, nil
}
// Close closes the underlying database handle and returns any error
// reported by it. It takes the write lock, so it waits for in-flight
// operations to finish first.
func (db *DB) Close() error {
	db.lock.Lock()
	defer db.lock.Unlock()
	// Close the wrapped *sql.DB. Calling db.Close() here would recurse into
	// this method and block forever on the lock that is already held.
	return db.db.Close()
}
// upgrade brings the database schema to the version this build expects,
// which is len(migrations).
//
// The stored version is read from the user_version pragma. A database at
// version 0 gets the full schema; any other older version gets the pending
// migrations applied in order. All changes, including the version bump,
// happen in a single transaction. A database newer than this build is
// rejected with an error.
func (db *DB) upgrade() error {
	db.lock.Lock()
	defer db.lock.Unlock()

	var current int
	row := db.db.QueryRow("PRAGMA user_version")
	if err := row.Scan(&current); err != nil {
		return fmt.Errorf("failed to query schema version: %v", err)
	}

	latest := len(migrations)
	switch {
	case current == latest:
		// Already up to date.
		return nil
	case current > latest:
		return fmt.Errorf("soju (version %d) older than schema (version %d)", latest, current)
	}

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// Rolls back on every early return below.
	defer tx.Rollback()

	if current != 0 {
		for i := current; i < latest; i++ {
			if _, err := tx.Exec(migrations[i]); err != nil {
				return fmt.Errorf("failed to execute migration #%v: %v", i, err)
			}
		}
	} else if _, err := tx.Exec(schema); err != nil {
		return fmt.Errorf("failed to initialize schema: %v", err)
	}

	// The version is formatted into the statement instead of being bound as
	// a parameter: prepared statements were found not to work for this pragma.
	bump := fmt.Sprintf("PRAGMA user_version = %d", latest)
	if _, err := tx.Exec(bump); err != nil {
		return fmt.Errorf("failed to bump schema version: %v", err)
	}

	return tx.Commit()
}
// fromStringPtr converts a nullable string, as scanned from a column that
// may be NULL, into a plain string. A nil pointer yields "".
func fromStringPtr(ptr *string) string {
	var s string
	if ptr != nil {
		s = *ptr
	}
	return s
}
// toStringPtr converts a string into a nullable value suitable for use as
// a query argument: the empty string becomes nil (stored as NULL), any
// other string becomes a pointer to a copy of it.
func toStringPtr(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}
// ListUsers returns every user stored in the database. The returned slice
// is nil when there are no users.
func (db *DB) ListUsers() ([]User, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	rows, err := db.db.Query("SELECT username, password FROM User")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []User
	for rows.Next() {
		var name string
		var hashed *string
		if err := rows.Scan(&name, &hashed); err != nil {
			return nil, err
		}
		list = append(list, User{
			Username: name,
			Password: fromStringPtr(hashed),
		})
	}
	// Report an iteration error, if any, rather than a partial result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return list, nil
}
// GetUser looks up the user with the given username. The error returned
// by the row scan (including the one for a missing row) is passed through
// unchanged.
func (db *DB) GetUser(username string) (*User, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	var hashed *string
	err := db.db.QueryRow("SELECT password FROM User WHERE username = ?", username).Scan(&hashed)
	if err != nil {
		return nil, err
	}

	return &User{
		Username: username,
		Password: fromStringPtr(hashed),
	}, nil
}
// CreateUser inserts a new row for user. An empty password is passed to
// the database as NULL.
func (db *DB) CreateUser(user *User) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	const query = "INSERT INTO User(username, password) VALUES (?, ?)"
	if _, err := db.db.Exec(query, user.Username, toStringPtr(user.Password)); err != nil {
		return err
	}
	return nil
}
// UpdatePassword replaces the stored password of the user identified by
// user.Username with user.Password. An empty password is passed to the
// database as NULL.
func (db *DB) UpdatePassword(user *User) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	const query = `UPDATE User
SET password = ?
WHERE username = ?`

	if _, err := db.db.Exec(query, toStringPtr(user.Password), user.Username); err != nil {
		return err
	}
	return nil
}
// ListNetworks returns the networks that belong to the user named
// username. NULL columns are mapped to empty strings, and the stored
// connect commands are split on "\r\n". The returned slice is nil when the
// user has no networks.
func (db *DB) ListNetworks(username string) ([]Network, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	const query = `SELECT id, name, addr, nick, username, realname, pass,
connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
FROM Network
WHERE user = ?`

	rows, err := db.db.Query(query, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Network
	for rows.Next() {
		var network Network
		var name, netUsername, realname, pass, commands *string
		var mechanism, plainUser, plainPass *string

		if err := rows.Scan(
			&network.ID, &name, &network.Addr, &network.Nick, &netUsername, &realname,
			&pass, &commands, &mechanism, &plainUser, &plainPass,
		); err != nil {
			return nil, err
		}

		network.Name = fromStringPtr(name)
		network.Username = fromStringPtr(netUsername)
		network.Realname = fromStringPtr(realname)
		network.Pass = fromStringPtr(pass)
		// A NULL column means "no commands"; leave the slice nil in that case.
		if commands != nil {
			network.ConnectCommands = strings.Split(*commands, "\r\n")
		}
		network.SASL.Mechanism = fromStringPtr(mechanism)
		network.SASL.Plain.Username = fromStringPtr(plainUser)
		network.SASL.Plain.Password = fromStringPtr(plainPass)

		result = append(result, network)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// StoreNetwork saves network for the user named username.
//
// A network whose ID is non-zero is updated in place. Otherwise a new row
// is inserted and network.ID is set to the ID of that row. Empty optional
// fields are stored as NULL, and the connect commands are stored joined
// with "\r\n". Only the "PLAIN" SASL mechanism can be stored; any other
// non-empty mechanism is rejected with an error before the database is
// touched.
func (db *DB) StoreNetwork(username string, network *Network) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	var saslMechanism, saslPlainUsername, saslPlainPassword *string
	if mech := network.SASL.Mechanism; mech != "" {
		if mech != "PLAIN" {
			return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", mech)
		}
		saslMechanism = &network.SASL.Mechanism
		saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
		saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
	}

	var (
		netName         = toStringPtr(network.Name)
		netUsername     = toStringPtr(network.Username)
		realname        = toStringPtr(network.Realname)
		pass            = toStringPtr(network.Pass)
		connectCommands = toStringPtr(strings.Join(network.ConnectCommands, "\r\n"))
	)

	// Existing network: update the row and stop here.
	if network.ID != 0 {
		_, err := db.db.Exec(`UPDATE Network
SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
WHERE id = ?`,
			netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
		return err
	}

	// New network: insert it and record the generated row ID.
	res, err := db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
sasl_plain_password)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
		saslMechanism, saslPlainUsername, saslPlainPassword)
	if err != nil {
		return err
	}
	network.ID, err = res.LastInsertId()
	return err
}
// DeleteNetwork removes the network with the given ID together with all
// of its channels. Both deletions happen in one transaction, so either
// everything is removed or nothing is.
func (db *DB) DeleteNetwork(id int64) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// Rolls back on every early return below.
	defer tx.Rollback()

	// Channel.network references Network(id). Delete the referencing rows
	// first so that removing the network cannot violate the foreign key
	// when the connection has foreign-key enforcement enabled.
	if _, err := tx.Exec("DELETE FROM Channel WHERE network = ?", id); err != nil {
		return err
	}
	if _, err := tx.Exec("DELETE FROM Network WHERE id = ?", id); err != nil {
		return err
	}

	return tx.Commit()
}
// ListChannels returns the channels saved for the network with ID
// networkID. A NULL key is mapped to the empty string. The returned slice
// is nil when the network has no channels.
func (db *DB) ListChannels(networkID int64) ([]Channel, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	rows, err := db.db.Query("SELECT id, name, key FROM Channel WHERE network = ?", networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []Channel
	for rows.Next() {
		var (
			id   int64
			name string
			key  *string
		)
		if err := rows.Scan(&id, &name, &key); err != nil {
			return nil, err
		}
		list = append(list, Channel{ID: id, Name: name, Key: fromStringPtr(key)})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return list, nil
}
// GetChannel looks up the channel called name on the network with ID
// networkID. It returns ErrNoSuchChannel when no such row exists; any
// other scan error is passed through unchanged.
func (db *DB) GetChannel(networkID int64, name string) (*Channel, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	var (
		id  int64
		key *string
	)
	row := db.db.QueryRow("SELECT id, key FROM Channel WHERE network = ? AND name = ?", networkID, name)
	switch err := row.Scan(&id, &key); {
	case err == sql.ErrNoRows:
		return nil, ErrNoSuchChannel
	case err != nil:
		return nil, err
	}

	return &Channel{ID: id, Name: name, Key: fromStringPtr(key)}, nil
}
// StoreChannel saves ch for the network with ID networkID.
//
// A channel whose ID is non-zero is updated in place. Otherwise a new row
// is inserted and ch.ID is set to the ID of that row. An empty key is
// stored as NULL.
func (db *DB) StoreChannel(networkID int64, ch *Channel) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	key := toStringPtr(ch.Key)

	// Existing channel: update the row and stop here.
	if ch.ID != 0 {
		_, err := db.db.Exec(`UPDATE Channel
SET network = ?, name = ?, key = ?
WHERE id = ?`, networkID, ch.Name, key, ch.ID)
		return err
	}

	// New channel: insert it and record the generated row ID.
	res, err := db.db.Exec(`INSERT INTO Channel(network, name, key)
VALUES (?, ?, ?)`, networkID, ch.Name, key)
	if err != nil {
		return err
	}
	ch.ID, err = res.LastInsertId()
	return err
}
// DeleteChannel removes the channel called name from the network with ID
// networkID. Deleting a channel that does not exist is not an error.
func (db *DB) DeleteChannel(networkID int64, name string) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	const query = "DELETE FROM Channel WHERE network = ? AND name = ?"
	if _, err := db.db.Exec(query, networkID, name); err != nil {
		return err
	}
	return nil
}