From 9df9880301377775d76943a085520eaf1d51af2b Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 26 Jan 2023 16:57:07 +0100 Subject: [PATCH] Add disable-inactive-user config option This can be used to automatically disable users if they don't actively use the bouncer for a while. --- cmd/soju/main.go | 17 +++---- config/config.go | 31 +++++++++++- database/database.go | 1 + database/postgres.go | 27 +++++++++++ database/sqlite.go | 27 +++++++++++ doc/soju.1.scd | 9 ++++ server.go | 112 +++++++++++++++++++++++++++++++++++++++---- user.go | 5 ++ 8 files changed, 211 insertions(+), 18 deletions(-) diff --git a/cmd/soju/main.go b/cmd/soju/main.go index 0094381..a35711e 100644 --- a/cmd/soju/main.go +++ b/cmd/soju/main.go @@ -84,14 +84,15 @@ func loadConfig() (*config.Server, *soju.Config, error) { } cfg := &soju.Config{ - Hostname: raw.Hostname, - Title: raw.Title, - LogPath: raw.MsgStore.Source, - HTTPOrigins: raw.HTTPOrigins, - AcceptProxyIPs: raw.AcceptProxyIPs, - MaxUserNetworks: raw.MaxUserNetworks, - UpstreamUserIPs: raw.UpstreamUserIPs, - MOTD: motd, + Hostname: raw.Hostname, + Title: raw.Title, + LogPath: raw.MsgStore.Source, + HTTPOrigins: raw.HTTPOrigins, + AcceptProxyIPs: raw.AcceptProxyIPs, + MaxUserNetworks: raw.MaxUserNetworks, + UpstreamUserIPs: raw.UpstreamUserIPs, + DisableInactiveUsersDelay: raw.DisableInactiveUsersDelay, + MOTD: motd, } return raw, cfg, nil } diff --git a/config/config.go b/config/config.go index 25233dd..c273705 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,8 @@ import ( "net" "os" "strconv" + "strings" + "time" "git.sr.ht/~emersion/go-scfg" ) @@ -32,6 +34,18 @@ var loopbackIPs = IPSet{ }, } +func parseDuration(s string) (time.Duration, error) { + if !strings.HasSuffix(s, "d") { + return 0, fmt.Errorf("missing 'd' suffix in duration") + } + s = strings.TrimSuffix(s, "d") + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, fmt.Errorf("invalid duration: %v", err) + } + return time.Duration(v * 24 * float64(time.Hour)), nil +} + type TLS struct { CertPath, KeyPath string } @@ -57,8 +71,9 @@ type Server struct { HTTPOrigins []string AcceptProxyIPs IPSet - MaxUserNetworks int - UpstreamUserIPs []*net.IPNet + MaxUserNetworks int + UpstreamUserIPs []*net.IPNet + DisableInactiveUsersDelay time.Duration } func Defaults() *Server { @@ -180,6 +195,18 @@ func parse(cfg scfg.Block) (*Server, error) { } srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) } + case "disable-inactive-user": + var durStr string + if err := d.ParseParams(&durStr); err != nil { + return nil, err + } + dur, err := parseDuration(durStr) + if err != nil { + return nil, fmt.Errorf("directive %q: %v", d.Name, err) + } else if dur < 0 { + return nil, fmt.Errorf("directive %q: duration must be positive", d.Name) + } + srv.DisableInactiveUsersDelay = dur default: return nil, fmt.Errorf("unknown directive %q", d.Name) } diff --git a/database/database.go b/database/database.go index 2a8027a..b032cb0 100644 --- a/database/database.go +++ b/database/database.go @@ -20,6 +20,7 @@ type Database interface { GetUser(ctx context.Context, username string) (*User, error) StoreUser(ctx context.Context, user *User) error DeleteUser(ctx context.Context, id int64) error + ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) ListNetworks(ctx context.Context, userID int64) ([]Network, error) StoreNetwork(ctx context.Context, userID int64, network *Network) error diff --git a/database/postgres.go b/database/postgres.go index 5a00c08..b7cec8c 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -354,6 +354,33 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro return user, nil } +func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) { + ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + `SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`, + limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var usernames []string + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return nil, err + } + usernames = append(usernames, username) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return usernames, nil +} + func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) defer cancel() diff --git a/database/sqlite.go b/database/sqlite.go index 28e18cc..58801b0 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -445,6 +445,33 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error) return user, nil } +func (db *SqliteDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) { + ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) + defer cancel() + + rows, err := db.db.QueryContext(ctx, + "SELECT username FROM User WHERE coalesce(downstream_interacted_at, created_at) < ?", + sqliteTime{limit}) + if err != nil { + return nil, err + } + defer rows.Close() + + var usernames []string + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return nil, err + } + usernames = append(usernames, username) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return usernames, nil +} + func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) defer cancel() diff --git a/doc/soju.1.scd b/doc/soju.1.scd index add5527..1805836 100644 --- a/doc/soju.1.scd +++ b/doc/soju.1.scd @@ -170,6 +170,15 @@ The following directives are supported: This can be useful to avoid having the whole bouncer banned from an upstream network because of one malicious user. +*disable-inactive-user* + Disable inactive users after the specified duration. + + A user is inactive when the last downstream connection is closed. + + The duration is a positive decimal number followed by the unit "d" (days). + For instance, "30d" disables users 30 days after they last disconnect from + the bouncer. + # IRC SERVICE soju exposes an IRC service called *BouncerServ* to manage the bouncer. diff --git a/server.go b/server.go index 7d1d305..c5c39c8 100644 --- a/server.go +++ b/server.go @@ -133,14 +133,15 @@ func (ln *retryListener) Accept() (net.Conn, error) { } type Config struct { - Hostname string - Title string - LogPath string - HTTPOrigins []string - AcceptProxyIPs config.IPSet - MaxUserNetworks int - MOTD string - UpstreamUserIPs []*net.IPNet + Hostname string + Title string + LogPath string + HTTPOrigins []string + AcceptProxyIPs config.IPSet + MaxUserNetworks int + MOTD string + UpstreamUserIPs []*net.IPNet + DisableInactiveUsersDelay time.Duration } type Server struct { @@ -151,6 +152,7 @@ type Server struct { config atomic.Value // *Config db database.Database stopWG sync.WaitGroup + stopCh chan struct{} lock sync.Mutex listeners map[net.Listener]struct{} @@ -178,6 +180,7 @@ func NewServer(db database.Database) *Server { db: db, listeners: make(map[net.Listener]struct{}), users: make(map[string]*user), + stopCh: make(chan struct{}), } srv.config.Store(&Config{ Hostname: "localhost", @@ -216,6 +219,12 @@ func (s *Server) Start() error { } s.lock.Unlock() + s.stopWG.Add(1) + go func() { + defer s.stopWG.Done() + s.disableInactiveUsersLoop() + }() + return nil } @@ -343,6 +352,8 @@ func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vap func (s *Server) Shutdown() { s.Logger.Printf("shutting down server") + close(s.stopCh) + s.lock.Lock() s.shutdown = true for ln := range s.listeners { @@ -547,3 +558,88 @@ func (s *Server) Stats() *ServerStats { stats.Upstreams = s.metrics.upstreams.Value() return &stats } + +func (s *Server) disableInactiveUsersLoop() { + ticker := time.NewTicker(4 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + } + + if err := s.disableInactiveUsers(context.TODO()); err != nil { + s.Logger.Printf("failed to disable inactive users: %v", err) + } + } +} + +func (s *Server) disableInactiveUsers(ctx context.Context) error { + delay := s.Config().DisableInactiveUsersDelay + if delay == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + usernames, err := s.db.ListInactiveUsernames(ctx, time.Now().Add(-delay)) + if err != nil { + return fmt.Errorf("failed to list inactive users: %v", err) + } else if len(usernames) == 0 { + return nil + } + + // Filter out users with active downstream connections + var users []*user + s.lock.Lock() + for _, username := range usernames { + u := s.users[username] + if u == nil { + // TODO: disable the user in the DB + continue + } + + if n := u.numDownstreamConns.Load(); n > 0 { + continue + } + + users = append(users, u) + } + s.lock.Unlock() + + if len(users) == 0 { + return nil + } + + s.Logger.Printf("found %v inactive users", len(users)) + for _, u := range users { + done := make(chan error, 1) + enabled := false + event := eventUserUpdate{ + enabled: &enabled, + done: done, + } + select { + case <-ctx.Done(): + return ctx.Err() + case u.events <- event: + // Event was sent, let's wait for the reply + } + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + if err != nil { + return err + } else { + s.Logger.Printf("deleted inactive user %q", u.Username) + } + } + } + + return nil +} diff --git a/user.go b/user.go index bdb72a8..c58fa04 100644 --- a/user.go +++ b/user.go @@ -11,6 +11,7 @@ import ( "net" "sort" "strings" + "sync/atomic" "time" "git.sr.ht/~emersion/soju/xirc" @@ -503,6 +504,8 @@ type user struct { events chan event done chan struct{} + numDownstreamConns atomic.Int64 + networks []*network downstreamConns []*downstreamConn msgStore msgstore.Store @@ -715,6 +718,7 @@ func (u *user) run() { } u.downstreamConns = append(u.downstreamConns, dc) + u.numDownstreamConns.Add(1) dc.forEachNetwork(func(network *network) { if network.lastError != nil { @@ -734,6 +738,7 @@ func (u *user) run() { for i := range u.downstreamConns { if u.downstreamConns[i] == dc { u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) + u.numDownstreamConns.Add(-1) break } }