From 66aea1b4a22d683f8b018b995dc898efcf45712c Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 8 Dec 2021 18:03:40 +0100 Subject: [PATCH] Add context to {conn,upstreamConn}.SendMessage This avoids blocking on upstream message rate limiting for too long. --- conn.go | 10 +++++++-- downstream.go | 30 +++++++++++++------------- service.go | 2 +- upstream.go | 60 ++++++++++++++++++++++++++++----------------------- 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/conn.go b/conn.go index 7b210cd..04b1237 100644 --- a/conn.go +++ b/conn.go @@ -217,14 +217,20 @@ func (c *conn) ReadMessage() (*irc.Message, error) { // // If the connection is closed before the message is sent, SendMessage silently // drops the message. -func (c *conn) SendMessage(msg *irc.Message) { +func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) { c.lock.Lock() defer c.lock.Unlock() if c.closed { return } - c.outgoing <- msg + + select { + case c.outgoing <- msg: + // Success + case <-ctx.Done(): + c.logger.Printf("failed to send message: %v", ctx.Err()) + } } func (c *conn) RemoteAddr() net.Addr { diff --git a/downstream.go b/downstream.go index 6efe756..5bbfcb7 100644 --- a/downstream.go +++ b/downstream.go @@ -551,7 +551,7 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) { } dc.srv.metrics.downstreamOutMessagesTotal.Inc() - dc.conn.SendMessage(msg) + dc.conn.SendMessage(context.TODO(), msg) } func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { @@ -1666,7 +1666,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if upstream != nil && upstream != uc { return } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "NICK", Params: []string{nick}, }) @@ -1700,7 +1700,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // We only need to call updateNetwork for upstreams that don't // support setname if uc := n.conn; uc != nil && uc.caps["setname"] { - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "SETNAME", Params: []string{realname}, }) @@ -1775,7 +1775,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if key != "" { params = append(params, key) } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "JOIN", Params: params, }) @@ -1835,7 +1835,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if reason != "" { params = append(params, reason) } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "PART", Params: params, }) @@ -1896,7 +1896,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if reason != "" { params = append(params, reason) } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "KICK", Params: params, }) @@ -1915,7 +1915,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if casemapASCII(name) == dc.nickCM { if modeStr != "" { if uc := dc.upstream(); uc != nil { - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "MODE", Params: []string{uc.nick, modeStr}, }) @@ -1956,7 +1956,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if modeStr != "" { params := []string{upstreamName, modeStr} params = append(params, msg.Params[2:]...) - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "MODE", Params: params, }) @@ -2005,7 +2005,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if len(msg.Params) > 1 { // setting topic topic := msg.Params[1] - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "TOPIC", Params: []string{upstreamName, topic}, }) @@ -2070,7 +2070,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. sendNames(dc, ch) } else { // NAMES on a channel we have not joined, ask upstream - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "NAMES", Params: []string{upstreamName}, }) @@ -2270,7 +2270,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. params = []string{upstreamNick} } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "WHOIS", Params: params, }) @@ -2347,7 +2347,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if uc.isChannel(upstreamName) { unmarshaledText = dc.unmarshalText(uc, text) } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Tags: tags, Command: msg.Command, Params: []string{upstreamName, unmarshaledText}, @@ -2398,7 +2398,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. continue } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Tags: tags, Command: "TAGMSG", Params: []string{upstreamName}, @@ -2430,7 +2430,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } uc := ucChannel - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "INVITE", Params: []string{upstreamUser, upstreamChannel}, }) @@ -2850,7 +2850,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return newUnknownCommandError(msg.Command) } - uc.SendMessageLabeled(dc.id, msg) + uc.SendMessageLabeled(ctx, dc.id, msg) } return nil } diff --git a/service.go b/service.go index 0d6d7f0..80fb261 100644 --- a/service.go +++ b/service.go @@ -616,7 +616,7 @@ func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params [ if err != nil { return fmt.Errorf("failed to parse command %q: %v", params[1], err) } - uc.SendMessage(m) + uc.SendMessage(ctx, m) sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) return nil diff --git a/upstream.go b/upstream.go index 1a454d3..6277a3b 100644 --- a/upstream.go +++ b/upstream.go @@ -342,7 +342,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) { if len(uc.pendingCmds[cmd]) == 0 { return } - uc.SendMessage(uc.pendingCmds[cmd][0].msg) + uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg) } func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { @@ -450,7 +450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err switch msg.Command { case "PING": - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "PONG", Params: msg.Params, }) @@ -529,7 +529,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err break // we'll send CAP END after authentication is completed } - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "CAP", Params: []string{"END"}, }) @@ -583,7 +583,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err // TODO: if a challenge is 400 bytes long, buffer it var challengeStr string if err := parseMessageParams(msg, &challengeStr); err != nil { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AUTHENTICATE", Params: []string{"*"}, }) @@ -595,7 +595,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err var err error challenge, err = base64.StdEncoding.DecodeString(challengeStr) if err != nil { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AUTHENTICATE", Params: []string{"*"}, }) @@ -612,7 +612,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err resp, err = uc.saslClient.Next(challenge) } if err != nil { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AUTHENTICATE", Params: []string{"*"}, }) @@ -625,7 +625,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err respStr = base64.StdEncoding.EncodeToString(resp) } - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AUTHENTICATE", Params: []string{respStr}, }) @@ -669,7 +669,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } if !uc.registered { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "CAP", Params: []string{"END"}, }) @@ -707,7 +707,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } for _, msg := range join(channels, keys) { - uc.SendMessage(msg) + uc.SendMessage(ctx, msg) } } case irc.RPL_MYINFO: @@ -931,7 +931,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err }) uc.updateChannelAutoDetach(ch) - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "MODE", Params: []string{ch}, }) @@ -1531,7 +1531,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } if found { uc.logger.Printf("desired nick %q is now available", wantNick) - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "NICK", Params: []string{wantNick}, }) @@ -1711,7 +1711,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.nick = uc.nick + "_" uc.nickCM = uc.network.casemap(uc.nick) uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "NICK", Params: []string{uc.nick}, }) @@ -1825,7 +1825,7 @@ func (uc *upstreamConn) requestCaps() { return } - uc.SendMessage(&irc.Message{ + uc.SendMessage(context.TODO(), &irc.Message{ Command: "CAP", Params: []string{"REQ", strings.Join(requestCaps, " ")}, }) @@ -1882,7 +1882,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error { return fmt.Errorf("unsupported SASL mechanism %q", name) } - uc.SendMessage(&irc.Message{ + uc.SendMessage(context.TODO(), &irc.Message{ Command: "AUTHENTICATE", Params: []string{auth.Mechanism}, }) @@ -1902,28 +1902,30 @@ func splitSpace(s string) []string { } func (uc *upstreamConn) register() { + ctx := context.TODO() + uc.nick = GetNick(&uc.user.User, &uc.network.Network) uc.nickCM = uc.network.casemap(uc.nick) uc.username = GetUsername(&uc.user.User, &uc.network.Network) uc.realname = GetRealname(&uc.user.User, &uc.network.Network) - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "CAP", Params: []string{"LS", "302"}, }) if uc.network.Pass != "" { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "PASS", Params: []string{uc.network.Pass}, }) } - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "NICK", Params: []string{uc.nick}, }) - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "USER", Params: []string{uc.username, "0", "*", uc.realname}, }) @@ -1960,7 +1962,7 @@ func (uc *upstreamConn) runUntilRegistered() error { if err != nil { uc.logger.Printf("failed to parse connect command %q: %v", command, err) } else { - uc.SendMessage(m) + uc.SendMessage(context.TODO(), m) } } @@ -1982,17 +1984,17 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error { return nil } -func (uc *upstreamConn) SendMessage(msg *irc.Message) { +func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { if !uc.caps["message-tags"] { msg = msg.Copy() msg.Tags = nil } uc.srv.metrics.upstreamOutMessagesTotal.Inc() - uc.conn.SendMessage(msg) + uc.conn.SendMessage(ctx, msg) } -func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) { +func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) { if uc.caps["labeled-response"] { if msg.Tags == nil { msg.Tags = make(map[string]irc.TagValue) @@ -2000,7 +2002,7 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) uc.nextLabelID++ } - uc.SendMessage(msg) + uc.SendMessage(ctx, msg) } // appendLog appends a message to the log file. @@ -2073,6 +2075,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr } func (uc *upstreamConn) updateAway() { + ctx := context.TODO() + away := true uc.forEachDownstream(func(*downstreamConn) { away = false @@ -2081,12 +2085,12 @@ func (uc *upstreamConn) updateAway() { return } if away { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AWAY", Params: []string{"Auto away"}, }) } else { - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "AWAY", }) } @@ -2110,6 +2114,8 @@ func (uc *upstreamConn) updateMonitor() { return } + ctx := context.TODO() + add := make(map[string]struct{}) var addList []string seen := make(map[string]struct{}) @@ -2148,7 +2154,7 @@ func (uc *upstreamConn) updateMonitor() { if removeAll && len(addList) == 0 && len(removeList) > 0 { // Optimization when the last MONITOR-aware downstream disconnects - uc.SendMessage(&irc.Message{ + uc.SendMessage(ctx, &irc.Message{ Command: "MONITOR", Params: []string{"C"}, }) @@ -2156,7 +2162,7 @@ func (uc *upstreamConn) updateMonitor() { msgs := generateMonitor("-", removeList) msgs = append(msgs, generateMonitor("+", addList)...) for _, msg := range msgs { - uc.SendMessage(msg) + uc.SendMessage(ctx, msg) } }