downstream: drop downstreamConn.unmarshalEntity{,Network}

This commit is contained in:
Simon Ser 2022-08-08 11:30:10 +02:00
parent dde4ee9518
commit c3ab11de4e
2 changed files with 139 additions and 249 deletions

View File

@ -42,6 +42,13 @@ func newUnknownCommandError(cmd string) ircError {
}}
}
func newUnknownIRCError(cmd, text string) ircError {
return ircError{&irc.Message{
Command: xirc.ERR_UNKNOWNERROR,
Params: []string{"*", cmd, text},
}}
}
func newNeedMoreParamsError(cmd string) ircError {
return ircError{&irc.Message{
Command: irc.ERR_NEEDMOREPARAMS,
@ -402,8 +409,8 @@ func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
})
}
// upstream returns the upstream connection, if any. If there are zero or if
// there are multiple upstream connections, it returns nil.
// upstream returns the upstream connection, if any. If there are zero upstream
// connections, it returns nil.
func (dc *downstreamConn) upstream() *upstreamConn {
if dc.network == nil {
return nil
@ -411,6 +418,16 @@ func (dc *downstreamConn) upstream() *upstreamConn {
return dc.network.conn
}
func (dc *downstreamConn) upstreamForCommand(cmd string) (*upstreamConn, error) {
if dc.network == nil {
return nil, newUnknownIRCError(cmd, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?")
}
if dc.network.conn == nil {
return nil, newUnknownIRCError(cmd, "Disconnected from upstream network")
}
return dc.network.conn, nil
}
func isOurNick(net *network, nick string) bool {
// TODO: this doesn't account for nick changes
if net.conn != nil {
@ -423,38 +440,6 @@ func isOurNick(net *network, nick string) bool {
return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
}
// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
// nick) into an upstream entity name.
//
// This involves removing the "/<network>" suffix.
func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
if dc.network != nil {
return dc.network, name, nil
}
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
}}
}
// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
// upstream connection and fails if the upstream is disconnected.
func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
net, name, err := dc.unmarshalEntityNetwork(name)
if err != nil {
return nil, "", err
}
if net.conn == nil {
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "Disconnected from upstream network"},
}}
}
return net.conn, name, nil
}
func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
msg, err := dc.conn.ReadMessage()
if err != nil {
@ -1802,6 +1787,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
}
case "JOIN":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
var namesStr string
if err := parseMessageParams(msg, &namesStr); err != nil {
return err
@ -1813,17 +1803,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
for i, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return err
}
var key string
if len(keys) > i {
key = keys[i]
}
if !uc.isChannel(upstreamName) {
if !uc.isChannel(name) {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_NOSUCHCHANNEL,
@ -1836,8 +1821,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// because some clients automatically send JOIN messages in bulk
// when reconnecting to the bouncer. We don't want to flood the
// upstream connection with these.
if !uc.channels.Has(upstreamName) {
params := []string{upstreamName}
if !uc.channels.Has(name) {
params := []string{name}
if key != "" {
params = append(params, key)
}
@ -1847,7 +1832,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
})
}
ch := uc.network.channels.Get(upstreamName)
ch := uc.network.channels.Get(name)
if ch != nil {
// Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key
@ -1857,16 +1842,21 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
uc.network.attach(ctx, ch)
} else {
ch = &database.Channel{
Name: upstreamName,
Name: name,
Key: key,
}
uc.network.channels.Set(ch)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
dc.logger.Printf("failed to create or update channel %q: %v", name, err)
}
}
case "PART":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
var namesStr string
if err := parseMessageParams(msg, &namesStr); err != nil {
return err
@ -1878,27 +1868,22 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
for _, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return err
}
if strings.EqualFold(reason, "detach") {
ch := uc.network.channels.Get(upstreamName)
ch := uc.network.channels.Get(name)
if ch != nil {
uc.network.detach(ch)
} else {
ch = &database.Channel{
Name: upstreamName,
Name: name,
Detached: true,
}
uc.network.channels.Set(ch)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
dc.logger.Printf("failed to create or update channel %q: %v", name, err)
}
} else {
params := []string{upstreamName}
params := []string{name}
if reason != "" {
params = append(params, reason)
}
@ -1907,69 +1892,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: params,
})
if err := uc.network.deleteChannel(ctx, upstreamName); err != nil {
dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err)
if err := uc.network.deleteChannel(ctx, name); err != nil {
dc.logger.Printf("failed to delete channel %q: %v", name, err)
}
uc.network.pushTargets.Del(upstreamName)
uc.network.pushTargets.Del(name)
}
}
case "KICK":
var channelStr, userStr string
if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
channels := strings.Split(channelStr, ",")
users := strings.Split(userStr, ",")
var reason string
if len(msg.Params) > 2 {
reason = msg.Params[2]
}
if len(channels) != 1 && len(channels) != len(users) {
return ircError{&irc.Message{
Command: irc.ERR_BADCHANMASK,
Params: []string{dc.nick, channelStr, "Bad channel mask"},
}}
}
for i, user := range users {
var channel string
if len(channels) == 1 {
channel = channels[0]
} else {
channel = channels[i]
}
ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
if err != nil {
return err
}
ucUser, upstreamUser, err := dc.unmarshalEntity(user)
if err != nil {
return err
}
if ucChannel != ucUser {
return ircError{&irc.Message{
Command: irc.ERR_USERNOTINCHANNEL,
Params: []string{dc.nick, user, channel, "They are on another network"},
}}
}
uc := ucChannel
params := []string{upstreamChannel, upstreamUser}
if reason != "" {
params = append(params, reason)
}
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "KICK",
Params: params,
})
}
uc.SendMessageLabeled(ctx, dc.id, msg)
case "MODE":
var name string
if err := parseMessageParams(msg, &name); err != nil {
@ -1992,7 +1928,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.ERR_UMODEUNKNOWNFLAG,
Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"},
Params: []string{dc.nick, "Cannot change user mode on bouncer connection"},
})
}
} else {
@ -2010,12 +1946,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil
}
uc, upstreamName, err := dc.unmarshalEntity(name)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
if !uc.isChannel(upstreamName) {
if !uc.isChannel(name) {
return ircError{&irc.Message{
Command: irc.ERR_USERSDONTMATCH,
Params: []string{dc.nick, "Cannot change mode for other users"},
@ -2023,14 +1959,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
if modeStr != "" {
params := []string{upstreamName, modeStr}
params := []string{name, modeStr}
params = append(params, msg.Params[2:]...)
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE",
Params: params,
})
} else {
ch := uc.channels.Get(upstreamName)
ch := uc.channels.Get(name)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
@ -2062,12 +1998,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
}
case "TOPIC":
var channel string
if err := parseMessageParams(msg, &channel); err != nil {
var name string
if err := parseMessageParams(msg, &name); err != nil {
return err
}
uc, upstreamName, err := dc.unmarshalEntity(channel)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
@ -2076,48 +2012,31 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
topic := msg.Params[1]
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "TOPIC",
Params: []string{upstreamName, topic},
Params: []string{name, topic},
})
} else { // getting topic
ch := uc.channels.Get(upstreamName)
ch := uc.channels.Get(name)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, upstreamName, "No such channel"},
Params: []string{dc.nick, name, "No such channel"},
}}
}
sendTopic(dc, ch)
}
case "LIST":
network := dc.network
if network == nil && len(msg.Params) > 0 {
var err error
network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
if err != nil {
return err
}
}
if network == nil {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_LISTEND,
Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
})
return nil
}
uc := network.conn
if uc == nil {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_LISTEND,
Params: []string{dc.nick, "Disconnected from upstream server"},
})
return nil
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
uc.enqueueCommand(dc, msg)
case "NAMES":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
if len(msg.Params) == 0 {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
@ -2128,20 +2047,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
channels := strings.Split(msg.Params[0], ",")
for _, channel := range channels {
uc, upstreamName, err := dc.unmarshalEntity(channel)
if err != nil {
return err
}
ch := uc.channels.Get(upstreamName)
for _, name := range channels {
ch := uc.channels.Get(name)
if ch != nil {
sendNames(dc, ch)
} else {
// NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NAMES",
Params: []string{upstreamName},
Params: []string{name},
})
}
}
@ -2229,12 +2143,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
})
return nil
}
// TODO: properly support WHO masks
uc, upstreamMask, err := dc.unmarshalEntity(mask)
if err != nil {
// Ignore the error here, because clients don't know how to deal
// with anything other than RPL_ENDOFWHO
if dc.network == nil {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHO,
@ -2243,15 +2152,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil
}
params := []string{upstreamMask}
if options != "" {
params = append(params, options)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
uc.enqueueCommand(dc, &irc.Message{
Command: "WHO",
Params: params,
})
uc.enqueueCommand(dc, msg)
case "WHOIS":
if len(msg.Params) == 0 {
return ircError{&irc.Message{
@ -2260,12 +2166,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}}
}
var target, mask string
var mask string
if len(msg.Params) == 1 {
target = ""
mask = msg.Params[0]
} else {
target = msg.Params[0]
mask = msg.Params[1]
}
// TODO: support multiple WHOIS users
@ -2337,27 +2241,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil
}
// TODO: support WHOIS masks
uc, upstreamNick, err := dc.unmarshalEntity(mask)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
var params []string
if target != "" {
if target == mask { // WHOIS nick nick
params = []string{upstreamNick, upstreamNick}
} else {
params = []string{target, upstreamNick}
}
} else {
params = []string{upstreamNick}
}
uc.enqueueCommand(dc, &irc.Message{
Command: "WHOIS",
Params: params,
})
uc.enqueueCommand(dc, msg)
case "PRIVMSG", "NOTICE", "TAGMSG":
var targetsStr, text string
if msg.Command != "TAGMSG" {
@ -2432,16 +2321,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue
}
uc, upstreamName, err := dc.unmarshalEntity(name)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" {
if msg.Command == "PRIVMSG" && uc.network.casemap(name) == "nickserv" {
dc.handleNickServPRIVMSG(ctx, uc, text)
}
upstreamParams := []string{upstreamName}
upstreamParams := []string{name}
if msg.Command != "TAGMSG" {
upstreamParams = append(upstreamParams, text)
}
@ -2456,7 +2345,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// when it is echoed from the upstream.
// Otherwise, produce/log it here because it's the last time we'll see it.
if !uc.caps.IsEnabled("echo-message") {
echoParams := []string{upstreamName}
echoParams := []string{name}
if msg.Command != "TAGMSG" {
echoParams = append(echoParams, text)
}
@ -2476,39 +2365,18 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Command: msg.Command,
Params: echoParams,
}
uc.produce(upstreamName, echoMsg, dc.id)
uc.produce(name, echoMsg, dc.id)
}
uc.updateChannelAutoDetach(upstreamName)
uc.updateChannelAutoDetach(name)
}
case "INVITE":
var user, channel string
if err := parseMessageParams(msg, &user, &channel); err != nil {
return err
}
ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
ucUser, upstreamUser, err := dc.unmarshalEntity(user)
if err != nil {
return err
}
if ucChannel != ucUser {
return ircError{&irc.Message{
Command: irc.ERR_USERNOTINCHANNEL,
Params: []string{dc.nick, user, channel, "They are on another network"},
}}
}
uc := ucChannel
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "INVITE",
Params: []string{upstreamUser, upstreamChannel},
})
uc.SendMessageLabeled(ctx, dc.id, msg)
case "AUTHENTICATE":
// Post-connection-registration AUTHENTICATE is unsupported in
// multi-upstream mode, or if the upstream doesn't support SASL
@ -2747,15 +2615,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if !ok {
return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND,
Params: []string{dc.nick, "CHATHISTORY", "Unknown command"},
Params: []string{dc.nick, "CHATHISTORY", "Chat history disabled"},
}}
}
network, entity, err := dc.unmarshalEntityNetwork(target)
if err != nil {
return err
network := dc.network
if network == nil {
return newChatHistoryError(subcommand, "Cannot fetch chat history on bouncer connection")
}
entity = network.casemap(entity)
target = network.casemap(target)
// TODO: support msgid criteria
var bounds [2]time.Time
@ -2791,7 +2660,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
options := msgstore.LoadMessageOptions{
Network: &network.Network,
Entity: entity,
Entity: target,
Limit: limit,
Events: eventPlayback,
}
@ -2871,22 +2740,26 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil
}
network, entity, err := dc.unmarshalEntityNetwork(target)
if err != nil {
return err
network := dc.network
if network == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Cannot set read markers on bouncer connection"},
}}
}
entityCM := network.casemap(entity)
r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, entityCM)
targetCM := network.casemap(target)
r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, targetCM)
if err != nil {
dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err)
dc.logger.Printf("failed to get the read receipt for %q: %v", target, err)
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"},
}}
} else if r == nil {
r = &database.ReadReceipt{
Target: entityCM,
Target: targetCM,
}
}
@ -2915,7 +2788,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if r.Timestamp.Before(timestamp) {
r.Timestamp = timestamp
if err := dc.srv.db.StoreReadReceipt(ctx, network.ID, r); err != nil {
dc.logger.Printf("failed to store receipt for %q: %v", entity, err)
dc.logger.Printf("failed to store receipt for %q: %v", target, err)
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"},
@ -2938,17 +2811,17 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
d.SendMessage(&irc.Message{
Prefix: d.prefix(),
Command: cmd,
Params: []string{entity, timestampStr},
Params: []string{target, timestampStr},
})
}
})
if broadcast && network.pushTargets.Has(entity) {
if broadcast && network.pushTargets.Has(target) {
// TODO: only broadcast if draft/read-marker has been negotiated
network.pushTargets.Del(entity)
network.pushTargets.Del(target)
go network.broadcastWebPush(&irc.Message{
Command: "MARKREAD",
Params: []string{entity, timestampStr},
Params: []string{target, timestampStr},
})
}
case "SEARCH":
@ -2965,7 +2838,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}
attrs := irc.ParseTags(attrsStr)
var uc *upstreamConn
var network *network
const searchMaxLimit = 100
opts := msgstore.SearchMessageOptions{
Limit: searchMaxLimit,
@ -2990,15 +2863,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
case "from":
opts.From = value
case "in":
u, upstreamName, err := dc.unmarshalEntity(value)
if err != nil {
network = dc.network
if network == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"SEARCH", "INVALID_PARAMS", name, "Invalid criteria"},
Params: []string{"SEARCH", "INVALID_PARAMS", name, "Cannot search on bouncer connection"},
}}
}
uc = u
opts.In = u.network.casemap(upstreamName)
opts.In = network.casemap(value)
case "text":
opts.Text = value
case "limit":
@ -3012,7 +2884,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
opts.Limit = limit
}
}
if uc == nil {
if network == nil {
return ircError{&irc.Message{
Command: "FAIL",
Params: []string{"SEARCH", "INVALID_PARAMS", "in", "The in parameter is mandatory"},
@ -3022,7 +2894,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
opts.Limit = searchMaxLimit
}
messages, err := store.Search(ctx, &uc.network.Network, &opts)
messages, err := store.Search(ctx, &network.Network, &opts)
if err != nil {
dc.logger.Printf("failed fetching messages for search: %v", err)
return ircError{&irc.Message{

View File

@ -1139,12 +1139,28 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
return err
}
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return fmt.Errorf("unknown channel %q", name)
network := dc.network
if network == nil {
l := strings.SplitN(name, "/", 2)
if len(l) != 2 {
return fmt.Errorf("missing network name")
}
name = l[0]
netName := l[1]
for _, n := range dc.user.networks {
if netName == n.GetName() {
network = n
break
}
}
if network == nil {
return fmt.Errorf("unknown network %q", netName)
}
}
ch := uc.network.channels.Get(upstreamName)
ch := network.channels.Get(name)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
}
@ -1155,15 +1171,17 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
if fs.Detached != nil && *fs.Detached != ch.Detached {
if *fs.Detached {
uc.network.detach(ch)
network.detach(ch)
} else {
uc.network.attach(ctx, ch)
network.attach(ctx, ch)
}
}
uc.updateChannelAutoDetach(upstreamName)
if network.conn != nil {
network.conn.updateChannelAutoDetach(name)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
if err := dc.srv.db.StoreChannel(ctx, network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err)
}