Add disable-inactive-user config option

This can be used to automatically disable users if they don't
actively use the bouncer for a while.
This commit is contained in:
Simon Ser 2023-01-26 16:57:07 +01:00
parent 57f5ee8d6f
commit 9df9880301
8 changed files with 211 additions and 18 deletions

View File

@ -84,14 +84,15 @@ func loadConfig() (*config.Server, *soju.Config, error) {
} }
cfg := &soju.Config{ cfg := &soju.Config{
Hostname: raw.Hostname, Hostname: raw.Hostname,
Title: raw.Title, Title: raw.Title,
LogPath: raw.MsgStore.Source, LogPath: raw.MsgStore.Source,
HTTPOrigins: raw.HTTPOrigins, HTTPOrigins: raw.HTTPOrigins,
AcceptProxyIPs: raw.AcceptProxyIPs, AcceptProxyIPs: raw.AcceptProxyIPs,
MaxUserNetworks: raw.MaxUserNetworks, MaxUserNetworks: raw.MaxUserNetworks,
UpstreamUserIPs: raw.UpstreamUserIPs, UpstreamUserIPs: raw.UpstreamUserIPs,
MOTD: motd, DisableInactiveUsersDelay: raw.DisableInactiveUsersDelay,
MOTD: motd,
} }
return raw, cfg, nil return raw, cfg, nil
} }

View File

@ -5,6 +5,8 @@ import (
"net" "net"
"os" "os"
"strconv" "strconv"
"strings"
"time"
"git.sr.ht/~emersion/go-scfg" "git.sr.ht/~emersion/go-scfg"
) )
@ -32,6 +34,18 @@ var loopbackIPs = IPSet{
}, },
} }
func parseDuration(s string) (time.Duration, error) {
if !strings.HasSuffix(s, "d") {
return 0, fmt.Errorf("missing 'd' suffix in duration")
}
s = strings.TrimSuffix(s, "d")
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return 0, fmt.Errorf("invalid duration: %v", err)
}
return time.Duration(v * 24 * float64(time.Hour)), nil
}
type TLS struct { type TLS struct {
CertPath, KeyPath string CertPath, KeyPath string
} }
@ -57,8 +71,9 @@ type Server struct {
HTTPOrigins []string HTTPOrigins []string
AcceptProxyIPs IPSet AcceptProxyIPs IPSet
MaxUserNetworks int MaxUserNetworks int
UpstreamUserIPs []*net.IPNet UpstreamUserIPs []*net.IPNet
DisableInactiveUsersDelay time.Duration
} }
func Defaults() *Server { func Defaults() *Server {
@ -180,6 +195,18 @@ func parse(cfg scfg.Block) (*Server, error) {
} }
srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n) srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
} }
case "disable-inactive-user":
var durStr string
if err := d.ParseParams(&durStr); err != nil {
return nil, err
}
dur, err := parseDuration(durStr)
if err != nil {
return nil, fmt.Errorf("directive %q: %v", d.Name, err)
} else if dur < 0 {
return nil, fmt.Errorf("directive %q: duration must be positive", d.Name)
}
srv.DisableInactiveUsersDelay = dur
default: default:
return nil, fmt.Errorf("unknown directive %q", d.Name) return nil, fmt.Errorf("unknown directive %q", d.Name)
} }

View File

@ -20,6 +20,7 @@ type Database interface {
GetUser(ctx context.Context, username string) (*User, error) GetUser(ctx context.Context, username string) (*User, error)
StoreUser(ctx context.Context, user *User) error StoreUser(ctx context.Context, user *User) error
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error)
ListNetworks(ctx context.Context, userID int64) ([]Network, error) ListNetworks(ctx context.Context, userID int64) ([]Network, error)
StoreNetwork(ctx context.Context, userID int64, network *Network) error StoreNetwork(ctx context.Context, userID int64, network *Network) error

View File

@ -354,6 +354,33 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
return user, nil return user, nil
} }
func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
`SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`,
limit)
if err != nil {
return nil, err
}
defer rows.Close()
var usernames []string
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
if err := rows.Err(); err != nil {
return nil, err
}
return usernames, nil
}
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error { func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout) ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel() defer cancel()

View File

@ -445,6 +445,33 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
return user, nil return user, nil
} }
func (db *SqliteDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
rows, err := db.db.QueryContext(ctx,
"SELECT username FROM User WHERE coalesce(downstream_interacted_at, created_at) < ?",
sqliteTime{limit})
if err != nil {
return nil, err
}
defer rows.Close()
var usernames []string
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
if err := rows.Err(); err != nil {
return nil, err
}
return usernames, nil
}
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error { func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout) ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel() defer cancel()

View File

@ -170,6 +170,15 @@ The following directives are supported:
This can be useful to avoid having the whole bouncer banned from an upstream This can be useful to avoid having the whole bouncer banned from an upstream
network because of one malicious user. network because of one malicious user.
*disable-inactive-user* <duration>
Disable inactive users after the specified duration.
A user is inactive when the last downstream connection is closed.
The duration is a positive decimal number followed by the unit "d" (days).
For instance, "30d" disables users 30 days after they last disconnect from
the bouncer.
# IRC SERVICE # IRC SERVICE
soju exposes an IRC service called *BouncerServ* to manage the bouncer. soju exposes an IRC service called *BouncerServ* to manage the bouncer.

112
server.go
View File

@ -133,14 +133,15 @@ func (ln *retryListener) Accept() (net.Conn, error) {
} }
type Config struct { type Config struct {
Hostname string Hostname string
Title string Title string
LogPath string LogPath string
HTTPOrigins []string HTTPOrigins []string
AcceptProxyIPs config.IPSet AcceptProxyIPs config.IPSet
MaxUserNetworks int MaxUserNetworks int
MOTD string MOTD string
UpstreamUserIPs []*net.IPNet UpstreamUserIPs []*net.IPNet
DisableInactiveUsersDelay time.Duration
} }
type Server struct { type Server struct {
@ -151,6 +152,7 @@ type Server struct {
config atomic.Value // *Config config atomic.Value // *Config
db database.Database db database.Database
stopWG sync.WaitGroup stopWG sync.WaitGroup
stopCh chan struct{}
lock sync.Mutex lock sync.Mutex
listeners map[net.Listener]struct{} listeners map[net.Listener]struct{}
@ -178,6 +180,7 @@ func NewServer(db database.Database) *Server {
db: db, db: db,
listeners: make(map[net.Listener]struct{}), listeners: make(map[net.Listener]struct{}),
users: make(map[string]*user), users: make(map[string]*user),
stopCh: make(chan struct{}),
} }
srv.config.Store(&Config{ srv.config.Store(&Config{
Hostname: "localhost", Hostname: "localhost",
@ -216,6 +219,12 @@ func (s *Server) Start() error {
} }
s.lock.Unlock() s.lock.Unlock()
s.stopWG.Add(1)
go func() {
defer s.stopWG.Done()
s.disableInactiveUsersLoop()
}()
return nil return nil
} }
@ -343,6 +352,8 @@ func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vap
func (s *Server) Shutdown() { func (s *Server) Shutdown() {
s.Logger.Printf("shutting down server") s.Logger.Printf("shutting down server")
close(s.stopCh)
s.lock.Lock() s.lock.Lock()
s.shutdown = true s.shutdown = true
for ln := range s.listeners { for ln := range s.listeners {
@ -547,3 +558,88 @@ func (s *Server) Stats() *ServerStats {
stats.Upstreams = s.metrics.upstreams.Value() stats.Upstreams = s.metrics.upstreams.Value()
return &stats return &stats
} }
func (s *Server) disableInactiveUsersLoop() {
ticker := time.NewTicker(4 * time.Hour)
defer ticker.Stop()
for {
select {
case <-s.stopCh:
return
case <-ticker.C:
}
if err := s.disableInactiveUsers(context.TODO()); err != nil {
s.Logger.Printf("failed to disable inactive users: %v", err)
}
}
}
func (s *Server) disableInactiveUsers(ctx context.Context) error {
delay := s.Config().DisableInactiveUsersDelay
if delay == 0 {
return nil
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
usernames, err := s.db.ListInactiveUsernames(ctx, time.Now().Add(-delay))
if err != nil {
return fmt.Errorf("failed to list inactive users: %v", err)
} else if len(usernames) == 0 {
return nil
}
// Filter out users with active downstream connections
var users []*user
s.lock.Lock()
for _, username := range usernames {
u := s.users[username]
if u == nil {
// TODO: disable the user in the DB
continue
}
if n := u.numDownstreamConns.Load(); n > 0 {
continue
}
users = append(users, u)
}
s.lock.Unlock()
if len(users) == 0 {
return nil
}
s.Logger.Printf("found %v inactive users", len(users))
for _, u := range users {
done := make(chan error, 1)
enabled := false
event := eventUserUpdate{
enabled: &enabled,
done: done,
}
select {
case <-ctx.Done():
return ctx.Err()
case u.events <- event:
// Event was sent, let's wait for the reply
}
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
if err != nil {
return err
} else {
s.Logger.Printf("deleted inactive user %q", u.Username)
}
}
}
return nil
}

View File

@ -11,6 +11,7 @@ import (
"net" "net"
"sort" "sort"
"strings" "strings"
"sync/atomic"
"time" "time"
"git.sr.ht/~emersion/soju/xirc" "git.sr.ht/~emersion/soju/xirc"
@ -503,6 +504,8 @@ type user struct {
events chan event events chan event
done chan struct{} done chan struct{}
numDownstreamConns atomic.Int64
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
msgStore msgstore.Store msgStore msgstore.Store
@ -715,6 +718,7 @@ func (u *user) run() {
} }
u.downstreamConns = append(u.downstreamConns, dc) u.downstreamConns = append(u.downstreamConns, dc)
u.numDownstreamConns.Add(1)
dc.forEachNetwork(func(network *network) { dc.forEachNetwork(func(network *network) {
if network.lastError != nil { if network.lastError != nil {
@ -734,6 +738,7 @@ func (u *user) run() {
for i := range u.downstreamConns { for i := range u.downstreamConns {
if u.downstreamConns[i] == dc { if u.downstreamConns[i] == dc {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
u.numDownstreamConns.Add(-1)
break break
} }
} }