Introduce a database package
This commit is contained in:
parent
27f21eab94
commit
3a7dee8128
@ -24,6 +24,7 @@ import (
|
||||
|
||||
"git.sr.ht/~emersion/soju"
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
// TCP keep-alive interval for downstream TCP connections
|
||||
@ -116,7 +117,7 @@ func main() {
|
||||
log.Printf("failed to bump max number of opened files: %v", err)
|
||||
}
|
||||
|
||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
||||
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
@ -308,7 +309,7 @@ func main() {
|
||||
log.Printf("server listening on %q", listen)
|
||||
}
|
||||
|
||||
if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
|
||||
if db, ok := db.(database.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
|
||||
if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil {
|
||||
log.Fatalf("failed to register database metrics: %v", err)
|
||||
}
|
||||
|
@ -9,10 +9,11 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"git.sr.ht/~emersion/soju"
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const usage = `usage: sojuctl [-config path] <action> [options...]
|
||||
@ -44,7 +45,7 @@ func main() {
|
||||
cfg = config.Defaults()
|
||||
}
|
||||
|
||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
||||
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
@ -73,7 +74,7 @@ func main() {
|
||||
log.Fatalf("failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := soju.User{
|
||||
user := database.User{
|
||||
Username: username,
|
||||
Password: string(hashed),
|
||||
Admin: *admin,
|
||||
|
@ -12,8 +12,8 @@ import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"git.sr.ht/~emersion/soju"
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const usage = `usage: znc-import [options...] <znc config path>
|
||||
@ -64,7 +64,7 @@ func main() {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
||||
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
@ -86,7 +86,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list users in DB: %v", err)
|
||||
}
|
||||
existingUsers := make(map[string]*soju.User, len(l))
|
||||
existingUsers := make(map[string]*database.User, len(l))
|
||||
for i, u := range l {
|
||||
existingUsers[u.Username] = &l[i]
|
||||
}
|
||||
@ -107,7 +107,7 @@ func main() {
|
||||
log.Printf("user %q: updating existing user", username)
|
||||
} else {
|
||||
// "!!" is an invalid crypt format, thus disables password auth
|
||||
u = &soju.User{Username: username, Password: "!!"}
|
||||
u = &database.User{Username: username, Password: "!!"}
|
||||
usersCreated++
|
||||
log.Printf("user %q: creating new user", username)
|
||||
}
|
||||
@ -123,7 +123,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
||||
}
|
||||
existingNetworks := make(map[string]*soju.Network, len(l))
|
||||
existingNetworks := make(map[string]*database.Network, len(l))
|
||||
for i, n := range l {
|
||||
existingNetworks[n.GetName()] = &l[i]
|
||||
}
|
||||
@ -175,7 +175,7 @@ func main() {
|
||||
if ok {
|
||||
logger.Printf("updating existing network")
|
||||
} else {
|
||||
n = &soju.Network{Name: netName}
|
||||
n = &database.Network{Name: netName}
|
||||
logger.Printf("creating new network")
|
||||
}
|
||||
|
||||
@ -194,7 +194,7 @@ func main() {
|
||||
if err != nil {
|
||||
logger.Fatalf("failed to list channels: %v", err)
|
||||
}
|
||||
existingChannels := make(map[string]*soju.Channel, len(l))
|
||||
existingChannels := make(map[string]*database.Channel, len(l))
|
||||
for i, ch := range l {
|
||||
existingChannels[ch.Name] = &l[i]
|
||||
}
|
||||
@ -213,7 +213,7 @@ func main() {
|
||||
if ok {
|
||||
logger.Printf("channel %q: updating existing channel", chName)
|
||||
} else {
|
||||
ch = &soju.Channel{Name: chName}
|
||||
ch = &database.Channel{Name: chName}
|
||||
logger.Printf("channel %q: creating new channel", chName)
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
package soju
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -38,7 +38,7 @@ type MetricsCollectorDatabase interface {
|
||||
RegisterMetrics(r prometheus.Registerer) error
|
||||
}
|
||||
|
||||
func OpenDB(driver, source string) (Database, error) {
|
||||
func Open(driver, source string) (Database, error) {
|
||||
switch driver {
|
||||
case "sqlite3":
|
||||
return OpenSqliteDB(source)
|
||||
@ -149,20 +149,6 @@ const (
|
||||
FilterMessage
|
||||
)
|
||||
|
||||
func parseFilter(filter string) (MessageFilter, error) {
|
||||
switch filter {
|
||||
case "default":
|
||||
return FilterDefault, nil
|
||||
case "none":
|
||||
return FilterNone, nil
|
||||
case "highlight":
|
||||
return FilterHighlight, nil
|
||||
case "message":
|
||||
return FilterMessage, nil
|
||||
}
|
||||
return 0, fmt.Errorf("unknown filter: %q", filter)
|
||||
}
|
||||
|
||||
type Channel struct {
|
||||
ID int64
|
||||
Name string
|
@ -1,4 +1,4 @@
|
||||
package soju
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -127,6 +127,37 @@ func OpenPostgresDB(source string) (Database, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func openTempPostgresDB(source string) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to PostgreSQL: %v", err)
|
||||
}
|
||||
|
||||
// Store all tables in a temporary schema which will be dropped when the
|
||||
// connection to PostgreSQL is closed.
|
||||
db.SetMaxOpenConns(1)
|
||||
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
|
||||
return nil, fmt.Errorf("failed to set PostgreSQL search_path: %v", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func OpenTempPostgresDB(source string) (Database, error) {
|
||||
sqlPostgresDB, err := openTempPostgresDB(source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := &PostgresDB{db: sqlPostgresDB}
|
||||
if err := db.upgrade(); err != nil {
|
||||
sqlPostgresDB.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) upgrade() error {
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
@ -1,7 +1,6 @@
|
||||
package soju
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
@ -68,29 +67,17 @@ CREATE TABLE "DeliveryReceipt" (
|
||||
);
|
||||
`
|
||||
|
||||
func openTempPostgresDB(t *testing.T) *sql.DB {
|
||||
func TestPostgresMigrations(t *testing.T) {
|
||||
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
|
||||
if !ok {
|
||||
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", source)
|
||||
sqlDB, err := openTempPostgresDB(source)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to PostgreSQL: %v", err)
|
||||
t.Fatalf("openTempPostgresDB() failed: %v", err)
|
||||
}
|
||||
|
||||
// Store all tables in a temporary schema which will be dropped when the
|
||||
// connection to PostgreSQL is closed.
|
||||
db.SetMaxOpenConns(1)
|
||||
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
|
||||
t.Fatalf("failed to set PostgreSQL search_path: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestPostgresMigrations(t *testing.T) {
|
||||
sqlDB := openTempPostgresDB(t)
|
||||
if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
|
||||
t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package soju
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -15,6 +15,12 @@ import (
|
||||
|
||||
const sqliteQueryTimeout = 5 * time.Second
|
||||
|
||||
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
|
||||
|
||||
func formatSqliteTime(t time.Time) string {
|
||||
return t.UTC().Format(sqliteTimeLayout)
|
||||
}
|
||||
|
||||
const sqliteSchema = `
|
||||
CREATE TABLE User (
|
||||
id INTEGER PRIMARY KEY,
|
||||
@ -212,6 +218,13 @@ func OpenSqliteDB(source string) (Database, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func OpenTempSqliteDB() (Database, error) {
|
||||
// :memory: will open a separate database for each new connection. Make
|
||||
// sure the sql package only uses a single connection via SetMaxOpenConns.
|
||||
// An alternative solution is to use "file::memory:?cache=shared".
|
||||
return OpenSqliteDB(":memory:")
|
||||
}
|
||||
|
||||
func (db *SqliteDB) Close() error {
|
||||
return db.db.Close()
|
||||
}
|
||||
@ -732,7 +745,7 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
|
||||
if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
receipt.Timestamp = t
|
||||
@ -746,7 +759,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
|
||||
|
||||
args := []interface{}{
|
||||
sql.Named("id", receipt.ID),
|
||||
sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
|
||||
sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)),
|
||||
sql.Named("network", networkID),
|
||||
sql.Named("target", receipt.Target),
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package soju
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
@ -16,6 +16,8 @@ import (
|
||||
"github.com/emersion/go-sasl"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
type ircError struct {
|
||||
@ -100,7 +102,7 @@ func parseBouncerNetID(subcommand, s string) (int64, error) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
|
||||
func fillNetworkAddrAttrs(attrs irc.Tags, network *database.Network) {
|
||||
u, err := network.URL()
|
||||
if err != nil {
|
||||
return
|
||||
@ -132,13 +134,13 @@ func getNetworkAttrs(network *network) irc.Tags {
|
||||
attrs := irc.Tags{
|
||||
"name": irc.TagValue(network.GetName()),
|
||||
"state": irc.TagValue(state),
|
||||
"nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
|
||||
"nickname": irc.TagValue(database.GetNick(&network.user.User, &network.Network)),
|
||||
}
|
||||
|
||||
if network.Username != "" {
|
||||
attrs["username"] = irc.TagValue(network.Username)
|
||||
}
|
||||
if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
|
||||
if realname := database.GetRealname(&network.user.User, &network.Network); realname != "" {
|
||||
attrs["realname"] = irc.TagValue(realname)
|
||||
}
|
||||
|
||||
@ -169,7 +171,7 @@ func networkAddrFromAttrs(attrs irc.Tags) string {
|
||||
return addr
|
||||
}
|
||||
|
||||
func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
|
||||
func updateNetworkAttrs(record *database.Network, attrs irc.Tags, subcommand string) error {
|
||||
addrAttrs := irc.Tags{}
|
||||
fillNetworkAddrAttrs(addrAttrs, record)
|
||||
|
||||
@ -414,7 +416,7 @@ func isOurNick(net *network, nick string) bool {
|
||||
// know whether this name is our nickname. Best-effort: use the network's
|
||||
// configured nickname and hope it was the one being used when we were
|
||||
// connected.
|
||||
return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
|
||||
return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
|
||||
}
|
||||
|
||||
// marshalEntity converts an upstream entity name (ie. channel or nick) into a
|
||||
@ -1146,9 +1148,9 @@ func (dc *downstreamConn) updateNick() {
|
||||
if uc := dc.upstream(); uc != nil {
|
||||
nick = uc.nick
|
||||
} else if dc.network != nil {
|
||||
nick = GetNick(&dc.user.User, &dc.network.Network)
|
||||
nick = database.GetNick(&dc.user.User, &dc.network.Network)
|
||||
} else {
|
||||
nick = GetNick(&dc.user.User, nil)
|
||||
nick = database.GetNick(&dc.user.User, nil)
|
||||
}
|
||||
|
||||
if nick == dc.nick {
|
||||
@ -1201,9 +1203,9 @@ func (dc *downstreamConn) updateRealname() {
|
||||
if uc := dc.upstream(); uc != nil {
|
||||
realname = uc.realname
|
||||
} else if dc.network != nil {
|
||||
realname = GetRealname(&dc.user.User, &dc.network.Network)
|
||||
realname = database.GetRealname(&dc.user.User, &dc.network.Network)
|
||||
} else {
|
||||
realname = GetRealname(&dc.user.User, nil)
|
||||
realname = database.GetRealname(&dc.user.User, nil)
|
||||
}
|
||||
|
||||
if realname != dc.realname {
|
||||
@ -1439,7 +1441,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
|
||||
|
||||
dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
|
||||
var err error
|
||||
network, err = dc.user.createNetwork(ctx, &Network{
|
||||
network, err = dc.user.createNetwork(ctx, &database.Network{
|
||||
Addr: dc.registration.networkName,
|
||||
Nick: nick,
|
||||
Enabled: true,
|
||||
@ -1475,7 +1477,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
|
||||
if uc := dc.upstream(); uc != nil {
|
||||
dc.nick = uc.nick
|
||||
} else if dc.network != nil {
|
||||
dc.nick = GetNick(&dc.user.User, &dc.network.Network)
|
||||
dc.nick = database.GetNick(&dc.user.User, &dc.network.Network)
|
||||
} else {
|
||||
dc.nick = dc.user.Username
|
||||
}
|
||||
@ -1931,7 +1933,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
}
|
||||
uc.network.attach(ctx, ch)
|
||||
} else {
|
||||
ch = &Channel{
|
||||
ch = &database.Channel{
|
||||
Name: upstreamName,
|
||||
Key: key,
|
||||
}
|
||||
@ -1963,7 +1965,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if ch != nil {
|
||||
uc.network.detach(ch)
|
||||
} else {
|
||||
ch = &Channel{
|
||||
ch = &database.Channel{
|
||||
Name: name,
|
||||
Detached: true,
|
||||
}
|
||||
@ -2911,7 +2913,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
|
||||
}}
|
||||
} else if r == nil {
|
||||
r = &ReadReceipt{
|
||||
r = &database.ReadReceipt{
|
||||
Target: entityCM,
|
||||
}
|
||||
}
|
||||
@ -3082,7 +3084,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
}
|
||||
attrs := irc.ParseTags(attrsStr)
|
||||
|
||||
record := &Network{Nick: dc.nick, Enabled: true}
|
||||
record := &database.Network{Nick: dc.nick, Enabled: true}
|
||||
if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
|
||||
return err
|
||||
}
|
||||
|
6
irc.go
6
irc.go
@ -9,6 +9,8 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -653,12 +655,12 @@ func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
|
||||
|
||||
type channelCasemapMap struct{ casemapMap }
|
||||
|
||||
func (cm *channelCasemapMap) Value(name string) *Channel {
|
||||
func (cm *channelCasemapMap) Value(name string) *database.Channel {
|
||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return entry.value.(*Channel)
|
||||
return entry.value.(*database.Channel)
|
||||
}
|
||||
|
||||
type membershipsCasemapMap struct{ casemapMap }
|
||||
|
16
msgstore.go
16
msgstore.go
@ -9,6 +9,8 @@ import (
|
||||
|
||||
"git.sr.ht/~sircmpwn/go-bare"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
// messageStore is a per-user store for IRC messages.
|
||||
@ -17,11 +19,11 @@ type messageStore interface {
|
||||
// LastMsgID queries the last message ID for the given network, entity and
|
||||
// date. The message ID returned may not refer to a valid message, but can be
|
||||
// used in history queries.
|
||||
LastMsgID(network *Network, entity string, t time.Time) (string, error)
|
||||
LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
|
||||
// LoadLatestID queries the latest non-event messages for the given network,
|
||||
// entity and date, up to a count of limit messages, sorted from oldest to newest.
|
||||
LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
|
||||
Append(network *Network, entity string, msg *irc.Message) (id string, err error)
|
||||
LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error)
|
||||
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
|
||||
}
|
||||
|
||||
type chatHistoryTarget struct {
|
||||
@ -38,17 +40,17 @@ type chatHistoryMessageStore interface {
|
||||
// It returns up to limit targets, starting from start and ending on end,
|
||||
// both excluded. end may be before or after start.
|
||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||
ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
|
||||
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
|
||||
// LoadBeforeTime loads up to limit messages before start down to end. The
|
||||
// returned messages must be between and excluding the provided bounds.
|
||||
// end is before start.
|
||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||
LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||
LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||
// LoadBeforeTime loads up to limit messages after start up to end. The
|
||||
// returned messages must be between and excluding the provided bounds.
|
||||
// end is after start.
|
||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||
LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||
LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||
}
|
||||
|
||||
type searchOptions struct {
|
||||
@ -66,7 +68,7 @@ type searchMessageStore interface {
|
||||
messageStore
|
||||
|
||||
// Search returns messages matching the specified options.
|
||||
Search(ctx context.Context, network *Network, search searchOptions) ([]*irc.Message, error)
|
||||
Search(ctx context.Context, network *database.Network, search searchOptions) ([]*irc.Message, error)
|
||||
}
|
||||
|
||||
type msgIDType uint
|
||||
|
@ -13,6 +13,8 @@ import (
|
||||
|
||||
"git.sr.ht/~sircmpwn/go-bare"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -80,7 +82,7 @@ type fsMessageStoreFile struct {
|
||||
// https://github.com/znc/znc/blob/master/modules/log.cpp
|
||||
type fsMessageStore struct {
|
||||
root string
|
||||
user *User
|
||||
user *database.User
|
||||
|
||||
// Write-only files used by Append
|
||||
files map[string]*fsMessageStoreFile // indexed by entity
|
||||
@ -90,7 +92,7 @@ var _ messageStore = (*fsMessageStore)(nil)
|
||||
var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
|
||||
var _ searchMessageStore = (*fsMessageStore)(nil)
|
||||
|
||||
func newFSMessageStore(root string, user *User) *fsMessageStore {
|
||||
func newFSMessageStore(root string, user *database.User) *fsMessageStore {
|
||||
return &fsMessageStore{
|
||||
root: filepath.Join(root, escapeFilename(user.Username)),
|
||||
user: user,
|
||||
@ -98,14 +100,14 @@ func newFSMessageStore(root string, user *User) *fsMessageStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
|
||||
func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
|
||||
year, month, day := t.Date()
|
||||
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
|
||||
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
|
||||
}
|
||||
|
||||
// nextMsgID queries the message ID for the next message to be written to f.
|
||||
func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
|
||||
func nextFSMsgID(network *database.Network, entity string, t time.Time, f *os.File) (string, error) {
|
||||
offset, err := f.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to query next FS message ID: %v", err)
|
||||
@ -113,7 +115,7 @@ func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (stri
|
||||
return formatFSMsgID(network.ID, entity, t, offset), nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
|
||||
func (ms *fsMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
|
||||
p := ms.logPath(network, entity, t)
|
||||
fi, err := os.Stat(p)
|
||||
if os.IsNotExist(err) {
|
||||
@ -124,7 +126,7 @@ func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time
|
||||
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
|
||||
func (ms *fsMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
|
||||
s := formatMessage(msg)
|
||||
if s == "" {
|
||||
return "", nil
|
||||
@ -253,7 +255,7 @@ func formatMessage(msg *irc.Message) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
|
||||
func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
|
||||
var hour, minute, second int
|
||||
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
|
||||
if err != nil {
|
||||
@ -380,7 +382,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
|
||||
// our nickname in the logs, so grab it from the network settings.
|
||||
// Not very accurate since this may not match our nick at the time
|
||||
// the message was received, but we can't do a lot better.
|
||||
entity = GetNick(ms.user, network)
|
||||
entity = database.GetNick(ms.user, network)
|
||||
}
|
||||
params = []string{entity, text}
|
||||
}
|
||||
@ -399,7 +401,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
|
||||
return msg, t, nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
path := ms.logPath(network, entity, ref)
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
@ -458,7 +460,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, r
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
path := ms.logPath(network, entity, ref)
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
@ -493,7 +495,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re
|
||||
return history, nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
if start.IsZero() {
|
||||
start = time.Now()
|
||||
} else {
|
||||
@ -526,11 +528,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, e
|
||||
return messages[remaining:], nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||
start = start.In(time.Local)
|
||||
if end.IsZero() {
|
||||
end = time.Now()
|
||||
@ -562,11 +564,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, en
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||
var afterTime time.Time
|
||||
var afterOffset int64
|
||||
if id != "" {
|
||||
@ -614,7 +616,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, en
|
||||
return history[remaining:], nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
|
||||
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
|
||||
start = start.In(time.Local)
|
||||
end = end.In(time.Local)
|
||||
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
|
||||
@ -693,7 +695,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, sta
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts searchOptions) ([]*irc.Message, error) {
|
||||
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts searchOptions) ([]*irc.Message, error) {
|
||||
text := strings.ToLower(opts.text)
|
||||
selector := func(m *irc.Message) bool {
|
||||
if opts.from != "" && m.User != opts.from {
|
||||
@ -711,7 +713,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts sea
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
|
||||
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
|
||||
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
|
||||
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
|
||||
// Avoid loosing data by overwriting an existing directory
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
|
||||
"git.sr.ht/~sircmpwn/go-bare"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const messageRingBufferCap = 4096
|
||||
@ -55,7 +57,7 @@ func (ms *memoryMessageStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
|
||||
func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer {
|
||||
k := ringBufferKey{networkID: network.ID, entity: entity}
|
||||
if rb, ok := ms.buffers[k]; ok {
|
||||
return rb
|
||||
@ -65,7 +67,7 @@ func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingB
|
||||
return rb
|
||||
}
|
||||
|
||||
func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
|
||||
func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
|
||||
var seq uint64
|
||||
k := ringBufferKey{networkID: network.ID, entity: entity}
|
||||
if rb, ok := ms.buffers[k]; ok {
|
||||
@ -74,7 +76,7 @@ func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.
|
||||
return formatMemoryMsgID(network.ID, entity, seq), nil
|
||||
}
|
||||
|
||||
func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
|
||||
func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
|
||||
switch msg.Command {
|
||||
case "PRIVMSG", "NOTICE":
|
||||
// Only append these messages, because LoadLatestID shouldn't return
|
||||
@ -94,7 +96,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M
|
||||
return formatMemoryMsgID(network.ID, entity, seq), nil
|
||||
}
|
||||
|
||||
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||
_, _, seq, err := parseMemoryMsgID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"git.sr.ht/~emersion/soju/config"
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
// TODO: make configurable
|
||||
@ -141,7 +142,7 @@ type Server struct {
|
||||
MetricsRegistry prometheus.Registerer // can be nil
|
||||
|
||||
config atomic.Value // *Config
|
||||
db Database
|
||||
db database.Database
|
||||
stopWG sync.WaitGroup
|
||||
|
||||
lock sync.Mutex
|
||||
@ -161,7 +162,7 @@ type Server struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewServer(db Database) *Server {
|
||||
func NewServer(db database.Database) *Server {
|
||||
srv := &Server{
|
||||
Logger: NewLogger(log.Writer(), true),
|
||||
db: db,
|
||||
@ -273,7 +274,7 @@ func (s *Server) Shutdown() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
|
||||
func (s *Server) createUser(ctx context.Context, user *database.User) (*user, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@ -304,7 +305,7 @@ func (s *Server) getUser(name string) *user {
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *Server) addUserLocked(user *User) *user {
|
||||
func (s *Server) addUserLocked(user *database.User) *user {
|
||||
s.Logger.Printf("starting bouncer for user %q", user.Username)
|
||||
u := newUser(s, user)
|
||||
s.users[u.Username] = u
|
||||
|
@ -3,10 +3,13 @@ package soju
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
var testServerPrefix = &irc.Prefix{Name: "soju-test-server"}
|
||||
@ -16,34 +19,35 @@ const (
|
||||
testPassword = testUsername
|
||||
)
|
||||
|
||||
func createTempSqliteDB(t *testing.T) Database {
|
||||
db, err := OpenDB("sqlite3", ":memory:")
|
||||
func createTempSqliteDB(t *testing.T) database.Database {
|
||||
db, err := database.OpenTempSqliteDB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temporary SQLite database: %v", err)
|
||||
}
|
||||
// :memory: will open a separate database for each new connection. Make
|
||||
// sure the sql package only uses a single connection. An alternative
|
||||
// solution is to use "file::memory:?cache=shared".
|
||||
db.(*SqliteDB).db.SetMaxOpenConns(1)
|
||||
return db
|
||||
}
|
||||
|
||||
func createTempPostgresDB(t *testing.T) Database {
|
||||
db := &PostgresDB{db: openTempPostgresDB(t)}
|
||||
if err := db.upgrade(); err != nil {
|
||||
t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
|
||||
func createTempPostgresDB(t *testing.T) database.Database {
|
||||
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
|
||||
if !ok {
|
||||
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
|
||||
}
|
||||
|
||||
db, err := database.OpenTempPostgresDB(source)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temporary PostgreSQL database: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createTestUser(t *testing.T, db Database) *User {
|
||||
func createTestUser(t *testing.T, db database.Database) *database.User {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate bcrypt hash: %v", err)
|
||||
}
|
||||
|
||||
record := &User{Username: testUsername, Password: string(hashed)}
|
||||
record := &database.User{Username: testUsername, Password: string(hashed)}
|
||||
if err := db.StoreUser(context.Background(), record); err != nil {
|
||||
t.Fatalf("failed to store test user: %v", err)
|
||||
}
|
||||
@ -57,13 +61,13 @@ func createTestDownstream(t *testing.T, srv *Server) ircConn {
|
||||
return newNetIRCConn(c2)
|
||||
}
|
||||
|
||||
func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
|
||||
func createTestUpstream(t *testing.T, db database.Database, user *database.User) (*database.Network, net.Listener) {
|
||||
ln, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create TCP listener: %v", err)
|
||||
}
|
||||
|
||||
network := &Network{
|
||||
network := &database.Network{
|
||||
Name: "testnet",
|
||||
Addr: "irc+insecure://" + ln.Addr().String(),
|
||||
Nick: user.Username,
|
||||
@ -95,7 +99,7 @@ func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
|
||||
return msg
|
||||
}
|
||||
|
||||
func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
|
||||
func registerDownstreamConn(t *testing.T, c ircConn, network *database.Network) {
|
||||
c.WriteMessage(&irc.Message{
|
||||
Command: "PASS",
|
||||
Params: []string{testPassword},
|
||||
@ -151,7 +155,7 @@ func registerUpstreamConn(t *testing.T, c ircConn) {
|
||||
})
|
||||
}
|
||||
|
||||
func testServer(t *testing.T, db Database) {
|
||||
func testServer(t *testing.T, db database.Database) {
|
||||
user := createTestUser(t, db)
|
||||
network, upstream := createTestUpstream(t, db, user)
|
||||
defer upstream.Close()
|
||||
|
28
service.go
28
service.go
@ -17,6 +17,8 @@ import (
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
const serviceNick = "BouncerServ"
|
||||
@ -447,7 +449,7 @@ func newNetworkFlagSet() *networkFlagSet {
|
||||
return fs
|
||||
}
|
||||
|
||||
func (fs *networkFlagSet) update(network *Network) error {
|
||||
func (fs *networkFlagSet) update(network *database.Network) error {
|
||||
if fs.Addr != nil {
|
||||
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
|
||||
scheme := addrParts[0]
|
||||
@ -508,7 +510,7 @@ func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params
|
||||
return fmt.Errorf("flag -addr is required")
|
||||
}
|
||||
|
||||
record := &Network{
|
||||
record := &database.Network{
|
||||
Addr: *fs.Addr,
|
||||
Enabled: true,
|
||||
}
|
||||
@ -833,7 +835,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
|
||||
return fmt.Errorf("failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := &User{
|
||||
user := &database.User{
|
||||
Username: *username,
|
||||
Password: string(hashed),
|
||||
Realname: *realname,
|
||||
@ -971,9 +973,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
||||
n := 0
|
||||
|
||||
sendNetwork := func(net *network) {
|
||||
var channels []*Channel
|
||||
var channels []*database.Channel
|
||||
for _, entry := range net.channels.innerMap {
|
||||
channels = append(channels, entry.value.(*Channel))
|
||||
channels = append(channels, entry.value.(*database.Channel))
|
||||
}
|
||||
|
||||
sort.Slice(channels, func(i, j int) bool {
|
||||
@ -1031,6 +1033,20 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFilter(filter string) (database.MessageFilter, error) {
|
||||
switch filter {
|
||||
case "default":
|
||||
return database.FilterDefault, nil
|
||||
case "none":
|
||||
return database.FilterNone, nil
|
||||
case "highlight":
|
||||
return database.FilterHighlight, nil
|
||||
case "message":
|
||||
return database.FilterMessage, nil
|
||||
}
|
||||
return 0, fmt.Errorf("unknown filter: %q", filter)
|
||||
}
|
||||
|
||||
type channelFlagSet struct {
|
||||
*flag.FlagSet
|
||||
RelayDetached, ReattachOn, DetachAfter, DetachOn *string
|
||||
@ -1045,7 +1061,7 @@ func newChannelFlagSet() *channelFlagSet {
|
||||
return fs
|
||||
}
|
||||
|
||||
func (fs *channelFlagSet) update(channel *Channel) error {
|
||||
func (fs *channelFlagSet) update(channel *database.Channel) error {
|
||||
if fs.RelayDetached != nil {
|
||||
filter, err := parseFilter(*fs.RelayDetached)
|
||||
if err != nil {
|
||||
|
20
upstream.go
20
upstream.go
@ -17,6 +17,8 @@ import (
|
||||
|
||||
"github.com/emersion/go-sasl"
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
// permanentUpstreamCaps is the static list of upstream capabilities always
|
||||
@ -510,7 +512,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
|
||||
highlight := uc.network.isHighlight(msg)
|
||||
if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
|
||||
if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) {
|
||||
uc.updateChannelAutoDetach(target)
|
||||
}
|
||||
}
|
||||
@ -765,7 +767,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
if uc.network.channels.Len() > 0 {
|
||||
var channels, keys []string
|
||||
for _, entry := range uc.network.channels.innerMap {
|
||||
ch := entry.value.(*Channel)
|
||||
ch := entry.value.(*database.Channel)
|
||||
channels = append(channels, ch.Name)
|
||||
keys = append(keys, ch.Key)
|
||||
}
|
||||
@ -1553,7 +1555,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
|
||||
// Check if the nick we want is now free
|
||||
wantNick := GetNick(&uc.user.User, &uc.network.Network)
|
||||
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
|
||||
wantNickCM := uc.network.casemap(wantNick)
|
||||
if !online && uc.nickCM != wantNickCM {
|
||||
found := false
|
||||
@ -1796,13 +1798,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
|
||||
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *database.Channel, msg *irc.Message) {
|
||||
if uc.network.detachedMessageNeedsRelay(ch, msg) {
|
||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||
dc.relayDetachedMessage(uc.network, msg)
|
||||
})
|
||||
}
|
||||
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
|
||||
if ch.ReattachOn == database.FilterMessage || (ch.ReattachOn == database.FilterHighlight && uc.network.isHighlight(msg)) {
|
||||
uc.network.attach(ctx, ch)
|
||||
if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
|
||||
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
|
||||
@ -1960,10 +1962,10 @@ func splitSpace(s string) []string {
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) register(ctx context.Context) {
|
||||
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
|
||||
uc.nick = database.GetNick(&uc.user.User, &uc.network.Network)
|
||||
uc.nickCM = uc.network.casemap(uc.nick)
|
||||
uc.username = GetUsername(&uc.user.User, &uc.network.Network)
|
||||
uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
|
||||
uc.username = database.GetUsername(&uc.user.User, &uc.network.Network)
|
||||
uc.realname = database.GetRealname(&uc.user.User, &uc.network.Network)
|
||||
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "CAP",
|
||||
@ -2193,7 +2195,7 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
}
|
||||
})
|
||||
|
||||
wantNick := GetNick(&uc.user.User, &uc.network.Network)
|
||||
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
|
||||
wantNickCM := uc.network.casemap(wantNick)
|
||||
if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
|
||||
addList = append(addList, wantNickCM)
|
||||
|
40
user.go
40
user.go
@ -14,6 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"gopkg.in/irc.v3"
|
||||
|
||||
"git.sr.ht/~emersion/soju/database"
|
||||
)
|
||||
|
||||
type event interface{}
|
||||
@ -123,7 +125,7 @@ func (ds deliveredStore) ForEachClient(f func(clientName string)) {
|
||||
}
|
||||
|
||||
type network struct {
|
||||
Network
|
||||
database.Network
|
||||
user *user
|
||||
logger Logger
|
||||
stopped chan struct{}
|
||||
@ -135,7 +137,7 @@ type network struct {
|
||||
casemap casemapping
|
||||
}
|
||||
|
||||
func newNetwork(user *user, record *Network, channels []Channel) *network {
|
||||
func newNetwork(user *user, record *database.Network, channels []database.Channel) *network {
|
||||
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
|
||||
|
||||
m := channelCasemapMap{newCasemapMap(0)}
|
||||
@ -176,7 +178,7 @@ func (net *network) isStopped() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func userIdent(u *User) string {
|
||||
func userIdent(u *database.User) string {
|
||||
// The ident is a string we will send to upstream servers in clear-text.
|
||||
// For privacy reasons, make sure it doesn't expose any meaningful user
|
||||
// metadata. We just use the base64-encoded hashed ID, so that people don't
|
||||
@ -278,7 +280,7 @@ func (net *network) stop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (net *network) detach(ch *Channel) {
|
||||
func (net *network) detach(ch *database.Channel) {
|
||||
if ch.Detached {
|
||||
return
|
||||
}
|
||||
@ -312,7 +314,7 @@ func (net *network) detach(ch *Channel) {
|
||||
})
|
||||
}
|
||||
|
||||
func (net *network) attach(ctx context.Context, ch *Channel) {
|
||||
func (net *network) attach(ctx context.Context, ch *database.Channel) {
|
||||
if !ch.Detached {
|
||||
return
|
||||
}
|
||||
@ -388,13 +390,13 @@ func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName
|
||||
return
|
||||
}
|
||||
|
||||
var receipts []DeliveryReceipt
|
||||
var receipts []database.DeliveryReceipt
|
||||
net.delivered.ForEachTarget(func(target string) {
|
||||
msgID := net.delivered.LoadID(target, clientName)
|
||||
if msgID == "" {
|
||||
return
|
||||
}
|
||||
receipts = append(receipts, DeliveryReceipt{
|
||||
receipts = append(receipts, database.DeliveryReceipt{
|
||||
Target: target,
|
||||
InternalMsgID: msgID,
|
||||
})
|
||||
@ -421,9 +423,9 @@ func (net *network) isHighlight(msg *irc.Message) bool {
|
||||
return msg.Prefix.Name != nick && isHighlight(text, nick)
|
||||
}
|
||||
|
||||
func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
|
||||
func (net *network) detachedMessageNeedsRelay(ch *database.Channel, msg *irc.Message) bool {
|
||||
highlight := net.isHighlight(msg)
|
||||
return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
|
||||
return ch.RelayDetached == database.FilterMessage || ((ch.RelayDetached == database.FilterHighlight || ch.RelayDetached == database.FilterDefault) && highlight)
|
||||
}
|
||||
|
||||
func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
|
||||
@ -443,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
|
||||
}
|
||||
|
||||
type user struct {
|
||||
User
|
||||
database.User
|
||||
srv *Server
|
||||
logger Logger
|
||||
|
||||
@ -455,7 +457,7 @@ type user struct {
|
||||
msgStore messageStore
|
||||
}
|
||||
|
||||
func newUser(srv *Server, record *User) *user {
|
||||
func newUser(srv *Server, record *database.User) *user {
|
||||
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
|
||||
|
||||
var msgStore messageStore
|
||||
@ -817,7 +819,7 @@ func (u *user) removeNetwork(network *network) {
|
||||
panic("tried to remove a non-existing network")
|
||||
}
|
||||
|
||||
func (u *user) checkNetwork(record *Network) error {
|
||||
func (u *user) checkNetwork(record *database.Network) error {
|
||||
url, err := record.URL()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -867,7 +869,7 @@ func (u *user) checkNetwork(record *Network) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
|
||||
func (u *user) createNetwork(ctx context.Context, record *database.Network) (*network, error) {
|
||||
if record.ID != 0 {
|
||||
panic("tried creating an already-existing network")
|
||||
}
|
||||
@ -894,7 +896,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
|
||||
return network, nil
|
||||
}
|
||||
|
||||
func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
|
||||
func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*network, error) {
|
||||
if record.ID == 0 {
|
||||
panic("tried updating a new network")
|
||||
}
|
||||
@ -920,9 +922,9 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
|
||||
|
||||
// Most network changes require us to re-connect to the upstream server
|
||||
|
||||
channels := make([]Channel, 0, network.channels.Len())
|
||||
channels := make([]database.Channel, 0, network.channels.Len())
|
||||
for _, entry := range network.channels.innerMap {
|
||||
ch := entry.value.(*Channel)
|
||||
ch := entry.value.(*database.Channel)
|
||||
channels = append(channels, *ch)
|
||||
}
|
||||
|
||||
@ -992,7 +994,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *user) updateUser(ctx context.Context, record *User) error {
|
||||
func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||
if u.ID != record.ID {
|
||||
panic("ID mismatch when updating user")
|
||||
}
|
||||
@ -1005,7 +1007,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
|
||||
|
||||
if realnameUpdated {
|
||||
// Re-connect to networks which use the default realname
|
||||
var needUpdate []Network
|
||||
var needUpdate []database.Network
|
||||
for _, net := range u.networks {
|
||||
if net.Realname != "" {
|
||||
continue
|
||||
@ -1016,7 +1018,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
|
||||
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "SETNAME",
|
||||
Params: []string{GetRealname(&u.User, &net.Network)},
|
||||
Params: []string{database.GetRealname(&u.User, &net.Network)},
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user