Make downstreamConn.marshal{Entity,UserPrefix} take a network

This will be used when sending history while upstream is disconnected.
This commit is contained in:
Simon Ser 2020-04-16 17:19:00 +02:00
parent 5cf876cb89
commit 45e897c1c1
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 51 additions and 39 deletions

View File

@ -16,7 +16,7 @@ func forwardChannel(dc *downstreamConn, ch *upstreamChannel) {
} }
func sendTopic(dc *downstreamConn, ch *upstreamChannel) { func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
downstreamName := dc.marshalEntity(ch.conn, ch.Name) downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
if ch.Topic != "" { if ch.Topic != "" {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -36,10 +36,10 @@ func sendTopic(dc *downstreamConn, ch *upstreamChannel) {
func sendNames(dc *downstreamConn, ch *upstreamChannel) { func sendNames(dc *downstreamConn, ch *upstreamChannel) {
// TODO: send multiple members in each message // TODO: send multiple members in each message
downstreamName := dc.marshalEntity(ch.conn, ch.Name) downstreamName := dc.marshalEntity(ch.conn.network, ch.Name)
for nick, membership := range ch.Members { for nick, membership := range ch.Members {
s := membership.String() + dc.marshalEntity(ch.conn, nick) s := membership.String() + dc.marshalEntity(ch.conn.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),

View File

@ -121,36 +121,48 @@ func (dc *downstreamConn) upstream() *upstreamConn {
return dc.network.upstream() return dc.network.upstream()
} }
func isOurNick(net *network, nick string) bool {
// TODO: this doesn't account for nick changes
if net.conn != nil {
return nick == net.conn.nick
}
// We're not currently connected to the upstream connection, so we don't
// know whether this name is our nickname. Best-effort: use the network's
// configured nickname and hope it was the one being used when we were
// connected.
return nick == net.Nick
}
// marshalEntity converts an upstream entity name (ie. channel or nick) into a // marshalEntity converts an upstream entity name (ie. channel or nick) into a
// downstream entity name. // downstream entity name.
// //
// This involves adding a "/<network>" suffix if the entity isn't the current // This involves adding a "/<network>" suffix if the entity isn't the current
// user. // user.
func (dc *downstreamConn) marshalEntity(uc *upstreamConn, name string) string { func (dc *downstreamConn) marshalEntity(net *network, name string) string {
if dc.network != nil { if dc.network != nil {
if dc.network != uc.network { if dc.network != net {
panic("soju: tried to marshal an entity for another network") panic("soju: tried to marshal an entity for another network")
} }
return name return name
} }
if name == uc.nick { if isOurNick(net, name) {
return dc.nick return dc.nick
} }
return name + "/" + uc.network.GetName() return name + "/" + net.GetName()
} }
func (dc *downstreamConn) marshalUserPrefix(uc *upstreamConn, prefix *irc.Prefix) *irc.Prefix { func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *irc.Prefix {
if prefix.Name == uc.nick { if isOurNick(net, prefix.Name) {
return dc.prefix() return dc.prefix()
} }
if dc.network != nil { if dc.network != nil {
if dc.network != uc.network { if dc.network != net {
panic("soju: tried to marshal a user prefix for another network") panic("soju: tried to marshal a user prefix for another network")
} }
return prefix return prefix
} }
return &irc.Prefix{ return &irc.Prefix{
Name: prefix.Name + "/" + uc.network.GetName(), Name: prefix.Name + "/" + net.GetName(),
User: prefix.User, User: prefix.User,
Host: prefix.Host, Host: prefix.Host,
} }
@ -228,23 +240,23 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) {
// messages that may appear in logs are supported. // messages that may appear in logs are supported.
func (dc *downstreamConn) marshalMessage(msg *irc.Message, uc *upstreamConn) *irc.Message { func (dc *downstreamConn) marshalMessage(msg *irc.Message, uc *upstreamConn) *irc.Message {
msg = msg.Copy() msg = msg.Copy()
msg.Prefix = dc.marshalUserPrefix(uc, msg.Prefix) msg.Prefix = dc.marshalUserPrefix(uc.network, msg.Prefix)
switch msg.Command { switch msg.Command {
case "PRIVMSG", "NOTICE": case "PRIVMSG", "NOTICE":
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
case "NICK": case "NICK":
// Nick change for another user // Nick change for another user
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
case "JOIN", "PART": case "JOIN", "PART":
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
case "KICK": case "KICK":
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
msg.Params[1] = dc.marshalEntity(uc, msg.Params[1]) msg.Params[1] = dc.marshalEntity(uc.network, msg.Params[1])
case "TOPIC": case "TOPIC":
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
case "MODE": case "MODE":
msg.Params[0] = dc.marshalEntity(uc, msg.Params[0]) msg.Params[0] = dc.marshalEntity(uc.network, msg.Params[0])
case "QUIT": case "QUIT":
// This space is intentinally left blank // This space is intentinally left blank
default: default:
@ -662,7 +674,7 @@ func (dc *downstreamConn) welcome() error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.prefix(), Prefix: dc.prefix(),
Command: "JOIN", Command: "JOIN",
Params: []string{dc.marshalEntity(ch.conn, ch.Name)}, Params: []string{dc.marshalEntity(ch.conn.network, ch.Name)},
}) })
forwardChannel(dc, ch) forwardChannel(dc, ch)
@ -713,7 +725,7 @@ func (dc *downstreamConn) sendNetworkHistory(net *network) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: "BATCH", Command: "BATCH",
Params: []string{"+" + batchRef, "chathistory", dc.marshalEntity(uc, target)}, Params: []string{"+" + batchRef, "chathistory", dc.marshalEntity(net, target)},
}) })
} }

View File

@ -764,7 +764,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
modeStr, modeParams := ch.modes.Format() modeStr, modeParams := ch.modes.Format()
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
params := []string{dc.nick, dc.marshalEntity(uc, channel), modeStr} params := []string{dc.nick, dc.marshalEntity(uc.network, channel), modeStr}
params = append(params, modeParams...) params = append(params, modeParams...)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -826,7 +826,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_LIST, Command: irc.RPL_LIST,
Params: []string{dc.nick, dc.marshalEntity(uc, channel), clients, topic}, Params: []string{dc.nick, dc.marshalEntity(uc.network, channel), clients, topic},
}) })
}) })
case irc.RPL_LISTEND: case irc.RPL_LISTEND:
@ -844,11 +844,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if !ok { if !ok {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc, name) channel := dc.marshalEntity(uc.network, name)
members := splitSpace(members) members := splitSpace(members)
for i, member := range members { for i, member := range members {
membership, nick := uc.parseMembershipPrefix(member) membership, nick := uc.parseMembershipPrefix(member)
members[i] = membership.String() + dc.marshalEntity(uc, nick) members[i] = membership.String() + dc.marshalEntity(uc.network, nick)
} }
memberStr := strings.Join(members, " ") memberStr := strings.Join(members, " ")
@ -881,7 +881,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if !ok { if !ok {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc, name) channel := dc.marshalEntity(uc.network, name)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -922,9 +922,9 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := channel channel := channel
if channel != "*" { if channel != "*" {
channel = dc.marshalEntity(uc, channel) channel = dc.marshalEntity(uc.network, channel)
} }
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOREPLY, Command: irc.RPL_WHOREPLY,
@ -941,7 +941,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
name := name name := name
if name != "*" { if name != "*" {
// TODO: support WHO masks // TODO: support WHO masks
name = dc.marshalEntity(uc, name) name = dc.marshalEntity(uc.network, name)
} }
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -956,7 +956,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISUSER, Command: irc.RPL_WHOISUSER,
@ -970,7 +970,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISSERVER, Command: irc.RPL_WHOISSERVER,
@ -984,7 +984,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISOPERATOR, Command: irc.RPL_WHOISOPERATOR,
@ -998,7 +998,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
params := []string{dc.nick, nick} params := []string{dc.nick, nick}
params = append(params, msg.Params[2:]...) params = append(params, msg.Params[2:]...)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
@ -1015,11 +1015,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
channels := splitSpace(channelList) channels := splitSpace(channelList)
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
channelList := make([]string, len(channels)) channelList := make([]string, len(channels))
for i, channel := range channels { for i, channel := range channels {
prefix, channel := uc.parseMembershipPrefix(channel) prefix, channel := uc.parseMembershipPrefix(channel)
channel = dc.marshalEntity(uc, channel) channel = dc.marshalEntity(uc.network, channel)
channelList[i] = prefix.String() + channel channelList[i] = prefix.String() + channel
} }
channels := strings.Join(channelList, " ") channels := strings.Join(channelList, " ")
@ -1036,7 +1036,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
nick := dc.marshalEntity(uc, nick) nick := dc.marshalEntity(uc.network, nick)
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_ENDOFWHOIS, Command: irc.RPL_ENDOFWHOIS,
@ -1076,9 +1076,9 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.marshalUserPrefix(uc, msg.Prefix), Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix),
Command: "INVITE", Command: "INVITE",
Params: []string{dc.marshalEntity(uc, nick), dc.marshalEntity(uc, channel)}, Params: []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
}) })
}) })
case irc.RPL_INVITING: case irc.RPL_INVITING:
@ -1092,7 +1092,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_INVITING, Command: irc.RPL_INVITING,
Params: []string{dc.nick, dc.marshalEntity(uc, nick), dc.marshalEntity(uc, channel)}, Params: []string{dc.nick, dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
}) })
}) })
case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: case irc.ERR_UNKNOWNCOMMAND, irc.RPL_TRYAGAIN: