Add conn.Shutdown

References: https://todo.sr.ht/~emersion/soju/156
This commit is contained in:
Simon Ser 2023-12-11 11:50:16 +01:00
parent e9678cee2f
commit d423a1ca24
3 changed files with 29 additions and 3 deletions

25
conn.go
View File

@ -143,6 +143,10 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn {
rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst) rl := rate.NewLimiter(rate.Every(options.RateLimitDelay), options.RateLimitBurst)
for msg := range outgoing { for msg := range outgoing {
if msg == nil {
break
}
if err := rl.Wait(ctx); err != nil { if err := rl.Wait(ctx); err != nil {
break break
} }
@ -224,6 +228,27 @@ func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
} }
} }
// Shutdown gracefully closes the connection, flushing any pending message.
func (c *conn) Shutdown(ctx context.Context) {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return
}
select {
case c.outgoing <- nil:
// Success
case <-ctx.Done():
c.logger.Printf("failed to shutdown connection: %v", ctx.Err())
// Forcibly close the connection
if err := c.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
c.logger.Printf("failed to close connection: %v", err)
}
}
}
func (c *conn) RemoteAddr() net.Addr { func (c *conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }

View File

@ -622,7 +622,8 @@ func (dc *downstreamConn) handleMessage(ctx context.Context, msg *irc.Message) e
switch msg.Command { switch msg.Command {
case "QUIT": case "QUIT":
return dc.Close() dc.conn.Shutdown(ctx)
return nil // TODO: stop handling commands
default: default:
if dc.registered { if dc.registered {
return dc.handleMessageRegistered(ctx, msg) return dc.handleMessageRegistered(ctx, msg)
@ -1698,7 +1699,7 @@ func (dc *downstreamConn) runUntilRegistered() error {
Command: "ERROR", Command: "ERROR",
Params: []string{"Connection registration timed out"}, Params: []string{"Connection registration timed out"},
}) })
dc.Close() dc.Shutdown(ctx)
} }
}() }()

View File

@ -469,7 +469,7 @@ func (s *Server) Handle(ic ircConn) {
id := atomic.AddUint64(&lastDownstreamID, 1) id := atomic.AddUint64(&lastDownstreamID, 1)
dc := newDownstreamConn(s, ic, id) dc := newDownstreamConn(s, ic, id)
defer dc.Close() defer dc.Shutdown(context.TODO())
if shutdown { if shutdown {
dc.SendMessage(context.TODO(), &irc.Message{ dc.SendMessage(context.TODO(), &irc.Message{