Implement casemapping

TL;DR: supports for casemapping, now logs are saved in
casemapped/canonical/tolower form
(eg. in the #channel directory instead of #Channel... or something)

== What is casemapping? ==

see <https://modern.ircdocs.horse/#casemapping-parameter>

== Casemapping and multi-upstream ==

Since each upstream does not necessarily use the same casemapping, and
since casemappings cannot coexist [0],

1. soju must also update the database accordingly to upstreams'
   casemapping, otherwise it will end up inconsistent,
2. soju must "normalize" entity names and expose only one casemapping
   that is a subset of all supported casemappings (here, ascii).

[0] On some upstreams, "emersion[m]" and "emersion{m}" refer to the same
user (upstreams that advertise rfc1459 for example), while on others
(upstreams that advertise ascii) they don't.

Once upstream's casemapping is known (default to rfc1459), entity names
in map keys are made into casemapped form, for upstreamConn,
upstreamChannel and network.

downstreamConn advertises "CASEMAPPING=ascii", and always casemap map
keys with ascii.

Some functions require the caller to casemap their argument (to avoid
needless calls to casemapping functions).

== Message forwarding and casemapping ==

downstream message handling (joins and parts basically):
When relaying entity names from downstreams to upstreams, soju uses the
upstream casemapping, in order to not get in the way of the user.  This
does not brings any issue, as long as soju replies with the ascii
casemapping in mind (solves point 1.).

marshalEntity/marshalUserPrefix:
When relaying entity names from upstreams with non-ascii casemappings,
soju *partially* casemap them: it only change the case of characters
which are not ascii letters.  ASCII case is thus kept intact, while
special symbols like []{} are the same every time soju sends them to
downstreams (solves point 2.).

== Casemapping changes ==

Casemapping changes are not fully supported by this patch and will
result in loss of history.  This is a limitation of the protocol and
should be solved by the RENAME spec.
This commit is contained in:
Hubert Hirtz 2021-03-16 10:00:34 +01:00 committed by Simon Ser
parent 56bf73716d
commit bdd0c7bc06
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
6 changed files with 379 additions and 112 deletions

View File

@ -54,7 +54,9 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) {
maxLength := maxMessageLength - len(emptyNameReply.String())
var buf strings.Builder
for nick, memberships := range ch.Members {
for _, entry := range ch.Members.innerMap {
nick := entry.originalKey
memberships := entry.value.(*memberships)
s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick)
n := buf.Len() + 1 + len(s)

View File

@ -116,6 +116,7 @@ type downstreamConn struct {
registered bool
user *user
nick string
nickCM string
rawUsername string
networkName string
clientName string
@ -192,13 +193,13 @@ func (dc *downstreamConn) upstream() *upstreamConn {
func isOurNick(net *network, nick string) bool {
// TODO: this doesn't account for nick changes
if net.conn != nil {
return nick == net.conn.nick
return net.casemap(nick) == net.conn.nickCM
}
// We're not currently connected to the upstream connection, so we don't
// know whether this name is our nickname. Best-effort: use the network's
// configured nickname and hope it was the one being used when we were
// connected.
return nick == net.Nick
return net.casemap(nick) == net.casemap(net.Nick)
}
// marshalEntity converts an upstream entity name (ie. channel or nick) into a
@ -210,6 +211,7 @@ func (dc *downstreamConn) marshalEntity(net *network, name string) string {
if isOurNick(net, name) {
return dc.nick
}
name = partialCasemap(net.casemap, name)
if dc.network != nil {
if dc.network != net {
panic("soju: tried to marshal an entity for another network")
@ -223,6 +225,7 @@ func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *i
if isOurNick(net, prefix.Name) {
return dc.prefix()
}
prefix.Name = partialCasemap(net.casemap, prefix.Name)
if dc.network != nil {
if dc.network != net {
panic("soju: tried to marshal a user prefix for another network")
@ -358,8 +361,8 @@ func (dc *downstreamConn) ackMsgID(id string) {
return
}
delivered, ok := network.delivered[entity]
if !ok {
delivered := network.delivered.Value(entity)
if delivered == nil {
return
}
@ -445,13 +448,15 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error {
Params: []string{dc.nick, nick, "contains illegal characters"},
}}
}
if nick == serviceNick {
nickCM := casemapASCII(nick)
if nickCM == serviceNickCM {
return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"},
}}
}
dc.nick = nick
dc.nickCM = nickCM
case "USER":
if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil {
return err
@ -766,6 +771,7 @@ func (dc *downstreamConn) updateNick() {
Params: []string{uc.nick},
})
dc.nick = uc.nick
dc.nickCM = casemapASCII(dc.nick)
}
}
@ -911,6 +917,7 @@ func (dc *downstreamConn) welcome() error {
isupport := []string{
fmt.Sprintf("CHATHISTORY=%v", dc.srv.HistoryLimit),
"CASEMAPPING=ascii",
}
if uc := dc.upstream(); uc != nil {
@ -960,11 +967,13 @@ func (dc *downstreamConn) welcome() error {
dc.updateSupportedCaps()
dc.forEachUpstream(func(uc *upstreamConn) {
for _, ch := range uc.channels {
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
if !ch.complete {
continue
}
if record, ok := uc.network.channels[ch.Name]; ok && record.Detached {
record := uc.network.channels.Value(ch.Name)
if record != nil && record.Detached {
continue
}
@ -987,8 +996,10 @@ func (dc *downstreamConn) welcome() error {
}
// Fast-forward history to last message
for target, delivered := range net.delivered {
if ch, ok := net.channels[target]; ok && ch.Detached {
for target, entry := range net.delivered.innerMap {
delivered := entry.value.(map[string]string)
ch := net.channels.Value(target)
if ch != nil && ch.Detached {
continue
}
@ -1019,7 +1030,8 @@ func (dc *downstreamConn) messageSupportsHistory(msg *irc.Message) bool {
}
func (dc *downstreamConn) sendNetworkBacklog(net *network) {
for target := range net.delivered {
for _, entry := range net.delivered.innerMap {
target := entry.originalKey
dc.sendTargetBacklog(net, target)
}
}
@ -1028,11 +1040,11 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target string) {
if dc.caps["draft/chathistory"] || dc.user.msgStore == nil {
return
}
if ch, ok := net.channels[target]; ok && ch.Detached {
if ch := net.channels.Value(target); ch != nil && ch.Detached {
return
}
delivered, ok := net.delivered[target]
if !ok {
delivered := net.delivered.Value(target)
if delivered == nil {
return
}
lastDelivered, ok := delivered[dc.clientName]
@ -1158,7 +1170,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{dc.nick, rawNick, "contains illegal characters"},
}}
}
if nick == serviceNick {
if casemapASCII(nick) == serviceNickCM {
return ircError{&irc.Message{
Command: irc.ERR_NICKNAMEINUSE,
Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"},
@ -1194,6 +1206,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: []string{nick},
})
dc.nick = nick
dc.nickCM = casemapASCII(dc.nick)
}
case "JOIN":
var namesStr string
@ -1226,9 +1239,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params,
})
var ch *Channel
var ok bool
if ch, ok = uc.network.channels[upstreamName]; ok {
ch := uc.network.channels.Value(upstreamName)
if ch != nil {
// Don't clear the channel key if there's one set
// TODO: add a way to unset the channel key
if key != "" {
@ -1240,7 +1252,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Name: upstreamName,
Key: key,
}
uc.network.channels[upstreamName] = ch
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1264,16 +1276,15 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
}
if strings.EqualFold(reason, "detach") {
var ch *Channel
var ok bool
if ch, ok = uc.network.channels[upstreamName]; ok {
ch := uc.network.channels.Value(upstreamName)
if ch != nil {
uc.network.detach(ch)
} else {
ch = &Channel{
Name: name,
Detached: true,
}
uc.network.channels[upstreamName] = ch
uc.network.channels.SetValue(upstreamName, ch)
}
if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil {
dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err)
@ -1360,7 +1371,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
modeStr = msg.Params[1]
}
if name == dc.nick {
if casemapASCII(name) == dc.nickCM {
if modeStr != "" {
dc.forEachUpstream(func(uc *upstreamConn) {
uc.SendMessageLabeled(dc.id, &irc.Message{
@ -1398,8 +1409,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
Params: params,
})
} else {
ch, ok := uc.channels[upstreamName]
if !ok {
ch := uc.channels.Value(upstreamName)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, name, "No such channel"},
@ -1435,7 +1446,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
return err
}
uc, upstreamChannel, err := dc.unmarshalEntity(channel)
uc, upstreamName, err := dc.unmarshalEntity(channel)
if err != nil {
return err
}
@ -1444,14 +1455,14 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
topic := msg.Params[1]
uc.SendMessageLabeled(dc.id, &irc.Message{
Command: "TOPIC",
Params: []string{upstreamChannel, topic},
Params: []string{upstreamName, topic},
})
} else { // getting topic
ch, ok := uc.channels[upstreamChannel]
if !ok {
ch := uc.channels.Value(upstreamName)
if ch == nil {
return ircError{&irc.Message{
Command: irc.ERR_NOSUCHCHANNEL,
Params: []string{dc.nick, upstreamChannel, "No such channel"},
Params: []string{dc.nick, upstreamName, "No such channel"},
}}
}
sendTopic(dc, ch)
@ -1513,19 +1524,19 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
channels := strings.Split(msg.Params[0], ",")
for _, channel := range channels {
uc, upstreamChannel, err := dc.unmarshalEntity(channel)
uc, upstreamName, err := dc.unmarshalEntity(channel)
if err != nil {
return err
}
ch, ok := uc.channels[upstreamChannel]
if ok {
ch := uc.channels.Value(upstreamName)
if ch != nil {
sendNames(dc, ch)
} else {
// NAMES on a channel we have not joined, ask upstream
uc.SendMessageLabeled(dc.id, &irc.Message{
Command: "NAMES",
Params: []string{upstreamChannel},
Params: []string{upstreamName},
})
}
}
@ -1542,8 +1553,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
// TODO: support WHO masks
entity := msg.Params[0]
entityCM := casemapASCII(entity)
if entity == dc.nick {
if entityCM == dc.nickCM {
// TODO: support AWAY (H/G) in self WHO reply
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
@ -1557,7 +1569,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
})
return nil
}
if entity == serviceNick {
if entityCM == serviceNickCM {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOREPLY,
@ -1608,7 +1620,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error {
mask = mask[:i]
}
if mask == dc.nick {
if casemapASCII(mask) == dc.nickCM {
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),
Command: irc.RPL_WHOISUSER,

197
irc.go
View File

@ -121,12 +121,13 @@ outer:
return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode)
}
member := arguments[nextArgument]
if _, ok := ch.Members[member]; ok {
m := ch.Members.Value(member)
if m != nil {
if plusMinus == '+' {
ch.Members[member].Add(ch.conn.availableMemberships, membership)
m.Add(ch.conn.availableMemberships, membership)
} else {
// TODO: for upstreams without multi-prefix, query the user modes again
ch.Members[member].Remove(membership)
m.Remove(membership)
}
}
needMarshaling[nextArgument] = struct{}{}
@ -418,3 +419,193 @@ func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) {
return cmd, params, true
}
type casemapping func(string) string
func casemapNone(name string) string {
return name
}
// CasemapASCII of name is the canonical representation of name according to the
// ascii casemapping.
func casemapASCII(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
}
sb.WriteRune(r)
}
return sb.String()
}
// casemapRFC1459 of name is the canonical representation of name according to the
// rfc1459 casemapping.
func casemapRFC1459(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
} else if r == '{' {
r = '['
} else if r == '}' {
r = ']'
} else if r == '\\' {
r = '|'
} else if r == '~' {
r = '^'
}
sb.WriteRune(r)
}
return sb.String()
}
// casemapRFC1459Strict of name is the canonical representation of name
// according to the rfc1459-strict casemapping.
func casemapRFC1459Strict(name string) string {
var sb strings.Builder
sb.Grow(len(name))
for _, r := range name {
if 'A' <= r && r <= 'Z' {
r += 'a' - 'A'
} else if r == '{' {
r = '['
} else if r == '}' {
r = ']'
} else if r == '\\' {
r = '|'
}
sb.WriteRune(r)
}
return sb.String()
}
func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) {
switch tokenValue {
case "ascii":
casemap = casemapASCII
case "rfc1459":
casemap = casemapRFC1459
case "rfc1459-strict":
casemap = casemapRFC1459Strict
default:
return nil, false
}
return casemap, true
}
func partialCasemap(higher casemapping, name string) string {
nameFullyCM := higher(name)
var sb strings.Builder
sb.Grow(len(name))
for i, r := range nameFullyCM {
if 'a' <= r && r <= 'z' {
r = rune(name[i])
}
sb.WriteRune(r)
}
return sb.String()
}
type casemapMap struct {
innerMap map[string]casemapEntry
casemap casemapping
}
type casemapEntry struct {
originalKey string
value interface{}
}
func newCasemapMap(size int) casemapMap {
return casemapMap{
innerMap: make(map[string]casemapEntry, size),
casemap: casemapNone,
}
}
func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return "", false
}
return entry.originalKey, true
}
func (cm *casemapMap) Has(name string) bool {
_, ok := cm.innerMap[cm.casemap(name)]
return ok
}
func (cm *casemapMap) Len() int {
return len(cm.innerMap)
}
func (cm *casemapMap) SetValue(name string, value interface{}) {
nameCM := cm.casemap(name)
entry, ok := cm.innerMap[nameCM]
if !ok {
cm.innerMap[nameCM] = casemapEntry{
originalKey: name,
value: value,
}
return
}
entry.value = value
cm.innerMap[nameCM] = entry
}
func (cm *casemapMap) Delete(name string) {
delete(cm.innerMap, cm.casemap(name))
}
func (cm *casemapMap) SetCasemapping(newCasemap casemapping) {
cm.casemap = newCasemap
newInnerMap := make(map[string]casemapEntry, len(cm.innerMap))
for _, entry := range cm.innerMap {
newInnerMap[cm.casemap(entry.originalKey)] = entry
}
cm.innerMap = newInnerMap
}
type upstreamChannelCasemapMap struct{ casemapMap }
func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*upstreamChannel)
}
type channelCasemapMap struct{ casemapMap }
func (cm *channelCasemapMap) Value(name string) *Channel {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*Channel)
}
type membershipsCasemapMap struct{ casemapMap }
func (cm *membershipsCasemapMap) Value(name string) *memberships {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(*memberships)
}
type mapStringStringCasemapMap struct{ casemapMap }
func (cm *mapStringStringCasemapMap) Value(name string) map[string]string {
entry, ok := cm.innerMap[cm.casemap(name)]
if !ok {
return nil
}
return entry.value.(map[string]string)
}

View File

@ -27,6 +27,7 @@ import (
)
const serviceNick = "BouncerServ"
const serviceNickCM = "bouncerserv"
const serviceRealname = "soju bouncer service"
var servicePrefix = &irc.Prefix{
@ -408,7 +409,7 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error {
} else {
statuses = append(statuses, "connected")
}
details = fmt.Sprintf("%v channels", len(uc.channels))
details = fmt.Sprintf("%v channels", uc.channels.Len())
} else {
statuses = append(statuses, "disconnected")
if net.lastError != nil {
@ -768,8 +769,8 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error {
return fmt.Errorf("unknown channel %q", name)
}
ch, ok := uc.network.channels[upstreamName]
if !ok {
ch := uc.network.channels.Value(upstreamName)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
}

View File

@ -48,7 +48,7 @@ type upstreamChannel struct {
Status channelStatus
modes channelModes
creationTime string
Members map[string]*memberships
Members membershipsCasemapMap
complete bool
detachTimer *time.Timer
}
@ -86,10 +86,11 @@ type upstreamConn struct {
registered bool
nick string
nickCM string
username string
realname string
modes userModes
channels map[string]*upstreamChannel
channels upstreamChannelCasemapMap
supportedCaps map[string]string
caps map[string]bool
batches map[string]batch
@ -99,6 +100,8 @@ type upstreamConn struct {
saslClient sasl.Client
saslStarted bool
casemapIsSet bool
// set of LIST commands in progress, per downstream
pendingLISTDownstreamSet map[uint64]struct{}
}
@ -186,7 +189,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options),
network: network,
user: network.user,
channels: make(map[string]*upstreamChannel),
channels: upstreamChannelCasemapMap{newCasemapMap(0)},
supportedCaps: make(map[string]string),
caps: make(map[string]bool),
batches: make(map[string]batch),
@ -213,8 +216,8 @@ func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn)
}
func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) {
ch, ok := uc.channels[name]
if !ok {
ch := uc.channels.Value(name)
if ch == nil {
return nil, fmt.Errorf("unknown channel %q", name)
}
return ch, nil
@ -224,6 +227,10 @@ func (uc *upstreamConn) isChannel(entity string) bool {
return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0]))
}
func (uc *upstreamConn) isOurNick(nick string) bool {
return uc.nickCM == uc.network.casemap(nick)
}
func (uc *upstreamConn) getPendingLIST() *pendingLIST {
for _, pl := range uc.user.pendingLISTs {
if _, ok := pl.pendingCommands[uc.network.ID]; !ok {
@ -413,11 +420,12 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.produce("", msg, nil)
} else { // regular user message
target := entity
if target == uc.nick {
if uc.isOurNick(target) {
target = msg.Prefix.Name
}
if ch, ok := uc.network.channels[target]; ok {
ch := uc.network.channels.Value(target)
if ch != nil {
if ch.Detached {
uc.handleDetachedMessage(msg.Prefix.Name, text, ch)
}
@ -590,9 +598,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.registered = true
uc.logger.Printf("connection registered")
if len(uc.network.channels) > 0 {
if uc.network.channels.Len() > 0 {
var channels, keys []string
for _, ch := range uc.network.channels {
for _, entry := range uc.network.channels.innerMap {
ch := entry.value.(*Channel)
channels = append(channels, ch.Name)
keys = append(keys, ch.Key)
}
@ -634,6 +643,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
var err error
switch parameter {
case "CASEMAPPING":
casemap, ok := parseCasemappingToken(value)
if !ok {
casemap = casemapRFC1459
}
uc.network.updateCasemapping(casemap)
uc.nickCM = uc.network.casemap(uc.nick)
uc.casemapIsSet = true
case "CHANMODES":
if !negate {
err = uc.handleChanModes(value)
@ -671,6 +688,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
dc.SendMessage(msg)
}
})
case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD:
if !uc.casemapIsSet {
// upstream did not send any CASEMAPPING token, thus
// we assume it implements the old RFCs with rfc1459.
uc.casemapIsSet = true
uc.network.updateCasemapping(casemapRFC1459)
uc.nickCM = uc.network.casemap(uc.nick)
}
case "BATCH":
var tag string
if err := parseMessageParams(msg, &tag); err != nil {
@ -716,16 +741,19 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}
me := false
if msg.Prefix.Name == uc.nick {
if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick)
me = true
uc.nick = newNick
uc.nickCM = uc.network.casemap(uc.nick)
}
for _, ch := range uc.channels {
if memberships, ok := ch.Members[msg.Prefix.Name]; ok {
delete(ch.Members, msg.Prefix.Name)
ch.Members[newNick] = memberships
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
memberships := ch.Members.Value(msg.Prefix.Name)
if memberships != nil {
ch.Members.Delete(msg.Prefix.Name)
ch.Members.SetValue(newNick, memberships)
uc.appendLog(ch.Name, msg)
}
}
@ -750,13 +778,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}
for _, ch := range strings.Split(channels, ",") {
if msg.Prefix.Name == uc.nick {
if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("joined channel %q", ch)
uc.channels[ch] = &upstreamChannel{
members := membershipsCasemapMap{newCasemapMap(0)}
members.casemap = uc.network.casemap
uc.channels.SetValue(ch, &upstreamChannel{
Name: ch,
conn: uc,
Members: make(map[string]*memberships),
}
Members: members,
})
uc.updateChannelAutoDetach(ch)
uc.SendMessage(&irc.Message{
@ -768,7 +798,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err != nil {
return err
}
ch.Members[msg.Prefix.Name] = &memberships{}
ch.Members.SetValue(msg.Prefix.Name, &memberships{})
}
chMsg := msg.Copy()
@ -786,10 +816,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}
for _, ch := range strings.Split(channels, ",") {
if msg.Prefix.Name == uc.nick {
if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("parted channel %q", ch)
if uch, ok := uc.channels[ch]; ok {
delete(uc.channels, ch)
uch := uc.channels.Value(ch)
if uch != nil {
uc.channels.Delete(ch)
uch.updateAutoDetach(0)
}
} else {
@ -797,7 +828,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
if err != nil {
return err
}
delete(ch.Members, msg.Prefix.Name)
ch.Members.Delete(msg.Prefix.Name)
}
chMsg := msg.Copy()
@ -814,15 +845,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
if user == uc.nick {
if uc.isOurNick(user) {
uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name)
delete(uc.channels, channel)
uc.channels.Delete(channel)
} else {
ch, err := uc.getChannel(channel)
if err != nil {
return err
}
delete(ch.Members, user)
ch.Members.Delete(user)
}
uc.produce(channel, msg, nil)
@ -831,13 +862,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return fmt.Errorf("expected a prefix")
}
if msg.Prefix.Name == uc.nick {
if uc.isOurNick(msg.Prefix.Name) {
uc.logger.Printf("quit")
}
for _, ch := range uc.channels {
if _, ok := ch.Members[msg.Prefix.Name]; ok {
delete(ch.Members, msg.Prefix.Name)
for _, entry := range uc.channels.innerMap {
ch := entry.value.(*upstreamChannel)
if ch.Members.Has(msg.Prefix.Name) {
ch.Members.Delete(msg.Prefix.Name)
uc.appendLog(ch.Name, msg)
}
@ -908,7 +940,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
uc.appendLog(ch.Name, msg)
if ch, ok := uc.network.channels[name]; !ok || !ch.Detached {
c := uc.network.channels.Value(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) {
params := make([]string, len(msg.Params))
params[0] = dc.marshalEntity(uc.network, name)
@ -964,7 +997,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
if firstMode {
if c, ok := uc.network.channels[channel]; !ok || !c.Detached {
c := uc.network.channels.Value(channel)
if c == nil || !c.Detached {
modeStr, modeParams := ch.modes.Format()
uc.forEachDownstream(func(dc *downstreamConn) {
@ -1061,8 +1095,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
ch, ok := uc.channels[name]
if !ok {
ch := uc.channels.Value(name)
if ch == nil {
// NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc.network, name)
@ -1090,7 +1124,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
for _, s := range splitSpace(members) {
memberships, nick := uc.parseMembershipPrefix(s)
ch.Members[nick] = memberships
ch.Members.SetValue(nick, memberships)
}
case irc.RPL_ENDOFNAMES:
var name string
@ -1098,8 +1132,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
ch, ok := uc.channels[name]
if !ok {
ch := uc.channels.Value(name)
if ch == nil {
// NAMES on a channel we have not joined, forward to downstream
uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) {
channel := dc.marshalEntity(uc.network, name)
@ -1118,7 +1152,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
}
ch.complete = true
if c, ok := uc.network.channels[name]; !ok || !c.Detached {
c := uc.network.channels.Value(name)
if c == nil || !c.Detached {
uc.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(dc, ch)
})
@ -1272,7 +1307,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
weAreInvited := nick == uc.nick
weAreInvited := uc.isOurNick(nick)
uc.forEachDownstream(func(dc *downstreamConn) {
if !weAreInvited && !dc.caps["invite-notify"] {
@ -1395,7 +1430,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
// Ignore
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
// Ignore
case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD:
case irc.RPL_MOTDSTART, irc.RPL_MOTD:
// Ignore
case irc.RPL_LISTSTART:
// Ignore
@ -1601,6 +1636,7 @@ func splitSpace(s string) []string {
func (uc *upstreamConn) register() {
uc.nick = uc.network.Nick
uc.nickCM = uc.network.casemap(uc.nick)
uc.username = uc.network.GetUsername()
uc.realname = uc.network.GetRealname()
@ -1705,20 +1741,21 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
}
detached := false
if ch, ok := uc.network.channels[entity]; ok {
if ch := uc.network.channels.Value(entity); ch != nil {
detached = ch.Detached
}
delivered, ok := uc.network.delivered[entity]
if !ok {
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now())
delivered := uc.network.delivered.Value(entity)
entityCM := uc.network.casemap(entity)
if delivered == nil {
lastID, err := uc.user.msgStore.LastMsgID(uc.network, entityCM, time.Now())
if err != nil {
uc.logger.Printf("failed to log message: failed to get last message ID: %v", err)
return ""
}
delivered = make(map[string]string)
uc.network.delivered[entity] = delivered
uc.network.delivered.SetValue(entity, delivered)
for clientName, _ := range uc.network.offlineClients {
delivered[clientName] = lastID
@ -1733,7 +1770,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
}
}
msgID, err := uc.user.msgStore.Append(uc.network, entity, msg)
msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg)
if err != nil {
uc.logger.Printf("failed to log message: %v", err)
return ""
@ -1754,7 +1791,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr
}
// Don't forward messages if it's a detached channel
if ch, ok := uc.network.channels[target]; ok && ch.Detached {
ch := uc.network.channels.Value(target)
if ch != nil && ch.Detached {
return
}
@ -1789,9 +1827,13 @@ func (uc *upstreamConn) updateAway() {
}
func (uc *upstreamConn) updateChannelAutoDetach(name string) {
if uch, ok := uc.channels[name]; ok {
if ch, ok := uc.network.channels[name]; ok && !ch.Detached {
uch.updateAutoDetach(ch.DetachAfter)
}
uch := uc.channels.Value(name)
if uch == nil {
return
}
ch := uc.network.channels.Value(name)
if ch == nil || ch.Detached {
return
}
uch.updateAutoDetach(ch.DetachAfter)
}

53
user.go
View File

@ -61,17 +61,18 @@ type network struct {
stopped chan struct{}
conn *upstreamConn
channels map[string]*Channel
delivered map[string]map[string]string // entity -> client name -> msg ID
offlineClients map[string]struct{} // indexed by client name
channels channelCasemapMap
delivered mapStringStringCasemapMap // entity -> client name -> msg ID
offlineClients map[string]struct{} // indexed by client name
lastError error
casemap casemapping
}
func newNetwork(user *user, record *Network, channels []Channel) *network {
m := make(map[string]*Channel, len(channels))
m := channelCasemapMap{newCasemapMap(0)}
for _, ch := range channels {
ch := ch
m[ch.Name] = &ch
m.SetValue(ch.Name, &ch)
}
return &network{
@ -79,8 +80,9 @@ func newNetwork(user *user, record *Network, channels []Channel) *network {
user: user,
stopped: make(chan struct{}),
channels: m,
delivered: make(map[string]map[string]string),
delivered: mapStringStringCasemapMap{newCasemapMap(0)},
offlineClients: make(map[string]struct{}),
casemap: casemapRFC1459,
}
}
@ -185,7 +187,8 @@ func (net *network) detach(ch *Channel) {
net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name)
if net.conn != nil {
if uch, ok := net.conn.channels[ch.Name]; ok {
uch := net.conn.channels.Value(ch.Name)
if uch != nil {
uch.updateAutoDetach(0)
}
}
@ -210,7 +213,7 @@ func (net *network) attach(ch *Channel) {
var uch *upstreamChannel
if net.conn != nil {
uch = net.conn.channels[ch.Name]
uch = net.conn.channels.Value(ch.Name)
net.conn.updateChannelAutoDetach(ch.Name)
}
@ -231,12 +234,13 @@ func (net *network) attach(ch *Channel) {
}
func (net *network) deleteChannel(name string) error {
ch, ok := net.channels[name]
if !ok {
ch := net.channels.Value(name)
if ch == nil {
return fmt.Errorf("unknown channel %q", name)
}
if net.conn != nil {
if uch, ok := net.conn.channels[ch.Name]; ok {
uch := net.conn.channels.Value(ch.Name)
if uch != nil {
uch.updateAutoDetach(0)
}
}
@ -244,10 +248,23 @@ func (net *network) deleteChannel(name string) error {
if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil {
return err
}
delete(net.channels, name)
net.channels.Delete(name)
return nil
}
func (net *network) updateCasemapping(newCasemap casemapping) {
net.casemap = newCasemap
net.channels.SetCasemapping(newCasemap)
net.delivered.SetCasemapping(newCasemap)
if net.conn != nil {
net.conn.channels.SetCasemapping(newCasemap)
for _, entry := range net.conn.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uch.Members.SetCasemapping(newCasemap)
}
}
}
type user struct {
User
srv *Server
@ -410,8 +427,8 @@ func (u *user) run() {
}
case eventChannelDetach:
uc, name := e.uc, e.name
c, ok := uc.network.channels[name]
if !ok || c.Detached {
c := uc.network.channels.Value(name)
if c == nil || c.Detached {
continue
}
uc.network.detach(c)
@ -499,7 +516,8 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) {
uc.endPendingLISTs(true)
for _, uch := range uc.channels {
for _, entry := range uc.channels.innerMap {
uch := entry.value.(*upstreamChannel)
uch.updateAutoDetach(0)
}
@ -570,8 +588,9 @@ func (u *user) updateNetwork(record *Network) (*network, error) {
// Most network changes require us to re-connect to the upstream server
channels := make([]Channel, 0, len(network.channels))
for _, ch := range network.channels {
channels := make([]Channel, 0, network.channels.Len())
for _, entry := range network.channels.innerMap {
ch := entry.value.(*Channel)
channels = append(channels, *ch)
}