Add user downstream interaction time

This commit is contained in:
Simon Ser 2023-01-26 14:02:11 +01:00
parent 05d7813835
commit 57f5ee8d6f
4 changed files with 60 additions and 28 deletions

View File

@ -72,6 +72,7 @@ type User struct {
Realname string Realname string
Admin bool Admin bool
Enabled bool Enabled bool
DownstreamInteractedAt time.Time
} }
func (u *User) CheckPassword(password string) (upgraded bool, err error) { func (u *User) CheckPassword(password string) (upgraded bool, err error) {

View File

@ -33,7 +33,8 @@ CREATE TABLE "User" (
nick VARCHAR(255), nick VARCHAR(255),
realname VARCHAR(255), realname VARCHAR(255),
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
enabled BOOLEAN NOT NULL DEFAULT TRUE enabled BOOLEAN NOT NULL DEFAULT TRUE,
downstream_interacted_at TIMESTAMP WITH TIME ZONE
); );
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL'); CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
@ -171,6 +172,7 @@ var postgresMigrations = []string{
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`, `ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`, `ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`, `ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`,
} }
type PostgresDB struct { type PostgresDB struct {
@ -304,7 +306,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled `SELECT id, username, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM "User"`) FROM "User"`)
if err != nil { if err != nil {
return nil, err return nil, err
@ -315,7 +318,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() { for rows.Next() {
var user User var user User
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -338,11 +341,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
row := db.db.QueryRowContext(ctx, row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname, enabled `SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at
FROM "User" FROM "User"
WHERE username = $1`, WHERE username = $1`,
username) username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
@ -362,16 +365,20 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
var err error var err error
if user.ID == 0 { if user.ID == 0 {
err = db.db.QueryRowContext(ctx, ` err = db.db.QueryRowContext(ctx, `
INSERT INTO "User" (username, password, admin, nick, realname, enabled) INSERT INTO "User" (username, password, admin, nick, realname,
VALUES ($1, $2, $3, $4, $5, $6) enabled, downstream_interacted_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`, RETURNING id`,
user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID) user.Username, password, user.Admin, nick, realname, user.Enabled,
user.DownstreamInteractedAt).Scan(&user.ID)
} else { } else {
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE "User" UPDATE "User"
SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5 SET password = $1, admin = $2, nick = $3, realname = $4,
WHERE id = $6`, enabled = $5, downstream_interacted_at = $6
password, user.Admin, nick, realname, user.Enabled, user.ID) WHERE id = $7`,
password, user.Admin, nick, realname, user.Enabled,
user.DownstreamInteractedAt, user.ID)
} }
return err return err
} }

View File

@ -63,7 +63,8 @@ CREATE TABLE User (
realname TEXT, realname TEXT,
nick TEXT, nick TEXT,
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1 enabled INTEGER NOT NULL DEFAULT 1,
downstream_interacted_at TEXT
); );
CREATE TABLE Network ( CREATE TABLE Network (
@ -291,6 +292,7 @@ var sqliteMigrations = []string{
UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now'); UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now');
`, `,
"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1", "ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
"ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;",
} }
type SqliteDB struct { type SqliteDB struct {
@ -390,7 +392,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel() defer cancel()
rows, err := db.db.QueryContext(ctx, rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled `SELECT id, username, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM User`) FROM User`)
if err != nil { if err != nil {
return nil, err return nil, err
@ -401,12 +404,14 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() { for rows.Next() {
var user User var user User
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { var downstreamInteractedAt sqliteTime
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
user.Nick = nick.String user.Nick = nick.String
user.Realname = realname.String user.Realname = realname.String
user.DownstreamInteractedAt = downstreamInteractedAt.Time
users = append(users, user) users = append(users, user)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -423,17 +428,20 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
user := &User{Username: username} user := &User{Username: username}
var password, nick, realname sql.NullString var password, nick, realname sql.NullString
var downstreamInteractedAt sqliteTime
row := db.db.QueryRowContext(ctx, row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname, enabled `SELECT id, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM User FROM User
WHERE username = ?`, WHERE username = ?`,
username) username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil { if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
return nil, err return nil, err
} }
user.Password = password.String user.Password = password.String
user.Nick = nick.String user.Nick = nick.String
user.Realname = realname.String user.Realname = realname.String
user.DownstreamInteractedAt = downstreamInteractedAt.Time
return user, nil return user, nil
} }
@ -449,6 +457,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
sql.Named("realname", toNullString(user.Realname)), sql.Named("realname", toNullString(user.Realname)),
sql.Named("enabled", user.Enabled), sql.Named("enabled", user.Enabled),
sql.Named("now", sqliteTime{time.Now()}), sql.Named("now", sqliteTime{time.Now()}),
sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}),
} }
var err error var err error
@ -456,7 +465,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
_, err = db.db.ExecContext(ctx, ` _, err = db.db.ExecContext(ctx, `
UPDATE User UPDATE User
SET password = :password, admin = :admin, nick = :nick, SET password = :password, admin = :admin, nick = :nick,
realname = :realname, enabled = :enabled realname = :realname, enabled = :enabled,
downstream_interacted_at = :downstream_interacted_at
WHERE username = :username`, WHERE username = :username`,
args...) args...)
} else { } else {
@ -464,9 +474,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
res, err = db.db.ExecContext(ctx, ` res, err = db.db.ExecContext(ctx, `
INSERT INTO INSERT INTO
User(username, password, admin, nick, realname, created_at, User(username, password, admin, nick, realname, created_at,
enabled) enabled, downstream_interacted_at)
VALUES (:username, :password, :admin, :nick, :realname, :now, VALUES (:username, :password, :admin, :nick, :realname, :now,
:enabled)`, :enabled, :downstream_interacted_at)`,
args...) args...)
if err != nil { if err != nil {
return err return err

18
user.go
View File

@ -683,6 +683,7 @@ func (u *user) run() {
} }
case eventDownstreamConnected: case eventDownstreamConnected:
dc := e.dc dc := e.dc
ctx := context.TODO()
if dc.network != nil { if dc.network != nil {
dc.monitored.SetCasemapping(dc.network.casemap) dc.monitored.SetCasemapping(dc.network.casemap)
@ -697,7 +698,7 @@ func (u *user) run() {
break break
} }
if err := dc.welcome(context.TODO()); err != nil { if err := dc.welcome(ctx); err != nil {
if ircErr, ok := err.(ircError); ok { if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy() msg := ircErr.Message.Copy()
msg.Prefix = dc.srv.prefix() msg.Prefix = dc.srv.prefix()
@ -724,8 +725,11 @@ func (u *user) run() {
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
uc.updateAway() uc.updateAway()
}) })
u.bumpDownstreamInteractionTime(ctx)
case eventDownstreamDisconnected: case eventDownstreamDisconnected:
dc := e.dc dc := e.dc
ctx := context.TODO()
for i := range u.downstreamConns { for i := range u.downstreamConns {
if u.downstreamConns[i] == dc { if u.downstreamConns[i] == dc {
@ -735,7 +739,7 @@ func (u *user) run() {
} }
dc.forEachNetwork(func(net *network) { dc.forEachNetwork(func(net *network) {
net.storeClientDeliveryReceipts(context.TODO(), dc.clientName) net.storeClientDeliveryReceipts(ctx, dc.clientName)
}) })
u.forEachUpstream(func(uc *upstreamConn) { u.forEachUpstream(func(uc *upstreamConn) {
@ -743,6 +747,8 @@ func (u *user) run() {
uc.updateAway() uc.updateAway()
uc.updateMonitor() uc.updateMonitor()
}) })
u.bumpDownstreamInteractionTime(ctx)
case eventDownstreamMessage: case eventDownstreamMessage:
msg, dc := e.msg, e.dc msg, dc := e.msg, e.dc
if dc.isClosed() { if dc.isClosed() {
@ -1229,3 +1235,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
return &net.TCPAddr{IP: ip}, nil return &net.TCPAddr{IP: ip}, nil
} }
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
record := u.User
record.DownstreamInteractedAt = time.Now()
if err := u.updateUser(ctx, &record); err != nil {
u.logger.Printf("failed to bump downstream interaction time: %v", err)
}
}