downstream: drop downstreamConn.unmarshalEntity{,Network}

This commit is contained in:
Simon Ser 2022-08-08 11:30:10 +02:00
parent dde4ee9518
commit c3ab11de4e
2 changed files with 139 additions and 249 deletions

View File

@ -42,6 +42,13 @@ func newUnknownCommandError(cmd string) ircError {
}} }}
} }
func newUnknownIRCError(cmd, text string) ircError {
return ircError{&irc.Message{
Command: xirc.ERR_UNKNOWNERROR,
Params: []string{"*", cmd, text},
}}
}
func newNeedMoreParamsError(cmd string) ircError { func newNeedMoreParamsError(cmd string) ircError {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NEEDMOREPARAMS, Command: irc.ERR_NEEDMOREPARAMS,
@ -402,8 +409,8 @@ func (dc *downstreamConn) forEachUpstream(f func(*upstreamConn)) {
}) })
} }
// upstream returns the upstream connection, if any. If there are zero or if // upstream returns the upstream connection, if any. If there are zero upstream
// there are multiple upstream connections, it returns nil. // connections, it returns nil.
func (dc *downstreamConn) upstream() *upstreamConn { func (dc *downstreamConn) upstream() *upstreamConn {
if dc.network == nil { if dc.network == nil {
return nil return nil
@ -411,6 +418,16 @@ func (dc *downstreamConn) upstream() *upstreamConn {
return dc.network.conn return dc.network.conn
} }
func (dc *downstreamConn) upstreamForCommand(cmd string) (*upstreamConn, error) {
if dc.network == nil {
return nil, newUnknownIRCError(cmd, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?")
}
if dc.network.conn == nil {
return nil, newUnknownIRCError(cmd, "Disconnected from upstream network")
}
return dc.network.conn, nil
}
func isOurNick(net *network, nick string) bool { func isOurNick(net *network, nick string) bool {
// TODO: this doesn't account for nick changes // TODO: this doesn't account for nick changes
if net.conn != nil { if net.conn != nil {
@ -423,38 +440,6 @@ func isOurNick(net *network, nick string) bool {
return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network)) return net.casemap(nick) == net.casemap(database.GetNick(&net.user.User, &net.Network))
} }
// unmarshalEntityNetwork converts a downstream entity name (ie. channel or
// nick) into an upstream entity name.
//
// This involves removing the "/<network>" suffix.
func (dc *downstreamConn) unmarshalEntityNetwork(name string) (*network, string, error) {
if dc.network != nil {
return dc.network, name, nil
}
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "Cannot interact with channels and users on the bouncer connection. Did you mean to use a specific network?"},
}}
}
// unmarshalEntity is the same as unmarshalEntityNetwork, but returns the
// upstream connection and fails if the upstream is disconnected.
func (dc *downstreamConn) unmarshalEntity(name string) (*upstreamConn, string, error) {
net, name, err := dc.unmarshalEntityNetwork(name)
if err != nil {
return nil, "", err
}
if net.conn == nil {
return nil, "", ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "Disconnected from upstream network"},
}}
}
return net.conn, name, nil
}
func (dc *downstreamConn) ReadMessage() (*irc.Message, error) { func (dc *downstreamConn) ReadMessage() (*irc.Message, error) {
msg, err := dc.conn.ReadMessage() msg, err := dc.conn.ReadMessage()
if err != nil { if err != nil {
@ -1802,6 +1787,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
} }
case "JOIN": case "JOIN":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
var namesStr string var namesStr string
if err := parseMessageParams(msg, &namesStr); err != nil { if err := parseMessageParams(msg, &namesStr); err != nil {
return err return err
@ -1813,17 +1803,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
for i, name := range strings.Split(namesStr, ",") { for i, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return err
}
var key string var key string
if len(keys) > i { if len(keys) > i {
key = keys[i] key = keys[i]
} }
if !uc.isChannel(upstreamName) { if !uc.isChannel(name) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
@ -1836,8 +1821,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// because some clients automatically send JOIN messages in bulk // because some clients automatically send JOIN messages in bulk
// when reconnecting to the bouncer. We don't want to flood the // when reconnecting to the bouncer. We don't want to flood the
// upstream connection with these. // upstream connection with these.
if !uc.channels.Has(upstreamName) { if !uc.channels.Has(name) {
params := []string{upstreamName} params := []string{name}
if key != "" { if key != "" {
params = append(params, key) params = append(params, key)
} }
@ -1847,7 +1832,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}) })
} }
ch := uc.network.channels.Get(upstreamName) ch := uc.network.channels.Get(name)
if ch != nil { if ch != nil {
// Don't clear the channel key if there's one set // Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key // TODO: add a way to unset the channel key
@ -1857,16 +1842,21 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
uc.network.attach(ctx, ch) uc.network.attach(ctx, ch)
} else { } else {
ch = &database.Channel{ ch = &database.Channel{
Name: upstreamName, Name: name,
Key: key, Key: key,
} }
uc.network.channels.Set(ch) uc.network.channels.Set(ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", name, err)
} }
} }
case "PART": case "PART":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
var namesStr string var namesStr string
if err := parseMessageParams(msg, &namesStr); err != nil { if err := parseMessageParams(msg, &namesStr); err != nil {
return err return err
@ -1878,27 +1868,22 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
for _, name := range strings.Split(namesStr, ",") { for _, name := range strings.Split(namesStr, ",") {
uc, upstreamName, err := dc.unmarshalEntity(name)
if err != nil {
return err
}
if strings.EqualFold(reason, "detach") { if strings.EqualFold(reason, "detach") {
ch := uc.network.channels.Get(upstreamName) ch := uc.network.channels.Get(name)
if ch != nil { if ch != nil {
uc.network.detach(ch) uc.network.detach(ch)
} else { } else {
ch = &database.Channel{ ch = &database.Channel{
Name: upstreamName, Name: name,
Detached: true, Detached: true,
} }
uc.network.channels.Set(ch) uc.network.channels.Set(ch)
} }
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", name, err)
} }
} else { } else {
params := []string{upstreamName} params := []string{name}
if reason != "" { if reason != "" {
params = append(params, reason) params = append(params, reason)
} }
@ -1907,69 +1892,20 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Params: params, Params: params,
}) })
if err := uc.network.deleteChannel(ctx, upstreamName); err != nil { if err := uc.network.deleteChannel(ctx, name); err != nil {
dc.logger.Printf("failed to delete channel %q: %v", upstreamName, err) dc.logger.Printf("failed to delete channel %q: %v", name, err)
} }
uc.network.pushTargets.Del(upstreamName) uc.network.pushTargets.Del(name)
} }
} }
case "KICK": case "KICK":
var channelStr, userStr string uc, err := dc.upstreamForCommand(msg.Command)
if err := parseMessageParams(msg, &channelStr, &userStr); err != nil {
return err
}
channels := strings.Split(channelStr, ",")
users := strings.Split(userStr, ",")
var reason string
if len(msg.Params) > 2 {
reason = msg.Params[2]
}
if len(channels) != 1 && len(channels) != len(users) {
return ircError{&irc.Message{
Command: irc.ERR_BADCHANMASK,
Params: []string{dc.nick, channelStr, "Bad channel mask"},
}}
}
for i, user := range users {
var channel string
if len(channels) == 1 {
channel = channels[0]
} else {
channel = channels[i]
}
ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
if err != nil { if err != nil {
return err return err
} }
ucUser, upstreamUser, err := dc.unmarshalEntity(user) uc.SendMessageLabeled(ctx, dc.id, msg)
if err != nil {
return err
}
if ucChannel != ucUser {
return ircError{&irc.Message{
Command: irc.ERR_USERNOTINCHANNEL,
Params: []string{dc.nick, user, channel, "They are on another network"},
}}
}
uc := ucChannel
params := []string{upstreamChannel, upstreamUser}
if reason != "" {
params = append(params, reason)
}
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "KICK",
Params: params,
})
}
case "MODE": case "MODE":
var name string var name string
if err := parseMessageParams(msg, &name); err != nil { if err := parseMessageParams(msg, &name); err != nil {
@ -1992,7 +1928,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.ERR_UMODEUNKNOWNFLAG, Command: irc.ERR_UMODEUNKNOWNFLAG,
Params: []string{dc.nick, "Cannot change user mode in multi-upstream mode"}, Params: []string{dc.nick, "Cannot change user mode on bouncer connection"},
}) })
} }
} else { } else {
@ -2010,12 +1946,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
uc, upstreamName, err := dc.unmarshalEntity(name) uc, err := dc.upstreamForCommand(msg.Command)
if err != nil { if err != nil {
return err return err
} }
if !uc.isChannel(upstreamName) { if !uc.isChannel(name) {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_USERSDONTMATCH, Command: irc.ERR_USERSDONTMATCH,
Params: []string{dc.nick, "Cannot change mode for other users"}, Params: []string{dc.nick, "Cannot change mode for other users"},
@ -2023,14 +1959,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
if modeStr != "" { if modeStr != "" {
params := []string{upstreamName, modeStr} params := []string{name, modeStr}
params = append(params, msg.Params[2:]...) params = append(params, msg.Params[2:]...)
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "MODE", Command: "MODE",
Params: params, Params: params,
}) })
} else { } else {
ch := uc.channels.Get(upstreamName) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
@ -2062,12 +1998,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
} }
case "TOPIC": case "TOPIC":
var channel string var name string
if err := parseMessageParams(msg, &channel); err != nil { if err := parseMessageParams(msg, &name); err != nil {
return err return err
} }
uc, upstreamName, err := dc.unmarshalEntity(channel) uc, err := dc.upstreamForCommand(msg.Command)
if err != nil { if err != nil {
return err return err
} }
@ -2076,48 +2012,31 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
topic := msg.Params[1] topic := msg.Params[1]
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "TOPIC", Command: "TOPIC",
Params: []string{upstreamName, topic}, Params: []string{name, topic},
}) })
} else { // getting topic } else { // getting topic
ch := uc.channels.Get(upstreamName) ch := uc.channels.Get(name)
if ch == nil { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, upstreamName, "No such channel"}, Params: []string{dc.nick, name, "No such channel"},
}} }}
} }
sendTopic(dc, ch) sendTopic(dc, ch)
} }
case "LIST": case "LIST":
network := dc.network uc, err := dc.upstreamForCommand(msg.Command)
if network == nil && len(msg.Params) > 0 {
var err error
network, msg.Params[0], err = dc.unmarshalEntityNetwork(msg.Params[0])
if err != nil { if err != nil {
return err return err
} }
}
if network == nil {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_LISTEND,
Params: []string{dc.nick, "LIST without a network suffix is not supported in multi-upstream mode"},
})
return nil
}
uc := network.conn
if uc == nil {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_LISTEND,
Params: []string{dc.nick, "Disconnected from upstream server"},
})
return nil
}
uc.enqueueCommand(dc, msg) uc.enqueueCommand(dc, msg)
case "NAMES": case "NAMES":
uc, err := dc.upstreamForCommand(msg.Command)
if err != nil {
return err
}
if len(msg.Params) == 0 { if len(msg.Params) == 0 {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -2128,20 +2047,15 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
channels := strings.Split(msg.Params[0], ",") channels := strings.Split(msg.Params[0], ",")
for _, channel := range channels { for _, name := range channels {
uc, upstreamName, err := dc.unmarshalEntity(channel) ch := uc.channels.Get(name)
if err != nil {
return err
}
ch := uc.channels.Get(upstreamName)
if ch != nil { if ch != nil {
sendNames(dc, ch) sendNames(dc, ch)
} else { } else {
// NAMES on a channel we have not joined, ask upstream // NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{ uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "NAMES", Command: "NAMES",
Params: []string{upstreamName}, Params: []string{name},
}) })
} }
} }
@ -2229,12 +2143,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}) })
return nil return nil
} }
if dc.network == nil {
// TODO: properly support WHO masks
uc, upstreamMask, err := dc.unmarshalEntity(mask)
if err != nil {
// Ignore the error here, because clients don't know how to deal
// with anything other than RPL_ENDOFWHO
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHO, Command: irc.RPL_ENDOFWHO,
@ -2243,15 +2152,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
params := []string{upstreamMask} uc, err := dc.upstreamForCommand(msg.Command)
if options != "" { if err != nil {
params = append(params, options) return err
} }
uc.enqueueCommand(dc, &irc.Message{ uc.enqueueCommand(dc, msg)
Command: "WHO",
Params: params,
})
case "WHOIS": case "WHOIS":
if len(msg.Params) == 0 { if len(msg.Params) == 0 {
return ircError{&irc.Message{ return ircError{&irc.Message{
@ -2260,12 +2166,10 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
}} }}
} }
var target, mask string var mask string
if len(msg.Params) == 1 { if len(msg.Params) == 1 {
target = ""
mask = msg.Params[0] mask = msg.Params[0]
} else { } else {
target = msg.Params[0]
mask = msg.Params[1] mask = msg.Params[1]
} }
// TODO: support multiple WHOIS users // TODO: support multiple WHOIS users
@ -2337,27 +2241,12 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
// TODO: support WHOIS masks uc, err := dc.upstreamForCommand(msg.Command)
uc, upstreamNick, err := dc.unmarshalEntity(mask)
if err != nil { if err != nil {
return err return err
} }
var params []string uc.enqueueCommand(dc, msg)
if target != "" {
if target == mask { // WHOIS nick nick
params = []string{upstreamNick, upstreamNick}
} else {
params = []string{target, upstreamNick}
}
} else {
params = []string{upstreamNick}
}
uc.enqueueCommand(dc, &irc.Message{
Command: "WHOIS",
Params: params,
})
case "PRIVMSG", "NOTICE", "TAGMSG": case "PRIVMSG", "NOTICE", "TAGMSG":
var targetsStr, text string var targetsStr, text string
if msg.Command != "TAGMSG" { if msg.Command != "TAGMSG" {
@ -2432,16 +2321,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
continue continue
} }
uc, upstreamName, err := dc.unmarshalEntity(name) uc, err := dc.upstreamForCommand(msg.Command)
if err != nil { if err != nil {
return err return err
} }
if msg.Command == "PRIVMSG" && uc.network.casemap(upstreamName) == "nickserv" { if msg.Command == "PRIVMSG" && uc.network.casemap(name) == "nickserv" {
dc.handleNickServPRIVMSG(ctx, uc, text) dc.handleNickServPRIVMSG(ctx, uc, text)
} }
upstreamParams := []string{upstreamName} upstreamParams := []string{name}
if msg.Command != "TAGMSG" { if msg.Command != "TAGMSG" {
upstreamParams = append(upstreamParams, text) upstreamParams = append(upstreamParams, text)
} }
@ -2456,7 +2345,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// when it is echoed from the upstream. // when it is echoed from the upstream.
// Otherwise, produce/log it here because it's the last time we'll see it. // Otherwise, produce/log it here because it's the last time we'll see it.
if !uc.caps.IsEnabled("echo-message") { if !uc.caps.IsEnabled("echo-message") {
echoParams := []string{upstreamName} echoParams := []string{name}
if msg.Command != "TAGMSG" { if msg.Command != "TAGMSG" {
echoParams = append(echoParams, text) echoParams = append(echoParams, text)
} }
@ -2476,39 +2365,18 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
Command: msg.Command, Command: msg.Command,
Params: echoParams, Params: echoParams,
} }
uc.produce(upstreamName, echoMsg, dc.id) uc.produce(name, echoMsg, dc.id)
} }
uc.updateChannelAutoDetach(upstreamName) uc.updateChannelAutoDetach(name)
} }
case "INVITE": case "INVITE":
var user, channel string uc, err := dc.upstreamForCommand(msg.Command)
if err := parseMessageParams(msg, &user, &channel); err != nil {
return err
}
ucChannel, upstreamChannel, err := dc.unmarshalEntity(channel)
if err != nil { if err != nil {
return err return err
} }
ucUser, upstreamUser, err := dc.unmarshalEntity(user) uc.SendMessageLabeled(ctx, dc.id, msg)
if err != nil {
return err
}
if ucChannel != ucUser {
return ircError{&irc.Message{
Command: irc.ERR_USERNOTINCHANNEL,
Params: []string{dc.nick, user, channel, "They are on another network"},
}}
}
uc := ucChannel
uc.SendMessageLabeled(ctx, dc.id, &irc.Message{
Command: "INVITE",
Params: []string{upstreamUser, upstreamChannel},
})
case "AUTHENTICATE": case "AUTHENTICATE":
// Post-connection-registration AUTHENTICATE is unsupported in // Post-connection-registration AUTHENTICATE is unsupported in
// multi-upstream mode, or if the upstream doesn't support SASL // multi-upstream mode, or if the upstream doesn't support SASL
@ -2747,15 +2615,16 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if !ok { if !ok {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_UNKNOWNCOMMAND, Command: irc.ERR_UNKNOWNCOMMAND,
Params: []string{dc.nick, "CHATHISTORY", "Unknown command"}, Params: []string{dc.nick, "CHATHISTORY", "Chat history disabled"},
}} }}
} }
network, entity, err := dc.unmarshalEntityNetwork(target) network := dc.network
if err != nil { if network == nil {
return err return newChatHistoryError(subcommand, "Cannot fetch chat history on bouncer connection")
} }
entity = network.casemap(entity)
target = network.casemap(target)
// TODO: support msgid criteria // TODO: support msgid criteria
var bounds [2]time.Time var bounds [2]time.Time
@ -2791,7 +2660,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
options := msgstore.LoadMessageOptions{ options := msgstore.LoadMessageOptions{
Network: &network.Network, Network: &network.Network,
Entity: entity, Entity: target,
Limit: limit, Limit: limit,
Events: eventPlayback, Events: eventPlayback,
} }
@ -2871,22 +2740,26 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
return nil return nil
} }
network, entity, err := dc.unmarshalEntityNetwork(target) network := dc.network
if err != nil { if network == nil {
return err return ircError{&irc.Message{
Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Cannot set read markers on bouncer connection"},
}}
} }
entityCM := network.casemap(entity)
r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, entityCM) targetCM := network.casemap(target)
r, err := dc.srv.db.GetReadReceipt(ctx, network.ID, targetCM)
if err != nil { if err != nil {
dc.logger.Printf("failed to get the read receipt for %q: %v", entity, err) dc.logger.Printf("failed to get the read receipt for %q: %v", target, err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"}, Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"},
}} }}
} else if r == nil { } else if r == nil {
r = &database.ReadReceipt{ r = &database.ReadReceipt{
Target: entityCM, Target: targetCM,
} }
} }
@ -2915,7 +2788,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
if r.Timestamp.Before(timestamp) { if r.Timestamp.Before(timestamp) {
r.Timestamp = timestamp r.Timestamp = timestamp
if err := dc.srv.db.StoreReadReceipt(ctx, network.ID, r); err != nil { if err := dc.srv.db.StoreReadReceipt(ctx, network.ID, r); err != nil {
dc.logger.Printf("failed to store receipt for %q: %v", entity, err) dc.logger.Printf("failed to store receipt for %q: %v", target, err)
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"}, Params: []string{msg.Command, "INTERNAL_ERROR", target, "Internal error"},
@ -2938,17 +2811,17 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
d.SendMessage(&irc.Message{ d.SendMessage(&irc.Message{
Prefix: d.prefix(), Prefix: d.prefix(),
Command: cmd, Command: cmd,
Params: []string{entity, timestampStr}, Params: []string{target, timestampStr},
}) })
} }
}) })
if broadcast && network.pushTargets.Has(entity) { if broadcast && network.pushTargets.Has(target) {
// TODO: only broadcast if draft/read-marker has been negotiated // TODO: only broadcast if draft/read-marker has been negotiated
network.pushTargets.Del(entity) network.pushTargets.Del(target)
go network.broadcastWebPush(&irc.Message{ go network.broadcastWebPush(&irc.Message{
Command: "MARKREAD", Command: "MARKREAD",
Params: []string{entity, timestampStr}, Params: []string{target, timestampStr},
}) })
} }
case "SEARCH": case "SEARCH":
@ -2965,7 +2838,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
} }
attrs := irc.ParseTags(attrsStr) attrs := irc.ParseTags(attrsStr)
var uc *upstreamConn var network *network
const searchMaxLimit = 100 const searchMaxLimit = 100
opts := msgstore.SearchMessageOptions{ opts := msgstore.SearchMessageOptions{
Limit: searchMaxLimit, Limit: searchMaxLimit,
@ -2990,15 +2863,14 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
case "from": case "from":
opts.From = value opts.From = value
case "in": case "in":
u, upstreamName, err := dc.unmarshalEntity(value) network = dc.network
if err != nil { if network == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{"SEARCH", "INVALID_PARAMS", name, "Invalid criteria"}, Params: []string{"SEARCH", "INVALID_PARAMS", name, "Cannot search on bouncer connection"},
}} }}
} }
uc = u opts.In = network.casemap(value)
opts.In = u.network.casemap(upstreamName)
case "text": case "text":
opts.Text = value opts.Text = value
case "limit": case "limit":
@ -3012,7 +2884,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
opts.Limit = limit opts.Limit = limit
} }
} }
if uc == nil { if network == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: "FAIL", Command: "FAIL",
Params: []string{"SEARCH", "INVALID_PARAMS", "in", "The in parameter is mandatory"}, Params: []string{"SEARCH", "INVALID_PARAMS", "in", "The in parameter is mandatory"},
@ -3022,7 +2894,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
opts.Limit = searchMaxLimit opts.Limit = searchMaxLimit
} }
messages, err := store.Search(ctx, &uc.network.Network, &opts) messages, err := store.Search(ctx, &network.Network, &opts)
if err != nil { if err != nil {
dc.logger.Printf("failed fetching messages for search: %v", err) dc.logger.Printf("failed fetching messages for search: %v", err)
return ircError{&irc.Message{ return ircError{&irc.Message{

View File

@ -1139,12 +1139,28 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
return err return err
} }
uc, upstreamName, err := dc.unmarshalEntity(name) network := dc.network
if err != nil { if network == nil {
return fmt.Errorf("unknown channel %q", name) l := strings.SplitN(name, "/", 2)
if len(l) != 2 {
return fmt.Errorf("missing network name")
}
name = l[0]
netName := l[1]
for _, n := range dc.user.networks {
if netName == n.GetName() {
network = n
break
}
} }
ch := uc.network.channels.Get(upstreamName) if network == nil {
return fmt.Errorf("unknown network %q", netName)
}
}
ch := network.channels.Get(name)
if ch == nil { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }
@ -1155,15 +1171,17 @@ func handleServiceChannelUpdate(ctx context.Context, dc *downstreamConn, params
if fs.Detached != nil && *fs.Detached != ch.Detached { if fs.Detached != nil && *fs.Detached != ch.Detached {
if *fs.Detached { if *fs.Detached {
uc.network.detach(ch) network.detach(ch)
} else { } else {
uc.network.attach(ctx, ch) network.attach(ctx, ch)
} }
} }
uc.updateChannelAutoDetach(upstreamName) if network.conn != nil {
network.conn.updateChannelAutoDetach(name)
}
if err := dc.srv.db.StoreChannel(ctx, uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(ctx, network.ID, ch); err != nil {
return fmt.Errorf("failed to update channel: %v", err) return fmt.Errorf("failed to update channel: %v", err)
} }