Add context to {conn,upstreamConn}.SendMessage

This avoids blocking on upstream message rate limiting for too
long.
This commit is contained in:
Simon Ser 2021-12-08 18:03:40 +01:00
parent d21fc06d88
commit 66aea1b4a2
4 changed files with 57 additions and 45 deletions

10
conn.go
View File

@ -217,14 +217,20 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
// //
// If the connection is closed before the message is sent, SendMessage silently // If the connection is closed before the message is sent, SendMessage silently
// drops the message. // drops the message.
func (c *conn) SendMessage(msg *irc.Message) { func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.closed { if c.closed {
return return
} }
c.outgoing <- msg
select {
case c.outgoing <- msg:
// Success
case <-ctx.Done():
c.logger.Printf("failed to send message: %v", ctx.Err())
}
} }
func (c *conn) RemoteAddr() net.Addr { func (c *conn) RemoteAddr() net.Addr {

View File

@ -551,7 +551,7 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
} }
dc.srv.metrics.downstreamOutMessagesTotal.Inc() dc.srv.metrics.downstreamOutMessagesTotal.Inc()
dc.conn.SendMessage(msg) dc.conn.SendMessage(context.TODO(), msg)
} }
func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) { func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
@ -1666,7 +1666,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if upstream != nil && upstream != uc { if upstream != nil && upstream != uc {
return return
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{nick}, Params: []string{nick},
}) })
@ -1700,7 +1700,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// We only need to call updateNetwork for upstreams that don't // We only need to call updateNetwork for upstreams that don't
// support setname // support setname
if uc := n.conn; uc != nil && uc.caps["setname"] { if uc := n.conn; uc != nil && uc.caps["setname"] {
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "SETNAME", Command: "SETNAME",
Params: []string{realname}, Params: []string{realname},
}) })
@ -1775,7 +1775,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if key != "" { if key != "" {
params = append(params, key) params = append(params, key)
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "JOIN", Command: "JOIN",
Params: params, Params: params,
}) })
@ -1835,7 +1835,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if reason != "" { if reason != "" {
params = append(params, reason) params = append(params, reason)
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "PART", Command: "PART",
Params: params, Params: params,
}) })
@ -1896,7 +1896,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if reason != "" { if reason != "" {
params = append(params, reason) params = append(params, reason)
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "KICK", Command: "KICK",
Params: params, Params: params,
}) })
@ -1915,7 +1915,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if casemapASCII(name) == dc.nickCM { if casemapASCII(name) == dc.nickCM {
if modeStr != "" { if modeStr != "" {
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE", Command: "MODE",
Params: []string{uc.nick, modeStr}, Params: []string{uc.nick, modeStr},
}) })
@ -1956,7 +1956,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if modeStr != "" { if modeStr != "" {
params := []string{upstreamName, modeStr} params := []string{upstreamName, modeStr}
params = append(params, msg.Params[2:]...) params = append(params, msg.Params[2:]...)
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE", Command: "MODE",
Params: params, Params: params,
}) })
@ -2005,7 +2005,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if len(msg.Params) > 1 { // setting topic if len(msg.Params) > 1 { // setting topic
topic := msg.Params[1] topic := msg.Params[1]
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "TOPIC", Command: "TOPIC",
Params: []string{upstreamName, topic}, Params: []string{upstreamName, topic},
}) })
@ -2070,7 +2070,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
sendNames(dc, ch) sendNames(dc, ch)
} else { } else {
// NAMES on a channel we have not joined, ask upstream // NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NAMES", Command: "NAMES",
Params: []string{upstreamName}, Params: []string{upstreamName},
}) })
@ -2270,7 +2270,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
params = []string{upstreamNick} params = []string{upstreamNick}
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "WHOIS", Command: "WHOIS",
Params: params, Params: params,
}) })
@ -2347,7 +2347,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if uc.isChannel(upstreamName) { if uc.isChannel(upstreamName) {
unmarshaledText = dc.unmarshalText(uc, text) unmarshaledText = dc.unmarshalText(uc, text)
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Tags: tags, Tags: tags,
Command: msg.Command, Command: msg.Command,
Params: []string{upstreamName, unmarshaledText}, Params: []string{upstreamName, unmarshaledText},
@ -2398,7 +2398,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue continue
} }
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Tags: tags, Tags: tags,
Command: "TAGMSG", Command: "TAGMSG",
Params: []string{upstreamName}, Params: []string{upstreamName},
@ -2430,7 +2430,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
uc := ucChannel uc := ucChannel
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "INVITE", Command: "INVITE",
Params: []string{upstreamUser, upstreamChannel}, Params: []string{upstreamUser, upstreamChannel},
}) })
@ -2850,7 +2850,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return newUnknownCommandError(msg.Command) return newUnknownCommandError(msg.Command)
} }
uc.SendMessageLabeled(dc.id, msg) uc.SendMessageLabeled(ctx, dc.id, msg)
} }
return nil return nil
} }

View File

@ -616,7 +616,7 @@ func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params [
if err != nil { if err != nil {
return fmt.Errorf("failed to parse command %q: %v", params[1], err) return fmt.Errorf("failed to parse command %q: %v", params[1], err)
} }
uc.SendMessage(m) uc.SendMessage(ctx, m)
sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName())) sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
return nil return nil

View File

@ -342,7 +342,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
if len(uc.pendingCmds[cmd]) == 0 { if len(uc.pendingCmds[cmd]) == 0 {
return return
} }
uc.SendMessage(uc.pendingCmds[cmd][0].msg) uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
} }
func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
@ -450,7 +450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
switch msg.Command { switch msg.Command {
case "PING": case "PING":
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "PONG", Command: "PONG",
Params: msg.Params, Params: msg.Params,
}) })
@ -529,7 +529,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
break // we'll send CAP END after authentication is completed break // we'll send CAP END after authentication is completed
} }
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{"END"}, Params: []string{"END"},
}) })
@ -583,7 +583,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
// TODO: if a challenge is 400 bytes long, buffer it // TODO: if a challenge is 400 bytes long, buffer it
var challengeStr string var challengeStr string
if err := parseMessageParams(msg, &challengeStr); err != nil { if err := parseMessageParams(msg, &challengeStr); err != nil {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{"*"}, Params: []string{"*"},
}) })
@ -595,7 +595,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
var err error var err error
challenge, err = base64.StdEncoding.DecodeString(challengeStr) challenge, err = base64.StdEncoding.DecodeString(challengeStr)
if err != nil { if err != nil {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{"*"}, Params: []string{"*"},
}) })
@ -612,7 +612,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
resp, err = uc.saslClient.Next(challenge) resp, err = uc.saslClient.Next(challenge)
} }
if err != nil { if err != nil {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{"*"}, Params: []string{"*"},
}) })
@ -625,7 +625,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
respStr = base64.StdEncoding.EncodeToString(resp) respStr = base64.StdEncoding.EncodeToString(resp)
} }
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{respStr}, Params: []string{respStr},
}) })
@ -669,7 +669,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
if !uc.registered { if !uc.registered {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{"END"}, Params: []string{"END"},
}) })
@ -707,7 +707,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
for _, msg := range join(channels, keys) { for _, msg := range join(channels, keys) {
uc.SendMessage(msg) uc.SendMessage(ctx, msg)
} }
} }
case irc.RPL_MYINFO: case irc.RPL_MYINFO:
@ -931,7 +931,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}) })
uc.updateChannelAutoDetach(ch) uc.updateChannelAutoDetach(ch)
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "MODE", Command: "MODE",
Params: []string{ch}, Params: []string{ch},
}) })
@ -1531,7 +1531,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
} }
if found { if found {
uc.logger.Printf("desired nick %q is now available", wantNick) uc.logger.Printf("desired nick %q is now available", wantNick)
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{wantNick}, Params: []string{wantNick},
}) })
@ -1711,7 +1711,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.nick = uc.nick + "_" uc.nick = uc.nick + "_"
uc.nickCM = uc.network.casemap(uc.nick) uc.nickCM = uc.network.casemap(uc.nick)
uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick) uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{uc.nick}, Params: []string{uc.nick},
}) })
@ -1825,7 +1825,7 @@ func (uc *upstreamConn) requestCaps() {
return return
} }
uc.SendMessage(&irc.Message{ uc.SendMessage(context.TODO(), &irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{"REQ", strings.Join(requestCaps, " ")}, Params: []string{"REQ", strings.Join(requestCaps, " ")},
}) })
@ -1882,7 +1882,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
return fmt.Errorf("unsupported SASL mechanism %q", name) return fmt.Errorf("unsupported SASL mechanism %q", name)
} }
uc.SendMessage(&irc.Message{ uc.SendMessage(context.TODO(), &irc.Message{
Command: "AUTHENTICATE", Command: "AUTHENTICATE",
Params: []string{auth.Mechanism}, Params: []string{auth.Mechanism},
}) })
@ -1902,28 +1902,30 @@ func splitSpace(s string) []string {
} }
func (uc *upstreamConn) register() { func (uc *upstreamConn) register() {
ctx := context.TODO()
uc.nick = GetNick(&uc.user.User, &uc.network.Network) uc.nick = GetNick(&uc.user.User, &uc.network.Network)
uc.nickCM = uc.network.casemap(uc.nick) uc.nickCM = uc.network.casemap(uc.nick)
uc.username = GetUsername(&uc.user.User, &uc.network.Network) uc.username = GetUsername(&uc.user.User, &uc.network.Network)
uc.realname = GetRealname(&uc.user.User, &uc.network.Network) uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "CAP", Command: "CAP",
Params: []string{"LS", "302"}, Params: []string{"LS", "302"},
}) })
if uc.network.Pass != "" { if uc.network.Pass != "" {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "PASS", Command: "PASS",
Params: []string{uc.network.Pass}, Params: []string{uc.network.Pass},
}) })
} }
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "NICK", Command: "NICK",
Params: []string{uc.nick}, Params: []string{uc.nick},
}) })
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "USER", Command: "USER",
Params: []string{uc.username, "0", "*", uc.realname}, Params: []string{uc.username, "0", "*", uc.realname},
}) })
@ -1960,7 +1962,7 @@ func (uc *upstreamConn) runUntilRegistered() error {
if err != nil { if err != nil {
uc.logger.Printf("failed to parse connect command %q: %v", command, err) uc.logger.Printf("failed to parse connect command %q: %v", command, err)
} else { } else {
uc.SendMessage(m) uc.SendMessage(context.TODO(), m)
} }
} }
@ -1982,17 +1984,17 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
return nil return nil
} }
func (uc *upstreamConn) SendMessage(msg *irc.Message) { func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
if !uc.caps["message-tags"] { if !uc.caps["message-tags"] {
msg = msg.Copy() msg = msg.Copy()
msg.Tags = nil msg.Tags = nil
} }
uc.srv.metrics.upstreamOutMessagesTotal.Inc() uc.srv.metrics.upstreamOutMessagesTotal.Inc()
uc.conn.SendMessage(msg) uc.conn.SendMessage(ctx, msg)
} }
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) { func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
if uc.caps["labeled-response"] { if uc.caps["labeled-response"] {
if msg.Tags == nil { if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue) msg.Tags = make(map[string]irc.TagValue)
@ -2000,7 +2002,7 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID)) msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
uc.nextLabelID++ uc.nextLabelID++
} }
uc.SendMessage(msg) uc.SendMessage(ctx, msg)
} }
// appendLog appends a message to the log file. // appendLog appends a message to the log file.
@ -2073,6 +2075,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
} }
func (uc *upstreamConn) updateAway() { func (uc *upstreamConn) updateAway() {
ctx := context.TODO()
away := true away := true
uc.forEachDownstream(func(*downstreamConn) { uc.forEachDownstream(func(*downstreamConn) {
away = false away = false
@ -2081,12 +2085,12 @@ func (uc *upstreamConn) updateAway() {
return return
} }
if away { if away {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AWAY", Command: "AWAY",
Params: []string{"Auto away"}, Params: []string{"Auto away"},
}) })
} else { } else {
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "AWAY", Command: "AWAY",
}) })
} }
@ -2110,6 +2114,8 @@ func (uc *upstreamConn) updateMonitor() {
return return
} }
ctx := context.TODO()
add := make(map[string]struct{}) add := make(map[string]struct{})
var addList []string var addList []string
seen := make(map[string]struct{}) seen := make(map[string]struct{})
@ -2148,7 +2154,7 @@ func (uc *upstreamConn) updateMonitor() {
if removeAll && len(addList) == 0 && len(removeList) > 0 { if removeAll && len(addList) == 0 && len(removeList) > 0 {
// Optimization when the last MONITOR-aware downstream disconnects // Optimization when the last MONITOR-aware downstream disconnects
uc.SendMessage(&irc.Message{ uc.SendMessage(ctx, &irc.Message{
Command: "MONITOR", Command: "MONITOR",
Params: []string{"C"}, Params: []string{"C"},
}) })
@ -2156,7 +2162,7 @@ func (uc *upstreamConn) updateMonitor() {
msgs := generateMonitor("-", removeList) msgs := generateMonitor("-", removeList)
msgs = append(msgs, generateMonitor("+", addList)...) msgs = append(msgs, generateMonitor("+", addList)...)
for _, msg := range msgs { for _, msg := range msgs {
uc.SendMessage(msg) uc.SendMessage(ctx, msg)
} }
} }