Add disable-inactive-user config option
This can be used to automatically disable users if they don't actively use the bouncer for a while.
This commit is contained in:
parent
57f5ee8d6f
commit
9df9880301
@ -91,6 +91,7 @@ func loadConfig() (*config.Server, *soju.Config, error) {
|
||||
AcceptProxyIPs: raw.AcceptProxyIPs,
|
||||
MaxUserNetworks: raw.MaxUserNetworks,
|
||||
UpstreamUserIPs: raw.UpstreamUserIPs,
|
||||
DisableInactiveUsersDelay: raw.DisableInactiveUsersDelay,
|
||||
MOTD: motd,
|
||||
}
|
||||
return raw, cfg, nil
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.sr.ht/~emersion/go-scfg"
|
||||
)
|
||||
@ -32,6 +34,18 @@ var loopbackIPs = IPSet{
|
||||
},
|
||||
}
|
||||
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if !strings.HasSuffix(s, "d") {
|
||||
return 0, fmt.Errorf("missing 'd' suffix in duration")
|
||||
}
|
||||
s = strings.TrimSuffix(s, "d")
|
||||
v, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid duration: %v", err)
|
||||
}
|
||||
return time.Duration(v * 24 * float64(time.Hour)), nil
|
||||
}
|
||||
|
||||
type TLS struct {
|
||||
CertPath, KeyPath string
|
||||
}
|
||||
@ -59,6 +73,7 @@ type Server struct {
|
||||
|
||||
MaxUserNetworks int
|
||||
UpstreamUserIPs []*net.IPNet
|
||||
DisableInactiveUsersDelay time.Duration
|
||||
}
|
||||
|
||||
func Defaults() *Server {
|
||||
@ -180,6 +195,18 @@ func parse(cfg scfg.Block) (*Server, error) {
|
||||
}
|
||||
srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
|
||||
}
|
||||
case "disable-inactive-user":
|
||||
var durStr string
|
||||
if err := d.ParseParams(&durStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dur, err := parseDuration(durStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("directive %q: %v", d.Name, err)
|
||||
} else if dur < 0 {
|
||||
return nil, fmt.Errorf("directive %q: duration must be positive", d.Name)
|
||||
}
|
||||
srv.DisableInactiveUsersDelay = dur
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown directive %q", d.Name)
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ type Database interface {
|
||||
GetUser(ctx context.Context, username string) (*User, error)
|
||||
StoreUser(ctx context.Context, user *User) error
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error)
|
||||
|
||||
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
|
||||
StoreNetwork(ctx context.Context, userID int64, network *Network) error
|
||||
|
@ -354,6 +354,33 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
`SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`,
|
||||
limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var usernames []string
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usernames, nil
|
||||
}
|
||||
|
||||
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||
defer cancel()
|
||||
|
@ -445,6 +445,33 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
rows, err := db.db.QueryContext(ctx,
|
||||
"SELECT username FROM User WHERE coalesce(downstream_interacted_at, created_at) < ?",
|
||||
sqliteTime{limit})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var usernames []string
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usernames, nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||
defer cancel()
|
||||
|
@ -170,6 +170,15 @@ The following directives are supported:
|
||||
This can be useful to avoid having the whole bouncer banned from an upstream
|
||||
network because of one malicious user.
|
||||
|
||||
*disable-inactive-user* <duration>
|
||||
Disable inactive users after the specified duration.
|
||||
|
||||
A user is inactive when the last downstream connection is closed.
|
||||
|
||||
The duration is a positive decimal number followed by the unit "d" (days).
|
||||
For instance, "30d" disables users 30 days after they last disconnect from
|
||||
the bouncer.
|
||||
|
||||
# IRC SERVICE
|
||||
|
||||
soju exposes an IRC service called *BouncerServ* to manage the bouncer.
|
||||
|
96
server.go
96
server.go
@ -141,6 +141,7 @@ type Config struct {
|
||||
MaxUserNetworks int
|
||||
MOTD string
|
||||
UpstreamUserIPs []*net.IPNet
|
||||
DisableInactiveUsersDelay time.Duration
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
@ -151,6 +152,7 @@ type Server struct {
|
||||
config atomic.Value // *Config
|
||||
db database.Database
|
||||
stopWG sync.WaitGroup
|
||||
stopCh chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
listeners map[net.Listener]struct{}
|
||||
@ -178,6 +180,7 @@ func NewServer(db database.Database) *Server {
|
||||
db: db,
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
users: make(map[string]*user),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
srv.config.Store(&Config{
|
||||
Hostname: "localhost",
|
||||
@ -216,6 +219,12 @@ func (s *Server) Start() error {
|
||||
}
|
||||
s.lock.Unlock()
|
||||
|
||||
s.stopWG.Add(1)
|
||||
go func() {
|
||||
defer s.stopWG.Done()
|
||||
s.disableInactiveUsersLoop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -343,6 +352,8 @@ func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vap
|
||||
func (s *Server) Shutdown() {
|
||||
s.Logger.Printf("shutting down server")
|
||||
|
||||
close(s.stopCh)
|
||||
|
||||
s.lock.Lock()
|
||||
s.shutdown = true
|
||||
for ln := range s.listeners {
|
||||
@ -547,3 +558,88 @@ func (s *Server) Stats() *ServerStats {
|
||||
stats.Upstreams = s.metrics.upstreams.Value()
|
||||
return &stats
|
||||
}
|
||||
|
||||
func (s *Server) disableInactiveUsersLoop() {
|
||||
ticker := time.NewTicker(4 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
if err := s.disableInactiveUsers(context.TODO()); err != nil {
|
||||
s.Logger.Printf("failed to disable inactive users: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) disableInactiveUsers(ctx context.Context) error {
|
||||
delay := s.Config().DisableInactiveUsersDelay
|
||||
if delay == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
usernames, err := s.db.ListInactiveUsernames(ctx, time.Now().Add(-delay))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list inactive users: %v", err)
|
||||
} else if len(usernames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter out users with active downstream connections
|
||||
var users []*user
|
||||
s.lock.Lock()
|
||||
for _, username := range usernames {
|
||||
u := s.users[username]
|
||||
if u == nil {
|
||||
// TODO: disable the user in the DB
|
||||
continue
|
||||
}
|
||||
|
||||
if n := u.numDownstreamConns.Load(); n > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
users = append(users, u)
|
||||
}
|
||||
s.lock.Unlock()
|
||||
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.Logger.Printf("found %v inactive users", len(users))
|
||||
for _, u := range users {
|
||||
done := make(chan error, 1)
|
||||
enabled := false
|
||||
event := eventUserUpdate{
|
||||
enabled: &enabled,
|
||||
done: done,
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case u.events <- event:
|
||||
// Event was sent, let's wait for the reply
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
s.Logger.Printf("deleted inactive user %q", u.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
5
user.go
5
user.go
@ -11,6 +11,7 @@ import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.sr.ht/~emersion/soju/xirc"
|
||||
@ -503,6 +504,8 @@ type user struct {
|
||||
events chan event
|
||||
done chan struct{}
|
||||
|
||||
numDownstreamConns atomic.Int64
|
||||
|
||||
networks []*network
|
||||
downstreamConns []*downstreamConn
|
||||
msgStore msgstore.Store
|
||||
@ -715,6 +718,7 @@ func (u *user) run() {
|
||||
}
|
||||
|
||||
u.downstreamConns = append(u.downstreamConns, dc)
|
||||
u.numDownstreamConns.Add(1)
|
||||
|
||||
dc.forEachNetwork(func(network *network) {
|
||||
if network.lastError != nil {
|
||||
@ -734,6 +738,7 @@ func (u *user) run() {
|
||||
for i := range u.downstreamConns {
|
||||
if u.downstreamConns[i] == dc {
|
||||
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
||||
u.numDownstreamConns.Add(-1)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user