From d423a1ca24200e2c1d714c6415efa9cb6e9332b2 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 11 Dec 2023 11:50:16 +0100 Subject: [PATCH] Add conn.Shutdown References: https://todo.sr.ht/~emersion/soju/156 --- conn.go | 25 +++++++++++++++++++++++++ downstream.go | 5 +++-- server.go | 2 +- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index f681b05..983f976 100644 --- a/conn.go +++ b/conn.go @@ -143,6 +143,10 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn { rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) for msg := range outgoing { + if msg == nil { + break + } + if err := rl.Wait(ctx); err != nil { break } @@ -224,6 +228,27 @@ func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { } } +// Shutdown gracefully closes the connection, flushing any pending message. +func (c *conn) Shutdown(ctx context.Context) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return + } + + select { + case c.outgoing <- nil: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to shutdown connection: %v", ctx.Err()) + // Forcibly close the connection + if err := c.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + c.logger.Printf("failed to close connection: %v", err) + } + } +} + func (c *conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } diff --git a/downstream.go b/downstream.go index b4a84ff..f28b40f 100644 --- a/downstream.go +++ b/downstream.go @@ -622,7 +622,8 @@ func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) e switch msg.Command { case "QUIT": - return dc.Close() + dc.conn.Shutdown(ctx) + return nil // TODO: stop handling commands default: if dc.registered { return dc.handleMessageRegistered(ctx, msg) @@ -1698,7 +1699,7 @@ func (dc *downstreamConn) runUntilRegistered() error { Command: "ERROR", Params: []string{"Connection registration timed out"}, }) - dc.Close() + dc.Shutdown(ctx) } }() diff --git a/server.go b/server.go index db62678..d26e873 100644 --- a/server.go +++ b/server.go @@ -469,7 +469,7 @@ func (s *Server) Handle(ic ircConn) { id := atomic.AddUint64(&lastDownstreamID, 1) dc := newDownstreamConn(s, ic, id) - defer dc.Close() + defer dc.Shutdown(context.TODO()) if shutdown { dc.SendMessage(context.TODO(), &irc.Message{