diff --git a/downstream.go b/downstream.go index 8ed0826..2e71da0 100644 --- a/downstream.go +++ b/downstream.go @@ -1077,6 +1077,45 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } sendTopic(dc, ch) } + case "LIST": + // TODO: support ELIST when supported by all upstreams + + dc.user.pendingLISTsLock.Lock() + defer dc.user.pendingLISTsLock.Unlock() + + pl := pendingLIST{ + downstreamID: dc.id, + pendingCommands: make(map[int64]*irc.Message), + } + var upstreamChannels map[int64][]string + if len(msg.Params) > 0 { + upstreamChannels = make(map[int64][]string) + channels := strings.Split(msg.Params[0], ",") + for _, channel := range channels { + uc, upstreamChannel, err := dc.unmarshalEntity(channel) + if err != nil { + return err + } + upstreamChannels[uc.network.ID] = append(upstreamChannels[uc.network.ID], upstreamChannel) + } + } + + dc.user.pendingLISTs = append(dc.user.pendingLISTs, pl) + dc.forEachUpstream(func(uc *upstreamConn) { + var params []string + if upstreamChannels != nil { + if channels, ok := upstreamChannels[uc.network.ID]; ok { + params = []string{strings.Join(channels, ",")} + } else { + return + } + } + pl.pendingCommands[uc.network.ID] = &irc.Message{ + Command: "LIST", + Params: params, + } + uc.trySendList(dc.id) + }) case "NAMES": if len(msg.Params) == 0 { dc.SendMessage(&irc.Message{ diff --git a/upstream.go b/upstream.go index fcc2be7..86f8a11 100644 --- a/upstream.go +++ b/upstream.go @@ -59,6 +59,10 @@ type upstreamConn struct { saslClient sasl.Client saslStarted bool + + // set of LIST commands in progress, per downstream + // access is synchronized with user.pendingLISTsLock + pendingLISTDownstreamSet map[uint64]struct{} } func connectToUpstream(network *network) (*upstreamConn, error) { @@ -79,19 +83,20 @@ func connectToUpstream(network *network) (*upstreamConn, error) { outgoing := make(chan *irc.Message, 64) uc := &upstreamConn{ - network: network, - logger: logger, - net: netConn, - irc: irc.NewConn(netConn), - srv: network.user.srv, - user: network.user, - outgoing: outgoing, - channels: make(map[string]*upstreamChannel), - caps: make(map[string]string), - batches: make(map[string]batch), - availableChannelTypes: stdChannelTypes, - availableChannelModes: stdChannelModes, - availableMemberships: stdMemberships, + network: network, + logger: logger, + net: netConn, + irc: irc.NewConn(netConn), + srv: network.user.srv, + user: network.user, + outgoing: outgoing, + channels: make(map[string]*upstreamChannel), + caps: make(map[string]string), + batches: make(map[string]batch), + availableChannelTypes: stdChannelTypes, + availableChannelModes: stdChannelModes, + availableMemberships: stdMemberships, + pendingLISTDownstreamSet: make(map[uint64]struct{}), } go func() { @@ -136,6 +141,8 @@ func (uc *upstreamConn) Close() error { return fmt.Errorf("upstream connection already closed") } close(uc.closed) + + uc.endPendingLists(true) return nil } @@ -172,6 +179,81 @@ func (uc *upstreamConn) isChannel(entity string) bool { return false } +func (uc *upstreamConn) getPendingList() *pendingLIST { + uc.user.pendingLISTsLock.Lock() + defer uc.user.pendingLISTsLock.Unlock() + for _, pl := range uc.user.pendingLISTs { + if _, ok := pl.pendingCommands[uc.network.ID]; !ok { + continue + } + return &pl + } + return nil +} + +func (uc *upstreamConn) endPendingLists(all bool) (found bool) { + found = false + uc.user.pendingLISTsLock.Lock() + defer uc.user.pendingLISTsLock.Unlock() + for i := 0; i < len(uc.user.pendingLISTs); i++ { + pl := uc.user.pendingLISTs[i] + if _, ok := pl.pendingCommands[uc.network.ID]; !ok { + continue + } + delete(pl.pendingCommands, uc.network.ID) + if len(pl.pendingCommands) == 0 { + uc.user.pendingLISTs = append(uc.user.pendingLISTs[:i], uc.user.pendingLISTs[i+1:]...) + i-- + uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LISTEND, + Params: []string{dc.nick, "End of /LIST"}, + }) + }) + } + found = true + if !all { + delete(uc.pendingLISTDownstreamSet, pl.downstreamID) + uc.user.forEachUpstream(func(uc *upstreamConn) { + uc.trySendList(pl.downstreamID) + }) + return + } + } + return +} + +func (uc *upstreamConn) trySendList(downstreamID uint64) { + // must be called with a lock in uc.user.pendingLISTsLock + + if _, ok := uc.pendingLISTDownstreamSet[downstreamID]; ok { + // a LIST command is already pending + // we will try again when that command is completed + return + } + + for _, pl := range uc.user.pendingLISTs { + if pl.downstreamID != downstreamID { + continue + } + // this is the first pending LIST command list of the downstream + listCommand, ok := pl.pendingCommands[uc.network.ID] + if !ok { + // there is no command for this upstream in these LIST commands + // do not send anything + continue + } + // there is a command for this upstream in these LIST commands + // send it now + + uc.SendMessageLabeled(downstreamID, listCommand) + + uc.pendingLISTDownstreamSet[downstreamID] = struct{}{} + return + } +} + func (uc *upstreamConn) parseMembershipPrefix(s string) (membership *membership, nick string) { for _, m := range uc.availableMemberships { if m.Prefix == s[0] { @@ -833,6 +915,29 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return fmt.Errorf("failed to parse topic time: %v", err) } ch.TopicTime = time.Unix(sec, 0) + case irc.RPL_LIST: + var channel, clients, topic string + if err := parseMessageParams(msg, nil, &channel, &clients, &topic); err != nil { + return err + } + + pl := uc.getPendingList() + if pl == nil { + return fmt.Errorf("unexpected RPL_LIST: no matching pending LIST") + } + + uc.forEachDownstreamByID(pl.downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: irc.RPL_LIST, + Params: []string{dc.nick, dc.marshalChannel(uc, channel), clients, topic}, + }) + }) + case irc.RPL_LISTEND: + ok := uc.endPendingLists(false) + if !ok { + return fmt.Errorf("unexpected RPL_LISTEND: no matching pending LIST") + } case irc.RPL_NAMREPLY: var name, statusStr, members string if err := parseMessageParams(msg, nil, &statusStr, &name, &members); err != nil { @@ -1090,6 +1195,25 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { Params: []string{dc.nick, dc.marshalNick(uc, nick), dc.marshalChannel(uc, channel)}, }) }) + case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: + var command, reason string + if err := parseMessageParams(msg, nil, &command, &reason); err != nil { + return err + } + + if command == "LIST" { + ok := uc.endPendingLists(false) + if !ok { + return fmt.Errorf("unexpected response for LIST: %q: no matching pending LIST", msg.Command) + } + uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { + dc.SendMessage(&irc.Message{ + Prefix: uc.srv.prefix(), + Command: msg.Command, + Params: []string{dc.nick, "LIST", reason}, + }) + }) + } case "TAGMSG": // TODO: relay to downstream connections that accept message-tags case "ACK": @@ -1100,6 +1224,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { // Ignore case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD: // Ignore + case irc.RPL_LISTSTART: + // Ignore case rpl_localusers, rpl_globalusers: // Ignore case irc.RPL_STATSVLINE, rpl_statsping, irc.RPL_STATSBLINE, irc.RPL_STATSDLINE: diff --git a/user.go b/user.go index 8c19079..d06df72 100644 --- a/user.go +++ b/user.go @@ -96,6 +96,16 @@ type user struct { networks []*network downstreamConns []*downstreamConn + + // LIST commands in progress + pendingLISTsLock sync.Mutex + pendingLISTs []pendingLIST +} + +type pendingLIST struct { + downstreamID uint64 + // list of per-upstream LIST commands not yet sent or completed + pendingCommands map[int64]*irc.Message } func newUser(srv *Server, record *User) *user {