From bdb132ad98a1498bba450c3b9d51aa075c3e6dde Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 19 Aug 2020 19:28:29 +0200 Subject: [PATCH] Implement rate limiting for upstream messages Allow up to 10 outgoing messages in a burst, then throttle to 1 message each 2 seconds. Closes: https://todo.sr.ht/~emersion/soju/87 --- conn.go | 60 +++++++++++++++++++++++++++++++++++++++++++++++++-- downstream.go | 3 ++- server.go | 4 +++- upstream.go | 8 ++++++- user.go | 4 ++-- 5 files changed, 72 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index c0f8f16..7a6c63b 100644 --- a/conn.go +++ b/conn.go @@ -106,6 +106,52 @@ func (wa websocketAddr) String() string { return string(wa) } +type rateLimiter struct { + C <-chan struct{} + ticker *time.Ticker + stopped chan struct{} +} + +func newRateLimiter(delay time.Duration, burst int) *rateLimiter { + ch := make(chan struct{}, burst) + for i := 0; i < burst; i++ { + ch <- struct{}{} + } + ticker := time.NewTicker(delay) + stopped := make(chan struct{}) + go func() { + for { + select { + case <-ticker.C: + select { + case ch <- struct{}{}: + // This space is intentionally left blank + case <-stopped: + return + } + case <-stopped: + return + } + } + }() + return &rateLimiter{ + C: ch, + ticker: ticker, + stopped: stopped, + } +} + +func (rl *rateLimiter) Stop() { + rl.ticker.Stop() + close(rl.stopped) +} + +type connOptions struct { + Logger Logger + RateLimitDelay time.Duration + RateLimitBurst int +} + type conn struct { conn ircConn srv *Server @@ -116,17 +162,27 @@ type conn struct { closed bool } -func newConn(srv *Server, ic ircConn, logger Logger) *conn { +func newConn(srv *Server, ic ircConn, options *connOptions) *conn { outgoing := make(chan *irc.Message, 64) c := &conn{ conn: ic, srv: srv, outgoing: outgoing, - logger: logger, + logger: options.Logger, } go func() { + var rl *rateLimiter + if options.RateLimitDelay > 0 && options.RateLimitBurst > 0 { + rl = newRateLimiter(options.RateLimitDelay, options.RateLimitBurst) + defer rl.Stop() + } + for msg := range outgoing { + if rl != nil { + <-rl.C + } + if c.srv.Debug { c.logger.Printf("sent: %v", msg) } diff --git a/downstream.go b/downstream.go index 8863f28..952121d 100644 --- a/downstream.go +++ b/downstream.go @@ -102,8 +102,9 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { remoteAddr := ic.RemoteAddr().String() logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)} + options := connOptions{Logger: logger} dc := &downstreamConn{ - conn: *newConn(srv, ic, logger), + conn: *newConn(srv, ic, &options), id: id, supportedCaps: make(map[string]string), caps: make(map[string]bool), diff --git a/server.go b/server.go index 883b9b5..c538c20 100644 --- a/server.go +++ b/server.go @@ -16,9 +16,11 @@ import ( ) // TODO: make configurable -var retryConnectMinDelay = time.Minute +var retryConnectDelay = time.Minute var connectTimeout = 15 * time.Second var writeTimeout = 10 * time.Second +var upstreamMessageDelay = 2 * time.Second +var upstreamMessageBurst = 10 type Logger interface { Print(v ...interface{}) diff --git a/upstream.go b/upstream.go index 77c155f..29613af 100644 --- a/upstream.go +++ b/upstream.go @@ -157,8 +157,14 @@ func connectToUpstream(network *network) (*upstreamConn, error) { return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) } + options := connOptions{ + Logger: logger, + RateLimitDelay: upstreamMessageDelay, + RateLimitBurst: upstreamMessageBurst, + } + uc := &upstreamConn{ - conn: *newConn(network.user.srv, newNetIRCConn(netConn), logger), + conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), network: network, user: network.user, channels: make(map[string]*upstreamChannel), diff --git a/user.go b/user.go index 0090321..6b9512a 100644 --- a/user.go +++ b/user.go @@ -120,8 +120,8 @@ func (net *network) run() { return } - if dur := time.Now().Sub(lastTry); dur < retryConnectMinDelay { - delay := retryConnectMinDelay - dur + if dur := time.Now().Sub(lastTry); dur < retryConnectDelay { + delay := retryConnectDelay - dur net.user.srv.Logger.Printf("waiting %v before trying to reconnect to %q", delay.Truncate(time.Second), net.Addr) time.Sleep(delay) }