Add context to {conn,upstreamConn}.SendMessage
This avoids blocking on upstream message rate limiting for too long.
This commit is contained in:
parent
d21fc06d88
commit
66aea1b4a2
10
conn.go
10
conn.go
@ -217,14 +217,20 @@ func (c *conn) ReadMessage() (*irc.Message, error) {
|
||||
//
|
||||
// If the connection is closed before the message is sent, SendMessage silently
|
||||
// drops the message.
|
||||
func (c *conn) SendMessage(msg *irc.Message) {
|
||||
func (c *conn) SendMessage(ctx context.Context, msg *irc.Message) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
c.outgoing <- msg
|
||||
|
||||
select {
|
||||
case c.outgoing <- msg:
|
||||
// Success
|
||||
case <-ctx.Done():
|
||||
c.logger.Printf("failed to send message: %v", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
|
@ -551,7 +551,7 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
|
||||
}
|
||||
|
||||
dc.srv.metrics.downstreamOutMessagesTotal.Inc()
|
||||
dc.conn.SendMessage(msg)
|
||||
dc.conn.SendMessage(context.TODO(), msg)
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef irc.TagValue)) {
|
||||
@ -1666,7 +1666,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if upstream != nil && upstream != uc {
|
||||
return
|
||||
}
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "NICK",
|
||||
Params: []string{nick},
|
||||
})
|
||||
@ -1700,7 +1700,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
// We only need to call updateNetwork for upstreams that don't
|
||||
// support setname
|
||||
if uc := n.conn; uc != nil && uc.caps["setname"] {
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "SETNAME",
|
||||
Params: []string{realname},
|
||||
})
|
||||
@ -1775,7 +1775,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if key != "" {
|
||||
params = append(params, key)
|
||||
}
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "JOIN",
|
||||
Params: params,
|
||||
})
|
||||
@ -1835,7 +1835,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if reason != "" {
|
||||
params = append(params, reason)
|
||||
}
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "PART",
|
||||
Params: params,
|
||||
})
|
||||
@ -1896,7 +1896,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if reason != "" {
|
||||
params = append(params, reason)
|
||||
}
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "KICK",
|
||||
Params: params,
|
||||
})
|
||||
@ -1915,7 +1915,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if casemapASCII(name) == dc.nickCM {
|
||||
if modeStr != "" {
|
||||
if uc := dc.upstream(); uc != nil {
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "MODE",
|
||||
Params: []string{uc.nick, modeStr},
|
||||
})
|
||||
@ -1956,7 +1956,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if modeStr != "" {
|
||||
params := []string{upstreamName, modeStr}
|
||||
params = append(params, msg.Params[2:]...)
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "MODE",
|
||||
Params: params,
|
||||
})
|
||||
@ -2005,7 +2005,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
|
||||
if len(msg.Params) > 1 { // setting topic
|
||||
topic := msg.Params[1]
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "TOPIC",
|
||||
Params: []string{upstreamName, topic},
|
||||
})
|
||||
@ -2070,7 +2070,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
sendNames(dc, ch)
|
||||
} else {
|
||||
// NAMES on a channel we have not joined, ask upstream
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "NAMES",
|
||||
Params: []string{upstreamName},
|
||||
})
|
||||
@ -2270,7 +2270,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
params = []string{upstreamNick}
|
||||
}
|
||||
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "WHOIS",
|
||||
Params: params,
|
||||
})
|
||||
@ -2347,7 +2347,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
if uc.isChannel(upstreamName) {
|
||||
unmarshaledText = dc.unmarshalText(uc, text)
|
||||
}
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Tags: tags,
|
||||
Command: msg.Command,
|
||||
Params: []string{upstreamName, unmarshaledText},
|
||||
@ -2398,7 +2398,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
continue
|
||||
}
|
||||
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Tags: tags,
|
||||
Command: "TAGMSG",
|
||||
Params: []string{upstreamName},
|
||||
@ -2430,7 +2430,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
}
|
||||
uc := ucChannel
|
||||
|
||||
uc.SendMessageLabeled(dc.id, &irc.Message{
|
||||
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
|
||||
Command: "INVITE",
|
||||
Params: []string{upstreamUser, upstreamChannel},
|
||||
})
|
||||
@ -2850,7 +2850,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
|
||||
return newUnknownCommandError(msg.Command)
|
||||
}
|
||||
|
||||
uc.SendMessageLabeled(dc.id, msg)
|
||||
uc.SendMessageLabeled(ctx, dc.id, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -616,7 +616,7 @@ func handleServiceNetworkQuote(ctx context.Context, dc *downstreamConn, params [
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse command %q: %v", params[1], err)
|
||||
}
|
||||
uc.SendMessage(m)
|
||||
uc.SendMessage(ctx, m)
|
||||
|
||||
sendServicePRIVMSG(dc, fmt.Sprintf("sent command to %q", net.GetName()))
|
||||
return nil
|
||||
|
60
upstream.go
60
upstream.go
@ -342,7 +342,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) {
|
||||
if len(uc.pendingCmds[cmd]) == 0 {
|
||||
return
|
||||
}
|
||||
uc.SendMessage(uc.pendingCmds[cmd][0].msg)
|
||||
uc.SendMessage(context.TODO(), uc.pendingCmds[cmd][0].msg)
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
|
||||
@ -450,7 +450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
|
||||
switch msg.Command {
|
||||
case "PING":
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "PONG",
|
||||
Params: msg.Params,
|
||||
})
|
||||
@ -529,7 +529,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
break // we'll send CAP END after authentication is completed
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"END"},
|
||||
})
|
||||
@ -583,7 +583,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
// TODO: if a challenge is 400 bytes long, buffer it
|
||||
var challengeStr string
|
||||
if err := parseMessageParams(msg, &challengeStr); err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
@ -595,7 +595,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
var err error
|
||||
challenge, err = base64.StdEncoding.DecodeString(challengeStr)
|
||||
if err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
@ -612,7 +612,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
resp, err = uc.saslClient.Next(challenge)
|
||||
}
|
||||
if err != nil {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{"*"},
|
||||
})
|
||||
@ -625,7 +625,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
respStr = base64.StdEncoding.EncodeToString(resp)
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{respStr},
|
||||
})
|
||||
@ -669,7 +669,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
|
||||
if !uc.registered {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"END"},
|
||||
})
|
||||
@ -707,7 +707,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
|
||||
for _, msg := range join(channels, keys) {
|
||||
uc.SendMessage(msg)
|
||||
uc.SendMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
case irc.RPL_MYINFO:
|
||||
@ -931,7 +931,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
})
|
||||
uc.updateChannelAutoDetach(ch)
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "MODE",
|
||||
Params: []string{ch},
|
||||
})
|
||||
@ -1531,7 +1531,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
}
|
||||
if found {
|
||||
uc.logger.Printf("desired nick %q is now available", wantNick)
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "NICK",
|
||||
Params: []string{wantNick},
|
||||
})
|
||||
@ -1711,7 +1711,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
|
||||
uc.nick = uc.nick + "_"
|
||||
uc.nickCM = uc.network.casemap(uc.nick)
|
||||
uc.logger.Printf("desired nick is not available, falling back to %q", uc.nick)
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "NICK",
|
||||
Params: []string{uc.nick},
|
||||
})
|
||||
@ -1825,7 +1825,7 @@ func (uc *upstreamConn) requestCaps() {
|
||||
return
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(context.TODO(), &irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"REQ", strings.Join(requestCaps, " ")},
|
||||
})
|
||||
@ -1882,7 +1882,7 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
|
||||
return fmt.Errorf("unsupported SASL mechanism %q", name)
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(context.TODO(), &irc.Message{
|
||||
Command: "AUTHENTICATE",
|
||||
Params: []string{auth.Mechanism},
|
||||
})
|
||||
@ -1902,28 +1902,30 @@ func splitSpace(s string) []string {
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) register() {
|
||||
ctx := context.TODO()
|
||||
|
||||
uc.nick = GetNick(&uc.user.User, &uc.network.Network)
|
||||
uc.nickCM = uc.network.casemap(uc.nick)
|
||||
uc.username = GetUsername(&uc.user.User, &uc.network.Network)
|
||||
uc.realname = GetRealname(&uc.user.User, &uc.network.Network)
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "CAP",
|
||||
Params: []string{"LS", "302"},
|
||||
})
|
||||
|
||||
if uc.network.Pass != "" {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "PASS",
|
||||
Params: []string{uc.network.Pass},
|
||||
})
|
||||
}
|
||||
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "NICK",
|
||||
Params: []string{uc.nick},
|
||||
})
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "USER",
|
||||
Params: []string{uc.username, "0", "*", uc.realname},
|
||||
})
|
||||
@ -1960,7 +1962,7 @@ func (uc *upstreamConn) runUntilRegistered() error {
|
||||
if err != nil {
|
||||
uc.logger.Printf("failed to parse connect command %q: %v", command, err)
|
||||
} else {
|
||||
uc.SendMessage(m)
|
||||
uc.SendMessage(context.TODO(), m)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1982,17 +1984,17 @@ func (uc *upstreamConn) readMessages(ch chan<- event) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) SendMessage(msg *irc.Message) {
|
||||
func (uc *upstreamConn) SendMessage(ctx context.Context, msg *irc.Message) {
|
||||
if !uc.caps["message-tags"] {
|
||||
msg = msg.Copy()
|
||||
msg.Tags = nil
|
||||
}
|
||||
|
||||
uc.srv.metrics.upstreamOutMessagesTotal.Inc()
|
||||
uc.conn.SendMessage(msg)
|
||||
uc.conn.SendMessage(ctx, msg)
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message) {
|
||||
func (uc *upstreamConn) SendMessageLabeled(ctx context.Context, downstreamID uint64, msg *irc.Message) {
|
||||
if uc.caps["labeled-response"] {
|
||||
if msg.Tags == nil {
|
||||
msg.Tags = make(map[string]irc.TagValue)
|
||||
@ -2000,7 +2002,7 @@ func (uc *upstreamConn) SendMessageLabeled(downstreamID uint64, msg *irc.Message
|
||||
msg.Tags["label"] = irc.TagValue(fmt.Sprintf("sd-%d-%d", downstreamID, uc.nextLabelID))
|
||||
uc.nextLabelID++
|
||||
}
|
||||
uc.SendMessage(msg)
|
||||
uc.SendMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// appendLog appends a message to the log file.
|
||||
@ -2073,6 +2075,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) updateAway() {
|
||||
ctx := context.TODO()
|
||||
|
||||
away := true
|
||||
uc.forEachDownstream(func(*downstreamConn) {
|
||||
away = false
|
||||
@ -2081,12 +2085,12 @@ func (uc *upstreamConn) updateAway() {
|
||||
return
|
||||
}
|
||||
if away {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AWAY",
|
||||
Params: []string{"Auto away"},
|
||||
})
|
||||
} else {
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "AWAY",
|
||||
})
|
||||
}
|
||||
@ -2110,6 +2114,8 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
add := make(map[string]struct{})
|
||||
var addList []string
|
||||
seen := make(map[string]struct{})
|
||||
@ -2148,7 +2154,7 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
|
||||
if removeAll && len(addList) == 0 && len(removeList) > 0 {
|
||||
// Optimization when the last MONITOR-aware downstream disconnects
|
||||
uc.SendMessage(&irc.Message{
|
||||
uc.SendMessage(ctx, &irc.Message{
|
||||
Command: "MONITOR",
|
||||
Params: []string{"C"},
|
||||
})
|
||||
@ -2156,7 +2162,7 @@ func (uc *upstreamConn) updateMonitor() {
|
||||
msgs := generateMonitor("-", removeList)
|
||||
msgs = append(msgs, generateMonitor("+", addList)...)
|
||||
for _, msg := range msgs {
|
||||
uc.SendMessage(msg)
|
||||
uc.SendMessage(ctx, msg)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user