Introduce a database package
This commit is contained in:
parent
27f21eab94
commit
3a7dee8128
@ -24,6 +24,7 @@ import (
|
|||||||
|
|
||||||
"git.sr.ht/~emersion/soju"
|
"git.sr.ht/~emersion/soju"
|
||||||
"git.sr.ht/~emersion/soju/config"
|
"git.sr.ht/~emersion/soju/config"
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCP keep-alive interval for downstream TCP connections
|
// TCP keep-alive interval for downstream TCP connections
|
||||||
@ -116,7 +117,7 @@ func main() {
|
|||||||
log.Printf("failed to bump max number of opened files: %v", err)
|
log.Printf("failed to bump max number of opened files: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to open database: %v", err)
|
log.Fatalf("failed to open database: %v", err)
|
||||||
}
|
}
|
||||||
@ -308,7 +309,7 @@ func main() {
|
|||||||
log.Printf("server listening on %q", listen)
|
log.Printf("server listening on %q", listen)
|
||||||
}
|
}
|
||||||
|
|
||||||
if db, ok := db.(soju.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
|
if db, ok := db.(database.MetricsCollectorDatabase); ok && srv.MetricsRegistry != nil {
|
||||||
if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil {
|
if err := db.RegisterMetrics(srv.MetricsRegistry); err != nil {
|
||||||
log.Fatalf("failed to register database metrics: %v", err)
|
log.Fatalf("failed to register database metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -9,10 +9,11 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/soju"
|
|
||||||
"git.sr.ht/~emersion/soju/config"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/config"
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usage = `usage: sojuctl [-config path] <action> [options...]
|
const usage = `usage: sojuctl [-config path] <action> [options...]
|
||||||
@ -44,7 +45,7 @@ func main() {
|
|||||||
cfg = config.Defaults()
|
cfg = config.Defaults()
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to open database: %v", err)
|
log.Fatalf("failed to open database: %v", err)
|
||||||
}
|
}
|
||||||
@ -73,7 +74,7 @@ func main() {
|
|||||||
log.Fatalf("failed to hash password: %v", err)
|
log.Fatalf("failed to hash password: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user := soju.User{
|
user := database.User{
|
||||||
Username: username,
|
Username: username,
|
||||||
Password: string(hashed),
|
Password: string(hashed),
|
||||||
Admin: *admin,
|
Admin: *admin,
|
||||||
|
@ -12,8 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/soju"
|
|
||||||
"git.sr.ht/~emersion/soju/config"
|
"git.sr.ht/~emersion/soju/config"
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usage = `usage: znc-import [options...] <znc config path>
|
const usage = `usage: znc-import [options...] <znc config path>
|
||||||
@ -64,7 +64,7 @@ func main() {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
db, err := soju.OpenDB(cfg.SQLDriver, cfg.SQLSource)
|
db, err := database.Open(cfg.SQLDriver, cfg.SQLSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to open database: %v", err)
|
log.Fatalf("failed to open database: %v", err)
|
||||||
}
|
}
|
||||||
@ -86,7 +86,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to list users in DB: %v", err)
|
log.Fatalf("failed to list users in DB: %v", err)
|
||||||
}
|
}
|
||||||
existingUsers := make(map[string]*soju.User, len(l))
|
existingUsers := make(map[string]*database.User, len(l))
|
||||||
for i, u := range l {
|
for i, u := range l {
|
||||||
existingUsers[u.Username] = &l[i]
|
existingUsers[u.Username] = &l[i]
|
||||||
}
|
}
|
||||||
@ -107,7 +107,7 @@ func main() {
|
|||||||
log.Printf("user %q: updating existing user", username)
|
log.Printf("user %q: updating existing user", username)
|
||||||
} else {
|
} else {
|
||||||
// "!!" is an invalid crypt format, thus disables password auth
|
// "!!" is an invalid crypt format, thus disables password auth
|
||||||
u = &soju.User{Username: username, Password: "!!"}
|
u = &database.User{Username: username, Password: "!!"}
|
||||||
usersCreated++
|
usersCreated++
|
||||||
log.Printf("user %q: creating new user", username)
|
log.Printf("user %q: creating new user", username)
|
||||||
}
|
}
|
||||||
@ -123,7 +123,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
log.Fatalf("failed to list networks for user %q: %v", username, err)
|
||||||
}
|
}
|
||||||
existingNetworks := make(map[string]*soju.Network, len(l))
|
existingNetworks := make(map[string]*database.Network, len(l))
|
||||||
for i, n := range l {
|
for i, n := range l {
|
||||||
existingNetworks[n.GetName()] = &l[i]
|
existingNetworks[n.GetName()] = &l[i]
|
||||||
}
|
}
|
||||||
@ -175,7 +175,7 @@ func main() {
|
|||||||
if ok {
|
if ok {
|
||||||
logger.Printf("updating existing network")
|
logger.Printf("updating existing network")
|
||||||
} else {
|
} else {
|
||||||
n = &soju.Network{Name: netName}
|
n = &database.Network{Name: netName}
|
||||||
logger.Printf("creating new network")
|
logger.Printf("creating new network")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,7 +194,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("failed to list channels: %v", err)
|
logger.Fatalf("failed to list channels: %v", err)
|
||||||
}
|
}
|
||||||
existingChannels := make(map[string]*soju.Channel, len(l))
|
existingChannels := make(map[string]*database.Channel, len(l))
|
||||||
for i, ch := range l {
|
for i, ch := range l {
|
||||||
existingChannels[ch.Name] = &l[i]
|
existingChannels[ch.Name] = &l[i]
|
||||||
}
|
}
|
||||||
@ -213,7 +213,7 @@ func main() {
|
|||||||
if ok {
|
if ok {
|
||||||
logger.Printf("channel %q: updating existing channel", chName)
|
logger.Printf("channel %q: updating existing channel", chName)
|
||||||
} else {
|
} else {
|
||||||
ch = &soju.Channel{Name: chName}
|
ch = &database.Channel{Name: chName}
|
||||||
logger.Printf("channel %q: creating new channel", chName)
|
logger.Printf("channel %q: creating new channel", chName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package soju
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -38,7 +38,7 @@ type MetricsCollectorDatabase interface {
|
|||||||
RegisterMetrics(r prometheus.Registerer) error
|
RegisterMetrics(r prometheus.Registerer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenDB(driver, source string) (Database, error) {
|
func Open(driver, source string) (Database, error) {
|
||||||
switch driver {
|
switch driver {
|
||||||
case "sqlite3":
|
case "sqlite3":
|
||||||
return OpenSqliteDB(source)
|
return OpenSqliteDB(source)
|
||||||
@ -149,20 +149,6 @@ const (
|
|||||||
FilterMessage
|
FilterMessage
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseFilter(filter string) (MessageFilter, error) {
|
|
||||||
switch filter {
|
|
||||||
case "default":
|
|
||||||
return FilterDefault, nil
|
|
||||||
case "none":
|
|
||||||
return FilterNone, nil
|
|
||||||
case "highlight":
|
|
||||||
return FilterHighlight, nil
|
|
||||||
case "message":
|
|
||||||
return FilterMessage, nil
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("unknown filter: %q", filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
ID int64
|
ID int64
|
||||||
Name string
|
Name string
|
@ -1,4 +1,4 @@
|
|||||||
package soju
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -127,6 +127,37 @@ func OpenPostgresDB(source string) (Database, error) {
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openTempPostgresDB(source string) (*sql.DB, error) {
|
||||||
|
db, err := sql.Open("postgres", source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to PostgreSQL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store all tables in a temporary schema which will be dropped when the
|
||||||
|
// connection to PostgreSQL is closed.
|
||||||
|
db.SetMaxOpenConns(1)
|
||||||
|
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set PostgreSQL search_path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenTempPostgresDB(source string) (Database, error) {
|
||||||
|
sqlPostgresDB, err := openTempPostgresDB(source)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db := &PostgresDB{db: sqlPostgresDB}
|
||||||
|
if err := db.upgrade(); err != nil {
|
||||||
|
sqlPostgresDB.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *PostgresDB) upgrade() error {
|
func (db *PostgresDB) upgrade() error {
|
||||||
tx, err := db.db.Begin()
|
tx, err := db.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,7 +1,6 @@
|
|||||||
package soju
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -68,29 +67,17 @@ CREATE TABLE "DeliveryReceipt" (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
func openTempPostgresDB(t *testing.T) *sql.DB {
|
func TestPostgresMigrations(t *testing.T) {
|
||||||
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
|
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
|
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("postgres", source)
|
sqlDB, err := openTempPostgresDB(source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to connect to PostgreSQL: %v", err)
|
t.Fatalf("openTempPostgresDB() failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store all tables in a temporary schema which will be dropped when the
|
|
||||||
// connection to PostgreSQL is closed.
|
|
||||||
db.SetMaxOpenConns(1)
|
|
||||||
if _, err := db.Exec("SET search_path TO pg_temp"); err != nil {
|
|
||||||
t.Fatalf("failed to set PostgreSQL search_path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresMigrations(t *testing.T) {
|
|
||||||
sqlDB := openTempPostgresDB(t)
|
|
||||||
if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
|
if _, err := sqlDB.Exec(postgresV0Schema); err != nil {
|
||||||
t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
|
t.Fatalf("DB.Exec() failed for v0 schema: %v", err)
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package soju
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -15,6 +15,12 @@ import (
|
|||||||
|
|
||||||
const sqliteQueryTimeout = 5 * time.Second
|
const sqliteQueryTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
const sqliteTimeLayout = "2006-01-02T15:04:05.000Z"
|
||||||
|
|
||||||
|
func formatSqliteTime(t time.Time) string {
|
||||||
|
return t.UTC().Format(sqliteTimeLayout)
|
||||||
|
}
|
||||||
|
|
||||||
const sqliteSchema = `
|
const sqliteSchema = `
|
||||||
CREATE TABLE User (
|
CREATE TABLE User (
|
||||||
id INTEGER PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
@ -212,6 +218,13 @@ func OpenSqliteDB(source string) (Database, error) {
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenTempSqliteDB() (Database, error) {
|
||||||
|
// :memory: will open a separate database for each new connection. Make
|
||||||
|
// sure the sql package only uses a single connection via SetMaxOpenConns.
|
||||||
|
// An alternative solution is to use "file::memory:?cache=shared".
|
||||||
|
return OpenSqliteDB(":memory:")
|
||||||
|
}
|
||||||
|
|
||||||
func (db *SqliteDB) Close() error {
|
func (db *SqliteDB) Close() error {
|
||||||
return db.db.Close()
|
return db.db.Close()
|
||||||
}
|
}
|
||||||
@ -732,7 +745,7 @@ func (db *SqliteDB) GetReadReceipt(ctx context.Context, networkID int64, name st
|
|||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if t, err := time.Parse(serverTimeLayout, timestamp); err != nil {
|
if t, err := time.Parse(sqliteTimeLayout, timestamp); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
receipt.Timestamp = t
|
receipt.Timestamp = t
|
||||||
@ -746,7 +759,7 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei
|
|||||||
|
|
||||||
args := []interface{}{
|
args := []interface{}{
|
||||||
sql.Named("id", receipt.ID),
|
sql.Named("id", receipt.ID),
|
||||||
sql.Named("timestamp", formatServerTime(receipt.Timestamp)),
|
sql.Named("timestamp", formatSqliteTime(receipt.Timestamp)),
|
||||||
sql.Named("network", networkID),
|
sql.Named("network", networkID),
|
||||||
sql.Named("target", receipt.Target),
|
sql.Named("target", receipt.Target),
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package soju
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
@ -16,6 +16,8 @@ import (
|
|||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ircError struct {
|
type ircError struct {
|
||||||
@ -100,7 +102,7 @@ func parseBouncerNetID(subcommand, s string) (int64, error) {
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fillNetworkAddrAttrs(attrs irc.Tags, network *Network) {
|
func fillNetworkAddrAttrs(attrs irc.Tags, network *database.Network) {
|
||||||
u, err := network.URL()
|
u, err := network.URL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -132,13 +134,13 @@ func getNetworkAttrs(network *network) irc.Tags {
|
|||||||
attrs := irc.Tags{
|
attrs := irc.Tags{
|
||||||
"name": irc.TagValue(network.GetName()),
|
"name": irc.TagValue(network.GetName()),
|
||||||
"state": irc.TagValue(state),
|
"state": irc.TagValue(state),
|
||||||
"nickname": irc.TagValue(GetNick(&network.user.User, &network.Network)),
|
"nickname": irc.TagValue(database.GetNick(&network.user.User, &network.Network)),
|
||||||
}
|
}
|
||||||
|
|
||||||
if network.Username != "" {
|
if network.Username != "" {
|
||||||
attrs["username"] = irc.TagValue(network.Username)
|
attrs["username"] = irc.TagValue(network.Username)
|
||||||
}
|
}
|
||||||
if realname := GetRealname(&network.user.User, &network.Network); realname != "" {
|
if realname := database.GetRealname(&network.user.User, &network.Network); realname != "" {
|
||||||
attrs["realname"] = irc.TagValue(realname)
|
attrs["realname"] = irc.TagValue(realname)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,7 +171,7 @@ func networkAddrFromAttrs(attrs irc.Tags) string {
|
|||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateNetworkAttrs(record *Network, attrs irc.Tags, subcommand string) error {
|
func updateNetworkAttrs(record *database.Network, attrs irc.Tags, subcommand string) error {
|
||||||
addrAttrs := irc.Tags{}
|
addrAttrs := irc.Tags{}
|
||||||
fillNetworkAddrAttrs(addrAttrs, record)
|
fillNetworkAddrAttrs(addrAttrs, record)
|
||||||
|
|
||||||
@ -414,7 +416,7 @@ func isOurNick(net *network, nick string) bool {
|
|||||||
// know whether this name is our nickname. Best-effort: use the network's
|
// know whether this name is our nickname. Best-effort: use the network's
|
||||||
// configured nickname and hope it was the one being used when we were
|
// configured nickname and hope it was the one being used when we were
|
||||||
// connected.
|
// connected.
|
||||||
return net.casemap(nick) == net.casemap(GetNick(&net.user.User, &net.Network))
|
return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
|
||||||
}
|
}
|
||||||
|
|
||||||
// marshalEntity converts an upstream entity name (ie. channel or nick) into a
|
// marshalEntity converts an upstream entity name (ie. channel or nick) into a
|
||||||
@ -1146,9 +1148,9 @@ func (dc *downstreamConn) updateNick() {
|
|||||||
if uc := dc.upstream(); uc != nil {
|
if uc := dc.upstream(); uc != nil {
|
||||||
nick = uc.nick
|
nick = uc.nick
|
||||||
} else if dc.network != nil {
|
} else if dc.network != nil {
|
||||||
nick = GetNick(&dc.user.User, &dc.network.Network)
|
nick = database.GetNick(&dc.user.User, &dc.network.Network)
|
||||||
} else {
|
} else {
|
||||||
nick = GetNick(&dc.user.User, nil)
|
nick = database.GetNick(&dc.user.User, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if nick == dc.nick {
|
if nick == dc.nick {
|
||||||
@ -1201,9 +1203,9 @@ func (dc *downstreamConn) updateRealname() {
|
|||||||
if uc := dc.upstream(); uc != nil {
|
if uc := dc.upstream(); uc != nil {
|
||||||
realname = uc.realname
|
realname = uc.realname
|
||||||
} else if dc.network != nil {
|
} else if dc.network != nil {
|
||||||
realname = GetRealname(&dc.user.User, &dc.network.Network)
|
realname = database.GetRealname(&dc.user.User, &dc.network.Network)
|
||||||
} else {
|
} else {
|
||||||
realname = GetRealname(&dc.user.User, nil)
|
realname = database.GetRealname(&dc.user.User, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if realname != dc.realname {
|
if realname != dc.realname {
|
||||||
@ -1439,7 +1441,7 @@ func (dc *downstreamConn) loadNetwork(ctx context.Context) error {
|
|||||||
|
|
||||||
dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
|
dc.logger.Printf("auto-saving network %q", dc.registration.networkName)
|
||||||
var err error
|
var err error
|
||||||
network, err = dc.user.createNetwork(ctx, &Network{
|
network, err = dc.user.createNetwork(ctx, &database.Network{
|
||||||
Addr: dc.registration.networkName,
|
Addr: dc.registration.networkName,
|
||||||
Nick: nick,
|
Nick: nick,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@ -1475,7 +1477,7 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
|
|||||||
if uc := dc.upstream(); uc != nil {
|
if uc := dc.upstream(); uc != nil {
|
||||||
dc.nick = uc.nick
|
dc.nick = uc.nick
|
||||||
} else if dc.network != nil {
|
} else if dc.network != nil {
|
||||||
dc.nick = GetNick(&dc.user.User, &dc.network.Network)
|
dc.nick = database.GetNick(&dc.user.User, &dc.network.Network)
|
||||||
} else {
|
} else {
|
||||||
dc.nick = dc.user.Username
|
dc.nick = dc.user.Username
|
||||||
}
|
}
|
||||||
@ -1931,7 +1933,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
}
|
}
|
||||||
uc.network.attach(ctx, ch)
|
uc.network.attach(ctx, ch)
|
||||||
} else {
|
} else {
|
||||||
ch = &Channel{
|
ch = &database.Channel{
|
||||||
Name: upstreamName,
|
Name: upstreamName,
|
||||||
Key: key,
|
Key: key,
|
||||||
}
|
}
|
||||||
@ -1963,7 +1965,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
if ch != nil {
|
if ch != nil {
|
||||||
uc.network.detach(ch)
|
uc.network.detach(ch)
|
||||||
} else {
|
} else {
|
||||||
ch = &Channel{
|
ch = &database.Channel{
|
||||||
Name: name,
|
Name: name,
|
||||||
Detached: true,
|
Detached: true,
|
||||||
}
|
}
|
||||||
@ -2911,7 +2913,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
|
Params: []string{"READ", "INTERNAL_ERROR", target, "Internal error"},
|
||||||
}}
|
}}
|
||||||
} else if r == nil {
|
} else if r == nil {
|
||||||
r = &ReadReceipt{
|
r = &database.ReadReceipt{
|
||||||
Target: entityCM,
|
Target: entityCM,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3082,7 +3084,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
|||||||
}
|
}
|
||||||
attrs := irc.ParseTags(attrsStr)
|
attrs := irc.ParseTags(attrsStr)
|
||||||
|
|
||||||
record := &Network{Nick: dc.nick, Enabled: true}
|
record := &database.Network{Nick: dc.nick, Enabled: true}
|
||||||
if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
|
if err := updateNetworkAttrs(record, attrs, subcommand); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
6
irc.go
6
irc.go
@ -9,6 +9,8 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -653,12 +655,12 @@ func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
|
|||||||
|
|
||||||
type channelCasemapMap struct{ casemapMap }
|
type channelCasemapMap struct{ casemapMap }
|
||||||
|
|
||||||
func (cm *channelCasemapMap) Value(name string) *Channel {
|
func (cm *channelCasemapMap) Value(name string) *database.Channel {
|
||||||
entry, ok := cm.innerMap[cm.casemap(name)]
|
entry, ok := cm.innerMap[cm.casemap(name)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return entry.value.(*Channel)
|
return entry.value.(*database.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
type membershipsCasemapMap struct{ casemapMap }
|
type membershipsCasemapMap struct{ casemapMap }
|
||||||
|
16
msgstore.go
16
msgstore.go
@ -9,6 +9,8 @@ import (
|
|||||||
|
|
||||||
"git.sr.ht/~sircmpwn/go-bare"
|
"git.sr.ht/~sircmpwn/go-bare"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// messageStore is a per-user store for IRC messages.
|
// messageStore is a per-user store for IRC messages.
|
||||||
@ -17,11 +19,11 @@ type messageStore interface {
|
|||||||
// LastMsgID queries the last message ID for the given network, entity and
|
// LastMsgID queries the last message ID for the given network, entity and
|
||||||
// date. The message ID returned may not refer to a valid message, but can be
|
// date. The message ID returned may not refer to a valid message, but can be
|
||||||
// used in history queries.
|
// used in history queries.
|
||||||
LastMsgID(network *Network, entity string, t time.Time) (string, error)
|
LastMsgID(network *database.Network, entity string, t time.Time) (string, error)
|
||||||
// LoadLatestID queries the latest non-event messages for the given network,
|
// LoadLatestID queries the latest non-event messages for the given network,
|
||||||
// entity and date, up to a count of limit messages, sorted from oldest to newest.
|
// entity and date, up to a count of limit messages, sorted from oldest to newest.
|
||||||
LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error)
|
LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error)
|
||||||
Append(network *Network, entity string, msg *irc.Message) (id string, err error)
|
Append(network *database.Network, entity string, msg *irc.Message) (id string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatHistoryTarget struct {
|
type chatHistoryTarget struct {
|
||||||
@ -38,17 +40,17 @@ type chatHistoryMessageStore interface {
|
|||||||
// It returns up to limit targets, starting from start and ending on end,
|
// It returns up to limit targets, starting from start and ending on end,
|
||||||
// both excluded. end may be before or after start.
|
// both excluded. end may be before or after start.
|
||||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||||
ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
|
ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error)
|
||||||
// LoadBeforeTime loads up to limit messages before start down to end. The
|
// LoadBeforeTime loads up to limit messages before start down to end. The
|
||||||
// returned messages must be between and excluding the provided bounds.
|
// returned messages must be between and excluding the provided bounds.
|
||||||
// end is before start.
|
// end is before start.
|
||||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||||
LoadBeforeTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||||
// LoadBeforeTime loads up to limit messages after start up to end. The
|
// LoadBeforeTime loads up to limit messages after start up to end. The
|
||||||
// returned messages must be between and excluding the provided bounds.
|
// returned messages must be between and excluding the provided bounds.
|
||||||
// end is after start.
|
// end is after start.
|
||||||
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
// If events is false, only PRIVMSG/NOTICE messages are considered.
|
||||||
LoadAfterTime(ctx context.Context, network *Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
LoadAfterTime(ctx context.Context, network *database.Network, entity string, start, end time.Time, limit int, events bool) ([]*irc.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type searchOptions struct {
|
type searchOptions struct {
|
||||||
@ -66,7 +68,7 @@ type searchMessageStore interface {
|
|||||||
messageStore
|
messageStore
|
||||||
|
|
||||||
// Search returns messages matching the specified options.
|
// Search returns messages matching the specified options.
|
||||||
Search(ctx context.Context, network *Network, search searchOptions) ([]*irc.Message, error)
|
Search(ctx context.Context, network *database.Network, search searchOptions) ([]*irc.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type msgIDType uint
|
type msgIDType uint
|
||||||
|
@ -13,6 +13,8 @@ import (
|
|||||||
|
|
||||||
"git.sr.ht/~sircmpwn/go-bare"
|
"git.sr.ht/~sircmpwn/go-bare"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -80,7 +82,7 @@ type fsMessageStoreFile struct {
|
|||||||
// https://github.com/znc/znc/blob/master/modules/log.cpp
|
// https://github.com/znc/znc/blob/master/modules/log.cpp
|
||||||
type fsMessageStore struct {
|
type fsMessageStore struct {
|
||||||
root string
|
root string
|
||||||
user *User
|
user *database.User
|
||||||
|
|
||||||
// Write-only files used by Append
|
// Write-only files used by Append
|
||||||
files map[string]*fsMessageStoreFile // indexed by entity
|
files map[string]*fsMessageStoreFile // indexed by entity
|
||||||
@ -90,7 +92,7 @@ var _ messageStore = (*fsMessageStore)(nil)
|
|||||||
var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
|
var _ chatHistoryMessageStore = (*fsMessageStore)(nil)
|
||||||
var _ searchMessageStore = (*fsMessageStore)(nil)
|
var _ searchMessageStore = (*fsMessageStore)(nil)
|
||||||
|
|
||||||
func newFSMessageStore(root string, user *User) *fsMessageStore {
|
func newFSMessageStore(root string, user *database.User) *fsMessageStore {
|
||||||
return &fsMessageStore{
|
return &fsMessageStore{
|
||||||
root: filepath.Join(root, escapeFilename(user.Username)),
|
root: filepath.Join(root, escapeFilename(user.Username)),
|
||||||
user: user,
|
user: user,
|
||||||
@ -98,14 +100,14 @@ func newFSMessageStore(root string, user *User) *fsMessageStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) logPath(network *Network, entity string, t time.Time) string {
|
func (ms *fsMessageStore) logPath(network *database.Network, entity string, t time.Time) string {
|
||||||
year, month, day := t.Date()
|
year, month, day := t.Date()
|
||||||
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
|
filename := fmt.Sprintf("%04d-%02d-%02d.log", year, month, day)
|
||||||
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
|
return filepath.Join(ms.root, escapeFilename(network.GetName()), escapeFilename(entity), filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
// nextMsgID queries the message ID for the next message to be written to f.
|
// nextMsgID queries the message ID for the next message to be written to f.
|
||||||
func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (string, error) {
|
func nextFSMsgID(network *database.Network, entity string, t time.Time, f *os.File) (string, error) {
|
||||||
offset, err := f.Seek(0, io.SeekEnd)
|
offset, err := f.Seek(0, io.SeekEnd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to query next FS message ID: %v", err)
|
return "", fmt.Errorf("failed to query next FS message ID: %v", err)
|
||||||
@ -113,7 +115,7 @@ func nextFSMsgID(network *Network, entity string, t time.Time, f *os.File) (stri
|
|||||||
return formatFSMsgID(network.ID, entity, t, offset), nil
|
return formatFSMsgID(network.ID, entity, t, offset), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
|
func (ms *fsMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
|
||||||
p := ms.logPath(network, entity, t)
|
p := ms.logPath(network, entity, t)
|
||||||
fi, err := os.Stat(p)
|
fi, err := os.Stat(p)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
@ -124,7 +126,7 @@ func (ms *fsMessageStore) LastMsgID(network *Network, entity string, t time.Time
|
|||||||
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
|
return formatFSMsgID(network.ID, entity, t, fi.Size()-1), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
|
func (ms *fsMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
|
||||||
s := formatMessage(msg)
|
s := formatMessage(msg)
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
@ -253,7 +255,7 @@ func formatMessage(msg *irc.Message) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) parseMessage(line string, network *Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
|
func (ms *fsMessageStore) parseMessage(line string, network *database.Network, entity string, ref time.Time, events bool) (*irc.Message, time.Time, error) {
|
||||||
var hour, minute, second int
|
var hour, minute, second int
|
||||||
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
|
_, err := fmt.Sscanf(line, "[%02d:%02d:%02d] ", &hour, &minute, &second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -380,7 +382,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
|
|||||||
// our nickname in the logs, so grab it from the network settings.
|
// our nickname in the logs, so grab it from the network settings.
|
||||||
// Not very accurate since this may not match our nick at the time
|
// Not very accurate since this may not match our nick at the time
|
||||||
// the message was received, but we can't do a lot better.
|
// the message was received, but we can't do a lot better.
|
||||||
entity = GetNick(ms.user, network)
|
entity = database.GetNick(ms.user, network)
|
||||||
}
|
}
|
||||||
params = []string{entity, text}
|
params = []string{entity, text}
|
||||||
}
|
}
|
||||||
@ -399,7 +401,7 @@ func (ms *fsMessageStore) parseMessage(line string, network *Network, entity str
|
|||||||
return msg, t, nil
|
return msg, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) parseMessagesBefore(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, afterOffset int64, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||||
path := ms.logPath(network, entity, ref)
|
path := ms.logPath(network, entity, ref)
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -458,7 +460,7 @@ func (ms *fsMessageStore) parseMessagesBefore(network *Network, entity string, r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) parseMessagesAfter(network *database.Network, entity string, ref time.Time, end time.Time, events bool, limit int, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||||
path := ms.logPath(network, entity, ref)
|
path := ms.logPath(network, entity, ref)
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -493,7 +495,7 @@ func (ms *fsMessageStore) parseMessagesAfter(network *Network, entity string, re
|
|||||||
return history, nil
|
return history, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||||
if start.IsZero() {
|
if start.IsZero() {
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
} else {
|
} else {
|
||||||
@ -526,11 +528,11 @@ func (ms *fsMessageStore) getBeforeTime(ctx context.Context, network *Network, e
|
|||||||
return messages[remaining:], nil
|
return messages[remaining:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) LoadBeforeTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||||
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
|
return ms.getBeforeTime(ctx, network, entity, start, end, limit, events, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool, selector func(m *irc.Message) bool) ([]*irc.Message, error) {
|
||||||
start = start.In(time.Local)
|
start = start.In(time.Local)
|
||||||
if end.IsZero() {
|
if end.IsZero() {
|
||||||
end = time.Now()
|
end = time.Now()
|
||||||
@ -562,11 +564,11 @@ func (ms *fsMessageStore) getAfterTime(ctx context.Context, network *Network, en
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) LoadAfterTime(ctx context.Context, network *database.Network, entity string, start time.Time, end time.Time, limit int, events bool) ([]*irc.Message, error) {
|
||||||
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
|
return ms.getAfterTime(ctx, network, entity, start, end, limit, events, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||||
var afterTime time.Time
|
var afterTime time.Time
|
||||||
var afterOffset int64
|
var afterOffset int64
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@ -614,7 +616,7 @@ func (ms *fsMessageStore) LoadLatestID(ctx context.Context, network *Network, en
|
|||||||
return history[remaining:], nil
|
return history[remaining:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
|
func (ms *fsMessageStore) ListTargets(ctx context.Context, network *database.Network, start, end time.Time, limit int, events bool) ([]chatHistoryTarget, error) {
|
||||||
start = start.In(time.Local)
|
start = start.In(time.Local)
|
||||||
end = end.In(time.Local)
|
end = end.In(time.Local)
|
||||||
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
|
rootPath := filepath.Join(ms.root, escapeFilename(network.GetName()))
|
||||||
@ -693,7 +695,7 @@ func (ms *fsMessageStore) ListTargets(ctx context.Context, network *Network, sta
|
|||||||
return targets, nil
|
return targets, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts searchOptions) ([]*irc.Message, error) {
|
func (ms *fsMessageStore) Search(ctx context.Context, network *database.Network, opts searchOptions) ([]*irc.Message, error) {
|
||||||
text := strings.ToLower(opts.text)
|
text := strings.ToLower(opts.text)
|
||||||
selector := func(m *irc.Message) bool {
|
selector := func(m *irc.Message) bool {
|
||||||
if opts.from != "" && m.User != opts.from {
|
if opts.from != "" && m.User != opts.from {
|
||||||
@ -711,7 +713,7 @@ func (ms *fsMessageStore) Search(ctx context.Context, network *Network, opts sea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *Network) error {
|
func (ms *fsMessageStore) RenameNetwork(oldNet, newNet *database.Network) error {
|
||||||
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
|
oldDir := filepath.Join(ms.root, escapeFilename(oldNet.GetName()))
|
||||||
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
|
newDir := filepath.Join(ms.root, escapeFilename(newNet.GetName()))
|
||||||
// Avoid loosing data by overwriting an existing directory
|
// Avoid loosing data by overwriting an existing directory
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
|
|
||||||
"git.sr.ht/~sircmpwn/go-bare"
|
"git.sr.ht/~sircmpwn/go-bare"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const messageRingBufferCap = 4096
|
const messageRingBufferCap = 4096
|
||||||
@ -55,7 +57,7 @@ func (ms *memoryMessageStore) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingBuffer {
|
func (ms *memoryMessageStore) get(network *database.Network, entity string) *messageRingBuffer {
|
||||||
k := ringBufferKey{networkID: network.ID, entity: entity}
|
k := ringBufferKey{networkID: network.ID, entity: entity}
|
||||||
if rb, ok := ms.buffers[k]; ok {
|
if rb, ok := ms.buffers[k]; ok {
|
||||||
return rb
|
return rb
|
||||||
@ -65,7 +67,7 @@ func (ms *memoryMessageStore) get(network *Network, entity string) *messageRingB
|
|||||||
return rb
|
return rb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.Time) (string, error) {
|
func (ms *memoryMessageStore) LastMsgID(network *database.Network, entity string, t time.Time) (string, error) {
|
||||||
var seq uint64
|
var seq uint64
|
||||||
k := ringBufferKey{networkID: network.ID, entity: entity}
|
k := ringBufferKey{networkID: network.ID, entity: entity}
|
||||||
if rb, ok := ms.buffers[k]; ok {
|
if rb, ok := ms.buffers[k]; ok {
|
||||||
@ -74,7 +76,7 @@ func (ms *memoryMessageStore) LastMsgID(network *Network, entity string, t time.
|
|||||||
return formatMemoryMsgID(network.ID, entity, seq), nil
|
return formatMemoryMsgID(network.ID, entity, seq), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.Message) (string, error) {
|
func (ms *memoryMessageStore) Append(network *database.Network, entity string, msg *irc.Message) (string, error) {
|
||||||
switch msg.Command {
|
switch msg.Command {
|
||||||
case "PRIVMSG", "NOTICE":
|
case "PRIVMSG", "NOTICE":
|
||||||
// Only append these messages, because LoadLatestID shouldn't return
|
// Only append these messages, because LoadLatestID shouldn't return
|
||||||
@ -94,7 +96,7 @@ func (ms *memoryMessageStore) Append(network *Network, entity string, msg *irc.M
|
|||||||
return formatMemoryMsgID(network.ID, entity, seq), nil
|
return formatMemoryMsgID(network.ID, entity, seq), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *Network, entity, id string, limit int) ([]*irc.Message, error) {
|
func (ms *memoryMessageStore) LoadLatestID(ctx context.Context, network *database.Network, entity, id string, limit int) ([]*irc.Message, error) {
|
||||||
_, _, seq, err := parseMemoryMsgID(id)
|
_, _, seq, err := parseMemoryMsgID(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/soju/config"
|
"git.sr.ht/~emersion/soju/config"
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: make configurable
|
// TODO: make configurable
|
||||||
@ -141,7 +142,7 @@ type Server struct {
|
|||||||
MetricsRegistry prometheus.Registerer // can be nil
|
MetricsRegistry prometheus.Registerer // can be nil
|
||||||
|
|
||||||
config atomic.Value // *Config
|
config atomic.Value // *Config
|
||||||
db Database
|
db database.Database
|
||||||
stopWG sync.WaitGroup
|
stopWG sync.WaitGroup
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
@ -161,7 +162,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(db Database) *Server {
|
func NewServer(db database.Database) *Server {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
Logger: NewLogger(log.Writer(), true),
|
Logger: NewLogger(log.Writer(), true),
|
||||||
db: db,
|
db: db,
|
||||||
@ -273,7 +274,7 @@ func (s *Server) Shutdown() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createUser(ctx context.Context, user *User) (*user, error) {
|
func (s *Server) createUser(ctx context.Context, user *database.User) (*user, error) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
|
|
||||||
@ -304,7 +305,7 @@ func (s *Server) getUser(name string) *user {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) addUserLocked(user *User) *user {
|
func (s *Server) addUserLocked(user *database.User) *user {
|
||||||
s.Logger.Printf("starting bouncer for user %q", user.Username)
|
s.Logger.Printf("starting bouncer for user %q", user.Username)
|
||||||
u := newUser(s, user)
|
u := newUser(s, user)
|
||||||
s.users[u.Username] = u
|
s.users[u.Username] = u
|
||||||
|
@ -3,10 +3,13 @@ package soju
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testServerPrefix = &irc.Prefix{Name: "soju-test-server"}
|
var testServerPrefix = &irc.Prefix{Name: "soju-test-server"}
|
||||||
@ -16,34 +19,35 @@ const (
|
|||||||
testPassword = testUsername
|
testPassword = testUsername
|
||||||
)
|
)
|
||||||
|
|
||||||
func createTempSqliteDB(t *testing.T) Database {
|
func createTempSqliteDB(t *testing.T) database.Database {
|
||||||
db, err := OpenDB("sqlite3", ":memory:")
|
db, err := database.OpenTempSqliteDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create temporary SQLite database: %v", err)
|
t.Fatalf("failed to create temporary SQLite database: %v", err)
|
||||||
}
|
}
|
||||||
// :memory: will open a separate database for each new connection. Make
|
|
||||||
// sure the sql package only uses a single connection. An alternative
|
|
||||||
// solution is to use "file::memory:?cache=shared".
|
|
||||||
db.(*SqliteDB).db.SetMaxOpenConns(1)
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTempPostgresDB(t *testing.T) Database {
|
func createTempPostgresDB(t *testing.T) database.Database {
|
||||||
db := &PostgresDB{db: openTempPostgresDB(t)}
|
source, ok := os.LookupEnv("SOJU_TEST_POSTGRES")
|
||||||
if err := db.upgrade(); err != nil {
|
if !ok {
|
||||||
t.Fatalf("failed to upgrade PostgreSQL database: %v", err)
|
t.Skip("set SOJU_TEST_POSTGRES to a connection string to execute PostgreSQL tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := database.OpenTempPostgresDB(source)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create temporary PostgreSQL database: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestUser(t *testing.T, db Database) *User {
|
func createTestUser(t *testing.T, db database.Database) *database.User {
|
||||||
hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
|
hashed, err := bcrypt.GenerateFromPassword([]byte(testPassword), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate bcrypt hash: %v", err)
|
t.Fatalf("failed to generate bcrypt hash: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
record := &User{Username: testUsername, Password: string(hashed)}
|
record := &database.User{Username: testUsername, Password: string(hashed)}
|
||||||
if err := db.StoreUser(context.Background(), record); err != nil {
|
if err := db.StoreUser(context.Background(), record); err != nil {
|
||||||
t.Fatalf("failed to store test user: %v", err)
|
t.Fatalf("failed to store test user: %v", err)
|
||||||
}
|
}
|
||||||
@ -57,13 +61,13 @@ func createTestDownstream(t *testing.T, srv *Server) ircConn {
|
|||||||
return newNetIRCConn(c2)
|
return newNetIRCConn(c2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTestUpstream(t *testing.T, db Database, user *User) (*Network, net.Listener) {
|
func createTestUpstream(t *testing.T, db database.Database, user *database.User) (*database.Network, net.Listener) {
|
||||||
ln, err := net.Listen("tcp", "localhost:0")
|
ln, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create TCP listener: %v", err)
|
t.Fatalf("failed to create TCP listener: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
network := &Network{
|
network := &database.Network{
|
||||||
Name: "testnet",
|
Name: "testnet",
|
||||||
Addr: "irc+insecure://" + ln.Addr().String(),
|
Addr: "irc+insecure://" + ln.Addr().String(),
|
||||||
Nick: user.Username,
|
Nick: user.Username,
|
||||||
@ -95,7 +99,7 @@ func expectMessage(t *testing.T, c ircConn, cmd string) *irc.Message {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerDownstreamConn(t *testing.T, c ircConn, network *Network) {
|
func registerDownstreamConn(t *testing.T, c ircConn, network *database.Network) {
|
||||||
c.WriteMessage(&irc.Message{
|
c.WriteMessage(&irc.Message{
|
||||||
Command: "PASS",
|
Command: "PASS",
|
||||||
Params: []string{testPassword},
|
Params: []string{testPassword},
|
||||||
@ -151,7 +155,7 @@ func registerUpstreamConn(t *testing.T, c ircConn) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testServer(t *testing.T, db Database) {
|
func testServer(t *testing.T, db database.Database) {
|
||||||
user := createTestUser(t, db)
|
user := createTestUser(t, db)
|
||||||
network, upstream := createTestUpstream(t, db, user)
|
network, upstream := createTestUpstream(t, db, user)
|
||||||
defer upstream.Close()
|
defer upstream.Close()
|
||||||
|
28
service.go
28
service.go
@ -17,6 +17,8 @@ import (
|
|||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serviceNick = "BouncerServ"
|
const serviceNick = "BouncerServ"
|
||||||
@ -447,7 +449,7 @@ func newNetworkFlagSet() *networkFlagSet {
|
|||||||
return fs
|
return fs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *networkFlagSet) update(network *Network) error {
|
func (fs *networkFlagSet) update(network *database.Network) error {
|
||||||
if fs.Addr != nil {
|
if fs.Addr != nil {
|
||||||
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
|
if addrParts := strings.SplitN(*fs.Addr, "://", 2); len(addrParts) == 2 {
|
||||||
scheme := addrParts[0]
|
scheme := addrParts[0]
|
||||||
@ -508,7 +510,7 @@ func handleServiceNetworkCreate(ctx context.Context, dc *downstreamConn, params
|
|||||||
return fmt.Errorf("flag -addr is required")
|
return fmt.Errorf("flag -addr is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
record := &Network{
|
record := &database.Network{
|
||||||
Addr: *fs.Addr,
|
Addr: *fs.Addr,
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
@ -833,7 +835,7 @@ func handleUserCreate(ctx context.Context, dc *downstreamConn, params []string)
|
|||||||
return fmt.Errorf("failed to hash password: %v", err)
|
return fmt.Errorf("failed to hash password: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user := &User{
|
user := &database.User{
|
||||||
Username: *username,
|
Username: *username,
|
||||||
Password: string(hashed),
|
Password: string(hashed),
|
||||||
Realname: *realname,
|
Realname: *realname,
|
||||||
@ -971,9 +973,9 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
|||||||
n := 0
|
n := 0
|
||||||
|
|
||||||
sendNetwork := func(net *network) {
|
sendNetwork := func(net *network) {
|
||||||
var channels []*Channel
|
var channels []*database.Channel
|
||||||
for _, entry := range net.channels.innerMap {
|
for _, entry := range net.channels.innerMap {
|
||||||
channels = append(channels, entry.value.(*Channel))
|
channels = append(channels, entry.value.(*database.Channel))
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(channels, func(i, j int) bool {
|
sort.Slice(channels, func(i, j int) bool {
|
||||||
@ -1031,6 +1033,20 @@ func handleServiceChannelStatus(ctx context.Context, dc *downstreamConn, params
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseFilter(filter string) (database.MessageFilter, error) {
|
||||||
|
switch filter {
|
||||||
|
case "default":
|
||||||
|
return database.FilterDefault, nil
|
||||||
|
case "none":
|
||||||
|
return database.FilterNone, nil
|
||||||
|
case "highlight":
|
||||||
|
return database.FilterHighlight, nil
|
||||||
|
case "message":
|
||||||
|
return database.FilterMessage, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("unknown filter: %q", filter)
|
||||||
|
}
|
||||||
|
|
||||||
type channelFlagSet struct {
|
type channelFlagSet struct {
|
||||||
*flag.FlagSet
|
*flag.FlagSet
|
||||||
RelayDetached, ReattachOn, DetachAfter, DetachOn *string
|
RelayDetached, ReattachOn, DetachAfter, DetachOn *string
|
||||||
@ -1045,7 +1061,7 @@ func newChannelFlagSet() *channelFlagSet {
|
|||||||
return fs
|
return fs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fs *channelFlagSet) update(channel *Channel) error {
|
func (fs *channelFlagSet) update(channel *database.Channel) error {
|
||||||
if fs.RelayDetached != nil {
|
if fs.RelayDetached != nil {
|
||||||
filter, err := parseFilter(*fs.RelayDetached)
|
filter, err := parseFilter(*fs.RelayDetached)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
20
upstream.go
20
upstream.go
@ -17,6 +17,8 @@ import (
|
|||||||
|
|
||||||
"github.com/emersion/go-sasl"
|
"github.com/emersion/go-sasl"
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// permanentUpstreamCaps is the static list of upstream capabilities always
|
// permanentUpstreamCaps is the static list of upstream capabilities always
|
||||||
@ -510,7 +512,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
highlight := uc.network.isHighlight(msg)
|
highlight := uc.network.isHighlight(msg)
|
||||||
if ch.DetachOn == FilterMessage || ch.DetachOn == FilterDefault || (ch.DetachOn == FilterHighlight && highlight) {
|
if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) {
|
||||||
uc.updateChannelAutoDetach(target)
|
uc.updateChannelAutoDetach(target)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -765,7 +767,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
|||||||
if uc.network.channels.Len() > 0 {
|
if uc.network.channels.Len() > 0 {
|
||||||
var channels, keys []string
|
var channels, keys []string
|
||||||
for _, entry := range uc.network.channels.innerMap {
|
for _, entry := range uc.network.channels.innerMap {
|
||||||
ch := entry.value.(*Channel)
|
ch := entry.value.(*database.Channel)
|
||||||
channels = append(channels, ch.Name)
|
channels = append(channels, ch.Name)
|
||||||
keys = append(keys, ch.Key)
|
keys = append(keys, ch.Key)
|
||||||
}
|
}
|
||||||
@ -1553,7 +1555,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the nick we want is now free
|
// Check if the nick we want is now free
|
||||||
wantNick := GetNick(&uc.user.User, &uc.network.Network)
|
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
|
||||||
wantNickCM := uc.network.casemap(wantNick)
|
wantNickCM := uc.network.casemap(wantNick)
|
||||||
if !online && uc.nickCM != wantNickCM {
|
if !online && uc.nickCM != wantNickCM {
|
||||||
found := false
|
found := false
|
||||||
@ -1796,13 +1798,13 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *Channel, msg *irc.Message) {
|
func (uc *upstreamConn) handleDetachedMessage(ctx context.Context, ch *database.Channel, msg *irc.Message) {
|
||||||
if uc.network.detachedMessageNeedsRelay(ch, msg) {
|
if uc.network.detachedMessageNeedsRelay(ch, msg) {
|
||||||
uc.forEachDownstream(func(dc *downstreamConn) {
|
uc.forEachDownstream(func(dc *downstreamConn) {
|
||||||
dc.relayDetachedMessage(uc.network, msg)
|
dc.relayDetachedMessage(uc.network, msg)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if ch.ReattachOn == FilterMessage || (ch.ReattachOn == FilterHighlight && uc.network.isHighlight(msg)) {
|
if ch.ReattachOn == database.FilterMessage || (ch.ReattachOn == database.FilterHighlight && uc.network.isHighlight(msg)) {
|
||||||
uc.network.attach(ctx, ch)
|
uc.network.attach(ctx, ch)
|
||||||
if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
|
if err := uc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
|
||||||
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
|
uc.logger.Printf("failed to update channel %q: %v", ch.Name, err)
|
||||||
@ -1960,10 +1962,10 @@ func splitSpace(s string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (uc *upstreamConn) register(ctx context.Context) {
|
func (uc *upstreamConn) register(ctx context.Context) {
|
||||||
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
|
uc.nick = database.GetNick(&uc.user.User, &uc.network.Network)
|
||||||
uc.nickCM = uc.network.casemap(uc.nick)
|
uc.nickCM = uc.network.casemap(uc.nick)
|
||||||
uc.username = GetUsername(&uc.user.User, &uc.network.Network)
|
uc.username = database.GetUsername(&uc.user.User, &uc.network.Network)
|
||||||
uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
|
uc.realname = database.GetRealname(&uc.user.User, &uc.network.Network)
|
||||||
|
|
||||||
uc.SendMessage(ctx, &irc.Message{
|
uc.SendMessage(ctx, &irc.Message{
|
||||||
Command: "CAP",
|
Command: "CAP",
|
||||||
@ -2193,7 +2195,7 @@ func (uc *upstreamConn) updateMonitor() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
wantNick := GetNick(&uc.user.User, &uc.network.Network)
|
wantNick := database.GetNick(&uc.user.User, &uc.network.Network)
|
||||||
wantNickCM := uc.network.casemap(wantNick)
|
wantNickCM := uc.network.casemap(wantNick)
|
||||||
if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
|
if _, ok := add[wantNickCM]; !ok && !uc.monitored.Has(wantNick) && !uc.isOurNick(wantNick) {
|
||||||
addList = append(addList, wantNickCM)
|
addList = append(addList, wantNickCM)
|
||||||
|
40
user.go
40
user.go
@ -14,6 +14,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gopkg.in/irc.v3"
|
"gopkg.in/irc.v3"
|
||||||
|
|
||||||
|
"git.sr.ht/~emersion/soju/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type event interface{}
|
type event interface{}
|
||||||
@ -123,7 +125,7 @@ func (ds deliveredStore) ForEachClient(f func(clientName string)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type network struct {
|
type network struct {
|
||||||
Network
|
database.Network
|
||||||
user *user
|
user *user
|
||||||
logger Logger
|
logger Logger
|
||||||
stopped chan struct{}
|
stopped chan struct{}
|
||||||
@ -135,7 +137,7 @@ type network struct {
|
|||||||
casemap casemapping
|
casemap casemapping
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNetwork(user *user, record *Network, channels []Channel) *network {
|
func newNetwork(user *user, record *database.Network, channels []database.Channel) *network {
|
||||||
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
|
logger := &prefixLogger{user.logger, fmt.Sprintf("network %q: ", record.GetName())}
|
||||||
|
|
||||||
m := channelCasemapMap{newCasemapMap(0)}
|
m := channelCasemapMap{newCasemapMap(0)}
|
||||||
@ -176,7 +178,7 @@ func (net *network) isStopped() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func userIdent(u *User) string {
|
func userIdent(u *database.User) string {
|
||||||
// The ident is a string we will send to upstream servers in clear-text.
|
// The ident is a string we will send to upstream servers in clear-text.
|
||||||
// For privacy reasons, make sure it doesn't expose any meaningful user
|
// For privacy reasons, make sure it doesn't expose any meaningful user
|
||||||
// metadata. We just use the base64-encoded hashed ID, so that people don't
|
// metadata. We just use the base64-encoded hashed ID, so that people don't
|
||||||
@ -278,7 +280,7 @@ func (net *network) stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *network) detach(ch *Channel) {
|
func (net *network) detach(ch *database.Channel) {
|
||||||
if ch.Detached {
|
if ch.Detached {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -312,7 +314,7 @@ func (net *network) detach(ch *Channel) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *network) attach(ctx context.Context, ch *Channel) {
|
func (net *network) attach(ctx context.Context, ch *database.Channel) {
|
||||||
if !ch.Detached {
|
if !ch.Detached {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -388,13 +390,13 @@ func (net *network) storeClientDeliveryReceipts(ctx context.Context, clientName
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var receipts []DeliveryReceipt
|
var receipts []database.DeliveryReceipt
|
||||||
net.delivered.ForEachTarget(func(target string) {
|
net.delivered.ForEachTarget(func(target string) {
|
||||||
msgID := net.delivered.LoadID(target, clientName)
|
msgID := net.delivered.LoadID(target, clientName)
|
||||||
if msgID == "" {
|
if msgID == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
receipts = append(receipts, DeliveryReceipt{
|
receipts = append(receipts, database.DeliveryReceipt{
|
||||||
Target: target,
|
Target: target,
|
||||||
InternalMsgID: msgID,
|
InternalMsgID: msgID,
|
||||||
})
|
})
|
||||||
@ -421,9 +423,9 @@ func (net *network) isHighlight(msg *irc.Message) bool {
|
|||||||
return msg.Prefix.Name != nick && isHighlight(text, nick)
|
return msg.Prefix.Name != nick && isHighlight(text, nick)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) bool {
|
func (net *network) detachedMessageNeedsRelay(ch *database.Channel, msg *irc.Message) bool {
|
||||||
highlight := net.isHighlight(msg)
|
highlight := net.isHighlight(msg)
|
||||||
return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
|
return ch.RelayDetached == database.FilterMessage || ((ch.RelayDetached == database.FilterHighlight || ch.RelayDetached == database.FilterDefault) && highlight)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
|
func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
|
||||||
@ -443,7 +445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
|
|||||||
}
|
}
|
||||||
|
|
||||||
type user struct {
|
type user struct {
|
||||||
User
|
database.User
|
||||||
srv *Server
|
srv *Server
|
||||||
logger Logger
|
logger Logger
|
||||||
|
|
||||||
@ -455,7 +457,7 @@ type user struct {
|
|||||||
msgStore messageStore
|
msgStore messageStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUser(srv *Server, record *User) *user {
|
func newUser(srv *Server, record *database.User) *user {
|
||||||
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
|
logger := &prefixLogger{srv.Logger, fmt.Sprintf("user %q: ", record.Username)}
|
||||||
|
|
||||||
var msgStore messageStore
|
var msgStore messageStore
|
||||||
@ -817,7 +819,7 @@ func (u *user) removeNetwork(network *network) {
|
|||||||
panic("tried to remove a non-existing network")
|
panic("tried to remove a non-existing network")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) checkNetwork(record *Network) error {
|
func (u *user) checkNetwork(record *database.Network) error {
|
||||||
url, err := record.URL()
|
url, err := record.URL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -867,7 +869,7 @@ func (u *user) checkNetwork(record *Network) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) createNetwork(ctx context.Context, record *Network) (*network, error) {
|
func (u *user) createNetwork(ctx context.Context, record *database.Network) (*network, error) {
|
||||||
if record.ID != 0 {
|
if record.ID != 0 {
|
||||||
panic("tried creating an already-existing network")
|
panic("tried creating an already-existing network")
|
||||||
}
|
}
|
||||||
@ -894,7 +896,7 @@ func (u *user) createNetwork(ctx context.Context, record *Network) (*network, er
|
|||||||
return network, nil
|
return network, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, error) {
|
func (u *user) updateNetwork(ctx context.Context, record *database.Network) (*network, error) {
|
||||||
if record.ID == 0 {
|
if record.ID == 0 {
|
||||||
panic("tried updating a new network")
|
panic("tried updating a new network")
|
||||||
}
|
}
|
||||||
@ -920,9 +922,9 @@ func (u *user) updateNetwork(ctx context.Context, record *Network) (*network, er
|
|||||||
|
|
||||||
// Most network changes require us to re-connect to the upstream server
|
// Most network changes require us to re-connect to the upstream server
|
||||||
|
|
||||||
channels := make([]Channel, 0, network.channels.Len())
|
channels := make([]database.Channel, 0, network.channels.Len())
|
||||||
for _, entry := range network.channels.innerMap {
|
for _, entry := range network.channels.innerMap {
|
||||||
ch := entry.value.(*Channel)
|
ch := entry.value.(*database.Channel)
|
||||||
channels = append(channels, *ch)
|
channels = append(channels, *ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -992,7 +994,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) updateUser(ctx context.Context, record *User) error {
|
func (u *user) updateUser(ctx context.Context, record *database.User) error {
|
||||||
if u.ID != record.ID {
|
if u.ID != record.ID {
|
||||||
panic("ID mismatch when updating user")
|
panic("ID mismatch when updating user")
|
||||||
}
|
}
|
||||||
@ -1005,7 +1007,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
|
|||||||
|
|
||||||
if realnameUpdated {
|
if realnameUpdated {
|
||||||
// Re-connect to networks which use the default realname
|
// Re-connect to networks which use the default realname
|
||||||
var needUpdate []Network
|
var needUpdate []database.Network
|
||||||
for _, net := range u.networks {
|
for _, net := range u.networks {
|
||||||
if net.Realname != "" {
|
if net.Realname != "" {
|
||||||
continue
|
continue
|
||||||
@ -1016,7 +1018,7 @@ func (u *user) updateUser(ctx context.Context, record *User) error {
|
|||||||
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
|
if uc := net.conn; uc != nil && uc.caps.IsEnabled("setname") {
|
||||||
uc.SendMessage(ctx, &irc.Message{
|
uc.SendMessage(ctx, &irc.Message{
|
||||||
Command: "SETNAME",
|
Command: "SETNAME",
|
||||||
Params: []string{GetRealname(&u.User, &net.Network)},
|
Params: []string{database.GetRealname(&u.User, &net.Network)},
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user