Sort and split JOIN messages

Sort channels so that channels with a key appear first. Split JOIN
messages so that we don't reach the message size limit.
This commit is contained in:
Simon Ser 2020-07-06 11:06:20 +02:00
parent 4c8b01fb51
commit c490705fee
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 82 additions and 12 deletions

78
irc.go
View File

@ -2,6 +2,7 @@ package soju
import ( import (
"fmt" "fmt"
"sort"
"strings" "strings"
"gopkg.in/irc.v3" "gopkg.in/irc.v3"
@ -18,6 +19,9 @@ const (
const maxMessageLength = 512 const maxMessageLength = 512
// The server-time layout, as defined in the IRCv3 spec.
const serverTimeLayout = "2006-01-02T15:04:05.000Z"
type userModes string type userModes string
func (ms userModes) Has(c byte) bool { func (ms userModes) Has(c byte) bool {
@ -293,5 +297,75 @@ type batch struct {
Label string Label string
} }
// The server-time layout, as defined in the IRCv3 spec. func join(channels, keys []string) []*irc.Message {
const serverTimeLayout = "2006-01-02T15:04:05.000Z" // Put channels with a key first
js := joinSorter{channels, keys}
sort.Sort(&js)
// Two spaces because there are three words (JOIN, channels and keys)
maxLength := maxMessageLength - (len("JOIN") + 2)
var msgs []*irc.Message
var channelsBuf, keysBuf strings.Builder
for i, channel := range channels {
key := keys[i]
n := channelsBuf.Len() + keysBuf.Len() + 1 + len(channel)
if key != "" {
n += 1 + len(key)
}
if channelsBuf.Len() > 0 && n > maxLength {
// No room for the new channel in this message
params := []string{channelsBuf.String()}
if keysBuf.Len() > 0 {
params = append(params, keysBuf.String())
}
msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
channelsBuf.Reset()
keysBuf.Reset()
}
if channelsBuf.Len() > 0 {
channelsBuf.WriteByte(',')
}
channelsBuf.WriteString(channel)
if key != "" {
if keysBuf.Len() > 0 {
keysBuf.WriteByte(',')
}
keysBuf.WriteString(key)
}
}
if channelsBuf.Len() > 0 {
params := []string{channelsBuf.String()}
if keysBuf.Len() > 0 {
params = append(params, keysBuf.String())
}
msgs = append(msgs, &irc.Message{Command: "JOIN", Params: params})
}
return msgs
}
type joinSorter struct {
channels []string
keys []string
}
func (js *joinSorter) Len() int {
return len(js.channels)
}
func (js *joinSorter) Less(i, j int) bool {
if (js.keys[i] != "") != (js.keys[j] != "") {
// Only one of the channels has a key
return js.keys[i] != ""
}
return js.channels[i] < js.channels[j]
}
func (js *joinSorter) Swap(i, j int) {
js.channels[i], js.channels[j] = js.channels[j], js.channels[i]
js.keys[i], js.keys[j] = js.keys[j], js.keys[i]
}

View File

@ -553,19 +553,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}) })
if len(uc.network.channels) > 0 { if len(uc.network.channels) > 0 {
// TODO: split this into multiple messages if need be var channels, keys []string
var names, keys []string
for _, ch := range uc.network.channels { for _, ch := range uc.network.channels {
names = append(names, ch.Name) channels = append(channels, ch.Name)
keys = append(keys, ch.Key) keys = append(keys, ch.Key)
} }
uc.SendMessage(&irc.Message{
Command: "JOIN", for _, msg := range join(channels, keys) {
Params: []string{ uc.SendMessage(msg)
strings.Join(names, ","), }
strings.Join(keys, ","),
},
})
} }
case irc.RPL_MYINFO: case irc.RPL_MYINFO:
if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil { if err := parseMessageParams(msg, nil, &uc.serverName, nil, &uc.availableUserModes, nil); err != nil {