Add user downstream interaction time

This commit is contained in:
Simon Ser 2023-01-26 14:02:11 +01:00
parent 05d7813835
commit 57f5ee8d6f
4 changed files with 60 additions and 28 deletions

View File

@ -72,6 +72,7 @@ type User struct {
Realname string
Admin bool
Enabled bool
DownstreamInteractedAt time.Time
}
func (u *User) CheckPassword(password string) (upgraded bool, err error) {

View File

@ -33,7 +33,8 @@ CREATE TABLE "User" (
nick VARCHAR(255),
realname VARCHAR(255),
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
enabled BOOLEAN NOT NULL DEFAULT TRUE
enabled BOOLEAN NOT NULL DEFAULT TRUE,
downstream_interacted_at TIMESTAMP WITH TIME ZONE
);
CREATE TYPE sasl_mechanism AS ENUM ('PLAIN', 'EXTERNAL');
@ -171,6 +172,7 @@ var postgresMigrations = []string{
`ALTER TABLE "Network" ADD COLUMN certfp TEXT`,
`ALTER TABLE "User" ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()`,
`ALTER TABLE "User" ADD COLUMN enabled BOOLEAN NOT NULL DEFAULT TRUE`,
`ALTER TABLE "User" ADD COLUMN downstream_interacted_at TIMESTAMP WITH TIME ZONE`,
}
type PostgresDB struct {
@ -304,7 +306,8 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel()
rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled
`SELECT id, username, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM "User"`)
if err != nil {
return nil, err
@ -315,7 +318,7 @@ func (db *PostgresDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() {
var user User
var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil {
return nil, err
}
user.Password = password.String
@ -338,11 +341,11 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
var password, nick, realname sql.NullString
row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname, enabled
`SELECT id, password, admin, nick, realname, enabled, downstream_interacted_at
FROM "User"
WHERE username = $1`,
username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &user.DownstreamInteractedAt); err != nil {
return nil, err
}
user.Password = password.String
@ -362,16 +365,20 @@ func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
var err error
if user.ID == 0 {
err = db.db.QueryRowContext(ctx, `
INSERT INTO "User" (username, password, admin, nick, realname, enabled)
VALUES ($1, $2, $3, $4, $5, $6)
INSERT INTO "User" (username, password, admin, nick, realname,
enabled, downstream_interacted_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`,
user.Username, password, user.Admin, nick, realname, user.Enabled).Scan(&user.ID)
user.Username, password, user.Admin, nick, realname, user.Enabled,
user.DownstreamInteractedAt).Scan(&user.ID)
} else {
_, err = db.db.ExecContext(ctx, `
UPDATE "User"
SET password = $1, admin = $2, nick = $3, realname = $4, enabled = $5
WHERE id = $6`,
password, user.Admin, nick, realname, user.Enabled, user.ID)
SET password = $1, admin = $2, nick = $3, realname = $4,
enabled = $5, downstream_interacted_at = $6
WHERE id = $7`,
password, user.Admin, nick, realname, user.Enabled,
user.DownstreamInteractedAt, user.ID)
}
return err
}

View File

@ -63,7 +63,8 @@ CREATE TABLE User (
realname TEXT,
nick TEXT,
created_at TEXT NOT NULL,
enabled INTEGER NOT NULL DEFAULT 1
enabled INTEGER NOT NULL DEFAULT 1,
downstream_interacted_at TEXT
);
CREATE TABLE Network (
@ -291,6 +292,7 @@ var sqliteMigrations = []string{
UPDATE User SET created_at = strftime('` + sqliteTimeFormat + `', 'now');
`,
"ALTER TABLE User ADD COLUMN enabled INTEGER NOT NULL DEFAULT 1",
"ALTER TABLE User ADD COLUMN downstream_interacted_at TEXT;",
}
type SqliteDB struct {
@ -390,7 +392,8 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
defer cancel()
rows, err := db.db.QueryContext(ctx,
`SELECT id, username, password, admin, nick, realname, enabled
`SELECT id, username, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM User`)
if err != nil {
return nil, err
@ -401,12 +404,14 @@ func (db *SqliteDB) ListUsers(ctx context.Context) ([]User, error) {
for rows.Next() {
var user User
var password, nick, realname sql.NullString
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
var downstreamInteractedAt sqliteTime
if err := rows.Scan(&user.ID, &user.Username, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
return nil, err
}
user.Password = password.String
user.Nick = nick.String
user.Realname = realname.String
user.DownstreamInteractedAt = downstreamInteractedAt.Time
users = append(users, user)
}
if err := rows.Err(); err != nil {
@ -423,17 +428,20 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
user := &User{Username: username}
var password, nick, realname sql.NullString
var downstreamInteractedAt sqliteTime
row := db.db.QueryRowContext(ctx,
`SELECT id, password, admin, nick, realname, enabled
`SELECT id, password, admin, nick, realname, enabled,
downstream_interacted_at
FROM User
WHERE username = ?`,
username)
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled); err != nil {
if err := row.Scan(&user.ID, &password, &user.Admin, &nick, &realname, &user.Enabled, &downstreamInteractedAt); err != nil {
return nil, err
}
user.Password = password.String
user.Nick = nick.String
user.Realname = realname.String
user.DownstreamInteractedAt = downstreamInteractedAt.Time
return user, nil
}
@ -449,6 +457,7 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
sql.Named("realname", toNullString(user.Realname)),
sql.Named("enabled", user.Enabled),
sql.Named("now", sqliteTime{time.Now()}),
sql.Named("downstream_interacted_at", sqliteTime{user.DownstreamInteractedAt}),
}
var err error
@ -456,7 +465,8 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
_, err = db.db.ExecContext(ctx, `
UPDATE User
SET password = :password, admin = :admin, nick = :nick,
realname = :realname, enabled = :enabled
realname = :realname, enabled = :enabled,
downstream_interacted_at = :downstream_interacted_at
WHERE username = :username`,
args...)
} else {
@ -464,9 +474,9 @@ func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
res, err = db.db.ExecContext(ctx, `
INSERT INTO
User(username, password, admin, nick, realname, created_at,
enabled)
enabled, downstream_interacted_at)
VALUES (:username, :password, :admin, :nick, :realname, :now,
:enabled)`,
:enabled, :downstream_interacted_at)`,
args...)
if err != nil {
return err

18
user.go
View File

@ -683,6 +683,7 @@ func (u *user) run() {
}
case eventDownstreamConnected:
dc := e.dc
ctx := context.TODO()
if dc.network != nil {
dc.monitored.SetCasemapping(dc.network.casemap)
@ -697,7 +698,7 @@ func (u *user) run() {
break
}
if err := dc.welcome(context.TODO()); err != nil {
if err := dc.welcome(ctx); err != nil {
if ircErr, ok := err.(ircError); ok {
msg := ircErr.Message.Copy()
msg.Prefix = dc.srv.prefix()
@ -724,8 +725,11 @@ func (u *user) run() {
u.forEachUpstream(func(uc *upstreamConn) {
uc.updateAway()
})
u.bumpDownstreamInteractionTime(ctx)
case eventDownstreamDisconnected:
dc := e.dc
ctx := context.TODO()
for i := range u.downstreamConns {
if u.downstreamConns[i] == dc {
@ -735,7 +739,7 @@ func (u *user) run() {
}
dc.forEachNetwork(func(net *network) {
net.storeClientDeliveryReceipts(context.TODO(), dc.clientName)
net.storeClientDeliveryReceipts(ctx, dc.clientName)
})
u.forEachUpstream(func(uc *upstreamConn) {
@ -743,6 +747,8 @@ func (u *user) run() {
uc.updateAway()
uc.updateMonitor()
})
u.bumpDownstreamInteractionTime(ctx)
case eventDownstreamMessage:
msg, dc := e.msg, e.dc
if dc.isClosed() {
@ -1229,3 +1235,11 @@ func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAd
return &net.TCPAddr{IP: ip}, nil
}
func (u *user) bumpDownstreamInteractionTime(ctx context.Context) {
record := u.User
record.DownstreamInteractedAt = time.Now()
if err := u.updateUser(ctx, &record); err != nil {
u.logger.Printf("failed to bump downstream interaction time: %v", err)
}
}