diff --git a/downstream.go b/downstream.go index d6d287a..95ee7d0 100644 --- a/downstream.go +++ b/downstream.go @@ -1864,7 +1864,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { return nil } - uc.enqueueLIST(dc, msg) + uc.enqueueCommand(dc, msg) case "NAMES": if len(msg.Params) == 0 { dc.SendMessage(&irc.Message{ @@ -1986,7 +1986,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { params = append(params, options) } - uc.SendMessageLabeled(dc.id, &irc.Message{ + uc.enqueueCommand(dc, &irc.Message{ Command: "WHO", Params: params, }) diff --git a/upstream.go b/upstream.go index 19e3da5..ae9843f 100644 --- a/upstream.go +++ b/upstream.go @@ -74,7 +74,7 @@ func (uc *upstreamChannel) updateAutoDetach(dur time.Duration) { type pendingUpstreamCommand struct { downstreamID uint64 - cmd *irc.Message + msg *irc.Message } type upstreamConn struct { @@ -109,10 +109,10 @@ type upstreamConn struct { casemapIsSet bool - // Queue of LIST commands in progress. The first entry has been sent to the - // server and is awaiting reply. The following entries have not been sent - // yet. - pendingLIST []pendingUpstreamCommand + // Queue of commands in progress, indexed by type. The first entry has been + // sent to the server and is awaiting reply. The following entries have not + // been sent yet. + pendingCmds map[string][]pendingUpstreamCommand gotMotd bool } @@ -208,6 +208,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { availableChannelModes: stdChannelModes, availableMemberships: stdMemberships, isupport: make(map[string]*string), + pendingCmds: make(map[string][]pendingUpstreamCommand), } return uc, nil } @@ -225,6 +226,15 @@ func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn) }) } +func (uc *upstreamConn) downstreamByID(id uint64) *downstreamConn { + for _, dc := range uc.user.downstreamConns { + if dc.id == id { + return dc + } + } + return nil +} + func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { ch := uc.channels.Value(name) if ch == nil { @@ -241,63 +251,85 @@ func (uc *upstreamConn) isOurNick(nick string) bool { return uc.nickCM == uc.network.casemap(nick) } -func (uc *upstreamConn) endPendingLISTs() { - for _, pendingCmd := range uc.pendingLIST { - uc.forEachDownstreamByID(pendingCmd.downstreamID, func(dc *downstreamConn) { - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_LISTEND, - Params: []string{dc.nick, "End of /LIST"}, - }) - }) - } - uc.pendingLIST = nil -} +func (uc *upstreamConn) endPendingCommands() { + for _, l := range uc.pendingCmds { + for _, pendingCmd := range l { + dc := uc.downstreamByID(pendingCmd.downstreamID) + if dc == nil { + continue + } -func (uc *upstreamConn) sendNextPendingLIST() { - if len(uc.pendingLIST) == 0 { - return - } - uc.SendMessage(uc.pendingLIST[0].cmd) -} - -func (uc *upstreamConn) enqueueLIST(dc *downstreamConn, cmd *irc.Message) { - uc.pendingLIST = append(uc.pendingLIST, pendingUpstreamCommand{ - downstreamID: dc.id, - cmd: cmd, - }) - - if len(uc.pendingLIST) == 1 { - uc.sendNextPendingLIST() - } -} - -func (uc *upstreamConn) currentPendingLIST() (*downstreamConn, *irc.Message) { - if len(uc.pendingLIST) == 0 { - return nil, nil - } - - pendingCmd := uc.pendingLIST[0] - for _, dc := range uc.user.downstreamConns { - if dc.id == pendingCmd.downstreamID { - return dc, pendingCmd.cmd + switch pendingCmd.msg.Command { + case "LIST": + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + case "WHO": + mask := "*" + if len(pendingCmd.msg.Params) > 0 { + mask = pendingCmd.msg.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO"}, + }) + default: + panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command)) + } } } - return nil, pendingCmd.cmd + uc.pendingCmds = make(map[string][]pendingUpstreamCommand) } -func (uc *upstreamConn) dequeueLIST() (*downstreamConn, *irc.Message) { - dc, cmd := uc.currentPendingLIST() +func (uc *upstreamConn) sendNextPendingCommand(cmd string) { + if len(uc.pendingCmds[cmd]) == 0 { + return + } + uc.SendMessage(uc.pendingCmds[cmd][0].msg) +} - if len(uc.pendingLIST) > 0 { - copy(uc.pendingLIST, uc.pendingLIST[1:]) - uc.pendingLIST = uc.pendingLIST[:len(uc.pendingLIST)-1] +func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) { + switch msg.Command { + case "LIST", "WHO": + // Supported + default: + panic(fmt.Errorf("Unsupported pending command %q", msg.Command)) } - uc.sendNextPendingLIST() + uc.pendingCmds[msg.Command] = append(uc.pendingCmds[msg.Command], pendingUpstreamCommand{ + downstreamID: dc.id, + msg: msg, + }) - return dc, cmd + if len(uc.pendingCmds[msg.Command]) == 1 { + uc.sendNextPendingCommand(msg.Command) + } +} + +func (uc *upstreamConn) currentPendingCommand(cmd string) (*downstreamConn, *irc.Message) { + if len(uc.pendingCmds[cmd]) == 0 { + return nil, nil + } + + pendingCmd := uc.pendingCmds[cmd][0] + return uc.downstreamByID(pendingCmd.downstreamID), pendingCmd.msg +} + +func (uc *upstreamConn) dequeueCommand(cmd string) (*downstreamConn, *irc.Message) { + dc, msg := uc.currentPendingCommand(cmd) + + if len(uc.pendingCmds[cmd]) > 0 { + copy(uc.pendingCmds[cmd], uc.pendingCmds[cmd][1:]) + uc.pendingCmds[cmd] = uc.pendingCmds[cmd][:len(uc.pendingCmds[cmd])-1] + } + + uc.sendNextPendingCommand(cmd) + + return dc, msg } func (uc *upstreamConn) parseMembershipPrefix(s string) (ms *memberships, nick string) { @@ -1095,7 +1127,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - dc, cmd := uc.currentPendingLIST() + dc, cmd := uc.currentPendingCommand("LIST") if cmd == nil { return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") } else if dc == nil { @@ -1108,7 +1140,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic}, }) case irc.RPL_LISTEND: - dc, cmd := uc.dequeueLIST() + dc, cmd := uc.dequeueCommand("LIST") if cmd == nil { return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") } else if dc == nil { @@ -1195,6 +1227,13 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOREPLY: no matching pending WHO") + } else if dc == nil { + return nil + } + parts := strings.SplitN(trailing, " ", 2) if len(parts) != 2 { return fmt.Errorf("received malformed RPL_WHOREPLY: wrong trailing parameter: %s", trailing) @@ -1208,35 +1247,46 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { trailing = strconv.Itoa(hops) + " " + realname - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - channel := channel - if channel != "*" { - channel = dc.marshalEntity(uc.network, channel) - } - nick := dc.marshalEntity(uc.network, nick) - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_WHOREPLY, - Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing}, - }) + if channel != "*" { + channel = dc.marshalEntity(uc.network, channel) + } + nick = dc.marshalEntity(uc.network, nick) + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_WHOREPLY, + Params: []string{dc.nick, channel, username, host, server, nick, mode, trailing}, }) + case rpl_whospcrpl: + dc, cmd := uc.currentPendingCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_WHOSPCRPL: no matching pending WHO") + } else if dc == nil { + return nil + } + + // Only supported in single-upstream mode, so forward as-is + dc.SendMessage(msg) case irc.RPL_ENDOFWHO: var name string if err := parseMessageParams(msg, nil, &name); err != nil { return err } - uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { - name := name - if name != "*" { - // TODO: support WHO masks - name = dc.marshalEntity(uc.network, name) - } - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_ENDOFWHO, - Params: []string{dc.nick, name, "End of /WHO list"}, - }) + dc, cmd := uc.dequeueCommand("WHO") + if cmd == nil { + return fmt.Errorf("unexpected RPL_ENDOFWHO: no matching pending WHO") + } else if dc == nil { + return nil + } + + mask := "*" + if len(cmd.Params) > 0 { + mask = cmd.Params[0] + } + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_ENDOFWHO, + Params: []string{dc.nick, mask, "End of /WHO list"}, }) case irc.RPL_WHOISUSER: var nick, username, host, realname string @@ -1436,8 +1486,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - if command == "LIST" { - uc.endPendingLISTs() + if command == "LIST" || command == "WHO" { + dc, _ := uc.dequeueCommand(command) + if dc != nil && downstreamID == 0 { + downstreamID = dc.id + } } uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { @@ -1453,11 +1506,6 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { // Ignore case irc.RPL_YOURHOST, irc.RPL_CREATED: // Ignore - case rpl_whospcrpl: - // Not supported in multi-upstream mode, forward as-is - uc.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: fallthrough case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: diff --git a/user.go b/user.go index 6f24e9c..0de7d6d 100644 --- a/user.go +++ b/user.go @@ -681,7 +681,7 @@ func (u *user) run() { func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.network.conn = nil - uc.endPendingLISTs() + uc.endPendingCommands() for _, entry := range uc.channels.innerMap { uch := entry.value.(*upstreamChannel)