Introduce a database package

This commit is contained in:
Simon Ser 2022-05-09 12:34:43 +02:00
parent 27f21eab94
commit 3a7dee8128
18 changed files with 206 additions and 152 deletions

View File

@ -24,6 +24,7 @@ import (
"git.sr.ht/~emersion/soju" "git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config" "git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
) )
// TCP keep-alive interval for downstream TCP connections // TCP keep-alive interval for downstream TCP connections
@ -116,7 +117,7 @@ func main() {
log.Printf("failed to bump max number of opened files: %v", err) log.Printf("failed to bump max number of opened files: %v", err)
} }
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil { if err != nil {
log.Fatalf("failed to open database: %v", err) log.Fatalf("failed to open database: %v", err)
} }
@ -308,7 +309,7 @@ func main() {
log.Printf("server listening on %q", listen) log.Printf("server listening on %q", listen)
} }
if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil { if db, ok := db.(database.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil { if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil {
log.Fatalf("failed to register database metrics: %v", err) log.Fatalf("failed to register database metrics: %v", err)
} }

View File

@ -9,10 +9,11 @@ import (
"log" "log"
"os" "os"
"git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
"git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
) )
const usage = `usage: sojuctl [-config path] <action> [options...] const usage = `usage: sojuctl [-config path] <action> [options...]
@ -44,7 +45,7 @@ func main() {
cfg = config.Defaults() cfg = config.Defaults()
} }
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil { if err != nil {
log.Fatalf("failed to open database: %v", err) log.Fatalf("failed to open database: %v", err)
} }
@ -73,7 +74,7 @@ func main() {
log.Fatalf("failed to hash password: %v", err) log.Fatalf("failed to hash password: %v", err)
} }
user := soju.User{ user := database.User{
Username: username, Username: username,
Password: string(hashed), Password: string(hashed),
Admin: *admin, Admin: *admin,

View File

@ -12,8 +12,8 @@ import (
"strings" "strings"
"unicode" "unicode"
"git.sr.ht/~emersion/soju"
"git.sr.ht/~emersion/soju/config" "git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
) )
const usage = `usage: znc-import [options...] <znc config path> const usage = `usage: znc-import [options...] <znc config path>
@ -64,7 +64,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource) db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
if err != nil { if err != nil {
log.Fatalf("failed to open database: %v", err) log.Fatalf("failed to open database: %v", err)
} }
@ -86,7 +86,7 @@ func main() {
if err != nil { if err != nil {
log.Fatalf("failed to list users in DB: %v", err) log.Fatalf("failed to list users in DB: %v", err)
} }
existingUsers := make(map[string]*soju.User, len(l)) existingUsers := make(map[string]*database.User, len(l))
for i, u := range l { for i, u := range l {
existingUsers[u.Username] = &l[i] existingUsers[u.Username] = &l[i]
} }
@ -107,7 +107,7 @@ func main() {
log.Printf("user %q: updating existing user", username) log.Printf("user %q: updating existing user", username)
} else { } else {
// "!!" is an invalid crypt format, thus disables password auth // "!!" is an invalid crypt format, thus disables password auth
u = &soju.User{Username: username, Password: "!!"} u = &database.User{Username: username, Password: "!!"}
usersCreated++ usersCreated++
log.Printf("user %q: creating new user", username) log.Printf("user %q: creating new user", username)
} }
@ -123,7 +123,7 @@ func main() {
if err != nil { if err != nil {
log.Fatalf("failed to list networks for user %q: %v", username, err) log.Fatalf("failed to list networks for user %q: %v", username, err)
} }
existingNetworks := make(map[string]*soju.Network, len(l)) existingNetworks := make(map[string]*database.Network, len(l))
for i, n := range l { for i, n := range l {
existingNetworks[n.GetName()] = &l[i] existingNetworks[n.GetName()] = &l[i]
} }
@ -175,7 +175,7 @@ func main() {
if ok { if ok {
logger.Printf("updating existing network") logger.Printf("updating existing network")
} else { } else {
n = &soju.Network{Name: netName} n = &database.Network{Name: netName}
logger.Printf("creating new network") logger.Printf("creating new network")
} }
@ -194,7 +194,7 @@ func main() {
if err != nil { if err != nil {
logger.Fatalf("failed to list channels: %v", err) logger.Fatalf("failed to list channels: %v", err)
} }
existingChannels := make(map[string]*soju.Channel, len(l)) existingChannels := make(map[string]*database.Channel, len(l))
for i, ch := range l { for i, ch := range l {
existingChannels[ch.Name] = &l[i] existingChannels[ch.Name] = &l[i]
} }
@ -213,7 +213,7 @@ func main() {
if ok { if ok {
logger.Printf("channel %q: updating existing channel", chName) logger.Printf("channel %q: updating existing channel", chName)
} else { } else {
ch = &soju.Channel{Name: chName} ch = &database.Channel{Name: chName}
logger.Printf("channel %q: creating new channel", chName) logger.Printf("channel %q: creating new channel", chName)
} }

View File

@ -1,4 +1,4 @@
package soju package database
import ( import (
"context" "context"
@ -38,7 +38,7 @@ type MetricsCollectorDatabase interface {
RegisterMetrics(r prometheus.Registerer) error RegisterMetrics(r prometheus.Registerer) error
} }
func OpenDB(driver, source string) (Database, error) { func Open(driver, source string) (Database, error) {
switch driver { switch driver {
case "sqlite3": case "sqlite3":
return OpenSqliteDB(source) return OpenSqliteDB(source)
@ -149,20 +149,6 @@ const (
FilterMessage FilterMessage
) )
func parseFilter(filter string) (MessageFilter, error) {
switch filter {
case "default":
return FilterDefault, nil
case "none":
return FilterNone, nil
case "highlight":
return FilterHighlight, nil
case "message":
return FilterMessage, nil
}
return 0, fmt.Errorf("unknown filter: %q", filter)
}
type Channel struct { type Channel struct {
ID int64 ID int64
Name string Name string

View File

@ -1,4 +1,4 @@
package soju package database
import ( import (
"context" "context"
@ -127,6 +127,37 @@ func OpenPostgresDB(source string) (Database, error) {
return db, nil return db, nil
} }
func openTempPostgresDB(source string) (*sql.DB, error) {
db, err := sql.Open("postgres", source)
if err != nil {
return nil, fmt.Errorf("failed to connect to PostgreSQL: %v", err)
}
// Store all tables in a temporary schema which will be dropped when the
// connection to PostgreSQL is closed.
db.SetMaxOpenConns(1)
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
return nil, fmt.Errorf("failed to set PostgreSQL search_path: %v", err)
}
return db, nil
}
func OpenTempPostgresDB(source string) (Database, error) {
sqlPostgresDB, err := openTempPostgresDB(source)
if err != nil {
return nil, err
}
db := &PostgresDB{db: sqlPostgresDB}
if err := db.upgrade(); err != nil {
sqlPostgresDB.Close()
return nil, err
}
return db, nil
}
func (db *PostgresDB) upgrade() error { func (db *PostgresDB) upgrade() error {
tx, err := db.db.Begin() tx, err := db.db.Begin()
if err != nil { if err != nil {

View File

@ -1,7 +1,6 @@
package soju package database
import ( import (
"database/sql"
"os" "os"
"testing" "testing"
) )
@ -68,29 +67,17 @@ CREATE TABLE "DeliveryReceipt" (
); );
` `
func openTempPostgresDB(t *testing.T) *sql.DB { func TestPostgresMigrations(t *testing.T) {
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES") source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
if !ok { if !ok {
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests") t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
} }
db, err := sql.Open("postgres", source) sqlDB, err := openTempPostgresDB(source)
if err != nil { if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err) t.Fatalf("openTempPostgresDB() failed: %v", err)
} }
// Store all tables in a temporary schema which will be dropped when the
// connection to PostgreSQL is closed.
db.SetMaxOpenConns(1)
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
t.Fatalf("failed to set PostgreSQL search_path: %v", err)
}
return db
}
func TestPostgresMigrations(t *testing.T) {
sqlDB := openTempPostgresDB(t)
if _, err := sqlDB.Exec(postgresV0Schema); err != nil { if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
t.Fatalf("DB.Exec() failed for v0 schema: %v", err) t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
} }

View File

@ -1,4 +1,4 @@
package soju package database
import ( import (
"context" "context"
@ -15,6 +15,12 @@ import (
const sqliteQueryTimeout = 5 * time.Second const sqliteQueryTimeout = 5 * time.Second
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
func formatSqliteTime(t time.Time) string {
return t.UTC().Format(sqliteTimeLayout)
}
const sqliteSchema = ` const sqliteSchema = `
CREATE TABLE User ( CREATE TABLE User (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
@ -212,6 +218,13 @@ func OpenSqliteDB(source string) (Database, error) {
return db, nil return db, nil
} }
func OpenTempSqliteDB() (Database, error) {
// :memory: will open a separate database for each new connection. Make
// sure the sql package only uses a single connection via SetMaxOpenConns.
// An alternative solution is to use "file::memory:?cache=shared".
return OpenSqliteDB(":memory:")
}
func (db *SqliteDB) Close() error { func (db *SqliteDB) Close() error {
return db.db.Close() return db.db.Close()
} }
@ -732,7 +745,7 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
} }
return nil, err return nil, err
} }
if t, err := time.Parse(serverTimeLayout, timestamp); err != nil { if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil {
return nil, err return nil, err
} else { } else {
receipt.Timestamp = t receipt.Timestamp = t
@ -746,7 +759,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
args := []interface{}{ args := []interface{}{
sql.Named("id", receipt.ID), sql.Named("id", receipt.ID),
sql.Named("timestamp", formatServerTime(receipt.Timestamp)), sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)),
sql.Named("network", networkID), sql.Named("network", networkID),
sql.Named("target", receipt.Target), sql.Named("target", receipt.Target),
} }

View File

@ -1,4 +1,4 @@
package soju package database
import ( import (
"database/sql" "database/sql"

View File

@ -16,6 +16,8 @@ import (
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
type ircError struct { type ircError struct {
@ -100,7 +102,7 @@ func parseBouncerNetID(subcommand, s string) (int64, error) {
return id, nil return id, nil
} }
func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) { func fillNetworkAddrAttrs(attrs irc.Tags, network *database.Network) {
u, err := network.URL() u, err := network.URL()
if err != nil { if err != nil {
return return
@ -132,13 +134,13 @@ func getNetworkAttrs(network *network) irc.Tags {
attrs := irc.Tags{ attrs := irc.Tags{
"name": irc.TagValue(network.GetName()), "name": irc.TagValue(network.GetName()),
"state": irc.TagValue(state), "state": irc.TagValue(state),
"nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)), "nickname": irc.TagValue(database.GetNick(&network.user.User, &network.Network)),
} }
if network.Username != "" { if network.Username != "" {
attrs["username"] = irc.TagValue(network.Username) attrs["username"] = irc.TagValue(network.Username)
} }
if realname := GetRealname(&network.user.User, &network.Network); realname != "" { if realname := database.GetRealname(&network.user.User, &network.Network); realname != "" {
attrs["realname"] = irc.TagValue(realname) attrs["realname"] = irc.TagValue(realname)
} }
@ -169,7 +171,7 @@ func networkAddrFromAttrs(attrs irc.Tags) string {
return addr return addr
} }
func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error { func updateNetworkAttrs(record *database.Network, attrs irc.Tags, subcommand string) error {
addrAttrs := irc.Tags{} addrAttrs := irc.Tags{}
fillNetworkAddrAttrs(addrAttrs, record) fillNetworkAddrAttrs(addrAttrs, record)
@ -414,7 +416,7 @@ func isOurNick(net *network, nick string) bool {
// know whether this name is our nickname. Best-effort: use the network's // know whether this name is our nickname. Best-effort: use the network's
// configured nickname and hope it was the one being used when we were // configured nickname and hope it was the one being used when we were
// connected. // connected.
return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network)) return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
} }
// marshalEntity converts an upstream entity name (ie. channel or nick) into a // marshalEntity converts an upstream entity name (ie. channel or nick) into a
@ -1146,9 +1148,9 @@ func (dc *downstreamConn) updateNick() {
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {
nick = uc.nick nick = uc.nick
} else if dc.network != nil { } else if dc.network != nil {
nick = GetNick(&dc.user.User, &dc.network.Network) nick = database.GetNick(&dc.user.User, &dc.network.Network)
} else { } else {
nick = GetNick(&dc.user.User, nil) nick = database.GetNick(&dc.user.User, nil)
} }
if nick == dc.nick { if nick == dc.nick {
@ -1201,9 +1203,9 @@ func (dc *downstreamConn) updateRealname() {
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {
realname = uc.realname realname = uc.realname
} else if dc.network != nil { } else if dc.network != nil {
realname = GetRealname(&dc.user.User, &dc.network.Network) realname = database.GetRealname(&dc.user.User, &dc.network.Network)
} else { } else {
realname = GetRealname(&dc.user.User, nil) realname = database.GetRealname(&dc.user.User, nil)
} }
if realname != dc.realname { if realname != dc.realname {
@ -1439,7 +1441,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
dc.logger.Printf("auto-saving network %q", dc.registration.networkName) dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
var err error var err error
network, err = dc.user.createNetwork(ctx, &Network{ network, err = dc.user.createNetwork(ctx, &database.Network{
Addr: dc.registration.networkName, Addr: dc.registration.networkName,
Nick: nick, Nick: nick,
Enabled: true, Enabled: true,
@ -1475,7 +1477,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {
dc.nick = uc.nick dc.nick = uc.nick
} else if dc.network != nil { } else if dc.network != nil {
dc.nick = GetNick(&dc.user.User, &dc.network.Network) dc.nick = database.GetNick(&dc.user.User, &dc.network.Network)
} else { } else {
dc.nick = dc.user.Username dc.nick = dc.user.Username
} }
@ -1931,7 +1933,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
uc.network.attach(ctx, ch) uc.network.attach(ctx, ch)
} else { } else {
ch = &Channel{ ch = &database.Channel{
Name: upstreamName, Name: upstreamName,
Key: key, Key: key,
} }
@ -1963,7 +1965,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if ch != nil { if ch != nil {
uc.network.detach(ch) uc.network.detach(ch)
} else { } else {
ch = &Channel{ ch = &database.Channel{
Name: name, Name: name,
Detached: true, Detached: true,
} }
@ -2911,7 +2913,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"}, Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
}} }}
} else if r == nil { } else if r == nil {
r = &ReadReceipt{ r = &database.ReadReceipt{
Target: entityCM, Target: entityCM,
} }
} }
@ -3082,7 +3084,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
attrs := irc.ParseTags(attrsStr) attrs := irc.ParseTags(attrsStr)
record := &Network{Nick: dc.nick, Enabled: true} record := &database.Network{Nick: dc.nick, Enabled: true}
if err := updateNetworkAttrs(record, attrs, subcommand); err != nil { if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
return err return err
} }

6
irc.go
View File

@ -9,6 +9,8 @@ import (
"unicode/utf8" "unicode/utf8"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
const ( const (
@ -653,12 +655,12 @@ func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
type channelCasemapMap struct{ casemapMap } type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *Channel { func (cm *channelCasemapMap) Value(name string) *database.Channel {
entry, ok := cm.innerMap[cm.casemap(name)] entry, ok := cm.innerMap[cm.casemap(name)]
if !ok { if !ok {
return nil return nil
} }
return entry.value.(*Channel) return entry.value.(*database.Channel)
} }
type membershipsCasemapMap struct{ casemapMap } type membershipsCasemapMap struct{ casemapMap }

View File

@ -9,6 +9,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare" "git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
// messageStore is a per-user store for IRC messages. // messageStore is a per-user store for IRC messages.
@ -17,11 +19,11 @@ type messageStore interface {
// LastMsgID queries the last message ID for the given network, entity and // LastMsgID queries the last message ID for the given network, entity and
// date. The message ID returned may not refer to a valid message, but can be // date. The message ID returned may not refer to a valid message, but can be
// used in history queries. // used in history queries.
LastMsgID(network *Network, entity string, t time.Time) (string, error) LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
// LoadLatestID queries the latest non-event messages for the given network, // LoadLatestID queries the latest non-event messages for the given network,
// entity and date, up to a count of limit messages, sorted from oldest to newest. // entity and date, up to a count of limit messages, sorted from oldest to newest.
LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error)
Append(network *Network, entity string, msg *irc.Message) (id string, err error) Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
} }
type chatHistoryTarget struct { type chatHistoryTarget struct {
@ -38,17 +40,17 @@ type chatHistoryMessageStore interface {
// It returns up to limit targets, starting from start and ending on end, // It returns up to limit targets, starting from start and ending on end,
// both excluded. end may be before or after start. // both excluded. end may be before or after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
// LoadBeforeTime loads up to limit messages before start down to end. The // LoadBeforeTime loads up to limit messages before start down to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is before start. // end is before start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
// LoadBeforeTime loads up to limit messages after start up to end. The // LoadBeforeTime loads up to limit messages after start up to end. The
// returned messages must be between and excluding the provided bounds. // returned messages must be between and excluding the provided bounds.
// end is after start. // end is after start.
// If events is false, only PRIVMSG/NOTICE messages are considered. // If events is false, only PRIVMSG/NOTICE messages are considered.
LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
} }
type searchOptions struct { type searchOptions struct {
@ -66,7 +68,7 @@ type searchMessageStore interface {
messageStore messageStore
// Search returns messages matching the specified options. // Search returns messages matching the specified options.
Search(ctx context.Context, network *Network, search searchOptions) ([]*irc.Message, error) Search(ctx context.Context, network *database.Network, search searchOptions) ([]*irc.Message, error)
} }
type msgIDType uint type msgIDType uint

View File

@ -13,6 +13,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare" "git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
const ( const (
@ -80,7 +82,7 @@ type fsMessageStoreFile struct {
// https://github.com/znc/znc/blob/master/modules/log.cpp // https://github.com/znc/znc/blob/master/modules/log.cpp
type fsMessageStore struct { type fsMessageStore struct {
root string root string
user *User user *database.User
// Write-only files used by Append // Write-only files used by Append
files map[string]*fsMessageStoreFile // indexed by entity files map[string]*fsMessageStoreFile // indexed by entity
@ -90,7 +92,7 @@ var _ messageStore = (*fsMessageStore)(nil)
var _ chatHistoryMessageStore = (*fsMessageStore)(nil) var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
var _ searchMessageStore = (*fsMessageStore)(nil) var _ searchMessageStore = (*fsMessageStore)(nil)
func newFSMessageStore(root string, user *User) *fsMessageStore { func newFSMessageStore(root string, user *database.User) *fsMessageStore {
return &fsMessageStore{ return &fsMessageStore{
root: filepath.Join(root, escapeFilename(user.Username)), root: filepath.Join(root, escapeFilename(user.Username)),
user: user, user: user,
@ -98,14 +100,14 @@ func newFSMessageStore(root string, user *User) *fsMessageStore {
} }
} }
func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string { func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
year, month, day := t.Date() year, month, day := t.Date()
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day) filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename) return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
} }
// nextMsgID queries the message ID for the next message to be written to f. // nextMsgID queries the message ID for the next message to be written to f.
func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) { func nextFSMsgID(network *database.Network, entity string, t time.Time, f *os.File) (string, error) {
offset, err := f.Seek(0, io.SeekEnd) offset, err := f.Seek(0, io.SeekEnd)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to query next FS message ID: %v", err) return "", fmt.Errorf("failed to query next FS message ID: %v", err)
@ -113,7 +115,7 @@ func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (stri
return formatFSMsgID(network.ID, entity, t, offset), nil return formatFSMsgID(network.ID, entity, t, offset), nil
} }
func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { func (ms *fsMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
p := ms.logPath(network, entity, t) p := ms.logPath(network, entity, t)
fi, err := os.Stat(p) fi, err := os.Stat(p)
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -124,7 +126,7 @@ func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
} }
func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { func (ms *fsMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
s := formatMessage(msg) s := formatMessage(msg)
if s == "" { if s == "" {
return "", nil return "", nil
@ -253,7 +255,7 @@ func formatMessage(msg *irc.Message) string {
} }
} }
func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) { func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
var hour, minute, second int var hour, minute, second int
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second) _, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
if err != nil { if err != nil {
@ -380,7 +382,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
// our nickname in the logs, so grab it from the network settings. // our nickname in the logs, so grab it from the network settings.
// Not very accurate since this may not match our nick at the time // Not very accurate since this may not match our nick at the time
// the message was received, but we can't do a lot better. // the message was received, but we can't do a lot better.
entity = GetNick(ms.user, network) entity = database.GetNick(ms.user, network)
} }
params = []string{entity, text} params = []string{entity, text}
} }
@ -399,7 +401,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
return msg, t, nil return msg, t, nil
} }
func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -458,7 +460,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, r
} }
} }
func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
path := ms.logPath(network, entity, ref) path := ms.logPath(network, entity, ref)
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -493,7 +495,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re
return history, nil return history, nil
} }
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
if start.IsZero() { if start.IsZero() {
start = time.Now() start = time.Now()
} else { } else {
@ -526,11 +528,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, e
return messages[remaining:], nil return messages[remaining:], nil
} }
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil) return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
} }
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) { func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
start = start.In(time.Local) start = start.In(time.Local)
if end.IsZero() { if end.IsZero() {
end = time.Now() end = time.Now()
@ -562,11 +564,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, en
return messages, nil return messages, nil
} }
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil) return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
} }
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
var afterTime time.Time var afterTime time.Time
var afterOffset int64 var afterOffset int64
if id != "" { if id != "" {
@ -614,7 +616,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, en
return history[remaining:], nil return history[remaining:], nil
} }
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) { func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
start = start.In(time.Local) start = start.In(time.Local)
end = end.In(time.Local) end = end.In(time.Local)
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName())) rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
@ -693,7 +695,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, sta
return targets, nil return targets, nil
} }
func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts searchOptions) ([]*irc.Message, error) { func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts searchOptions) ([]*irc.Message, error) {
text := strings.ToLower(opts.text) text := strings.ToLower(opts.text)
selector := func(m *irc.Message) bool { selector := func(m *irc.Message) bool {
if opts.from != "" && m.User != opts.from { if opts.from != "" && m.User != opts.from {
@ -711,7 +713,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts sea
} }
} }
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error { func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName())) oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName())) newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
// Avoid loosing data by overwriting an existing directory // Avoid loosing data by overwriting an existing directory

View File

@ -7,6 +7,8 @@ import (
"git.sr.ht/~sircmpwn/go-bare" "git.sr.ht/~sircmpwn/go-bare"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
const messageRingBufferCap = 4096 const messageRingBufferCap = 4096
@ -55,7 +57,7 @@ func (ms *memoryMessageStore) Close() error {
return nil return nil
} }
func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer { func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer {
k := ringBufferKey{networkID: network.ID, entity: entity} k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok { if rb, ok := ms.buffers[k]; ok {
return rb return rb
@ -65,7 +67,7 @@ func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingB
return rb return rb
} }
func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) { func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
var seq uint64 var seq uint64
k := ringBufferKey{networkID: network.ID, entity: entity} k := ringBufferKey{networkID: network.ID, entity: entity}
if rb, ok := ms.buffers[k]; ok { if rb, ok := ms.buffers[k]; ok {
@ -74,7 +76,7 @@ func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) { func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
switch msg.Command { switch msg.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
// Only append these messages, because LoadLatestID shouldn't return // Only append these messages, because LoadLatestID shouldn't return
@ -94,7 +96,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M
return formatMemoryMsgID(network.ID, entity, seq), nil return formatMemoryMsgID(network.ID, entity, seq), nil
} }
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) { func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
_, _, seq, err := parseMemoryMsgID(id) _, _, seq, err := parseMemoryMsgID(id)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -20,6 +20,7 @@ import (
"nhooyr.io/websocket" "nhooyr.io/websocket"
"git.sr.ht/~emersion/soju/config" "git.sr.ht/~emersion/soju/config"
"git.sr.ht/~emersion/soju/database"
) )
// TODO: make configurable // TODO: make configurable
@ -141,7 +142,7 @@ type Server struct {
MetricsRegistry prometheus.Registerer // can be nil MetricsRegistry prometheus.Registerer // can be nil
config atomic.Value // *Config config atomic.Value // *Config
db Database db database.Database
stopWG sync.WaitGroup stopWG sync.WaitGroup
lock sync.Mutex lock sync.Mutex
@ -161,7 +162,7 @@ type Server struct {
} }
} }
func NewServer(db Database) *Server { func NewServer(db database.Database) *Server {
srv := &Server{ srv := &Server{
Logger: NewLogger(log.Writer(), true), Logger: NewLogger(log.Writer(), true),
db: db, db: db,
@ -273,7 +274,7 @@ func (s *Server) Shutdown() {
} }
} }
func (s *Server) createUser(ctx context.Context, user *User) (*user, error) { func (s *Server) createUser(ctx context.Context, user *database.User) (*user, error) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -304,7 +305,7 @@ func (s *Server) getUser(name string) *user {
return u return u
} }
func (s *Server) addUserLocked(user *User) *user { func (s *Server) addUserLocked(user *database.User) *user {
s.Logger.Printf("starting bouncer for user %q", user.Username) s.Logger.Printf("starting bouncer for user %q", user.Username)
u := newUser(s, user) u := newUser(s, user)
s.users[u.Username] = u s.users[u.Username] = u

View File

@ -3,10 +3,13 @@ package soju
import ( import (
"context" "context"
"net" "net"
"os"
"testing" "testing"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
var testServerPrefix = &irc.Prefix{Name: "soju-test-server"} var testServerPrefix = &irc.Prefix{Name: "soju-test-server"}
@ -16,34 +19,35 @@ const (
testPassword = testUsername testPassword = testUsername
) )
func createTempSqliteDB(t *testing.T) Database { func createTempSqliteDB(t *testing.T) database.Database {
db, err := OpenDB("sqlite3", ":memory:") db, err := database.OpenTempSqliteDB()
if err != nil { if err != nil {
t.Fatalf("failed to create temporary SQLite database: %v", err) t.Fatalf("failed to create temporary SQLite database: %v", err)
} }
// :memory: will open a separate database for each new connection. Make
// sure the sql package only uses a single connection. An alternative
// solution is to use "file::memory:?cache=shared".
db.(*SqliteDB).db.SetMaxOpenConns(1)
return db return db
} }
func createTempPostgresDB(t *testing.T) Database { func createTempPostgresDB(t *testing.T) database.Database {
db := &PostgresDB{db: openTempPostgresDB(t)} source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
if err := db.upgrade(); err != nil { if !ok {
t.Fatalf("failed to upgrade PostgreSQL database: %v", err) t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
}
db, err := database.OpenTempPostgresDB(source)
if err != nil {
t.Fatalf("failed to create temporary PostgreSQL database: %v", err)
} }
return db return db
} }
func createTestUser(t *testing.T, db Database) *User { func createTestUser(t *testing.T, db database.Database) *database.User {
hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost) hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
if err != nil { if err != nil {
t.Fatalf("failed to generate bcrypt hash: %v", err) t.Fatalf("failed to generate bcrypt hash: %v", err)
} }
record := &User{Username: testUsername, Password: string(hashed)} record := &database.User{Username: testUsername, Password: string(hashed)}
if err := db.StoreUser(context.Background(), record); err != nil { if err := db.StoreUser(context.Background(), record); err != nil {
t.Fatalf("failed to store test user: %v", err) t.Fatalf("failed to store test user: %v", err)
} }
@ -57,13 +61,13 @@ func createTestDownstream(t *testing.T, srv *Server) ircConn {
return newNetIRCConn(c2) return newNetIRCConn(c2)
} }
func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) { func createTestUpstream(t *testing.T, db database.Database, user *database.User) (*database.Network, net.Listener) {
ln, err := net.Listen("tcp", "localhost:0") ln, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
t.Fatalf("failed to create TCP listener: %v", err) t.Fatalf("failed to create TCP listener: %v", err)
} }
network := &Network{ network := &database.Network{
Name: "testnet", Name: "testnet",
Addr: "irc+insecure://" + ln.Addr().String(), Addr: "irc+insecure://" + ln.Addr().String(),
Nick: user.Username, Nick: user.Username,
@ -95,7 +99,7 @@ func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
return msg return msg
} }
func registerDownstreamConn(t *testing.T, c ircConn, network *Network) { func registerDownstreamConn(t *testing.T, c ircConn, network *database.Network) {
c.WriteMessage(&irc.Message{ c.WriteMessage(&irc.Message{
Command: "PASS", Command: "PASS",
Params: []string{testPassword}, Params: []string{testPassword},
@ -151,7 +155,7 @@ func registerUpstreamConn(t *testing.T, c ircConn) {
}) })
} }
func testServer(t *testing.T, db Database) { func testServer(t *testing.T, db database.Database) {
user := createTestUser(t, db) user := createTestUser(t, db)
network, upstream := createTestUpstream(t, db, user) network, upstream := createTestUpstream(t, db, user)
defer upstream.Close() defer upstream.Close()

View File

@ -17,6 +17,8 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
const serviceNick = "BouncerServ" const serviceNick = "BouncerServ"
@ -447,7 +449,7 @@ func newNetworkFlagSet() *networkFlagSet {
return fs return fs
} }
func (fs *networkFlagSet) update(network *Network) error { func (fs *networkFlagSet) update(network *database.Network) error {
if fs.Addr != nil { if fs.Addr != nil {
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 { if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
scheme := addrParts[0] scheme := addrParts[0]
@ -508,7 +510,7 @@ func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params
return fmt.Errorf("flag -addr is required") return fmt.Errorf("flag -addr is required")
} }
record := &Network{ record := &database.Network{
Addr: *fs.Addr, Addr: *fs.Addr,
Enabled: true, Enabled: true,
} }
@ -833,7 +835,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
return fmt.Errorf("failed to hash password: %v", err) return fmt.Errorf("failed to hash password: %v", err)
} }
user := &User{ user := &database.User{
Username: *username, Username: *username,
Password: string(hashed), Password: string(hashed),
Realname: *realname, Realname: *realname,
@ -971,9 +973,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
n := 0 n := 0
sendNetwork := func(net *network) { sendNetwork := func(net *network) {
var channels []*Channel var channels []*database.Channel
for _, entry := range net.channels.innerMap { for _, entry := range net.channels.innerMap {
channels = append(channels, entry.value.(*Channel)) channels = append(channels, entry.value.(*database.Channel))
} }
sort.Slice(channels, func(i, j int) bool { sort.Slice(channels, func(i, j int) bool {
@ -1031,6 +1033,20 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
return nil return nil
} }
func parseFilter(filter string) (database.MessageFilter, error) {
switch filter {
case "default":
return database.FilterDefault, nil
case "none":
return database.FilterNone, nil
case "highlight":
return database.FilterHighlight, nil
case "message":
return database.FilterMessage, nil
}
return 0, fmt.Errorf("unknown filter: %q", filter)
}
type channelFlagSet struct { type channelFlagSet struct {
*flag.FlagSet *flag.FlagSet
RelayDetached, ReattachOn, DetachAfter, DetachOn *string RelayDetached, ReattachOn, DetachAfter, DetachOn *string
@ -1045,7 +1061,7 @@ func newChannelFlagSet() *channelFlagSet {
return fs return fs
} }
func (fs *channelFlagSet) update(channel *Channel) error { func (fs *channelFlagSet) update(channel *database.Channel) error {
if fs.RelayDetached != nil { if fs.RelayDetached != nil {
filter, err := parseFilter(*fs.RelayDetached) filter, err := parseFilter(*fs.RelayDetached)
if err != nil { if err != nil {

View File

@ -17,6 +17,8 @@ import (
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
// permanentUpstreamCaps is the static list of upstream capabilities always // permanentUpstreamCaps is the static list of upstream capabilities always
@ -510,7 +512,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
highlight := uc.network.isHighlight(msg) highlight := uc.network.isHighlight(msg)
if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) { if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) {
uc.updateChannelAutoDetach(target) uc.updateChannelAutoDetach(target)
} }
} }
@ -765,7 +767,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
if uc.network.channels.Len() > 0 { if uc.network.channels.Len() > 0 {
var channels, keys []string var channels, keys []string
for _, entry := range uc.network.channels.innerMap { for _, entry := range uc.network.channels.innerMap {
ch := entry.value.(*Channel) ch := entry.value.(*database.Channel)
channels = append(channels, ch.Name) channels = append(channels, ch.Name)
keys = append(keys, ch.Key) keys = append(keys, ch.Key)
} }
@ -1553,7 +1555,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
// Check if the nick we want is now free // Check if the nick we want is now free
wantNick := GetNick(&uc.user.User, &uc.network.Network) wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
wantNickCM := uc.network.casemap(wantNick) wantNickCM := uc.network.casemap(wantNick)
if !online && uc.nickCM != wantNickCM { if !online && uc.nickCM != wantNickCM {
found := false found := false
@ -1796,13 +1798,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
return nil return nil
} }
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) { func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *database.Channel, msg *irc.Message) {
if uc.network.detachedMessageNeedsRelay(ch, msg) { if uc.network.detachedMessageNeedsRelay(ch, msg) {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.relayDetachedMessage(uc.network, msg) dc.relayDetachedMessage(uc.network, msg)
}) })
} }
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) { if ch.ReattachOn == database.FilterMessage || (ch.ReattachOn == database.FilterHighlight && uc.network.isHighlight(msg)) {
uc.network.attach(ctx, ch) uc.network.attach(ctx, ch)
if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err) uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
@ -1960,10 +1962,10 @@ func splitSpace(s string) []string {
} }
func (uc *upstreamConn) register(ctx context.Context) { func (uc *upstreamConn) register(ctx context.Context) {
uc.nick = GetNick(&uc.user.User, &uc.network.Network) uc.nick = database.GetNick(&uc.user.User, &uc.network.Network)
uc.nickCM = uc.network.casemap(uc.nick) uc.nickCM = uc.network.casemap(uc.nick)
uc.username = GetUsername(&uc.user.User, &uc.network.Network) uc.username = database.GetUsername(&uc.user.User, &uc.network.Network)
uc.realname = GetRealname(&uc.user.User, &uc.network.Network) uc.realname = database.GetRealname(&uc.user.User, &uc.network.Network)
uc.SendMessage(ctx, &irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "CAP", Command: "CAP",
@ -2193,7 +2195,7 @@ func (uc *upstreamConn) updateMonitor() {
} }
}) })
wantNick := GetNick(&uc.user.User, &uc.network.Network) wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
wantNickCM := uc.network.casemap(wantNick) wantNickCM := uc.network.casemap(wantNick)
if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) { if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
addList = append(addList, wantNickCM) addList = append(addList, wantNickCM)

40
user.go
View File

@ -14,6 +14,8 @@ import (
"time" "time"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
"git.sr.ht/~emersion/soju/database"
) )
type event interface{} type event interface{}
@ -123,7 +125,7 @@ func (ds deliveredStore) ForEachClient(f func(clientName string)) {
} }
type network struct { type network struct {
Network database.Network
user *user user *user
logger Logger logger Logger
stopped chan struct{} stopped chan struct{}
@ -135,7 +137,7 @@ type network struct {
casemap casemapping casemap casemapping
} }
func newNetwork(user *user, record *Network, channels []Channel) *network { func newNetwork(user *user, record *database.Network, channels []database.Channel) *network {
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())} logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
m := channelCasemapMap{newCasemapMap(0)} m := channelCasemapMap{newCasemapMap(0)}
@ -176,7 +178,7 @@ func (net *network) isStopped() bool {
} }
} }
func userIdent(u *User) string { func userIdent(u *database.User) string {
// The ident is a string we will send to upstream servers in clear-text. // The ident is a string we will send to upstream servers in clear-text.
// For privacy reasons, make sure it doesn't expose any meaningful user // For privacy reasons, make sure it doesn't expose any meaningful user
// metadata. We just use the base64-encoded hashed ID, so that people don't // metadata. We just use the base64-encoded hashed ID, so that people don't
@ -278,7 +280,7 @@ func (net *network) stop() {
} }
} }
func (net *network) detach(ch *Channel) { func (net *network) detach(ch *database.Channel) {
if ch.Detached { if ch.Detached {
return return
} }
@ -312,7 +314,7 @@ func (net *network) detach(ch *Channel) {
}) })
} }
func (net *network) attach(ctx context.Context, ch *Channel) { func (net *network) attach(ctx context.Context, ch *database.Channel) {
if !ch.Detached { if !ch.Detached {
return return
} }
@ -388,13 +390,13 @@ func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName
return return
} }
var receipts []DeliveryReceipt var receipts []database.DeliveryReceipt
net.delivered.ForEachTarget(func(target string) { net.delivered.ForEachTarget(func(target string) {
msgID := net.delivered.LoadID(target, clientName) msgID := net.delivered.LoadID(target, clientName)
if msgID == "" { if msgID == "" {
return return
} }
receipts = append(receipts, DeliveryReceipt{ receipts = append(receipts, database.DeliveryReceipt{
Target: target, Target: target,
InternalMsgID: msgID, InternalMsgID: msgID,
}) })
@ -421,9 +423,9 @@ func (net *network) isHighlight(msg *irc.Message) bool {
return msg.Prefix.Name != nick && isHighlight(text, nick) return msg.Prefix.Name != nick && isHighlight(text, nick)
} }
func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool { func (net *network) detachedMessageNeedsRelay(ch *database.Channel, msg *irc.Message) bool {
highlight := net.isHighlight(msg) highlight := net.isHighlight(msg)
return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight) return ch.RelayDetached == database.FilterMessage || ((ch.RelayDetached == database.FilterHighlight || ch.RelayDetached == database.FilterDefault) && highlight)
} }
func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) { func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
@ -443,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
} }
type user struct { type user struct {
User database.User
srv *Server srv *Server
logger Logger logger Logger
@ -455,7 +457,7 @@ type user struct {
msgStore messageStore msgStore messageStore
} }
func newUser(srv *Server, record *User) *user { func newUser(srv *Server, record *database.User) *user {
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
var msgStore messageStore var msgStore messageStore
@ -817,7 +819,7 @@ func (u *user) removeNetwork(network *network) {
panic("tried to remove a non-existing network") panic("tried to remove a non-existing network")
} }
func (u *user) checkNetwork(record *Network) error { func (u *user) checkNetwork(record *database.Network) error {
url, err := record.URL() url, err := record.URL()
if err != nil { if err != nil {
return err return err
@ -867,7 +869,7 @@ func (u *user) checkNetwork(record *Network) error {
return nil return nil
} }
func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) { func (u *user) createNetwork(ctx context.Context, record *database.Network) (*network, error) {
if record.ID != 0 { if record.ID != 0 {
panic("tried creating an already-existing network") panic("tried creating an already-existing network")
} }
@ -894,7 +896,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
return network, nil return network, nil
} }
func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) { func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*network, error) {
if record.ID == 0 { if record.ID == 0 {
panic("tried updating a new network") panic("tried updating a new network")
} }
@ -920,9 +922,9 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
// Most network changes require us to re-connect to the upstream server // Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, network.channels.Len()) channels := make([]database.Channel, 0, network.channels.Len())
for _, entry := range network.channels.innerMap { for _, entry := range network.channels.innerMap {
ch := entry.value.(*Channel) ch := entry.value.(*database.Channel)
channels = append(channels, *ch) channels = append(channels, *ch)
} }
@ -992,7 +994,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
return nil return nil
} }
func (u *user) updateUser(ctx context.Context, record *User) error { func (u *user) updateUser(ctx context.Context, record *database.User) error {
if u.ID != record.ID { if u.ID != record.ID {
panic("ID mismatch when updating user") panic("ID mismatch when updating user")
} }
@ -1005,7 +1007,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
if realnameUpdated { if realnameUpdated {
// Re-connect to networks which use the default realname // Re-connect to networks which use the default realname
var needUpdate []Network var needUpdate []database.Network
for _, net := range u.networks { for _, net := range u.networks {
if net.Realname != "" { if net.Realname != "" {
continue continue
@ -1016,7 +1018,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") { if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
uc.SendMessage(ctx, &irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "SETNAME", Command: "SETNAME",
Params: []string{GetRealname(&u.User, &net.Network)}, Params: []string{database.GetRealname(&u.User, &net.Network)},
}) })
continue continue
} }