diff --git a/downstream.go b/downstream.go index 90767fa..9538f95 100644 --- a/downstream.go +++ b/downstream.go @@ -474,7 +474,7 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error { // SendMessage sends an outgoing message. // // This can only called from the user goroutine. -func (dc *downstreamConn) SendMessage(msg *irc.Message) { +func (dc *downstreamConn) SendMessage(ctx context.Context, msg *irc.Message) { if !dc.caps.IsEnabled("message-tags") { if msg.Command == "TAGMSG" { return @@ -528,15 +528,15 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) { } dc.srv.metrics.downstreamOutMessagesTotal.Inc() - dc.conn.SendMessage(context.TODO(), msg) + dc.conn.SendMessage(ctx, msg) } -func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f func(batchRef string)) { +func (dc *downstreamConn) SendBatch(ctx context.Context, typ string, params []string, tags irc.Tags, f func(batchRef string)) { dc.lastBatchRef++ ref := fmt.Sprintf("%v", dc.lastBatchRef) if dc.caps.IsEnabled("batch") { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: tags, Prefix: dc.srv.prefix(), Command: "BATCH", @@ -547,7 +547,7 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, f(ref) if dc.caps.IsEnabled("batch") { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "BATCH", Params: []string{"-" + ref}, @@ -556,25 +556,25 @@ func (dc *downstreamConn) SendBatch(typ string, params []string, tags irc.Tags, } // sendMessageWithID sends an outgoing message with the specified internal ID. -func (dc *downstreamConn) sendMessageWithID(msg *irc.Message, id string) { - dc.SendMessage(msg) +func (dc *downstreamConn) sendMessageWithID(ctx context.Context, msg *irc.Message, id string) { + dc.SendMessage(ctx, msg) if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") { return } - dc.sendPing(id) + dc.sendPing(ctx, id) } // advanceMessageWithID advances history to the specified message ID without // sending a message. This is useful e.g. for self-messages when echo-message // isn't enabled. -func (dc *downstreamConn) advanceMessageWithID(msg *irc.Message, id string) { +func (dc *downstreamConn) advanceMessageWithID(ctx context.Context, msg *irc.Message, id string) { if id == "" || !dc.messageSupportsBacklog(msg) || dc.caps.IsEnabled("draft/chathistory") { return } - dc.sendPing(id) + dc.sendPing(ctx, id) } // ackMsgID acknowledges that a message has been received. @@ -593,9 +593,9 @@ func (dc *downstreamConn) ackMsgID(id string) { network.delivered.StoreID(entity, dc.clientName, id) } -func (dc *downstreamConn) sendPing(msgID string) { +func (dc *downstreamConn) sendPing(ctx context.Context, msgID string) { token := "soju-msgid-" + msgID - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: "PING", Params: []string{token}, }) @@ -644,9 +644,9 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return err } case "CAP": - return dc.handleCap(msg) + return dc.handleCap(ctx, msg) case "AUTHENTICATE": - credentials, err := dc.handleAuthenticate(msg) + credentials, err := dc.handleAuthenticate(ctx, msg) if err != nil { return err } else if credentials == nil { @@ -690,7 +690,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir if err != nil { dc.logger.Printf("SASL %v authentication error for nick %q: %v", credentials.mechanism, dc.nick, err) - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, authErrorReason(err)}, @@ -707,7 +707,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir // RPL_LOGGEDIN to mirror the upstream connection status. Let's // see how many clients that breaks. See: // https://github.com/ircv3/ircv3-specifications/pull/476 - dc.endSASL(nil) + dc.endSASL(ctx, nil) case "BOUNCER": var subcommand string if err := parseMessageParams(msg, &subcommand); err != nil { @@ -747,7 +747,7 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir return nil } -func (dc *downstreamConn) handleCap(msg *irc.Message) error { +func (dc *downstreamConn) handleCap(ctx context.Context, msg *irc.Message) error { var cmd string if err := parseMessageParams(msg, &cmd); err != nil { return err @@ -781,7 +781,7 @@ func (dc *downstreamConn) handleCap(msg *irc.Message) error { } // TODO: multi-line replies - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "CAP", Params: []string{dc.nick, "LS", strings.Join(caps, " ")}, @@ -802,7 +802,7 @@ func (dc *downstreamConn) handleCap(msg *irc.Message) error { } // TODO: multi-line replies - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "CAP", Params: []string{dc.nick, "LIST", strings.Join(caps, " ")}, @@ -860,7 +860,7 @@ func (dc *downstreamConn) handleCap(msg *irc.Message) error { if ack { reply = "ACK" } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "CAP", Params: []string{dc.nick, reply, args[0]}, @@ -882,7 +882,7 @@ func (dc *downstreamConn) handleCap(msg *irc.Message) error { return nil } -func (dc *downstreamConn) handleAuthenticate(msg *irc.Message) (result *downstreamSASL, err error) { +func (dc *downstreamConn) handleAuthenticate(ctx context.Context, msg *irc.Message) (result *downstreamSASL, err error) { defer func() { if err != nil { dc.sasl = nil @@ -989,7 +989,7 @@ func (dc *downstreamConn) handleAuthenticate(msg *irc.Message) (result *downstre } // TODO: multi-line messages - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "AUTHENTICATE", Params: []string{challengeStr}, @@ -998,7 +998,7 @@ func (dc *downstreamConn) handleAuthenticate(msg *irc.Message) (result *downstre } } -func (dc *downstreamConn) endSASL(msg *irc.Message) { +func (dc *downstreamConn) endSASL(ctx context.Context, msg *irc.Message) { if dc.sasl == nil { return } @@ -1006,9 +1006,9 @@ func (dc *downstreamConn) endSASL(msg *irc.Message) { dc.sasl = nil if msg != nil { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } else { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_SASLSUCCESS, Params: []string{dc.nick, "SASL authentication successful"}, @@ -1016,7 +1016,7 @@ func (dc *downstreamConn) endSASL(msg *irc.Message) { } } -func (dc *downstreamConn) setSupportedCap(name, value string) { +func (dc *downstreamConn) setSupportedCap(ctx context.Context, name, value string) { prevValue, hasPrev := dc.caps.Available[name] changed := !hasPrev || prevValue != value dc.caps.Available[name] = value @@ -1030,14 +1030,14 @@ func (dc *downstreamConn) setSupportedCap(name, value string) { cap = name + "=" + value } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "CAP", Params: []string{dc.nick, "NEW", cap}, }) } -func (dc *downstreamConn) unsetSupportedCap(name string) { +func (dc *downstreamConn) unsetSupportedCap(ctx context.Context, name string) { hasPrev := dc.caps.IsAvailable(name) dc.caps.Del(name) @@ -1045,14 +1045,14 @@ func (dc *downstreamConn) unsetSupportedCap(name string) { return } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "CAP", Params: []string{dc.nick, "DEL", name}, }) } -func (dc *downstreamConn) updateSupportedCaps() { +func (dc *downstreamConn) updateSupportedCaps(ctx context.Context) { supportedCaps := make(map[string]bool) for cap := range needAllDownstreamCaps { supportedCaps[cap] = true @@ -1065,16 +1065,16 @@ func (dc *downstreamConn) updateSupportedCaps() { for cap, supported := range supportedCaps { if supported { - dc.setSupportedCap(cap, needAllDownstreamCaps[cap]) + dc.setSupportedCap(ctx, cap, needAllDownstreamCaps[cap]) } else { - dc.unsetSupportedCap(cap) + dc.unsetSupportedCap(ctx, cap) } } if uc := dc.upstream(); uc != nil && uc.supportsSASL("PLAIN") { - dc.setSupportedCap("sasl", "PLAIN,ANONYMOUS") + dc.setSupportedCap(ctx, "sasl", "PLAIN,ANONYMOUS") } else if dc.network != nil { - dc.unsetSupportedCap("sasl") + dc.unsetSupportedCap(ctx, "sasl") } if uc := dc.upstream(); uc != nil && uc.caps.IsEnabled("draft/account-registration") { @@ -1087,19 +1087,19 @@ func (dc *downstreamConn) updateSupportedCaps() { break } } - dc.setSupportedCap("draft/account-registration", strings.Join(values, ",")) + dc.setSupportedCap(ctx, "draft/account-registration", strings.Join(values, ",")) } else { - dc.unsetSupportedCap("draft/account-registration") + dc.unsetSupportedCap(ctx, "draft/account-registration") } if _, ok := dc.user.msgStore.(msgstore.ChatHistoryStore); ok && dc.network != nil { - dc.setSupportedCap("draft/event-playback", "") + dc.setSupportedCap(ctx, "draft/event-playback", "") } else { - dc.unsetSupportedCap("draft/event-playback") + dc.unsetSupportedCap(ctx, "draft/event-playback") } } -func (dc *downstreamConn) updateNick() { +func (dc *downstreamConn) updateNick(ctx context.Context) { var nick string if uc := dc.upstream(); uc != nil { nick = uc.nick @@ -1113,7 +1113,7 @@ func (dc *downstreamConn) updateNick() { return } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "NICK", Params: []string{nick}, @@ -1122,7 +1122,7 @@ func (dc *downstreamConn) updateNick() { dc.nickCM = dc.casemap(dc.nick) } -func (dc *downstreamConn) updateHost() { +func (dc *downstreamConn) updateHost(ctx context.Context) { uc := dc.upstream() if uc == nil || uc.hostname == "" { return @@ -1133,13 +1133,13 @@ func (dc *downstreamConn) updateHost() { } if dc.caps.IsEnabled("chghost") { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "CHGHOST", Params: []string{uc.username, uc.hostname}, }) } else if uc.hostname != dc.hostname { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: xirc.RPL_VISIBLEHOST, Params: []string{dc.nick, uc.hostname, "is now your visible host"}, @@ -1150,7 +1150,7 @@ func (dc *downstreamConn) updateHost() { dc.username = uc.username } -func (dc *downstreamConn) updateRealname() { +func (dc *downstreamConn) updateRealname(ctx context.Context) { if !dc.caps.IsEnabled("setname") { return } @@ -1165,7 +1165,7 @@ func (dc *downstreamConn) updateRealname() { } if realname != dc.realname { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "SETNAME", Params: []string{realname}, @@ -1174,7 +1174,7 @@ func (dc *downstreamConn) updateRealname() { } } -func (dc *downstreamConn) updateAccount() { +func (dc *downstreamConn) updateAccount(ctx context.Context) { var account string if dc.network == nil { account = dc.user.Username @@ -1189,13 +1189,13 @@ func (dc *downstreamConn) updateAccount() { } if account != "" { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_LOGGEDIN, Params: []string{dc.nick, dc.prefix().String(), account, "You are logged in as " + account}, }) } else { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_LOGGEDOUT, Params: []string{dc.nick, dc.prefix().String(), "You are logged out"}, @@ -1266,7 +1266,7 @@ func (dc *downstreamConn) register(ctx context.Context) error { } if dc.sasl != nil { - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLABORTED, Params: []string{dc.nick, "SASL authentication aborted"}, @@ -1430,7 +1430,7 @@ func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { dc.registration = nil - dc.updateSupportedCaps() + dc.updateSupportedCaps(ctx) if uc := dc.upstream(); uc != nil { dc.nick = uc.nick @@ -1482,54 +1482,54 @@ func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { } } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WELCOME, Params: []string{dc.nick, "Welcome to soju, " + dc.nick}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_YOURHOST, Params: []string{dc.nick, "Your host is " + dc.srv.Config().Hostname}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_MYINFO, Params: []string{dc.nick, dc.srv.Config().Hostname, "soju", "aiwroO", "OovaimnqpsrtklbeI"}, }) for _, msg := range xirc.GenerateIsupport(dc.srv.prefix(), dc.nick, isupport) { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } if uc := dc.upstream(); uc != nil { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_UMODEIS, Params: []string{dc.nick, "+" + string(uc.modes)}, }) } if dc.network == nil && dc.user.Admin { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_UMODEIS, Params: []string{dc.nick, "+o"}, }) } - dc.updateHost() - dc.updateRealname() - dc.updateAccount() + dc.updateHost(ctx) + dc.updateRealname(ctx) + dc.updateAccount(ctx) dc.updateCasemapping() if motd := dc.user.srv.Config().MOTD; motd != "" && dc.network == nil { for _, msg := range xirc.GenerateMOTD(dc.srv.prefix(), dc.nick, motd) { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } } else { motdHint := "No MOTD" if dc.network != nil { motdHint = "Use /motd to read the message of the day" } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_NOMOTD, Params: []string{dc.nick, motdHint}, @@ -1537,11 +1537,11 @@ func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { } if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { - dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef string) { + dc.SendBatch(ctx, "soju.im/bouncer-networks", nil, nil, func(batchRef string) { for _, network := range dc.user.networks { idStr := fmt.Sprintf("%v", network.ID) attrs := getNetworkAttrs(network) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: irc.Tags{"batch": batchRef}, Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -1561,7 +1561,7 @@ func (dc *downstreamConn) welcome(ctx context.Context, user *user) error { return } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "JOIN", Params: []string{ch.Name}, @@ -1643,7 +1643,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t return } - dc.SendBatch("chathistory", []string{target}, nil, func(batchRef string) { + dc.SendBatch(ctx, "chathistory", []string{target}, nil, func(batchRef string) { for _, msg := range history { if ch != nil && ch.Detached { if net.detachedMessageNeedsRelay(ch, msg) { @@ -1651,7 +1651,7 @@ func (dc *downstreamConn) sendTargetBacklog(ctx context.Context, net *network, t } } else { msg.Tags["batch"] = batchRef - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } } }) @@ -1679,7 +1679,7 @@ func (dc *downstreamConn) runUntilRegistered() error { go func() { <-ctx.Done() if err := ctx.Err(); err == context.DeadlineExceeded { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "ERROR", Params: []string{"Connection registration timed out"}, @@ -1697,7 +1697,7 @@ func (dc *downstreamConn) runUntilRegistered() error { err = dc.handleMessage(ctx, msg) if ircErr, ok := err.(ircError); ok { ircErr.Message.Prefix = dc.srv.prefix() - dc.SendMessage(ircErr.Message) + dc.SendMessage(ctx, ircErr.Message) } else if err != nil { return fmt.Errorf("failed to handle IRC command %q: %v", msg, err) } @@ -1713,7 +1713,7 @@ func (dc *downstreamConn) runUntilRegistered() error { func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.Message) error { switch msg.Command { case "CAP": - return dc.handleCap(msg) + return dc.handleCap(ctx, msg) case "PING": var source, destination string if err := parseMessageParams(msg, &source); err != nil { @@ -1729,7 +1729,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{dc.nick, destination, "No such server"}, }} } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "PONG", Params: []string{hostname, source}, @@ -1792,12 +1792,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{nick}, }) } else { - dc.updateNick() + dc.updateNick(ctx) } } else { for _, c := range dc.user.downstreamConns { if c.network == nil { - c.updateNick() + c.updateNick(ctx) } } } @@ -1808,7 +1808,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if dc.realname == realname { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "SETNAME", Params: []string{realname}, @@ -1856,7 +1856,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if dc.network == nil { for _, c := range dc.user.downstreamConns { if c.network == nil { - c.updateRealname() + c.updateRealname(ctx) } } } @@ -1883,7 +1883,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if name == "" || strings.ContainsAny(name, illegalChanChars) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_BADCHANMASK, Params: []string{name, "Invalid channel name"}, @@ -1891,7 +1891,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. continue } if !uc.isChannel(name) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_NOSUCHCHANNEL, Params: []string{name, "Not a channel name"}, @@ -2007,7 +2007,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{uc.nick, modeStr}, }) } else { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_UMODEUNKNOWNFLAG, Params: []string{dc.nick, "Cannot change user mode on bouncer connection"}, @@ -2019,7 +2019,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. userMode = string(uc.modes) } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_UMODEIS, Params: []string{dc.nick, "+" + userMode}, @@ -2066,13 +2066,13 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. params := []string{dc.nick, name, modeStr} params = append(params, modeParams...) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_CHANNELMODEIS, Params: params, }) if ch.creationTime != "" { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: xirc.RPL_CREATIONTIME, Params: []string{dc.nick, name, ch.creationTime}, @@ -2104,7 +2104,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{dc.nick, name, "No such channel"}, }} } - sendTopic(dc, ch) + sendTopic(ctx, dc, ch) } case "LIST": uc, err := dc.upstreamForCommand(msg.Command) @@ -2120,7 +2120,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if len(msg.Params) == 0 { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFNAMES, Params: []string{dc.nick, "*", "End of /NAMES list"}, @@ -2132,7 +2132,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. for _, name := range channels { ch := uc.channels.Get(name) if ch != nil { - sendNames(dc, ch) + sendNames(ctx, dc, ch) } else { // NAMES on a channel we have not joined, ask upstream uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ @@ -2143,7 +2143,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } case "WHO": if len(msg.Params) == 0 { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, "*", "End of /WHO list"}, @@ -2181,8 +2181,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Account: dc.user.Username, Realname: dc.realname, } - dc.SendMessage(xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, @@ -2208,8 +2208,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Account: serviceNick, Realname: serviceRealname, } - dc.SendMessage(xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, @@ -2217,7 +2217,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } if dc.network == nil { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, @@ -2246,9 +2246,9 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if uc.isChannel(mask) { info.Channel = mask } - dc.SendMessage(xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) + dc.SendMessage(ctx, xirc.GenerateWHOXReply(dc.srv.prefix(), dc.nick, fields, &info)) } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, endOfWhoToken, "End of /WHO list"}, @@ -2277,29 +2277,29 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if dc.network == nil && dc.casemap(mask) == dc.nickCM { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISUSER, Params: []string{dc.nick, dc.nick, dc.user.Username, dc.hostname, "*", dc.realname}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISSERVER, Params: []string{dc.nick, dc.nick, dc.srv.Config().Hostname, "soju"}, }) if dc.user.Admin { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISOPERATOR, Params: []string{dc.nick, dc.nick, "is a bouncer administrator"}, }) } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: xirc.RPL_WHOISACCOUNT, Params: []string{dc.nick, dc.nick, dc.user.Username, "is logged in as"}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHOIS, Params: []string{dc.nick, dc.nick, "End of /WHOIS list"}, @@ -2307,32 +2307,32 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } if dc.casemap(mask) == serviceNickCM { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISUSER, Params: []string{dc.nick, serviceNick, servicePrefix.User, servicePrefix.Host, "*", serviceRealname}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISSERVER, Params: []string{dc.nick, serviceNick, dc.srv.Config().Hostname, "soju"}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISOPERATOR, Params: []string{dc.nick, serviceNick, "is the bouncer service"}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: xirc.RPL_WHOISACCOUNT, Params: []string{dc.nick, serviceNick, serviceNick, "is logged in as"}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: xirc.RPL_WHOISBOT, Params: []string{dc.nick, serviceNick, "is a bot"}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHOIS, Params: []string{dc.nick, serviceNick, "End of /WHOIS list"}, @@ -2394,7 +2394,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if dc.network == nil && dc.casemap(name) == dc.nickCM { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: msg.Tags.Copy(), Prefix: dc.prefix(), Command: msg.Command, @@ -2407,7 +2407,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if dc.caps.IsEnabled("echo-message") { echoTags := tags.Copy() echoTags["time"] = dc.user.FormatServerTime(time.Now()) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: echoTags, Prefix: dc.prefix(), Command: msg.Command, @@ -2499,7 +2499,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - credentials, err := dc.handleAuthenticate(msg) + credentials, err := dc.handleAuthenticate(ctx, msg) if err != nil { return err } else if credentials == nil { @@ -2507,7 +2507,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if uc.saslClient != nil { - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, "Another authentication attempt is already in progress"}, @@ -2534,7 +2534,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. _, err := dc.user.updateNetwork(ctx, &record) if err != nil { dc.logger.Printf("failed to clear SASL credentials") - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, "Internal server error"}, @@ -2542,9 +2542,9 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. break } } - dc.endSASL(nil) + dc.endSASL(ctx, nil) default: - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLFAIL, Params: []string{dc.nick, "Unsupported SASL authentication mechanism"}, @@ -2580,7 +2580,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. cmd = irc.RPL_UNAWAY desc = "You are no longer marked as being away" } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: cmd, Params: []string{dc.nick, desc}, }) @@ -2591,11 +2591,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } case "INFO": if dc.network == nil { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: irc.RPL_INFO, Params: []string{dc.nick, "soju "}, }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: irc.RPL_ENDOFINFO, Params: []string{dc.nick, "End of INFO"}, }) @@ -2635,7 +2635,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if subcommand == "+" { // Hard limit, just to avoid having downstreams fill our map if dc.monitored.Len() >= 1000 { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_MONLISTFULL, Params: []string{dc.nick, "1000", target, "Bouncer monitor list is full"}, @@ -2647,7 +2647,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if uc.network.casemap(target) == serviceNickCM { // BouncerServ is never tired - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_MONONLINE, Params: []string{dc.nick, target}, @@ -2661,7 +2661,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. cmd = irc.RPL_MONONLINE } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: cmd, Params: []string{dc.nick, target}, @@ -2678,13 +2678,13 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. case "L": // list // TODO: be less lazy and pack the list dc.monitored.ForEach(func(name string, _ struct{}) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_MONLIST, Params: []string{dc.nick, name}, }) }) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFMONLIST, Params: []string{dc.nick, "End of MONITOR list"}, @@ -2701,7 +2701,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. cmd = irc.RPL_MONONLINE } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: cmd, Params: []string{dc.nick, target}, @@ -2728,7 +2728,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if dc.network == nil { // Either an unbound bouncer network, in which case we should return no targets, // or a multi-upstream downstream, but we don't support CHATHISTORY TARGETS for those yet. - dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef string) {}) + dc.SendBatch(ctx, "draft/chathistory-targets", nil, nil, func(batchRef string) {}) return nil } if err := parseMessageParams(msg, nil, &boundsStr[0], &boundsStr[1], &limitStr); err != nil { @@ -2744,7 +2744,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // We don't save history for our service if dc.casemap(target) == serviceNickCM { - dc.SendBatch("chathistory", []string{target}, nil, func(batchRef string) {}) + dc.SendBatch(ctx, "chathistory", []string{target}, nil, func(batchRef string) {}) return nil } @@ -2827,13 +2827,13 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - dc.SendBatch("draft/chathistory-targets", nil, nil, func(batchRef string) { + dc.SendBatch(ctx, "draft/chathistory-targets", nil, nil, func(batchRef string) { for _, target := range targets { if ch := network.channels.Get(target.Name); ch != nil && ch.Detached { continue } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: irc.Tags{"batch": batchRef}, Prefix: dc.srv.prefix(), Command: "CHATHISTORY", @@ -2849,10 +2849,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return newChatHistoryError(subcommand, target) } - dc.SendBatch("chathistory", []string{target}, nil, func(batchRef string) { + dc.SendBatch(ctx, "chathistory", []string{target}, nil, func(batchRef string) { for _, msg := range history { msg.Tags["batch"] = batchRef - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } }) case "READ", "MARKREAD": @@ -2869,7 +2869,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // We don't save read receipts for our service if dc.casemap(target) == serviceNickCM { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: msg.Command, Params: []string{target, "*"}, @@ -2945,7 +2945,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if !d.caps.IsEnabled("draft/read-marker") { cmd = "READ" } - d.SendMessage(&irc.Message{ + d.SendMessage(ctx, &irc.Message{ Prefix: d.prefix(), Command: cmd, Params: []string{target, timestampStr}, @@ -3042,10 +3042,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - dc.SendBatch("soju.im/search", nil, nil, func(batchRef string) { + dc.SendBatch(ctx, "soju.im/search", nil, nil, func(batchRef string) { for _, msg := range messages { msg.Tags["batch"] = batchRef - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } }) case "BOUNCER": @@ -3061,11 +3061,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: []string{"BOUNCER", "REGISTRATION_IS_COMPLETED", "BIND", "Cannot bind to a network after registration"}, }} case "LISTNETWORKS": - dc.SendBatch("soju.im/bouncer-networks", nil, nil, func(batchRef string) { + dc.SendBatch(ctx, "soju.im/bouncer-networks", nil, nil, func(batchRef string) { for _, network := range dc.user.networks { idStr := fmt.Sprintf("%v", network.ID) attrs := getNetworkAttrs(network) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Tags: irc.Tags{"batch": batchRef}, Prefix: dc.srv.prefix(), Command: "BOUNCER", @@ -3100,7 +3100,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", Params: []string{"ADDNETWORK", fmt.Sprintf("%v", network.ID)}, @@ -3144,7 +3144,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", Params: []string{"CHANGENETWORK", idStr}, @@ -3171,7 +3171,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return err } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", Params: []string{"DELNETWORK", idStr}, @@ -3275,7 +3275,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "WEBPUSH", Params: []string{"REGISTER", endpoint}, @@ -3297,7 +3297,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. oldSub := findWebPushSubscription(subs, endpoint) if oldSub == nil { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "WEBPUSH", Params: []string{"UNREGISTER", endpoint}, @@ -3313,7 +3313,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "WEBPUSH", Params: []string{"UNREGISTER", endpoint}, @@ -3407,7 +3407,7 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel // RPL_NOTOPIC shouldn't be sent on JOIN if ch.Topic != "" { - sendTopic(dc, ch) + sendTopic(ctx, dc, ch) } var markReadCmd string @@ -3426,7 +3426,7 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel if r != nil { timestampStr = fmt.Sprintf("timestamp=%s", xirc.FormatServerTime(r.Timestamp)) } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: markReadCmd, Params: []string{ch.Name, timestampStr}, @@ -3435,27 +3435,27 @@ func forwardChannel(ctx context.Context, dc *downstreamConn, ch *upstreamChannel } if !dc.caps.IsEnabled("soju.im/no-implicit-names") { - sendNames(dc, ch) + sendNames(ctx, dc, ch) } } -func sendTopic(dc *downstreamConn, ch *upstreamChannel) { +func sendTopic(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { if ch.Topic != "" { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_TOPIC, Params: []string{dc.nick, ch.Name, ch.Topic}, }) if ch.TopicWho != nil { topicTime := strconv.FormatInt(ch.TopicTime.Unix(), 10) - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: xirc.RPL_TOPICWHOTIME, Params: []string{dc.nick, ch.Name, ch.TopicWho.String(), topicTime}, }) } } else { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_NOTOPIC, Params: []string{dc.nick, ch.Name, "No topic is set"}, @@ -3463,7 +3463,7 @@ func sendTopic(dc *downstreamConn, ch *upstreamChannel) { } } -func sendNames(dc *downstreamConn, ch *upstreamChannel) { +func sendNames(ctx context.Context, dc *downstreamConn, ch *upstreamChannel) { var members []string ch.Members.ForEach(func(nick string, memberships *xirc.MembershipSet) { s := formatMemberPrefix(*memberships, dc) + nick @@ -3472,6 +3472,6 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) { msgs := xirc.GenerateNamesReply(dc.srv.prefix(), dc.nick, ch.Name, ch.Status, members) for _, msg := range msgs { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } } diff --git a/server.go b/server.go index 092493b..9abd2b1 100644 --- a/server.go +++ b/server.go @@ -462,7 +462,7 @@ func (s *Server) Handle(ic ircConn) { defer dc.Close() if shutdown { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Command: "ERROR", Params: []string{"Server is shutting down"}, }) @@ -478,7 +478,7 @@ func (s *Server) Handle(ic ircConn) { user, err := s.getOrCreateUser(context.TODO(), dc.registration.authUsername) if err != nil { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Command: "ERROR", Params: []string{"Internal server error"}, }) diff --git a/service.go b/service.go index a1f01f1..acd9e41 100644 --- a/service.go +++ b/service.go @@ -56,7 +56,7 @@ type serviceCommand struct { } func sendServiceNOTICE(dc *downstreamConn, text string) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Prefix: servicePrefix, Command: "NOTICE", Params: []string{dc.nick, text}, @@ -64,7 +64,7 @@ func sendServiceNOTICE(dc *downstreamConn, text string) { } func sendServicePRIVMSG(dc *downstreamConn, text string) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Prefix: servicePrefix, Command: "PRIVMSG", Params: []string{dc.nick, text}, diff --git a/upstream.go b/upstream.go index 5669bc3..c26f993 100644 --- a/upstream.go +++ b/upstream.go @@ -421,19 +421,20 @@ func (uc *upstreamConn) isOurNick(nick string) bool { return uc.network.equalCasemap(uc.nick, nick) } -func (uc *upstreamConn) forwardMessage(msg *irc.Message) { +func (uc *upstreamConn) forwardMessage(ctx context.Context, msg *irc.Message) { uc.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) }) } -func (uc *upstreamConn) forwardMsgByID(id uint64, msg *irc.Message) { +func (uc *upstreamConn) forwardMsgByID(ctx context.Context, id uint64, msg *irc.Message) { uc.forEachDownstreamByID(id, func(dc *downstreamConn) { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) }) } func (uc *upstreamConn) abortPendingCommands() { + ctx := context.TODO() for _, l := range uc.pendingCmds { for _, pendingCmd := range l { dc := uc.downstreamByID(pendingCmd.downstreamID) @@ -443,7 +444,7 @@ func (uc *upstreamConn) abortPendingCommands() { switch pendingCmd.msg.Command { case "LIST": - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_LISTEND, Params: []string{dc.nick, "Command aborted"}, @@ -453,26 +454,26 @@ func (uc *upstreamConn) abortPendingCommands() { if len(pendingCmd.msg.Params) > 0 { mask = pendingCmd.msg.Params[0] } - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, Params: []string{dc.nick, mask, "Command aborted"}, }) case "WHOIS": nick := pendingCmd.msg.Params[len(pendingCmd.msg.Params)-1] - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHOIS, Params: []string{dc.nick, nick, "Command aborted"}, }) case "AUTHENTICATE": - dc.endSASL(&irc.Message{ + dc.endSASL(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_SASLABORTED, Params: []string{dc.nick, "SASL authentication aborted"}, }) case "REGISTER", "VERIFY": - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "FAIL", Params: []string{pendingCmd.msg.Command, "TEMPORARILY_UNAVAILABLE", pendingCmd.msg.Params[0], "Command aborted"}, @@ -732,7 +733,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.registered { uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateSupportedCaps() + dc.updateSupportedCaps(ctx) }) } case "NEW": @@ -753,7 +754,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if uc.registered { uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateSupportedCaps() + dc.updateSupportedCaps(ctx) }) } default: @@ -818,8 +819,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("logged in with account %q", uc.account) uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateAccount() - dc.updateHost() + dc.updateAccount(ctx) + dc.updateHost(ctx) }) case irc.RPL_LOGGEDOUT: var rawPrefix string @@ -835,8 +836,8 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.logger.Printf("logged out") uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateAccount() - dc.updateHost() + dc.updateAccount(ctx) + dc.updateHost(ctx) }) case xirc.RPL_VISIBLEHOST: var rawHost string @@ -852,7 +853,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateHost() + dc.updateHost(ctx) }) case irc.ERR_NICKLOCKED, irc.RPL_SASLSUCCESS, irc.ERR_SASLFAIL, irc.ERR_SASLTOOLONG, irc.ERR_SASLABORTED: var info string @@ -876,7 +877,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.network.autoSaveSASLPlain(ctx, dc.sasl.plain.Username, dc.sasl.plain.Password) } - dc.endSASL(msg) + dc.endSASL(ctx, msg) } if !uc.registered { @@ -898,7 +899,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.network.autoSaveSASLPlain(ctx, account, password) } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } case irc.RPL_WELCOME: if err := parseMessageParams(msg, &uc.nick); err != nil { @@ -993,7 +994,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.forEachDownstream(func(dc *downstreamConn) { msgs := xirc.GenerateIsupport(dc.srv.prefix(), dc.nick, downstreamIsupport) for _, msg := range msgs { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } }) case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: @@ -1017,7 +1018,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case "BATCH": var tag string if err := parseMessageParams(msg, &tag); err != nil { @@ -1088,10 +1089,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err }) if !me { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } else { uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateNick() + dc.updateNick(ctx) }) uc.updateMonitor() } @@ -1112,10 +1113,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.realname = newRealname uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateRealname() + dc.updateRealname(ctx) }) } else { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case "CHGHOST": var newUsername, newHostname string @@ -1135,11 +1136,11 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.hostname = newHostname uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateHost() + dc.updateHost(ctx) }) } else { // TODO: add fallback with QUIT/JOIN/MODE messages - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case "JOIN": var channels string @@ -1274,7 +1275,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.users.Del(msg.Prefix.Name) if msg.Prefix.Name != uc.nick { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case irc.RPL_TOPIC, irc.RPL_NOTOPIC: var name, topic string @@ -1322,7 +1323,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } else { // channel mode change ch, err := uc.getChannel(name) if err != nil { @@ -1338,7 +1339,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err c := uc.network.channels.Get(name) if c == nil || !c.Detached { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } } case irc.RPL_UMODEIS: @@ -1355,7 +1356,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) case irc.RPL_CHANNELMODEIS: var channel string if err := parseMessageParams(msg, nil, &channel); err != nil { @@ -1379,7 +1380,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err c := uc.network.channels.Get(channel) if firstMode && (c == nil || !c.Detached) { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case xirc.RPL_CREATIONTIME: var channel, creationTime string @@ -1397,7 +1398,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err c := uc.network.channels.Get(channel) if firstCreationTime && (c == nil || !c.Detached) { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case xirc.RPL_TOPICWHOTIME: var channel, who, timeStr string @@ -1420,7 +1421,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err c := uc.network.channels.Get(channel) if firstTopicWhoTime && (c == nil || !c.Detached) { - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) } case irc.RPL_LIST: dc, cmd := uc.currentPendingCommand("LIST") @@ -1430,7 +1431,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) case irc.RPL_LISTEND: dc, cmd := uc.dequeueCommand("LIST") if cmd == nil { @@ -1439,7 +1440,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) case irc.RPL_NAMREPLY: var name, statusStr, members string if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { @@ -1449,7 +1450,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) return nil } @@ -1472,7 +1473,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err ch := uc.channels.Get(name) if ch == nil { // NAMES on a channel we have not joined, forward to downstream - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) return nil } @@ -1506,7 +1507,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } realname := parts[1] - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) if uc.shouldCacheUserInfo(nick) { uc.cacheUserInfo(nick, &upstreamUser{ @@ -1526,7 +1527,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) if len(cmd.Params) > 1 { fields, _ := xirc.ParseWHOXOptions(cmd.Params[1]) @@ -1559,7 +1560,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) case xirc.RPL_WHOISCERTFP, xirc.RPL_WHOISREGNICK, irc.RPL_WHOISUSER, irc.RPL_WHOISSERVER, irc.RPL_WHOISCHANNELS, irc.RPL_WHOISOPERATOR, irc.RPL_WHOISIDLE, xirc.RPL_WHOISSPECIAL, xirc.RPL_WHOISACCOUNT, xirc.RPL_WHOISACTUALLY, xirc.RPL_WHOISHOST, xirc.RPL_WHOISMODES, xirc.RPL_WHOISSECURE: dc, cmd := uc.currentPendingCommand("WHOIS") if cmd == nil { @@ -1568,7 +1569,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) case irc.RPL_ENDOFWHOIS: dc, cmd := uc.dequeueCommand("WHOIS") if cmd == nil { @@ -1577,7 +1578,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) case "INVITE": var nick, channel string if err := parseMessageParams(msg, &nick, &channel); err != nil { @@ -1590,7 +1591,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if !weAreInvited && !dc.caps.IsEnabled("invite-notify") { return } - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) }) if weAreInvited { @@ -1602,7 +1603,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return err } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case irc.RPL_MONONLINE, irc.RPL_MONOFFLINE: var targetsStr string if err := parseMessageParams(msg, nil, &targetsStr); err != nil { @@ -1640,7 +1641,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err for _, target := range targets { prefix := irc.ParsePrefix(target) if dc.monitored.Has(prefix.Name) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: msg.Command, Params: []string{dc.nick, target}, @@ -1658,7 +1659,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.forEachDownstream(func(dc *downstreamConn) { for _, target := range targets { if dc.monitored.Has(target) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: msg.Command, Params: []string{dc.nick, limit, target}, @@ -1667,7 +1668,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err } }) case irc.RPL_AWAY: - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case "AWAY": // Update user flags, if we already have the flags cached uu := uc.users.Get(msg.Prefix.Name) @@ -1683,7 +1684,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err }) } - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) case "ACCOUNT": var account string if err := parseMessageParams(msg, &account); err != nil { @@ -1692,9 +1693,9 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.cacheUserInfo(msg.Prefix.Name, &upstreamUser{ Account: account, }) - uc.forwardMessage(msg) + uc.forwardMessage(ctx, msg) case irc.RPL_BANLIST, irc.RPL_INVITELIST, irc.RPL_EXCEPTLIST, irc.RPL_ENDOFBANLIST, irc.RPL_ENDOFINVITELIST, irc.RPL_ENDOFEXCEPTLIST: - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case irc.ERR_NOSUCHNICK: var nick, reason string if err := parseMessageParams(msg, nil, &nick, &reason); err != nil { @@ -1706,10 +1707,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if cmd != nil && cm(cmd.Params[len(cmd.Params)-1]) == cm(nick) { uc.dequeueCommand("WHOIS") if dc != nil { - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } } else { - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) } case xirc.ERR_UNKNOWNERROR, irc.ERR_UNKNOWNCOMMAND, irc.ERR_NEEDMOREPARAMS, irc.RPL_TRYAGAIN: var command, reason string @@ -1726,7 +1727,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err uc.saslStarted = false } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case "FAIL": var command, code string if err := parseMessageParams(msg, &command, &code); err != nil { @@ -1741,7 +1742,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err downstreamID = dc.id } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case "ACK": // Ignore case irc.RPL_NOWAWAY, irc.RPL_UNAWAY: @@ -1761,7 +1762,7 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err return nil } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) case irc.RPL_LISTSTART: // Ignore case "ERROR": @@ -1801,10 +1802,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err if !uc.registered { return registrationError{msg} } - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) default: uc.logger.Printf("unhandled message: %v", msg) - uc.forwardMsgByID(downstreamID, msg) + uc.forwardMsgByID(ctx, downstreamID, msg) } return nil } @@ -2146,12 +2147,13 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, originID uint64 ch := uc.network.channels.Get(target) detached := ch != nil && ch.Detached + ctx := context.TODO() uc.forEachDownstream(func(dc *downstreamConn) { echo := dc.id == originID && msg.Prefix != nil && uc.isOurNick(msg.Prefix.Name) if !detached && (!echo || dc.caps.IsEnabled("echo-message")) { - dc.sendMessageWithID(msg, msgID) + dc.sendMessageWithID(ctx, msg, msgID) } else { - dc.advanceMessageWithID(msg, msgID) + dc.advanceMessageWithID(ctx, msg, msgID) } }) } diff --git a/user.go b/user.go index b9997cf..d83d14a 100644 --- a/user.go +++ b/user.go @@ -337,7 +337,7 @@ func (net *network) detach(ch *database.Channel) { } net.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Prefix: dc.prefix(), Command: "PART", Params: []string{ch.Name, "Detach"}, @@ -364,7 +364,7 @@ func (net *network) attach(ctx context.Context, ch *database.Channel) { } net.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.prefix(), Command: "JOIN", Params: []string{ch.Name}, @@ -642,17 +642,18 @@ func (u *user) run() { uc.updateAway() uc.updateMonitor() + ctx := context.TODO() uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateSupportedCaps() + dc.updateSupportedCaps(ctx) if !dc.caps.IsEnabled("soju.im/bouncer-networks") { sendServiceNOTICE(dc, fmt.Sprintf("connected to %s", uc.network.GetName())) } - dc.updateNick() - dc.updateHost() - dc.updateRealname() - dc.updateAccount() + dc.updateNick(ctx) + dc.updateHost(ctx) + dc.updateRealname(ctx) + dc.updateAccount(ctx) dc.updateCasemapping() }) u.notifyBouncerNetworkState(uc.network.ID, irc.Tags{ @@ -729,7 +730,7 @@ func (u *user) run() { } if !u.Enabled { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: "ERROR", Params: []string{"This bouncer account is disabled"}, }) @@ -741,9 +742,9 @@ func (u *user) run() { if ircErr, ok := err.(ircError); ok { msg := ircErr.Message.Copy() msg.Prefix = dc.srv.prefix() - dc.SendMessage(msg) + dc.SendMessage(ctx, msg) } else { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Command: "ERROR", Params: []string{"Internal server error"}, }) @@ -799,7 +800,7 @@ func (u *user) run() { err := dc.handleMessage(context.TODO(), msg) if ircErr, ok := err.(ircError); ok { ircErr.Message.Prefix = dc.srv.prefix() - dc.SendMessage(ircErr.Message) + dc.SendMessage(context.TODO(), ircErr.Message) } else if err != nil { dc.logger.Printf("failed to handle message %q: %v", msg, err) dc.Close() @@ -807,7 +808,7 @@ func (u *user) run() { case eventBroadcast: msg := e.msg for _, dc := range u.downstreamConns { - dc.SendMessage(msg) + dc.SendMessage(context.TODO(), msg) } case eventUserUpdate: e.done <- u.updateUser(context.TODO(), func(record *database.User) error { @@ -882,7 +883,7 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { }) uc.forEachDownstream(func(dc *downstreamConn) { - dc.updateSupportedCaps() + dc.updateSupportedCaps(context.TODO()) }) // If the network has been removed, don't send a state change notification @@ -912,7 +913,7 @@ func (u *user) notifyBouncerNetworkState(netID int64, attrs irc.Tags) { netIDStr := fmt.Sprintf("%v", netID) for _, dc := range u.downstreamConns { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { - dc.SendMessage(&irc.Message{ + dc.SendMessage(context.TODO(), &irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", Params: []string{"NETWORK", netIDStr, attrs.String()}, @@ -1116,7 +1117,7 @@ func (u *user) deleteNetwork(ctx context.Context, id int64) error { idStr := fmt.Sprintf("%v", network.ID) for _, dc := range u.downstreamConns { if dc.caps.IsEnabled("soju.im/bouncer-networks-notify") { - dc.SendMessage(&irc.Message{ + dc.SendMessage(ctx, &irc.Message{ Prefix: dc.srv.prefix(), Command: "BOUNCER", Params: []string{"NETWORK", idStr, "*"},