From c3ab11de4e86216143cb9b7cbdd5a40e2c2f9280 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 8 Aug 2022 11:30:10 +0200 Subject: [PATCH] downstream: drop downstreamConn.unmarshalEntity{,Network} --- downstream.go | 354 ++++++++++++++++---------------------------------- service.go | 34 +++-- 2 files changed, 139 insertions(+), 249 deletions(-) diff --git a/downstream.go b/downstream.go index 5c05290..12dcaff 100644 --- a/downstream.go +++ b/downstream.go @@ -42,6 +42,13 @@ func newUnknownCommandError(cmd string) ircError { }} } +func newUnknownIRCError(cmd, text string) ircError { + return ircError{&irc.Message{ + Command: xirc.ERR_UNKNOWNERROR, + Params: []string{"*", cmd, text}, + }} +} + func newNeedMoreParamsError(cmd string) ircError { return ircError{&irc.Message{ Command: irc.ERR_NEEDMOREPARAMS, @@ -402,8 +409,8 @@ func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) { }) } -// upstream returns the upstream connection, if any. If there are zero or if -// there are multiple upstream connections, it returns nil. +// upstream returns the upstream connection, if any. If there are zero upstream +// connections, it returns nil. func (dc *downstreamConn) upstream() *upstreamConn { if dc.network == nil { return nil @@ -411,6 +418,16 @@ func (dc *downstreamConn) upstream() *upstreamConn { return dc.network.conn } +func (dc *downstreamConn) upstreamForCommand(cmd string) (*upstreamConn, error) { + if dc.network == nil { + return nil, newUnknownIRCError(cmd, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?") + } + if dc.network.conn == nil { + return nil, newUnknownIRCError(cmd, "Disconnected from upstream network") + } + return dc.network.conn, nil +} + func isOurNick(net *network, nick string) bool { // TODO: this doesn't account for nick changes if net.conn != nil { @@ -423,38 +440,6 @@ func isOurNick(net *network, nick string) bool { return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network)) } -// unmarshalEntityNetwork converts a downstream entity name (ie. channel or -// nick) into an upstream entity name. -// -// This involves removing the "/" suffix. -func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) { - if dc.network != nil { - return dc.network, name, nil - } - return nil, "", ircError{&irc.Message{ - Command: irc.ERR_NOSUCHCHANNEL, - Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"}, - }} -} - -// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the -// upstream connection and fails if the upstream is disconnected. -func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) { - net, name, err := dc.unmarshalEntityNetwork(name) - if err != nil { - return nil, "", err - } - - if net.conn == nil { - return nil, "", ircError{&irc.Message{ - Command: irc.ERR_NOSUCHCHANNEL, - Params: []string{dc.nick, name, "Disconnected from upstream network"}, - }} - } - - return net.conn, name, nil -} - func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { msg, err := dc.conn.ReadMessage() if err != nil { @@ -1802,6 +1787,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } } case "JOIN": + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { + return err + } + var namesStr string if err := parseMessageParams(msg, &namesStr); err != nil { return err @@ -1813,17 +1803,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } for i, name := range strings.Split(namesStr, ",") { - uc, upstreamName, err := dc.unmarshalEntity(name) - if err != nil { - return err - } - var key string if len(keys) > i { key = keys[i] } - if !uc.isChannel(upstreamName) { + if !uc.isChannel(name) { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_NOSUCHCHANNEL, @@ -1836,8 +1821,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // because some clients automatically send JOIN messages in bulk // when reconnecting to the bouncer. We don't want to flood the // upstream connection with these. - if !uc.channels.Has(upstreamName) { - params := []string{upstreamName} + if !uc.channels.Has(name) { + params := []string{name} if key != "" { params = append(params, key) } @@ -1847,7 +1832,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) } - ch := uc.network.channels.Get(upstreamName) + ch := uc.network.channels.Get(name) if ch != nil { // Don't clear the channel key if there's one set // TODO: add a way to unset the channel key @@ -1857,16 +1842,21 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. uc.network.attach(ctx, ch) } else { ch = &database.Channel{ - Name: upstreamName, + Name: name, Key: key, } uc.network.channels.Set(ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { - dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + dc.logger.Printf("failed to create or update channel %q: %v", name, err) } } case "PART": + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { + return err + } + var namesStr string if err := parseMessageParams(msg, &namesStr); err != nil { return err @@ -1878,27 +1868,22 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } for _, name := range strings.Split(namesStr, ",") { - uc, upstreamName, err := dc.unmarshalEntity(name) - if err != nil { - return err - } - if strings.EqualFold(reason, "detach") { - ch := uc.network.channels.Get(upstreamName) + ch := uc.network.channels.Get(name) if ch != nil { uc.network.detach(ch) } else { ch = &database.Channel{ - Name: upstreamName, + Name: name, Detached: true, } uc.network.channels.Set(ch) } if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { - dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) + dc.logger.Printf("failed to create or update channel %q: %v", name, err) } } else { - params := []string{upstreamName} + params := []string{name} if reason != "" { params = append(params, reason) } @@ -1907,69 +1892,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Params: params, }) - if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { - dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) + if err := uc.network.deleteChannel(ctx, name); err != nil { + dc.logger.Printf("failed to delete channel %q: %v", name, err) } - uc.network.pushTargets.Del(upstreamName) + uc.network.pushTargets.Del(name) } } case "KICK": - var channelStr, userStr string - if err := parseMessageParams(msg, &channelStr, &userStr); err != nil { + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { return err } - channels := strings.Split(channelStr, ",") - users := strings.Split(userStr, ",") - - var reason string - if len(msg.Params) > 2 { - reason = msg.Params[2] - } - - if len(channels) != 1 && len(channels) != len(users) { - return ircError{&irc.Message{ - Command: irc.ERR_BADCHANMASK, - Params: []string{dc.nick, channelStr, "Bad channel mask"}, - }} - } - - for i, user := range users { - var channel string - if len(channels) == 1 { - channel = channels[0] - } else { - channel = channels[i] - } - - ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) - if err != nil { - return err - } - - ucUser, upstreamUser, err := dc.unmarshalEntity(user) - if err != nil { - return err - } - - if ucChannel != ucUser { - return ircError{&irc.Message{ - Command: irc.ERR_USERNOTINCHANNEL, - Params: []string{dc.nick, user, channel, "They are on another network"}, - }} - } - uc := ucChannel - - params := []string{upstreamChannel, upstreamUser} - if reason != "" { - params = append(params, reason) - } - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "KICK", - Params: params, - }) - } + uc.SendMessageLabeled(ctx, dc.id, msg) case "MODE": var name string if err := parseMessageParams(msg, &name); err != nil { @@ -1992,7 +1928,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.ERR_UMODEUNKNOWNFLAG, - Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, + Params: []string{dc.nick, "Cannot change user mode on bouncer connection"}, }) } } else { @@ -2010,12 +1946,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } - uc, upstreamName, err := dc.unmarshalEntity(name) + uc, err := dc.upstreamForCommand(msg.Command) if err != nil { return err } - if !uc.isChannel(upstreamName) { + if !uc.isChannel(name) { return ircError{&irc.Message{ Command: irc.ERR_USERSDONTMATCH, Params: []string{dc.nick, "Cannot change mode for other users"}, @@ -2023,14 +1959,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } if modeStr != "" { - params := []string{upstreamName, modeStr} + params := []string{name, modeStr} params = append(params, msg.Params[2:]...) uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "MODE", Params: params, }) } else { - ch := uc.channels.Get(upstreamName) + ch := uc.channels.Get(name) if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, @@ -2062,12 +1998,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } } case "TOPIC": - var channel string - if err := parseMessageParams(msg, &channel); err != nil { + var name string + if err := parseMessageParams(msg, &name); err != nil { return err } - uc, upstreamName, err := dc.unmarshalEntity(channel) + uc, err := dc.upstreamForCommand(msg.Command) if err != nil { return err } @@ -2076,48 +2012,31 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. topic := msg.Params[1] uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "TOPIC", - Params: []string{upstreamName, topic}, + Params: []string{name, topic}, }) } else { // getting topic - ch := uc.channels.Get(upstreamName) + ch := uc.channels.Get(name) if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, - Params: []string{dc.nick, upstreamName, "No such channel"}, + Params: []string{dc.nick, name, "No such channel"}, }} } sendTopic(dc, ch) } case "LIST": - network := dc.network - if network == nil && len(msg.Params) > 0 { - var err error - network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0]) - if err != nil { - return err - } - } - if network == nil { - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_LISTEND, - Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"}, - }) - return nil - } - - uc := network.conn - if uc == nil { - dc.SendMessage(&irc.Message{ - Prefix: dc.srv.prefix(), - Command: irc.RPL_LISTEND, - Params: []string{dc.nick, "Disconnected from upstream server"}, - }) - return nil + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { + return err } uc.enqueueCommand(dc, msg) case "NAMES": + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { + return err + } + if len(msg.Params) == 0 { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), @@ -2128,20 +2047,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } channels := strings.Split(msg.Params[0], ",") - for _, channel := range channels { - uc, upstreamName, err := dc.unmarshalEntity(channel) - if err != nil { - return err - } - - ch := uc.channels.Get(upstreamName) + for _, name := range channels { + ch := uc.channels.Get(name) if ch != nil { sendNames(dc, ch) } else { // NAMES on a channel we have not joined, ask upstream uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ Command: "NAMES", - Params: []string{upstreamName}, + Params: []string{name}, }) } } @@ -2229,12 +2143,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) return nil } - - // TODO: properly support WHO masks - uc, upstreamMask, err := dc.unmarshalEntity(mask) - if err != nil { - // Ignore the error here, because clients don't know how to deal - // with anything other than RPL_ENDOFWHO + if dc.network == nil { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_ENDOFWHO, @@ -2243,15 +2152,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } - params := []string{upstreamMask} - if options != "" { - params = append(params, options) + uc, err := dc.upstreamForCommand(msg.Command) + if err != nil { + return err } - uc.enqueueCommand(dc, &irc.Message{ - Command: "WHO", - Params: params, - }) + uc.enqueueCommand(dc, msg) case "WHOIS": if len(msg.Params) == 0 { return ircError{&irc.Message{ @@ -2260,12 +2166,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }} } - var target, mask string + var mask string if len(msg.Params) == 1 { - target = "" mask = msg.Params[0] } else { - target = msg.Params[0] mask = msg.Params[1] } // TODO: support multiple WHOIS users @@ -2337,27 +2241,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } - // TODO: support WHOIS masks - uc, upstreamNick, err := dc.unmarshalEntity(mask) + uc, err := dc.upstreamForCommand(msg.Command) if err != nil { return err } - var params []string - if target != "" { - if target == mask { // WHOIS nick nick - params = []string{upstreamNick, upstreamNick} - } else { - params = []string{target, upstreamNick} - } - } else { - params = []string{upstreamNick} - } - - uc.enqueueCommand(dc, &irc.Message{ - Command: "WHOIS", - Params: params, - }) + uc.enqueueCommand(dc, msg) case "PRIVMSG", "NOTICE", "TAGMSG": var targetsStr, text string if msg.Command != "TAGMSG" { @@ -2432,16 +2321,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. continue } - uc, upstreamName, err := dc.unmarshalEntity(name) + uc, err := dc.upstreamForCommand(msg.Command) if err != nil { return err } - if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { + if msg.Command == "PRIVMSG" && uc.network.casemap(name) == "nickserv" { dc.handleNickServPRIVMSG(ctx, uc, text) } - upstreamParams := []string{upstreamName} + upstreamParams := []string{name} if msg.Command != "TAGMSG" { upstreamParams = append(upstreamParams, text) } @@ -2456,7 +2345,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. // when it is echoed from the upstream. // Otherwise, produce/log it here because it's the last time we'll see it. if !uc.caps.IsEnabled("echo-message") { - echoParams := []string{upstreamName} + echoParams := []string{name} if msg.Command != "TAGMSG" { echoParams = append(echoParams, text) } @@ -2476,39 +2365,18 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. Command: msg.Command, Params: echoParams, } - uc.produce(upstreamName, echoMsg, dc.id) + uc.produce(name, echoMsg, dc.id) } - uc.updateChannelAutoDetach(upstreamName) + uc.updateChannelAutoDetach(name) } case "INVITE": - var user, channel string - if err := parseMessageParams(msg, &user, &channel); err != nil { - return err - } - - ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel) + uc, err := dc.upstreamForCommand(msg.Command) if err != nil { return err } - ucUser, upstreamUser, err := dc.unmarshalEntity(user) - if err != nil { - return err - } - - if ucChannel != ucUser { - return ircError{&irc.Message{ - Command: irc.ERR_USERNOTINCHANNEL, - Params: []string{dc.nick, user, channel, "They are on another network"}, - }} - } - uc := ucChannel - - uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ - Command: "INVITE", - Params: []string{upstreamUser, upstreamChannel}, - }) + uc.SendMessageLabeled(ctx, dc.id, msg) case "AUTHENTICATE": // Post-connection-registration AUTHENTICATE is unsupported in // multi-upstream mode, or if the upstream doesn't support SASL @@ -2747,15 +2615,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if !ok { return ircError{&irc.Message{ Command: irc.ERR_UNKNOWNCOMMAND, - Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, + Params: []string{dc.nick, "CHATHISTORY", "Chat history disabled"}, }} } - network, entity, err := dc.unmarshalEntityNetwork(target) - if err != nil { - return err + network := dc.network + if network == nil { + return newChatHistoryError(subcommand, "Cannot fetch chat history on bouncer connection") } - entity = network.casemap(entity) + + target = network.casemap(target) // TODO: support msgid criteria var bounds [2]time.Time @@ -2791,7 +2660,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. options := msgstore.LoadMessageOptions{ Network: &network.Network, - Entity: entity, + Entity: target, Limit: limit, Events: eventPlayback, } @@ -2871,22 +2740,26 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. return nil } - network, entity, err := dc.unmarshalEntityNetwork(target) - if err != nil { - return err + network := dc.network + if network == nil { + return ircError{&irc.Message{ + Command: "FAIL", + Params: []string{msg.Command, "INTERNAL_ERROR", target, "Cannot set read markers on bouncer connection"}, + }} } - entityCM := network.casemap(entity) - r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, entityCM) + targetCM := network.casemap(target) + + r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, targetCM) if err != nil { - dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) + dc.logger.Printf("failed to get the read receipt for %q: %v", target, err) return ircError{&irc.Message{ Command: "FAIL", Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"}, }} } else if r == nil { r = &database.ReadReceipt{ - Target: entityCM, + Target: targetCM, } } @@ -2915,7 +2788,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. if r.Timestamp.Before(timestamp) { r.Timestamp = timestamp if err := dc.srv.db.StoreReadReceipt(ctx, network.ID, r); err != nil { - dc.logger.Printf("failed to store receipt for %q: %v", entity, err) + dc.logger.Printf("failed to store receipt for %q: %v", target, err) return ircError{&irc.Message{ Command: "FAIL", Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"}, @@ -2938,17 +2811,17 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. d.SendMessage(&irc.Message{ Prefix: d.prefix(), Command: cmd, - Params: []string{entity, timestampStr}, + Params: []string{target, timestampStr}, }) } }) - if broadcast && network.pushTargets.Has(entity) { + if broadcast && network.pushTargets.Has(target) { // TODO: only broadcast if draft/read-marker has been negotiated - network.pushTargets.Del(entity) + network.pushTargets.Del(target) go network.broadcastWebPush(&irc.Message{ Command: "MARKREAD", - Params: []string{entity, timestampStr}, + Params: []string{target, timestampStr}, }) } case "SEARCH": @@ -2965,7 +2838,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. } attrs := irc.ParseTags(attrsStr) - var uc *upstreamConn + var network *network const searchMaxLimit = 100 opts := msgstore.SearchMessageOptions{ Limit: searchMaxLimit, @@ -2990,15 +2863,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. case "from": opts.From = value case "in": - u, upstreamName, err := dc.unmarshalEntity(value) - if err != nil { + network = dc.network + if network == nil { return ircError{&irc.Message{ Command: "FAIL", - Params: []string{"SEARCH", "INVALID_PARAMS", name, "Invalid criteria"}, + Params: []string{"SEARCH", "INVALID_PARAMS", name, "Cannot search on bouncer connection"}, }} } - uc = u - opts.In = u.network.casemap(upstreamName) + opts.In = network.casemap(value) case "text": opts.Text = value case "limit": @@ -3012,7 +2884,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. opts.Limit = limit } } - if uc == nil { + if network == nil { return ircError{&irc.Message{ Command: "FAIL", Params: []string{"SEARCH", "INVALID_PARAMS", "in", "The in parameter is mandatory"}, @@ -3022,7 +2894,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. opts.Limit = searchMaxLimit } - messages, err := store.Search(ctx, &uc.network.Network, &opts) + messages, err := store.Search(ctx, &network.Network, &opts) if err != nil { dc.logger.Printf("failed fetching messages for search: %v", err) return ircError{&irc.Message{ diff --git a/service.go b/service.go index 7cc110b..97de0f3 100644 --- a/service.go +++ b/service.go @@ -1139,12 +1139,28 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params return err } - uc, upstreamName, err := dc.unmarshalEntity(name) - if err != nil { - return fmt.Errorf("unknown channel %q", name) + network := dc.network + if network == nil { + l := strings.SplitN(name, "/", 2) + if len(l) != 2 { + return fmt.Errorf("missing network name") + } + name = l[0] + netName := l[1] + + for _, n := range dc.user.networks { + if netName == n.GetName() { + network = n + break + } + } + + if network == nil { + return fmt.Errorf("unknown network %q", netName) + } } - ch := uc.network.channels.Get(upstreamName) + ch := network.channels.Get(name) if ch == nil { return fmt.Errorf("unknown channel %q", name) } @@ -1155,15 +1171,17 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params if fs.Detached != nil && *fs.Detached != ch.Detached { if *fs.Detached { - uc.network.detach(ch) + network.detach(ch) } else { - uc.network.attach(ctx, ch) + network.attach(ctx, ch) } } - uc.updateChannelAutoDetach(upstreamName) + if network.conn != nil { + network.conn.updateChannelAutoDetach(name) + } - if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { + if err := dc.srv.db.StoreChannel(ctx, network.ID, ch); err != nil { return fmt.Errorf("failed to update channel: %v", err) }