Implement casemapping

TL;DR: supports for casemapping, now logs are saved in
casemapped/canonical/tolower form
(eg. in the #channel directory instead of #Channel... or something)

== What is casemapping? ==

see <https://modern.ircdocs.horse/#casemapping-parameter>

== Casemapping and multi-upstream ==

Since each upstream does not necessarily use the same casemapping, and
since casemappings cannot coexist [0],

1. soju must also update the database accordingly to upstreams'
   casemapping, otherwise it will end up inconsistent,
2. soju must "normalize" entity names and expose only one casemapping
   that is a subset of all supported casemappings (here, ascii).

[0] On some upstreams, "emersion[m]" and "emersion{m}" refer to the same
user (upstreams that advertise rfc1459 for example), while on others
(upstreams that advertise ascii) they don't.

Once upstream's casemapping is known (default to rfc1459), entity names
in map keys are made into casemapped form, for upstreamConn,
upstreamChannel and network.

downstreamConn advertises "CASEMAPPING=ascii", and always casemap map
keys with ascii.

Some functions require the caller to casemap their argument (to avoid
needless calls to casemapping functions).

== Message forwarding and casemapping ==

downstream message handling (joins and parts basically):
When relaying entity names from downstreams to upstreams, soju uses the
upstream casemapping, in order to not get in the way of the user.  This
does not brings any issue, as long as soju replies with the ascii
casemapping in mind (solves point 1.).

marshalEntity/marshalUserPrefix:
When relaying entity names from upstreams with non-ascii casemappings,
soju *partially* casemap them: it only change the case of characters
which are not ascii letters.  ASCII case is thus kept intact, while
special symbols like []{} are the same every time soju sends them to
downstreams (solves point 2.).

== Casemapping changes ==

Casemapping changes are not fully supported by this patch and will
result in loss of history.  This is a limitation of the protocol and
should be solved by the RENAME spec.
This commit is contained in:
Hubert Hirtz 2021-03-16 10:00:34 +01:00 committed by Simon Ser
parent 56bf73716d
commit bdd0c7bc06
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
6 changed files with 379 additions and 112 deletions

View File

@ -54,7 +54,9 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) {
maxLength := maxMessageLength - len(emptyNameReply.String()) maxLength := maxMessageLength - len(emptyNameReply.String())
var buf strings.Builder var buf strings.Builder
for nick, memberships := range ch.Members { for _, entry := range ch.Members.innerMap {
nick := entry.originalKey
memberships := entry.value.(*memberships)
s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
n := buf.Len() + 1 + len(s) n := buf.Len() + 1 + len(s)

View File

@ -116,6 +116,7 @@ type downstreamConn struct {
registered bool registered bool
user *user user *user
nick string nick string
nickCM string
rawUsername string rawUsername string
networkName string networkName string
clientName string clientName string
@ -192,13 +193,13 @@ func (dc *downstreamConn) upstream() *upstreamConn {
func isOurNick(net *network, nick string) bool { func isOurNick(net *network, nick string) bool {
// TODO: this doesn't account for nick changes // TODO: this doesn't account for nick changes
if net.conn != nil { if net.conn != nil {
return nick == net.conn.nick return net.casemap(nick) == net.conn.nickCM
} }
// We're not currently connected to the upstream connection, so we don't // We're not currently connected to the upstream connection, so we don't
// know whether this name is our nickname. Best-effort: use the network's // know whether this name is our nickname. Best-effort: use the network's
// configured nickname and hope it was the one being used when we were // configured nickname and hope it was the one being used when we were
// connected. // connected.
return nick == net.Nick return net.casemap(nick) == net.casemap(net.Nick)
} }
// marshalEntity converts an upstream entity name (ie. channel or nick) into a // marshalEntity converts an upstream entity name (ie. channel or nick) into a
@ -210,6 +211,7 @@ func (dc *downstreamConn) marshalEntity(net *network, name string) string {
if isOurNick(net, name) { if isOurNick(net, name) {
return dc.nick return dc.nick
} }
name = partialCasemap(net.casemap, name)
if dc.network != nil { if dc.network != nil {
if dc.network != net { if dc.network != net {
panic("soju: tried to marshal an entity for another network") panic("soju: tried to marshal an entity for another network")
@ -223,6 +225,7 @@ func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *i
if isOurNick(net, prefix.Name) { if isOurNick(net, prefix.Name) {
return dc.prefix() return dc.prefix()
} }
prefix.Name = partialCasemap(net.casemap, prefix.Name)
if dc.network != nil { if dc.network != nil {
if dc.network != net { if dc.network != net {
panic("soju: tried to marshal a user prefix for another network") panic("soju: tried to marshal a user prefix for another network")
@ -358,8 +361,8 @@ func (dc *downstreamConn) ackMsgID(id string) {
return return
} }
delivered, ok := network.delivered[entity] delivered := network.delivered.Value(entity)
if !ok { if delivered == nil {
return return
} }
@ -445,13 +448,15 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
Params: []string{dc.nick, nick, "contains illegal characters"}, Params: []string{dc.nick, nick, "contains illegal characters"},
}} }}
} }
if nick == serviceNick { nickCM := casemapASCII(nick)
if nickCM == serviceNickCM {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE, Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
}} }}
} }
dc.nick = nick dc.nick = nick
dc.nickCM = nickCM
case "USER": case "USER":
if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
return err return err
@ -766,6 +771,7 @@ func (dc *downstreamConn) updateNick() {
Params: []string{uc.nick}, Params: []string{uc.nick},
}) })
dc.nick = uc.nick dc.nick = uc.nick
dc.nickCM = casemapASCII(dc.nick)
} }
} }
@ -911,6 +917,7 @@ func (dc *downstreamConn) welcome() error {
isupport := []string{ isupport := []string{
fmt.Sprintf("CHATHISTORY=%v", dc.srv.HistoryLimit), fmt.Sprintf("CHATHISTORY=%v", dc.srv.HistoryLimit),
"CASEMAPPING=ascii",
} }
if uc := dc.upstream(); uc != nil { if uc := dc.upstream(); uc != nil {
@ -960,11 +967,13 @@ func (dc *downstreamConn) welcome() error {
dc.updateSupportedCaps() dc.updateSupportedCaps()
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
for _, ch := range uc.channels { for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
if !ch.complete { if !ch.complete {
continue continue
} }
if record, ok := uc.network.channels[ch.Name]; ok && record.Detached { record := uc.network.channels.Value(ch.Name)
if record != nil && record.Detached {
continue continue
} }
@ -987,8 +996,10 @@ func (dc *downstreamConn) welcome() error {
} }
// Fast-forward history to last message // Fast-forward history to last message
for target, delivered := range net.delivered { for target, entry := range net.delivered.innerMap {
if ch, ok := net.channels[target]; ok && ch.Detached { delivered := entry.value.(map[string]string)
ch := net.channels.Value(target)
if ch != nil && ch.Detached {
continue continue
} }
@ -1019,7 +1030,8 @@ func (dc *downstreamConn) messageSupportsHistory(msg *irc.Message) bool {
} }
func (dc *downstreamConn) sendNetworkBacklog(net *network) { func (dc *downstreamConn) sendNetworkBacklog(net *network) {
for target := range net.delivered { for _, entry := range net.delivered.innerMap {
target := entry.originalKey
dc.sendTargetBacklog(net, target) dc.sendTargetBacklog(net, target)
} }
} }
@ -1028,11 +1040,11 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target string) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
return return
} }
if ch, ok := net.channels[target]; ok && ch.Detached { if ch := net.channels.Value(target); ch != nil && ch.Detached {
return return
} }
delivered, ok := net.delivered[target] delivered := net.delivered.Value(target)
if !ok { if delivered == nil {
return return
} }
lastDelivered, ok := delivered[dc.clientName] lastDelivered, ok := delivered[dc.clientName]
@ -1158,7 +1170,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{dc.nick, rawNick, "contains illegal characters"}, Params: []string{dc.nick, rawNick, "contains illegal characters"},
}} }}
} }
if nick == serviceNick { if casemapASCII(nick) == serviceNickCM {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE, Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
@ -1194,6 +1206,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{nick}, Params: []string{nick},
}) })
dc.nick = nick dc.nick = nick
dc.nickCM = casemapASCII(dc.nick)
} }
case "JOIN": case "JOIN":
var namesStr string var namesStr string
@ -1226,9 +1239,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params, Params: params,
}) })
var ch *Channel ch := uc.network.channels.Value(upstreamName)
var ok bool if ch != nil {
if ch, ok = uc.network.channels[upstreamName]; ok {
// Don't clear the channel key if there's one set // Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key // TODO: add a way to unset the channel key
if key != "" { if key != "" {
@ -1240,7 +1252,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Name: upstreamName, Name: upstreamName,
Key: key, Key: key,
} }
uc.network.channels[upstreamName] = ch uc.network.channels.SetValue(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1264,16 +1276,15 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
} }
if strings.EqualFold(reason, "detach") { if strings.EqualFold(reason, "detach") {
var ch *Channel ch := uc.network.channels.Value(upstreamName)
var ok bool if ch != nil {
if ch, ok = uc.network.channels[upstreamName]; ok {
uc.network.detach(ch) uc.network.detach(ch)
} else { } else {
ch = &Channel{ ch = &Channel{
Name: name, Name: name,
Detached: true, Detached: true,
} }
uc.network.channels[upstreamName] = ch uc.network.channels.SetValue(upstreamName, ch)
} }
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1360,7 +1371,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
modeStr = msg.Params[1] modeStr = msg.Params[1]
} }
if name == dc.nick { if casemapASCII(name) == dc.nickCM {
if modeStr != "" { if modeStr != "" {
dc.forEachUpstream(func(uc *upstreamConn) { dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(dc.id, &irc.Message{
@ -1398,8 +1409,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params, Params: params,
}) })
} else { } else {
ch, ok := uc.channels[upstreamName] ch := uc.channels.Value(upstreamName)
if !ok { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "No such channel"}, Params: []string{dc.nick, name, "No such channel"},
@ -1435,7 +1446,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return err return err
} }
uc, upstreamChannel, err := dc.unmarshalEntity(channel) uc, upstreamName, err := dc.unmarshalEntity(channel)
if err != nil { if err != nil {
return err return err
} }
@ -1444,14 +1455,14 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
topic := msg.Params[1] topic := msg.Params[1]
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(dc.id, &irc.Message{
Command: "TOPIC", Command: "TOPIC",
Params: []string{upstreamChannel, topic}, Params: []string{upstreamName, topic},
}) })
} else { // getting topic } else { // getting topic
ch, ok := uc.channels[upstreamChannel] ch := uc.channels.Value(upstreamName)
if !ok { if ch == nil {
return ircError{&irc.Message{ return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL, Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, upstreamChannel, "No such channel"}, Params: []string{dc.nick, upstreamName, "No such channel"},
}} }}
} }
sendTopic(dc, ch) sendTopic(dc, ch)
@ -1513,19 +1524,19 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
channels := strings.Split(msg.Params[0], ",") channels := strings.Split(msg.Params[0], ",")
for _, channel := range channels { for _, channel := range channels {
uc, upstreamChannel, err := dc.unmarshalEntity(channel) uc, upstreamName, err := dc.unmarshalEntity(channel)
if err != nil { if err != nil {
return err return err
} }
ch, ok := uc.channels[upstreamChannel] ch := uc.channels.Value(upstreamName)
if ok { if ch != nil {
sendNames(dc, ch) sendNames(dc, ch)
} else { } else {
// NAMES on a channel we have not joined, ask upstream // NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(dc.id, &irc.Message{ uc.SendMessageLabeled(dc.id, &irc.Message{
Command: "NAMES", Command: "NAMES",
Params: []string{upstreamChannel}, Params: []string{upstreamName},
}) })
} }
} }
@ -1542,8 +1553,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
// TODO: support WHO masks // TODO: support WHO masks
entity := msg.Params[0] entity := msg.Params[0]
entityCM := casemapASCII(entity)
if entity == dc.nick { if entityCM == dc.nickCM {
// TODO: support AWAY (H/G) in self WHO reply // TODO: support AWAY (H/G) in self WHO reply
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
@ -1557,7 +1569,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}) })
return nil return nil
} }
if entity == serviceNick { if entityCM == serviceNickCM {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOREPLY, Command: irc.RPL_WHOREPLY,
@ -1608,7 +1620,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
mask = mask[:i] mask = mask[:i]
} }
if mask == dc.nick { if casemapASCII(mask) == dc.nickCM {
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISUSER, Command: irc.RPL_WHOISUSER,

197
irc.go
View File

@ -121,12 +121,13 @@ outer:
return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
} }
member := arguments[nextArgument] member := arguments[nextArgument]
if _, ok := ch.Members[member]; ok { m := ch.Members.Value(member)
if m != nil {
if plusMinus == '+' { if plusMinus == '+' {
ch.Members[member].Add(ch.conn.availableMemberships, membership) m.Add(ch.conn.availableMemberships, membership)
} else { } else {
// TODO: for upstreams without multi-prefix, query the user modes again // TODO: for upstreams without multi-prefix, query the user modes again
ch.Members[member].Remove(membership) m.Remove(membership)
} }
} }
needMarshaling[nextArgument] = struct{}{} needMarshaling[nextArgument] = struct{}{}
@ -418,3 +419,193 @@ func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
return cmd, params, true return cmd, params, true
} }
type casemapping func(string) string
func casemapNone(name string) string {
return name
}
// CasemapASCII of name is the canonical representation of name according to the
// ascii casemapping.
func casemapASCII(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
}
sb.WriteRune(r)
}
return sb.String()
}
// casemapRFC1459 of name is the canonical representation of name according to the
// rfc1459 casemapping.
func casemapRFC1459(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
} else if r == '{' {
r = '['
} else if r == '}' {
r = ']'
} else if r == '\\' {
r = '|'
} else if r == '~' {
r = '^'
}
sb.WriteRune(r)
}
return sb.String()
}
// casemapRFC1459Strict of name is the canonical representation of name
// according to the rfc1459-strict casemapping.
func casemapRFC1459Strict(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
} else if r == '{' {
r = '['
} else if r == '}' {
r = ']'
} else if r == '\\' {
r = '|'
}
sb.WriteRune(r)
}
return sb.String()
}
func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
switch tokenValue {
case "ascii":
casemap = casemapASCII
case "rfc1459":
casemap = casemapRFC1459
case "rfc1459-strict":
casemap = casemapRFC1459Strict
default:
return nil, false
}
return casemap, true
}
func partialCasemap(higher casemapping, name string) string {
nameFullyCM := higher(name)
var sb strings.Builder
sb.Grow(len(name))
for i, r := range nameFullyCM {
if 'a' <= r && r <= 'z' {
r = rune(name[i])
}
sb.WriteRune(r)
}
return sb.String()
}
type casemapMap struct {
innerMap map[string]casemapEntry
casemap casemapping
}
type casemapEntry struct {
originalKey string
value interface{}
}
func newCasemapMap(size int) casemapMap {
return casemapMap{
innerMap: make(map[string]casemapEntry, size),
casemap: casemapNone,
}
}
func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return "", false
}
return entry.originalKey, true
}
func (cm *casemapMap) Has(name string) bool {
_, ok := cm.innerMap[cm.casemap(name)]
return ok
}
func (cm *casemapMap) Len() int {
return len(cm.innerMap)
}
func (cm *casemapMap) SetValue(name string, value interface{}) {
nameCM := cm.casemap(name)
entry, ok := cm.innerMap[nameCM]
if !ok {
cm.innerMap[nameCM] = casemapEntry{
originalKey: name,
value: value,
}
return
}
entry.value = value
cm.innerMap[nameCM] = entry
}
func (cm *casemapMap) Delete(name string) {
delete(cm.innerMap, cm.casemap(name))
}
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
cm.casemap = newCasemap
newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
for _, entry := range cm.innerMap {
newInnerMap[cm.casemap(entry.originalKey)] = entry
}
cm.innerMap = newInnerMap
}
type upstreamChannelCasemapMap struct{ casemapMap }
func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*upstreamChannel)
}
type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *Channel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*Channel)
}
type membershipsCasemapMap struct{ casemapMap }
func (cm *membershipsCasemapMap) Value(name string) *memberships {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*memberships)
}
type mapStringStringCasemapMap struct{ casemapMap }
func (cm *mapStringStringCasemapMap) Value(name string) map[string]string {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(map[string]string)
}

View File

@ -27,6 +27,7 @@ import (
) )
const serviceNick = "BouncerServ" const serviceNick = "BouncerServ"
const serviceNickCM = "bouncerserv"
const serviceRealname = "soju bouncer service" const serviceRealname = "soju bouncer service"
var servicePrefix = &irc.Prefix{ var servicePrefix = &irc.Prefix{
@ -408,7 +409,7 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
} else { } else {
statuses = append(statuses, "connected") statuses = append(statuses, "connected")
} }
details = fmt.Sprintf("%v channels", len(uc.channels)) details = fmt.Sprintf("%v channels", uc.channels.Len())
} else { } else {
statuses = append(statuses, "disconnected") statuses = append(statuses, "disconnected")
if net.lastError != nil { if net.lastError != nil {
@ -768,8 +769,8 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }
ch, ok := uc.network.channels[upstreamName] ch := uc.network.channels.Value(upstreamName)
if !ok { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }

View File

@ -48,7 +48,7 @@ type upstreamChannel struct {
Status channelStatus Status channelStatus
modes channelModes modes channelModes
creationTime string creationTime string
Members map[string]*memberships Members membershipsCasemapMap
complete bool complete bool
detachTimer *time.Timer detachTimer *time.Timer
} }
@ -86,10 +86,11 @@ type upstreamConn struct {
registered bool registered bool
nick string nick string
nickCM string
username string username string
realname string realname string
modes userModes modes userModes
channels map[string]*upstreamChannel channels upstreamChannelCasemapMap
supportedCaps map[string]string supportedCaps map[string]string
caps map[string]bool caps map[string]bool
batches map[string]batch batches map[string]batch
@ -99,6 +100,8 @@ type upstreamConn struct {
saslClient sasl.Client saslClient sasl.Client
saslStarted bool saslStarted bool
casemapIsSet bool
// set of LIST commands in progress, per downstream // set of LIST commands in progress, per downstream
pendingLISTDownstreamSet map[uint64]struct{} pendingLISTDownstreamSet map[uint64]struct{}
} }
@ -186,7 +189,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
network: network, network: network,
user: network.user, user: network.user,
channels: make(map[string]*upstreamChannel), channels: upstreamChannelCasemapMap{newCasemapMap(0)},
supportedCaps: make(map[string]string), supportedCaps: make(map[string]string),
caps: make(map[string]bool), caps: make(map[string]bool),
batches: make(map[string]batch), batches: make(map[string]batch),
@ -213,8 +216,8 @@ func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)
} }
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := uc.channels[name] ch := uc.channels.Value(name)
if !ok { if ch == nil {
return nil, fmt.Errorf("unknown channel %q", name) return nil, fmt.Errorf("unknown channel %q", name)
} }
return ch, nil return ch, nil
@ -224,6 +227,10 @@ func (uc *upstreamConn) isChannel(entity string) bool {
return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
} }
func (uc *upstreamConn) isOurNick(nick string) bool {
return uc.nickCM == uc.network.casemap(nick)
}
func (uc *upstreamConn) getPendingLIST() *pendingLIST { func (uc *upstreamConn) getPendingLIST() *pendingLIST {
for _, pl := range uc.user.pendingLISTs { for _, pl := range uc.user.pendingLISTs {
if _, ok := pl.pendingCommands[uc.network.ID]; !ok { if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
@ -413,11 +420,12 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.produce("", msg, nil) uc.produce("", msg, nil)
} else { // regular user message } else { // regular user message
target := entity target := entity
if target == uc.nick { if uc.isOurNick(target) {
target = msg.Prefix.Name target = msg.Prefix.Name
} }
if ch, ok := uc.network.channels[target]; ok { ch := uc.network.channels.Value(target)
if ch != nil {
if ch.Detached { if ch.Detached {
uc.handleDetachedMessage(msg.Prefix.Name, text, ch) uc.handleDetachedMessage(msg.Prefix.Name, text, ch)
} }
@ -590,9 +598,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true uc.registered = true
uc.logger.Printf("connection registered") uc.logger.Printf("connection registered")
if len(uc.network.channels) > 0 { if uc.network.channels.Len() > 0 {
var channels, keys []string var channels, keys []string
for _, ch := range uc.network.channels { for _, entry := range uc.network.channels.innerMap {
ch := entry.value.(*Channel)
channels = append(channels, ch.Name) channels = append(channels, ch.Name)
keys = append(keys, ch.Key) keys = append(keys, ch.Key)
} }
@ -634,6 +643,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
var err error var err error
switch parameter { switch parameter {
case "CASEMAPPING":
casemap, ok := parseCasemappingToken(value)
if !ok {
casemap = casemapRFC1459
}
uc.network.updateCasemapping(casemap)
uc.nickCM = uc.network.casemap(uc.nick)
uc.casemapIsSet = true
case "CHANMODES": case "CHANMODES":
if !negate { if !negate {
err = uc.handleChanModes(value) err = uc.handleChanModes(value)
@ -671,6 +688,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
dc.SendMessage(msg) dc.SendMessage(msg)
} }
}) })
case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
if !uc.casemapIsSet {
// upstream did not send any CASEMAPPING token, thus
// we assume it implements the old RFCs with rfc1459.
uc.casemapIsSet = true
uc.network.updateCasemapping(casemapRFC1459)
uc.nickCM = uc.network.casemap(uc.nick)
}
case "BATCH": case "BATCH":
var tag string var tag string
if err := parseMessageParams(msg, &tag); err != nil { if err := parseMessageParams(msg, &tag); err != nil {
@ -716,16 +741,19 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
me := false me := false
if msg.Prefix.Name == uc.nick { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
me = true me = true
uc.nick = newNick uc.nick = newNick
uc.nickCM = uc.network.casemap(uc.nick)
} }
for _, ch := range uc.channels { for _, entry := range uc.channels.innerMap {
if memberships, ok := ch.Members[msg.Prefix.Name]; ok { ch := entry.value.(*upstreamChannel)
delete(ch.Members, msg.Prefix.Name) memberships := ch.Members.Value(msg.Prefix.Name)
ch.Members[newNick] = memberships if memberships != nil {
ch.Members.Delete(msg.Prefix.Name)
ch.Members.SetValue(newNick, memberships)
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
} }
} }
@ -750,13 +778,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
for _, ch := range strings.Split(channels, ",") { for _, ch := range strings.Split(channels, ",") {
if msg.Prefix.Name == uc.nick { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("joined channel %q", ch) uc.logger.Printf("joined channel %q", ch)
uc.channels[ch] = &upstreamChannel{ members := membershipsCasemapMap{newCasemapMap(0)}
members.casemap = uc.network.casemap
uc.channels.SetValue(ch, &upstreamChannel{
Name: ch, Name: ch,
conn: uc, conn: uc,
Members: make(map[string]*memberships), Members: members,
} })
uc.updateChannelAutoDetach(ch) uc.updateChannelAutoDetach(ch)
uc.SendMessage(&irc.Message{ uc.SendMessage(&irc.Message{
@ -768,7 +798,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err != nil { if err != nil {
return err return err
} }
ch.Members[msg.Prefix.Name] = &memberships{} ch.Members.SetValue(msg.Prefix.Name, &memberships{})
} }
chMsg := msg.Copy() chMsg := msg.Copy()
@ -786,10 +816,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
for _, ch := range strings.Split(channels, ",") { for _, ch := range strings.Split(channels, ",") {
if msg.Prefix.Name == uc.nick { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("parted channel %q", ch) uc.logger.Printf("parted channel %q", ch)
if uch, ok := uc.channels[ch]; ok { uch := uc.channels.Value(ch)
delete(uc.channels, ch) if uch != nil {
uc.channels.Delete(ch)
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
} else { } else {
@ -797,7 +828,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err != nil { if err != nil {
return err return err
} }
delete(ch.Members, msg.Prefix.Name) ch.Members.Delete(msg.Prefix.Name)
} }
chMsg := msg.Copy() chMsg := msg.Copy()
@ -814,15 +845,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
if user == uc.nick { if uc.isOurNick(user) {
uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
delete(uc.channels, channel) uc.channels.Delete(channel)
} else { } else {
ch, err := uc.getChannel(channel) ch, err := uc.getChannel(channel)
if err != nil { if err != nil {
return err return err
} }
delete(ch.Members, user) ch.Members.Delete(user)
} }
uc.produce(channel, msg, nil) uc.produce(channel, msg, nil)
@ -831,13 +862,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return fmt.Errorf("expected a prefix") return fmt.Errorf("expected a prefix")
} }
if msg.Prefix.Name == uc.nick { if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("quit") uc.logger.Printf("quit")
} }
for _, ch := range uc.channels { for _, entry := range uc.channels.innerMap {
if _, ok := ch.Members[msg.Prefix.Name]; ok { ch := entry.value.(*upstreamChannel)
delete(ch.Members, msg.Prefix.Name) if ch.Members.Has(msg.Prefix.Name) {
ch.Members.Delete(msg.Prefix.Name)
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
} }
@ -908,7 +940,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.appendLog(ch.Name, msg) uc.appendLog(ch.Name, msg)
if ch, ok := uc.network.channels[name]; !ok || !ch.Detached { c := uc.network.channels.Value(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
params := make([]string, len(msg.Params)) params := make([]string, len(msg.Params))
params[0] = dc.marshalEntity(uc.network, name) params[0] = dc.marshalEntity(uc.network, name)
@ -964,7 +997,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
if firstMode { if firstMode {
if c, ok := uc.network.channels[channel]; !ok || !c.Detached { c := uc.network.channels.Value(channel)
if c == nil || !c.Detached {
modeStr, modeParams := ch.modes.Format() modeStr, modeParams := ch.modes.Format()
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
@ -1061,8 +1095,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
ch, ok := uc.channels[name] ch := uc.channels.Value(name)
if !ok { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc.network, name) channel := dc.marshalEntity(uc.network, name)
@ -1090,7 +1124,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
for _, s := range splitSpace(members) { for _, s := range splitSpace(members) {
memberships, nick := uc.parseMembershipPrefix(s) memberships, nick := uc.parseMembershipPrefix(s)
ch.Members[nick] = memberships ch.Members.SetValue(nick, memberships)
} }
case irc.RPL_ENDOFNAMES: case irc.RPL_ENDOFNAMES:
var name string var name string
@ -1098,8 +1132,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
ch, ok := uc.channels[name] ch := uc.channels.Value(name)
if !ok { if ch == nil {
// NAMES on a channel we have not joined, forward to downstream // NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc.network, name) channel := dc.marshalEntity(uc.network, name)
@ -1118,7 +1152,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
} }
ch.complete = true ch.complete = true
if c, ok := uc.network.channels[name]; !ok || !c.Detached { c := uc.network.channels.Value(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(dc, ch) forwardChannel(dc, ch)
}) })
@ -1272,7 +1307,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
weAreInvited := nick == uc.nick weAreInvited := uc.isOurNick(nick)
uc.forEachDownstream(func(dc *downstreamConn) { uc.forEachDownstream(func(dc *downstreamConn) {
if !weAreInvited && !dc.caps["invite-notify"] { if !weAreInvited && !dc.caps["invite-notify"] {
@ -1395,7 +1430,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
// Ignore // Ignore
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
// Ignore // Ignore
case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD: case irc.RPL_MOTDSTART, irc.RPL_MOTD:
// Ignore // Ignore
case irc.RPL_LISTSTART: case irc.RPL_LISTSTART:
// Ignore // Ignore
@ -1601,6 +1636,7 @@ func splitSpace(s string) []string {
func (uc *upstreamConn) register() { func (uc *upstreamConn) register() {
uc.nick = uc.network.Nick uc.nick = uc.network.Nick
uc.nickCM = uc.network.casemap(uc.nick)
uc.username = uc.network.GetUsername() uc.username = uc.network.GetUsername()
uc.realname = uc.network.GetRealname() uc.realname = uc.network.GetRealname()
@ -1705,20 +1741,21 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
} }
detached := false detached := false
if ch, ok := uc.network.channels[entity]; ok { if ch := uc.network.channels.Value(entity); ch != nil {
detached = ch.Detached detached = ch.Detached
} }
delivered, ok := uc.network.delivered[entity] delivered := uc.network.delivered.Value(entity)
if !ok { entityCM := uc.network.casemap(entity)
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now()) if delivered == nil {
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entityCM, time.Now())
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
return "" return ""
} }
delivered = make(map[string]string) delivered = make(map[string]string)
uc.network.delivered[entity] = delivered uc.network.delivered.SetValue(entity, delivered)
for clientName, _ := range uc.network.offlineClients { for clientName, _ := range uc.network.offlineClients {
delivered[clientName] = lastID delivered[clientName] = lastID
@ -1733,7 +1770,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
} }
} }
msgID, err := uc.user.msgStore.Append(uc.network, entity, msg) msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg)
if err != nil { if err != nil {
uc.logger.Printf("failed to log message: %v", err) uc.logger.Printf("failed to log message: %v", err)
return "" return ""
@ -1754,7 +1791,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
} }
// Don't forward messages if it's a detached channel // Don't forward messages if it's a detached channel
if ch, ok := uc.network.channels[target]; ok && ch.Detached { ch := uc.network.channels.Value(target)
if ch != nil && ch.Detached {
return return
} }
@ -1789,9 +1827,13 @@ func (uc *upstreamConn) updateAway() {
} }
func (uc *upstreamConn) updateChannelAutoDetach(name string) { func (uc *upstreamConn) updateChannelAutoDetach(name string) {
if uch, ok := uc.channels[name]; ok { uch := uc.channels.Value(name)
if ch, ok := uc.network.channels[name]; ok && !ch.Detached { if uch == nil {
uch.updateAutoDetach(ch.DetachAfter) return
}
} }
ch := uc.network.channels.Value(name)
if ch == nil || ch.Detached {
return
}
uch.updateAutoDetach(ch.DetachAfter)
} }

53
user.go
View File

@ -61,17 +61,18 @@ type network struct {
stopped chan struct{} stopped chan struct{}
conn *upstreamConn conn *upstreamConn
channels map[string]*Channel channels channelCasemapMap
delivered map[string]map[string]string // entity -> client name -> msg ID delivered mapStringStringCasemapMap // entity -> client name -> msg ID
offlineClients map[string]struct{} // indexed by client name offlineClients map[string]struct{} // indexed by client name
lastError error lastError error
casemap casemapping
} }
func newNetwork(user *user, record *Network, channels []Channel) *network { func newNetwork(user *user, record *Network, channels []Channel) *network {
m := make(map[string]*Channel, len(channels)) m := channelCasemapMap{newCasemapMap(0)}
for _, ch := range channels { for _, ch := range channels {
ch := ch ch := ch
m[ch.Name] = &ch m.SetValue(ch.Name, &ch)
} }
return &network{ return &network{
@ -79,8 +80,9 @@ func newNetwork(user *user, record *Network, channels []Channel) *network {
user: user, user: user,
stopped: make(chan struct{}), stopped: make(chan struct{}),
channels: m, channels: m,
delivered: make(map[string]map[string]string), delivered: mapStringStringCasemapMap{newCasemapMap(0)},
offlineClients: make(map[string]struct{}), offlineClients: make(map[string]struct{}),
casemap: casemapRFC1459,
} }
} }
@ -185,7 +187,8 @@ func (net *network) detach(ch *Channel) {
net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name) net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
if net.conn != nil { if net.conn != nil {
if uch, ok := net.conn.channels[ch.Name]; ok { uch := net.conn.channels.Value(ch.Name)
if uch != nil {
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
} }
@ -210,7 +213,7 @@ func (net *network) attach(ch *Channel) {
var uch *upstreamChannel var uch *upstreamChannel
if net.conn != nil { if net.conn != nil {
uch = net.conn.channels[ch.Name] uch = net.conn.channels.Value(ch.Name)
net.conn.updateChannelAutoDetach(ch.Name) net.conn.updateChannelAutoDetach(ch.Name)
} }
@ -231,12 +234,13 @@ func (net *network) attach(ch *Channel) {
} }
func (net *network) deleteChannel(name string) error { func (net *network) deleteChannel(name string) error {
ch, ok := net.channels[name] ch := net.channels.Value(name)
if !ok { if ch == nil {
return fmt.Errorf("unknown channel %q", name) return fmt.Errorf("unknown channel %q", name)
} }
if net.conn != nil { if net.conn != nil {
if uch, ok := net.conn.channels[ch.Name]; ok { uch := net.conn.channels.Value(ch.Name)
if uch != nil {
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
} }
@ -244,10 +248,23 @@ func (net *network) deleteChannel(name string) error {
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil { if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
return err return err
} }
delete(net.channels, name) net.channels.Delete(name)
return nil return nil
} }
func (net *network) updateCasemapping(newCasemap casemapping) {
net.casemap = newCasemap
net.channels.SetCasemapping(newCasemap)
net.delivered.SetCasemapping(newCasemap)
if net.conn != nil {
net.conn.channels.SetCasemapping(newCasemap)
for _, entry := range net.conn.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uch.Members.SetCasemapping(newCasemap)
}
}
}
type user struct { type user struct {
User User
srv *Server srv *Server
@ -410,8 +427,8 @@ func (u *user) run() {
} }
case eventChannelDetach: case eventChannelDetach:
uc, name := e.uc, e.name uc, name := e.uc, e.name
c, ok := uc.network.channels[name] c := uc.network.channels.Value(name)
if !ok || c.Detached { if c == nil || c.Detached {
continue continue
} }
uc.network.detach(c) uc.network.detach(c)
@ -499,7 +516,8 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.endPendingLISTs(true) uc.endPendingLISTs(true)
for _, uch := range uc.channels { for _, entry := range uc.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uch.updateAutoDetach(0) uch.updateAutoDetach(0)
} }
@ -570,8 +588,9 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
// Most network changes require us to re-connect to the upstream server // Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, len(network.channels)) channels := make([]Channel, 0, network.channels.Len())
for _, ch := range network.channels { for _, entry := range network.channels.innerMap {
ch := entry.value.(*Channel)
channels = append(channels, *ch) channels = append(channels, *ch)
} }