diff --git a/database/database.go b/database/database.go index d26d1a7..2a8027a 100644 --- a/database/database.go +++ b/database/database.go @@ -65,13 +65,14 @@ type DatabaseStats struct { } type User struct { - ID int64 - Username string - Password string // hashed - Nick string - Realname string - Admin bool - Enabled bool + ID int64 + Username string + Password string // hashed + Nick string + Realname string + Admin bool + Enabled bool + DownstreamInteractedAt time.Time } func (u *User) CheckPassword(password string) (upgraded bool, err error) { diff --git a/database/postgres.go b/database/postgres.go index f209bad..5a00c08 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -33,7 +33,8 @@ CREATE TABLE "User" ( nick VARCHAR(255), realname VARCHAR(255), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), - enabled BOOLEAN NOT NULL DEFAULT TRUE + enabled BOOLEAN NOT NULL DEFAULT TRUE, + downstream_interacted_at TIMESTAMP WITH TIME ZONE ); CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); @@ -171,6 +172,7 @@ var postgresMigrations = []string{ `ALTER TABLE "Network" ADD COLUMN certfp TEXT`, `ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`, `ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`, + `ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`, } type PostgresDB struct { @@ -304,7 +306,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - `SELECT id, username, password, admin, nick, realname, enabled + `SELECT id, username, password, admin, nick, realname, enabled, + downstream_interacted_at FROM "User"`) if err != nil { return nil, err @@ -315,7 +318,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) { for rows.Next() { var user User var password, nick, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil { return nil, err } user.Password = password.String @@ -338,11 +341,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro var password, nick, realname sql.NullString row := db.db.QueryRowContext(ctx, - `SELECT id, password, admin, nick, realname, enabled + `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at FROM "User" WHERE username = $1`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil { return nil, err } user.Password = password.String @@ -362,16 +365,20 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { var err error if user.ID == 0 { err = db.db.QueryRowContext(ctx, ` - INSERT INTO "User" (username, password, admin, nick, realname, enabled) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO "User" (username, password, admin, nick, realname, + enabled, downstream_interacted_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id`, - user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID) + user.Username, password, user.Admin, nick, realname, user.Enabled, + user.DownstreamInteractedAt).Scan(&user.ID) } else { _, err = db.db.ExecContext(ctx, ` UPDATE "User" - SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5 - WHERE id = $6`, - password, user.Admin, nick, realname, user.Enabled, user.ID) + SET password = $1, admin = $2, nick = $3, realname = $4, + enabled = $5, downstream_interacted_at = $6 + WHERE id = $7`, + password, user.Admin, nick, realname, user.Enabled, + user.DownstreamInteractedAt, user.ID) } return err } diff --git a/database/sqlite.go b/database/sqlite.go index 85a647c..28e18cc 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -63,7 +63,8 @@ CREATE TABLE User ( realname TEXT, nick TEXT, created_at TEXT NOT NULL, - enabled INTEGER NOT NULL DEFAULT 1 + enabled INTEGER NOT NULL DEFAULT 1, + downstream_interacted_at TEXT ); CREATE TABLE Network ( @@ -291,6 +292,7 @@ var sqliteMigrations = []string{ UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now'); `, "ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", + "ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;", } type SqliteDB struct { @@ -390,7 +392,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { defer cancel() rows, err := db.db.QueryContext(ctx, - `SELECT id, username, password, admin, nick, realname, enabled + `SELECT id, username, password, admin, nick, realname, enabled, + downstream_interacted_at FROM User`) if err != nil { return nil, err @@ -401,12 +404,14 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) { for rows.Next() { var user User var password, nick, realname sql.NullString - if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { + var downstreamInteractedAt sqliteTime + if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { return nil, err } user.Password = password.String user.Nick = nick.String user.Realname = realname.String + user.DownstreamInteractedAt = downstreamInteractedAt.Time users = append(users, user) } if err := rows.Err(); err != nil { @@ -423,17 +428,20 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) user := &User{Username: username} var password, nick, realname sql.NullString + var downstreamInteractedAt sqliteTime row := db.db.QueryRowContext(ctx, - `SELECT id, password, admin, nick, realname, enabled + `SELECT id, password, admin, nick, realname, enabled, + downstream_interacted_at FROM User WHERE username = ?`, username) - if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { + if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil { return nil, err } user.Password = password.String user.Nick = nick.String user.Realname = realname.String + user.DownstreamInteractedAt = downstreamInteractedAt.Time return user, nil } @@ -449,6 +457,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { sql.Named("realname", toNullString(user.Realname)), sql.Named("enabled", user.Enabled), sql.Named("now", sqliteTime{time.Now()}), + sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}), } var err error @@ -456,7 +465,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { _, err = db.db.ExecContext(ctx, ` UPDATE User SET password = :password, admin = :admin, nick = :nick, - realname = :realname, enabled = :enabled + realname = :realname, enabled = :enabled, + downstream_interacted_at = :downstream_interacted_at WHERE username = :username`, args...) } else { @@ -464,9 +474,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { res, err = db.db.ExecContext(ctx, ` INSERT INTO User(username, password, admin, nick, realname, created_at, - enabled) + enabled, downstream_interacted_at) VALUES (:username, :password, :admin, :nick, :realname, :now, - :enabled)`, + :enabled, :downstream_interacted_at)`, args...) if err != nil { return err diff --git a/user.go b/user.go index 2b651d7..bdb72a8 100644 --- a/user.go +++ b/user.go @@ -683,6 +683,7 @@ func (u *user) run() { } case eventDownstreamConnected: dc := e.dc + ctx := context.TODO() if dc.network != nil { dc.monitored.SetCasemapping(dc.network.casemap) @@ -697,7 +698,7 @@ func (u *user) run() { break } - if err := dc.welcome(context.TODO()); err != nil { + if err := dc.welcome(ctx); err != nil { if ircErr, ok := err.(ircError); ok { msg := ircErr.Message.Copy() msg.Prefix = dc.srv.prefix() @@ -724,8 +725,11 @@ func (u *user) run() { u.forEachUpstream(func(uc *upstreamConn) { uc.updateAway() }) + + u.bumpDownstreamInteractionTime(ctx) case eventDownstreamDisconnected: dc := e.dc + ctx := context.TODO() for i := range u.downstreamConns { if u.downstreamConns[i] == dc { @@ -735,7 +739,7 @@ func (u *user) run() { } dc.forEachNetwork(func(net *network) { - net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) + net.storeClientDeliveryReceipts(ctx, dc.clientName) }) u.forEachUpstream(func(uc *upstreamConn) { @@ -743,6 +747,8 @@ func (u *user) run() { uc.updateAway() uc.updateMonitor() }) + + u.bumpDownstreamInteractionTime(ctx) case eventDownstreamMessage: msg, dc := e.msg, e.dc if dc.isClosed() { @@ -1229,3 +1235,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd return &net.TCPAddr{IP: ip}, nil } + +func (u *user) bumpDownstreamInteractionTime(ctx context.Context) { + record := u.User + record.DownstreamInteractedAt = time.Now() + if err := u.updateUser(ctx, &record); err != nil { + u.logger.Printf("failed to bump downstream interaction time: %v", err) + } +}