Add context to {conn,upstreamConn}.SendMessage

This avoids blocking on upstream message rate limiting for too
long.
This commit is contained in:
Simon Ser 2021-12-08 18:03:40 +01:00
parent d21fc06d88
commit 66aea1b4a2
4 changed files with 57 additions and 45 deletions

10
conn.go
View File

@ -217,14 +217,20 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
//
// If the connection is closed before the message is sent, SendMessage silently
// drops the message.
func (c *conn) SendMessage(msg *irc.Message) {
func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
c.lock.Lock()
defer c.lock.Unlock()
if c.closed {
return
}
c.outgoing <- msg
select {
case c.outgoing <- msg:
// Success
case <-ctx.Done():
c.logger.Printf("failed to send message: %v", ctx.Err())
}
}
func (c *conn) RemoteAddr() net.Addr {

View File

@ -551,7 +551,7 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
}
dc.srv.metrics.downstreamOutMessagesTotal.Inc()
dc.conn.SendMessage(msg)
dc.conn.SendMessage(context.TODO(), msg)
}
func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
@ -1666,7 +1666,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if upstream != nil && upstream != uc {
return
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NICK",
Params: []string{nick},
})
@ -1700,7 +1700,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// We only need to call updateNetwork for upstreams that don't
// support setname
if uc := n.conn; uc != nil && uc.caps["setname"] {
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "SETNAME",
Params: []string{realname},
})
@ -1775,7 +1775,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if key != "" {
params = append(params, key)
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "JOIN",
Params: params,
})
@ -1835,7 +1835,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if reason != "" {
params = append(params, reason)
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "PART",
Params: params,
})
@ -1896,7 +1896,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if reason != "" {
params = append(params, reason)
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "KICK",
Params: params,
})
@ -1915,7 +1915,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if casemapASCII(name) == dc.nickCM {
if modeStr != "" {
if uc := dc.upstream(); uc != nil {
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE",
Params: []string{uc.nick, modeStr},
})
@ -1956,7 +1956,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if modeStr != "" {
params := []string{upstreamName, modeStr}
params = append(params, msg.Params[2:]...)
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE",
Params: params,
})
@ -2005,7 +2005,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if len(msg.Params) > 1 { // setting topic
topic := msg.Params[1]
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "TOPIC",
Params: []string{upstreamName, topic},
})
@ -2070,7 +2070,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
sendNames(dc, ch)
} else {
// NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NAMES",
Params: []string{upstreamName},
})
@ -2270,7 +2270,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
params = []string{upstreamNick}
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "WHOIS",
Params: params,
})
@ -2347,7 +2347,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if uc.isChannel(upstreamName) {
unmarshaledText = dc.unmarshalText(uc, text)
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Tags: tags,
Command: msg.Command,
Params: []string{upstreamName, unmarshaledText},
@ -2398,7 +2398,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue
}
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Tags: tags,
Command: "TAGMSG",
Params: []string{upstreamName},
@ -2430,7 +2430,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
uc := ucChannel
uc.SendMessageLabeled(dc.id, &irc.Message{
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "INVITE",
Params: []string{upstreamUser, upstreamChannel},
})
@ -2850,7 +2850,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return newUnknownCommandError(msg.Command)
}
uc.SendMessageLabeled(dc.id, msg)
uc.SendMessageLabeled(ctx, dc.id, msg)
}
return nil
}

View File

@ -616,7 +616,7 @@ func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params [
if err != nil {
return fmt.Errorf("failed to parse command %q: %v", params[1], err)
}
uc.SendMessage(m)
uc.SendMessage(ctx, m)
sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
return nil

View File

@ -342,7 +342,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
if len(uc.pendingCmds[cmd]) == 0 {
return
}
uc.SendMessage(uc.pendingCmds[cmd][0].msg)
uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
}
func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
@ -450,7 +450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
switch msg.Command {
case "PING":
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "PONG",
Params: msg.Params,
})
@ -529,7 +529,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
break // we'll send CAP END after authentication is completed
}
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "CAP",
Params: []string{"END"},
})
@ -583,7 +583,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
// TODO: if a challenge is 400 bytes long, buffer it
var challengeStr string
if err := parseMessageParams(msg, &challengeStr); err != nil {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
@ -595,7 +595,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
var err error
challenge, err = base64.StdEncoding.DecodeString(challengeStr)
if err != nil {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
@ -612,7 +612,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
resp, err = uc.saslClient.Next(challenge)
}
if err != nil {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE",
Params: []string{"*"},
})
@ -625,7 +625,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
respStr = base64.StdEncoding.EncodeToString(resp)
}
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AUTHENTICATE",
Params: []string{respStr},
})
@ -669,7 +669,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
if !uc.registered {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "CAP",
Params: []string{"END"},
})
@ -707,7 +707,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
for _, msg := range join(channels, keys) {
uc.SendMessage(msg)
uc.SendMessage(ctx, msg)
}
}
case irc.RPL_MYINFO:
@ -931,7 +931,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
})
uc.updateChannelAutoDetach(ch)
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "MODE",
Params: []string{ch},
})
@ -1531,7 +1531,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
}
if found {
uc.logger.Printf("desired nick %q is now available", wantNick)
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "NICK",
Params: []string{wantNick},
})
@ -1711,7 +1711,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
uc.nick = uc.nick + "_"
uc.nickCM = uc.network.casemap(uc.nick)
uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "NICK",
Params: []string{uc.nick},
})
@ -1825,7 +1825,7 @@ func (uc *upstreamConn) requestCaps() {
return
}
uc.SendMessage(&irc.Message{
uc.SendMessage(context.TODO(), &irc.Message{
Command: "CAP",
Params: []string{"REQ", strings.Join(requestCaps, " ")},
})
@ -1882,7 +1882,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
return fmt.Errorf("unsupported SASL mechanism %q", name)
}
uc.SendMessage(&irc.Message{
uc.SendMessage(context.TODO(), &irc.Message{
Command: "AUTHENTICATE",
Params: []string{auth.Mechanism},
})
@ -1902,28 +1902,30 @@ func splitSpace(s string) []string {
}
func (uc *upstreamConn) register() {
ctx := context.TODO()
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
uc.nickCM = uc.network.casemap(uc.nick)
uc.username = GetUsername(&uc.user.User, &uc.network.Network)
uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "CAP",
Params: []string{"LS", "302"},
})
if uc.network.Pass != "" {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "PASS",
Params: []string{uc.network.Pass},
})
}
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "NICK",
Params: []string{uc.nick},
})
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "USER",
Params: []string{uc.username, "0", "*", uc.realname},
})
@ -1960,7 +1962,7 @@ func (uc *upstreamConn) runUntilRegistered() error {
if err != nil {
uc.logger.Printf("failed to parse connect command %q: %v", command, err)
} else {
uc.SendMessage(m)
uc.SendMessage(context.TODO(), m)
}
}
@ -1982,17 +1984,17 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
return nil
}
func (uc *upstreamConn) SendMessage(msg *irc.Message) {
func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
if !uc.caps["message-tags"] {
msg = msg.Copy()
msg.Tags = nil
}
uc.srv.metrics.upstreamOutMessagesTotal.Inc()
uc.conn.SendMessage(msg)
uc.conn.SendMessage(ctx, msg)
}
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
if uc.caps["labeled-response"] {
if msg.Tags == nil {
msg.Tags = make(map[string]irc.TagValue)
@ -2000,7 +2002,7 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
uc.nextLabelID++
}
uc.SendMessage(msg)
uc.SendMessage(ctx, msg)
}
// appendLog appends a message to the log file.
@ -2073,6 +2075,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
}
func (uc *upstreamConn) updateAway() {
ctx := context.TODO()
away := true
uc.forEachDownstream(func(*downstreamConn) {
away = false
@ -2081,12 +2085,12 @@ func (uc *upstreamConn) updateAway() {
return
}
if away {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AWAY",
Params: []string{"Auto away"},
})
} else {
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "AWAY",
})
}
@ -2110,6 +2114,8 @@ func (uc *upstreamConn) updateMonitor() {
return
}
ctx := context.TODO()
add := make(map[string]struct{})
var addList []string
seen := make(map[string]struct{})
@ -2148,7 +2154,7 @@ func (uc *upstreamConn) updateMonitor() {
if removeAll && len(addList) == 0 && len(removeList) > 0 {
// Optimization when the last MONITOR-aware downstream disconnects
uc.SendMessage(&irc.Message{
uc.SendMessage(ctx, &irc.Message{
Command: "MONITOR",
Params: []string{"C"},
})
@ -2156,7 +2162,7 @@ func (uc *upstreamConn) updateMonitor() {
msgs := generateMonitor("-", removeList)
msgs = append(msgs, generateMonitor("+", addList)...)
for _, msg := range msgs {
uc.SendMessage(msg)
uc.SendMessage(ctx, msg)
}
}