Add disable-inactive-user config option
This can be used to automatically disable users if they don't actively use the bouncer for a while.
This commit is contained in:
parent
57f5ee8d6f
commit
9df9880301
@ -84,14 +84,15 @@ func loadConfig() (*config.Server, *soju.Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := &soju.Config{
|
cfg := &soju.Config{
|
||||||
Hostname: raw.Hostname,
|
Hostname: raw.Hostname,
|
||||||
Title: raw.Title,
|
Title: raw.Title,
|
||||||
LogPath: raw.MsgStore.Source,
|
LogPath: raw.MsgStore.Source,
|
||||||
HTTPOrigins: raw.HTTPOrigins,
|
HTTPOrigins: raw.HTTPOrigins,
|
||||||
AcceptProxyIPs: raw.AcceptProxyIPs,
|
AcceptProxyIPs: raw.AcceptProxyIPs,
|
||||||
MaxUserNetworks: raw.MaxUserNetworks,
|
MaxUserNetworks: raw.MaxUserNetworks,
|
||||||
UpstreamUserIPs: raw.UpstreamUserIPs,
|
UpstreamUserIPs: raw.UpstreamUserIPs,
|
||||||
MOTD: motd,
|
DisableInactiveUsersDelay: raw.DisableInactiveUsersDelay,
|
||||||
|
MOTD: motd,
|
||||||
}
|
}
|
||||||
return raw, cfg, nil
|
return raw, cfg, nil
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/go-scfg"
|
"git.sr.ht/~emersion/go-scfg"
|
||||||
)
|
)
|
||||||
@ -32,6 +34,18 @@ var loopbackIPs = IPSet{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseDuration(s string) (time.Duration, error) {
|
||||||
|
if !strings.HasSuffix(s, "d") {
|
||||||
|
return 0, fmt.Errorf("missing 'd' suffix in duration")
|
||||||
|
}
|
||||||
|
s = strings.TrimSuffix(s, "d")
|
||||||
|
v, err := strconv.ParseFloat(s, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid duration: %v", err)
|
||||||
|
}
|
||||||
|
return time.Duration(v * 24 * float64(time.Hour)), nil
|
||||||
|
}
|
||||||
|
|
||||||
type TLS struct {
|
type TLS struct {
|
||||||
CertPath, KeyPath string
|
CertPath, KeyPath string
|
||||||
}
|
}
|
||||||
@ -57,8 +71,9 @@ type Server struct {
|
|||||||
HTTPOrigins []string
|
HTTPOrigins []string
|
||||||
AcceptProxyIPs IPSet
|
AcceptProxyIPs IPSet
|
||||||
|
|
||||||
MaxUserNetworks int
|
MaxUserNetworks int
|
||||||
UpstreamUserIPs []*net.IPNet
|
UpstreamUserIPs []*net.IPNet
|
||||||
|
DisableInactiveUsersDelay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func Defaults() *Server {
|
func Defaults() *Server {
|
||||||
@ -180,6 +195,18 @@ func parse(cfg scfg.Block) (*Server, error) {
|
|||||||
}
|
}
|
||||||
srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
|
srv.UpstreamUserIPs = append(srv.UpstreamUserIPs, n)
|
||||||
}
|
}
|
||||||
|
case "disable-inactive-user":
|
||||||
|
var durStr string
|
||||||
|
if err := d.ParseParams(&durStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dur, err := parseDuration(durStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("directive %q: %v", d.Name, err)
|
||||||
|
} else if dur < 0 {
|
||||||
|
return nil, fmt.Errorf("directive %q: duration must be positive", d.Name)
|
||||||
|
}
|
||||||
|
srv.DisableInactiveUsersDelay = dur
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown directive %q", d.Name)
|
return nil, fmt.Errorf("unknown directive %q", d.Name)
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@ type Database interface {
|
|||||||
GetUser(ctx context.Context, username string) (*User, error)
|
GetUser(ctx context.Context, username string) (*User, error)
|
||||||
StoreUser(ctx context.Context, user *User) error
|
StoreUser(ctx context.Context, user *User) error
|
||||||
DeleteUser(ctx context.Context, id int64) error
|
DeleteUser(ctx context.Context, id int64) error
|
||||||
|
ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error)
|
||||||
|
|
||||||
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
|
ListNetworks(ctx context.Context, userID int64) ([]Network, error)
|
||||||
StoreNetwork(ctx context.Context, userID int64, network *Network) error
|
StoreNetwork(ctx context.Context, userID int64, network *Network) error
|
||||||
|
@ -354,6 +354,33 @@ func (db *PostgresDB) GetUser(ctx context.Context, username string) (*User, erro
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PostgresDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rows, err := db.db.QueryContext(ctx,
|
||||||
|
`SELECT username FROM "User" WHERE COALESCE(downstream_interacted_at, created_at) < $1`,
|
||||||
|
limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var usernames []string
|
||||||
|
for rows.Next() {
|
||||||
|
var username string
|
||||||
|
if err := rows.Scan(&username); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
usernames = append(usernames, username)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return usernames, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
func (db *PostgresDB) StoreUser(ctx context.Context, user *User) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -445,6 +445,33 @@ func (db *SqliteDB) GetUser(ctx context.Context, username string) (*User, error)
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *SqliteDB) ListInactiveUsernames(ctx context.Context, limit time.Time) ([]string, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rows, err := db.db.QueryContext(ctx,
|
||||||
|
"SELECT username FROM User WHERE coalesce(downstream_interacted_at, created_at) < ?",
|
||||||
|
sqliteTime{limit})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var usernames []string
|
||||||
|
for rows.Next() {
|
||||||
|
var username string
|
||||||
|
if err := rows.Scan(&username); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
usernames = append(usernames, username)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return usernames, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
func (db *SqliteDB) StoreUser(ctx context.Context, user *User) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -170,6 +170,15 @@ The following directives are supported:
|
|||||||
This can be useful to avoid having the whole bouncer banned from an upstream
|
This can be useful to avoid having the whole bouncer banned from an upstream
|
||||||
network because of one malicious user.
|
network because of one malicious user.
|
||||||
|
|
||||||
|
*disable-inactive-user* <duration>
|
||||||
|
Disable inactive users after the specified duration.
|
||||||
|
|
||||||
|
A user is inactive when the last downstream connection is closed.
|
||||||
|
|
||||||
|
The duration is a positive decimal number followed by the unit "d" (days).
|
||||||
|
For instance, "30d" disables users 30 days after they last disconnect from
|
||||||
|
the bouncer.
|
||||||
|
|
||||||
# IRC SERVICE
|
# IRC SERVICE
|
||||||
|
|
||||||
soju exposes an IRC service called *BouncerServ* to manage the bouncer.
|
soju exposes an IRC service called *BouncerServ* to manage the bouncer.
|
||||||
|
112
server.go
112
server.go
@ -133,14 +133,15 @@ func (ln *retryListener) Accept() (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
Title string
|
Title string
|
||||||
LogPath string
|
LogPath string
|
||||||
HTTPOrigins []string
|
HTTPOrigins []string
|
||||||
AcceptProxyIPs config.IPSet
|
AcceptProxyIPs config.IPSet
|
||||||
MaxUserNetworks int
|
MaxUserNetworks int
|
||||||
MOTD string
|
MOTD string
|
||||||
UpstreamUserIPs []*net.IPNet
|
UpstreamUserIPs []*net.IPNet
|
||||||
|
DisableInactiveUsersDelay time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@ -151,6 +152,7 @@ type Server struct {
|
|||||||
config atomic.Value // *Config
|
config atomic.Value // *Config
|
||||||
db database.Database
|
db database.Database
|
||||||
stopWG sync.WaitGroup
|
stopWG sync.WaitGroup
|
||||||
|
stopCh chan struct{}
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
listeners map[net.Listener]struct{}
|
listeners map[net.Listener]struct{}
|
||||||
@ -178,6 +180,7 @@ func NewServer(db database.Database) *Server {
|
|||||||
db: db,
|
db: db,
|
||||||
listeners: make(map[net.Listener]struct{}),
|
listeners: make(map[net.Listener]struct{}),
|
||||||
users: make(map[string]*user),
|
users: make(map[string]*user),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
srv.config.Store(&Config{
|
srv.config.Store(&Config{
|
||||||
Hostname: "localhost",
|
Hostname: "localhost",
|
||||||
@ -216,6 +219,12 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
s.lock.Unlock()
|
s.lock.Unlock()
|
||||||
|
|
||||||
|
s.stopWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.stopWG.Done()
|
||||||
|
s.disableInactiveUsersLoop()
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,6 +352,8 @@ func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vap
|
|||||||
func (s *Server) Shutdown() {
|
func (s *Server) Shutdown() {
|
||||||
s.Logger.Printf("shutting down server")
|
s.Logger.Printf("shutting down server")
|
||||||
|
|
||||||
|
close(s.stopCh)
|
||||||
|
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
s.shutdown = true
|
s.shutdown = true
|
||||||
for ln := range s.listeners {
|
for ln := range s.listeners {
|
||||||
@ -547,3 +558,88 @@ func (s *Server) Stats() *ServerStats {
|
|||||||
stats.Upstreams = s.metrics.upstreams.Value()
|
stats.Upstreams = s.metrics.upstreams.Value()
|
||||||
return &stats
|
return &stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) disableInactiveUsersLoop() {
|
||||||
|
ticker := time.NewTicker(4 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.disableInactiveUsers(context.TODO()); err != nil {
|
||||||
|
s.Logger.Printf("failed to disable inactive users: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) disableInactiveUsers(ctx context.Context) error {
|
||||||
|
delay := s.Config().DisableInactiveUsersDelay
|
||||||
|
if delay == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
usernames, err := s.db.ListInactiveUsernames(ctx, time.Now().Add(-delay))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list inactive users: %v", err)
|
||||||
|
} else if len(usernames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out users with active downstream connections
|
||||||
|
var users []*user
|
||||||
|
s.lock.Lock()
|
||||||
|
for _, username := range usernames {
|
||||||
|
u := s.users[username]
|
||||||
|
if u == nil {
|
||||||
|
// TODO: disable the user in the DB
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := u.numDownstreamConns.Load(); n > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
users = append(users, u)
|
||||||
|
}
|
||||||
|
s.lock.Unlock()
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Logger.Printf("found %v inactive users", len(users))
|
||||||
|
for _, u := range users {
|
||||||
|
done := make(chan error, 1)
|
||||||
|
enabled := false
|
||||||
|
event := eventUserUpdate{
|
||||||
|
enabled: &enabled,
|
||||||
|
done: done,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case u.events <- event:
|
||||||
|
// Event was sent, let's wait for the reply
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
s.Logger.Printf("deleted inactive user %q", u.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
5
user.go
5
user.go
@ -11,6 +11,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/soju/xirc"
|
"git.sr.ht/~emersion/soju/xirc"
|
||||||
@ -503,6 +504,8 @@ type user struct {
|
|||||||
events chan event
|
events chan event
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
|
|
||||||
|
numDownstreamConns atomic.Int64
|
||||||
|
|
||||||
networks []*network
|
networks []*network
|
||||||
downstreamConns []*downstreamConn
|
downstreamConns []*downstreamConn
|
||||||
msgStore msgstore.Store
|
msgStore msgstore.Store
|
||||||
@ -715,6 +718,7 @@ func (u *user) run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
u.downstreamConns = append(u.downstreamConns, dc)
|
u.downstreamConns = append(u.downstreamConns, dc)
|
||||||
|
u.numDownstreamConns.Add(1)
|
||||||
|
|
||||||
dc.forEachNetwork(func(network *network) {
|
dc.forEachNetwork(func(network *network) {
|
||||||
if network.lastError != nil {
|
if network.lastError != nil {
|
||||||
@ -734,6 +738,7 @@ func (u *user) run() {
|
|||||||
for i := range u.downstreamConns {
|
for i := range u.downstreamConns {
|
||||||
if u.downstreamConns[i] == dc {
|
if u.downstreamConns[i] == dc {
|
||||||
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
||||||
|
u.numDownstreamConns.Add(-1)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user