Implement rate limiting for upstream messages

Allow up to 10 outgoing messages in a burst, then throttle to 1 message
each 2 seconds.

Closes: https://todo.sr.ht/~emersion/soju/87
This commit is contained in:
Simon Ser 2020-08-19 19:28:29 +02:00
parent 9f26422592
commit bdb132ad98
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
5 changed files with 72 additions and 7 deletions

60
conn.go
View File

@ -106,6 +106,52 @@ func (wa websocketAddr) String() string {
return string(wa) return string(wa)
} }
type rateLimiter struct {
C <-chan struct{}
ticker *time.Ticker
stopped chan struct{}
}
func newRateLimiter(delay time.Duration, burst int) *rateLimiter {
ch := make(chan struct{}, burst)
for i := 0; i < burst; i++ {
ch <- struct{}{}
}
ticker := time.NewTicker(delay)
stopped := make(chan struct{})
go func() {
for {
select {
case <-ticker.C:
select {
case ch <- struct{}{}:
// This space is intentionally left blank
case <-stopped:
return
}
case <-stopped:
return
}
}
}()
return &rateLimiter{
C: ch,
ticker: ticker,
stopped: stopped,
}
}
func (rl *rateLimiter) Stop() {
rl.ticker.Stop()
close(rl.stopped)
}
type connOptions struct {
Logger Logger
RateLimitDelay time.Duration
RateLimitBurst int
}
type conn struct { type conn struct {
conn ircConn conn ircConn
srv *Server srv *Server
@ -116,17 +162,27 @@ type conn struct {
closed bool closed bool
} }
func newConn(srv *Server, ic ircConn, logger Logger) *conn { func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
outgoing := make(chan *irc.Message, 64) outgoing := make(chan *irc.Message, 64)
c := &conn{ c := &conn{
conn: ic, conn: ic,
srv: srv, srv: srv,
outgoing: outgoing, outgoing: outgoing,
logger: logger, logger: options.Logger,
} }
go func() { go func() {
var rl *rateLimiter
if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 {
rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst)
defer rl.Stop()
}
for msg := range outgoing { for msg := range outgoing {
if rl != nil {
<-rl.C
}
if c.srv.Debug { if c.srv.Debug {
c.logger.Printf("sent: %v", msg) c.logger.Printf("sent: %v", msg)
} }

View File

@ -102,8 +102,9 @@ type downstreamConn struct {
func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {
remoteAddr := ic.RemoteAddr().String() remoteAddr := ic.RemoteAddr().String()
logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
options := connOptions{Logger: logger}
dc := &downstreamConn{ dc := &downstreamConn{
conn: *newConn(srv, ic, logger), conn: *newConn(srv, ic, &options),
id: id, id: id,
supportedCaps: make(map[string]string), supportedCaps: make(map[string]string),
caps: make(map[string]bool), caps: make(map[string]bool),

View File

@ -16,9 +16,11 @@ import (
) )
// TODO: make configurable // TODO: make configurable
var retryConnectMinDelay = time.Minute var retryConnectDelay = time.Minute
var connectTimeout = 15 * time.Second var connectTimeout = 15 * time.Second
var writeTimeout = 10 * time.Second var writeTimeout = 10 * time.Second
var upstreamMessageDelay = 2 * time.Second
var upstreamMessageBurst = 10
type Logger interface { type Logger interface {
Print(v ...interface{}) Print(v ...interface{})

View File

@ -157,8 +157,14 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
} }
options := connOptions{
Logger: logger,
RateLimitDelay: upstreamMessageDelay,
RateLimitBurst: upstreamMessageBurst,
}
uc := &upstreamConn{ uc := &upstreamConn{
conn: *newConn(network.user.srv, newNetIRCConn(netConn), logger), conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
network: network, network: network,
user: network.user, user: network.user,
channels: make(map[string]*upstreamChannel), channels: make(map[string]*upstreamChannel),

View File

@ -120,8 +120,8 @@ func (net *network) run() {
return return
} }
if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay { if dur := time.Now().Sub(lastTry); dur < retryConnectDelay {
delay := retryConnectMinDelay - dur delay := retryConnectDelay - dur
net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr)
time.Sleep(delay) time.Sleep(delay)
} }