diff --git a/bridge.go b/bridge.go index f707e33..509f1fd 100644 --- a/bridge.go +++ b/bridge.go @@ -54,7 +54,9 @@ func sendNames(dc *downstreamConn, ch *upstreamChannel) { maxLength := maxMessageLength - len(emptyNameReply.String()) var buf strings.Builder - for nick, memberships := range ch.Members { + for _, entry := range ch.Members.innerMap { + nick := entry.originalKey + memberships := entry.value.(*memberships) s := memberships.Format(dc) + dc.marshalEntity(ch.conn.network, nick) n := buf.Len() + 1 + len(s) diff --git a/downstream.go b/downstream.go index a6fe74c..cf9da70 100644 --- a/downstream.go +++ b/downstream.go @@ -116,6 +116,7 @@ type downstreamConn struct { registered bool user *user nick string + nickCM string rawUsername string networkName string clientName string @@ -192,13 +193,13 @@ func (dc *downstreamConn) upstream() *upstreamConn { func isOurNick(net *network, nick string) bool { // TODO: this doesn't account for nick changes if net.conn != nil { - return nick == net.conn.nick + return net.casemap(nick) == net.conn.nickCM } // We're not currently connected to the upstream connection, so we don't // know whether this name is our nickname. Best-effort: use the network's // configured nickname and hope it was the one being used when we were // connected. - return nick == net.Nick + return net.casemap(nick) == net.casemap(net.Nick) } // marshalEntity converts an upstream entity name (ie. channel or nick) into a @@ -210,6 +211,7 @@ func (dc *downstreamConn) marshalEntity(net *network, name string) string { if isOurNick(net, name) { return dc.nick } + name = partialCasemap(net.casemap, name) if dc.network != nil { if dc.network != net { panic("soju: tried to marshal an entity for another network") @@ -223,6 +225,7 @@ func (dc *downstreamConn) marshalUserPrefix(net *network, prefix *irc.Prefix) *i if isOurNick(net, prefix.Name) { return dc.prefix() } + prefix.Name = partialCasemap(net.casemap, prefix.Name) if dc.network != nil { if dc.network != net { panic("soju: tried to marshal a user prefix for another network") @@ -358,8 +361,8 @@ func (dc *downstreamConn) ackMsgID(id string) { return } - delivered, ok := network.delivered[entity] - if !ok { + delivered := network.delivered.Value(entity) + if delivered == nil { return } @@ -445,13 +448,15 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { Params: []string{dc.nick, nick, "contains illegal characters"}, }} } - if nick == serviceNick { + nickCM := casemapASCII(nick) + if nickCM == serviceNickCM { return ircError{&irc.Message{ Command: irc.ERR_NICKNAMEINUSE, Params: []string{dc.nick, nick, "Nickname reserved for bouncer service"}, }} } dc.nick = nick + dc.nickCM = nickCM case "USER": if err := parseMessageParams(msg, &dc.rawUsername, nil, nil, &dc.realname); err != nil { return err @@ -766,6 +771,7 @@ func (dc *downstreamConn) updateNick() { Params: []string{uc.nick}, }) dc.nick = uc.nick + dc.nickCM = casemapASCII(dc.nick) } } @@ -911,6 +917,7 @@ func (dc *downstreamConn) welcome() error { isupport := []string{ fmt.Sprintf("CHATHISTORY=%v", dc.srv.HistoryLimit), + "CASEMAPPING=ascii", } if uc := dc.upstream(); uc != nil { @@ -960,11 +967,13 @@ func (dc *downstreamConn) welcome() error { dc.updateSupportedCaps() dc.forEachUpstream(func(uc *upstreamConn) { - for _, ch := range uc.channels { + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) if !ch.complete { continue } - if record, ok := uc.network.channels[ch.Name]; ok && record.Detached { + record := uc.network.channels.Value(ch.Name) + if record != nil && record.Detached { continue } @@ -987,8 +996,10 @@ func (dc *downstreamConn) welcome() error { } // Fast-forward history to last message - for target, delivered := range net.delivered { - if ch, ok := net.channels[target]; ok && ch.Detached { + for target, entry := range net.delivered.innerMap { + delivered := entry.value.(map[string]string) + ch := net.channels.Value(target) + if ch != nil && ch.Detached { continue } @@ -1019,7 +1030,8 @@ func (dc *downstreamConn) messageSupportsHistory(msg *irc.Message) bool { } func (dc *downstreamConn) sendNetworkBacklog(net *network) { - for target := range net.delivered { + for _, entry := range net.delivered.innerMap { + target := entry.originalKey dc.sendTargetBacklog(net, target) } } @@ -1028,11 +1040,11 @@ func (dc *downstreamConn) sendTargetBacklog(net *network, target string) { if dc.caps["draft/chathistory"] || dc.user.msgStore == nil { return } - if ch, ok := net.channels[target]; ok && ch.Detached { + if ch := net.channels.Value(target); ch != nil && ch.Detached { return } - delivered, ok := net.delivered[target] - if !ok { + delivered := net.delivered.Value(target) + if delivered == nil { return } lastDelivered, ok := delivered[dc.clientName] @@ -1158,7 +1170,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: []string{dc.nick, rawNick, "contains illegal characters"}, }} } - if nick == serviceNick { + if casemapASCII(nick) == serviceNickCM { return ircError{&irc.Message{ Command: irc.ERR_NICKNAMEINUSE, Params: []string{dc.nick, rawNick, "Nickname reserved for bouncer service"}, @@ -1194,6 +1206,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: []string{nick}, }) dc.nick = nick + dc.nickCM = casemapASCII(dc.nick) } case "JOIN": var namesStr string @@ -1226,9 +1239,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: params, }) - var ch *Channel - var ok bool - if ch, ok = uc.network.channels[upstreamName]; ok { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { // Don't clear the channel key if there's one set // TODO: add a way to unset the channel key if key != "" { @@ -1240,7 +1252,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Name: upstreamName, Key: key, } - uc.network.channels[upstreamName] = ch + uc.network.channels.SetValue(upstreamName, ch) } if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) @@ -1264,16 +1276,15 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { } if strings.EqualFold(reason, "detach") { - var ch *Channel - var ok bool - if ch, ok = uc.network.channels[upstreamName]; ok { + ch := uc.network.channels.Value(upstreamName) + if ch != nil { uc.network.detach(ch) } else { ch = &Channel{ Name: name, Detached: true, } - uc.network.channels[upstreamName] = ch + uc.network.channels.SetValue(upstreamName, ch) } if err := dc.srv.db.StoreChannel(uc.network.ID, ch); err != nil { dc.logger.Printf("failed to create or update channel %q: %v", upstreamName, err) @@ -1360,7 +1371,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { modeStr = msg.Params[1] } - if name == dc.nick { + if casemapASCII(name) == dc.nickCM { if modeStr != "" { dc.forEachUpstream(func(uc *upstreamConn) { uc.SendMessageLabeled(dc.id, &irc.Message{ @@ -1398,8 +1409,8 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { Params: params, }) } else { - ch, ok := uc.channels[upstreamName] - if !ok { + ch := uc.channels.Value(upstreamName) + if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, Params: []string{dc.nick, name, "No such channel"}, @@ -1435,7 +1446,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { return err } - uc, upstreamChannel, err := dc.unmarshalEntity(channel) + uc, upstreamName, err := dc.unmarshalEntity(channel) if err != nil { return err } @@ -1444,14 +1455,14 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { topic := msg.Params[1] uc.SendMessageLabeled(dc.id, &irc.Message{ Command: "TOPIC", - Params: []string{upstreamChannel, topic}, + Params: []string{upstreamName, topic}, }) } else { // getting topic - ch, ok := uc.channels[upstreamChannel] - if !ok { + ch := uc.channels.Value(upstreamName) + if ch == nil { return ircError{&irc.Message{ Command: irc.ERR_NOSUCHCHANNEL, - Params: []string{dc.nick, upstreamChannel, "No such channel"}, + Params: []string{dc.nick, upstreamName, "No such channel"}, }} } sendTopic(dc, ch) @@ -1513,19 +1524,19 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { channels := strings.Split(msg.Params[0], ",") for _, channel := range channels { - uc, upstreamChannel, err := dc.unmarshalEntity(channel) + uc, upstreamName, err := dc.unmarshalEntity(channel) if err != nil { return err } - ch, ok := uc.channels[upstreamChannel] - if ok { + ch := uc.channels.Value(upstreamName) + if ch != nil { sendNames(dc, ch) } else { // NAMES on a channel we have not joined, ask upstream uc.SendMessageLabeled(dc.id, &irc.Message{ Command: "NAMES", - Params: []string{upstreamChannel}, + Params: []string{upstreamName}, }) } } @@ -1542,8 +1553,9 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { // TODO: support WHO masks entity := msg.Params[0] + entityCM := casemapASCII(entity) - if entity == dc.nick { + if entityCM == dc.nickCM { // TODO: support AWAY (H/G) in self WHO reply dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), @@ -1557,7 +1569,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { }) return nil } - if entity == serviceNick { + if entityCM == serviceNickCM { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOREPLY, @@ -1608,7 +1620,7 @@ func (dc *downstreamConn) handleMessageRegistered(msg *irc.Message) error { mask = mask[:i] } - if mask == dc.nick { + if casemapASCII(mask) == dc.nickCM { dc.SendMessage(&irc.Message{ Prefix: dc.srv.prefix(), Command: irc.RPL_WHOISUSER, diff --git a/irc.go b/irc.go index 94fdc7f..ae0ae43 100644 --- a/irc.go +++ b/irc.go @@ -121,12 +121,13 @@ outer: return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) } member := arguments[nextArgument] - if _, ok := ch.Members[member]; ok { + m := ch.Members.Value(member) + if m != nil { if plusMinus == '+' { - ch.Members[member].Add(ch.conn.availableMemberships, membership) + m.Add(ch.conn.availableMemberships, membership) } else { // TODO: for upstreams without multi-prefix, query the user modes again - ch.Members[member].Remove(membership) + m.Remove(membership) } } needMarshaling[nextArgument] = struct{}{} @@ -418,3 +419,193 @@ func parseCTCPMessage(msg *irc.Message) (cmd string, params string, ok bool) { return cmd, params, true } + +type casemapping func(string) string + +func casemapNone(name string) string { + return name +} + +// CasemapASCII of name is the canonical representation of name according to the +// ascii casemapping. +func casemapASCII(name string) string { + var sb strings.Builder + sb.Grow(len(name)) + for _, r := range name { + if 'A' <= r && r <= 'Z' { + r += 'a' - 'A' + } + sb.WriteRune(r) + } + return sb.String() +} + +// casemapRFC1459 of name is the canonical representation of name according to the +// rfc1459 casemapping. +func casemapRFC1459(name string) string { + var sb strings.Builder + sb.Grow(len(name)) + for _, r := range name { + if 'A' <= r && r <= 'Z' { + r += 'a' - 'A' + } else if r == '{' { + r = '[' + } else if r == '}' { + r = ']' + } else if r == '\\' { + r = '|' + } else if r == '~' { + r = '^' + } + sb.WriteRune(r) + } + return sb.String() +} + +// casemapRFC1459Strict of name is the canonical representation of name +// according to the rfc1459-strict casemapping. +func casemapRFC1459Strict(name string) string { + var sb strings.Builder + sb.Grow(len(name)) + for _, r := range name { + if 'A' <= r && r <= 'Z' { + r += 'a' - 'A' + } else if r == '{' { + r = '[' + } else if r == '}' { + r = ']' + } else if r == '\\' { + r = '|' + } + sb.WriteRune(r) + } + return sb.String() +} + +func parseCasemappingToken(tokenValue string) (casemap casemapping, ok bool) { + switch tokenValue { + case "ascii": + casemap = casemapASCII + case "rfc1459": + casemap = casemapRFC1459 + case "rfc1459-strict": + casemap = casemapRFC1459Strict + default: + return nil, false + } + return casemap, true +} + +func partialCasemap(higher casemapping, name string) string { + nameFullyCM := higher(name) + var sb strings.Builder + sb.Grow(len(name)) + for i, r := range nameFullyCM { + if 'a' <= r && r <= 'z' { + r = rune(name[i]) + } + sb.WriteRune(r) + } + return sb.String() +} + +type casemapMap struct { + innerMap map[string]casemapEntry + casemap casemapping +} + +type casemapEntry struct { + originalKey string + value interface{} +} + +func newCasemapMap(size int) casemapMap { + return casemapMap{ + innerMap: make(map[string]casemapEntry, size), + casemap: casemapNone, + } +} + +func (cm *casemapMap) OriginalKey(name string) (key string, ok bool) { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return "", false + } + return entry.originalKey, true +} + +func (cm *casemapMap) Has(name string) bool { + _, ok := cm.innerMap[cm.casemap(name)] + return ok +} + +func (cm *casemapMap) Len() int { + return len(cm.innerMap) +} + +func (cm *casemapMap) SetValue(name string, value interface{}) { + nameCM := cm.casemap(name) + entry, ok := cm.innerMap[nameCM] + if !ok { + cm.innerMap[nameCM] = casemapEntry{ + originalKey: name, + value: value, + } + return + } + entry.value = value + cm.innerMap[nameCM] = entry +} + +func (cm *casemapMap) Delete(name string) { + delete(cm.innerMap, cm.casemap(name)) +} + +func (cm *casemapMap) SetCasemapping(newCasemap casemapping) { + cm.casemap = newCasemap + newInnerMap := make(map[string]casemapEntry, len(cm.innerMap)) + for _, entry := range cm.innerMap { + newInnerMap[cm.casemap(entry.originalKey)] = entry + } + cm.innerMap = newInnerMap +} + +type upstreamChannelCasemapMap struct{ casemapMap } + +func (cm *upstreamChannelCasemapMap) Value(name string) *upstreamChannel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*upstreamChannel) +} + +type channelCasemapMap struct{ casemapMap } + +func (cm *channelCasemapMap) Value(name string) *Channel { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*Channel) +} + +type membershipsCasemapMap struct{ casemapMap } + +func (cm *membershipsCasemapMap) Value(name string) *memberships { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(*memberships) +} + +type mapStringStringCasemapMap struct{ casemapMap } + +func (cm *mapStringStringCasemapMap) Value(name string) map[string]string { + entry, ok := cm.innerMap[cm.casemap(name)] + if !ok { + return nil + } + return entry.value.(map[string]string) +} diff --git a/service.go b/service.go index 1c239e1..01e27d9 100644 --- a/service.go +++ b/service.go @@ -27,6 +27,7 @@ import ( ) const serviceNick = "BouncerServ" +const serviceNickCM = "bouncerserv" const serviceRealname = "soju bouncer service" var servicePrefix = &irc.Prefix{ @@ -408,7 +409,7 @@ func handleServiceNetworkStatus(dc *downstreamConn, params []string) error { } else { statuses = append(statuses, "connected") } - details = fmt.Sprintf("%v channels", len(uc.channels)) + details = fmt.Sprintf("%v channels", uc.channels.Len()) } else { statuses = append(statuses, "disconnected") if net.lastError != nil { @@ -768,8 +769,8 @@ func handleServiceChannelUpdate(dc *downstreamConn, params []string) error { return fmt.Errorf("unknown channel %q", name) } - ch, ok := uc.network.channels[upstreamName] - if !ok { + ch := uc.network.channels.Value(upstreamName) + if ch == nil { return fmt.Errorf("unknown channel %q", name) } diff --git a/upstream.go b/upstream.go index d5ba66b..42143e2 100644 --- a/upstream.go +++ b/upstream.go @@ -48,7 +48,7 @@ type upstreamChannel struct { Status channelStatus modes channelModes creationTime string - Members map[string]*memberships + Members membershipsCasemapMap complete bool detachTimer *time.Timer } @@ -86,10 +86,11 @@ type upstreamConn struct { registered bool nick string + nickCM string username string realname string modes userModes - channels map[string]*upstreamChannel + channels upstreamChannelCasemapMap supportedCaps map[string]string caps map[string]bool batches map[string]batch @@ -99,6 +100,8 @@ type upstreamConn struct { saslClient sasl.Client saslStarted bool + casemapIsSet bool + // set of LIST commands in progress, per downstream pendingLISTDownstreamSet map[uint64]struct{} } @@ -186,7 +189,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { conn: *newConn(network.user.srv, newNetIRCConn(netConn), &options), network: network, user: network.user, - channels: make(map[string]*upstreamChannel), + channels: upstreamChannelCasemapMap{newCasemapMap(0)}, supportedCaps: make(map[string]string), caps: make(map[string]bool), batches: make(map[string]batch), @@ -213,8 +216,8 @@ func (uc *upstreamConn) forEachDownstreamByID(id uint64, f func(*downstreamConn) } func (uc *upstreamConn) getChannel(name string) (*upstreamChannel, error) { - ch, ok := uc.channels[name] - if !ok { + ch := uc.channels.Value(name) + if ch == nil { return nil, fmt.Errorf("unknown channel %q", name) } return ch, nil @@ -224,6 +227,10 @@ func (uc *upstreamConn) isChannel(entity string) bool { return strings.ContainsRune(uc.availableChannelTypes, rune(entity[0])) } +func (uc *upstreamConn) isOurNick(nick string) bool { + return uc.nickCM == uc.network.casemap(nick) +} + func (uc *upstreamConn) getPendingLIST() *pendingLIST { for _, pl := range uc.user.pendingLISTs { if _, ok := pl.pendingCommands[uc.network.ID]; !ok { @@ -413,11 +420,12 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.produce("", msg, nil) } else { // regular user message target := entity - if target == uc.nick { + if uc.isOurNick(target) { target = msg.Prefix.Name } - if ch, ok := uc.network.channels[target]; ok { + ch := uc.network.channels.Value(target) + if ch != nil { if ch.Detached { uc.handleDetachedMessage(msg.Prefix.Name, text, ch) } @@ -590,9 +598,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.registered = true uc.logger.Printf("connection registered") - if len(uc.network.channels) > 0 { + if uc.network.channels.Len() > 0 { var channels, keys []string - for _, ch := range uc.network.channels { + for _, entry := range uc.network.channels.innerMap { + ch := entry.value.(*Channel) channels = append(channels, ch.Name) keys = append(keys, ch.Key) } @@ -634,6 +643,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { var err error switch parameter { + case "CASEMAPPING": + casemap, ok := parseCasemappingToken(value) + if !ok { + casemap = casemapRFC1459 + } + uc.network.updateCasemapping(casemap) + uc.nickCM = uc.network.casemap(uc.nick) + uc.casemapIsSet = true case "CHANMODES": if !negate { err = uc.handleChanModes(value) @@ -671,6 +688,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { dc.SendMessage(msg) } }) + case irc.ERR_NOMOTD, irc.RPL_ENDOFMOTD: + if !uc.casemapIsSet { + // upstream did not send any CASEMAPPING token, thus + // we assume it implements the old RFCs with rfc1459. + uc.casemapIsSet = true + uc.network.updateCasemapping(casemapRFC1459) + uc.nickCM = uc.network.casemap(uc.nick) + } case "BATCH": var tag string if err := parseMessageParams(msg, &tag); err != nil { @@ -716,16 +741,19 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } me := false - if msg.Prefix.Name == uc.nick { + if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("changed nick from %q to %q", uc.nick, newNick) me = true uc.nick = newNick + uc.nickCM = uc.network.casemap(uc.nick) } - for _, ch := range uc.channels { - if memberships, ok := ch.Members[msg.Prefix.Name]; ok { - delete(ch.Members, msg.Prefix.Name) - ch.Members[newNick] = memberships + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + memberships := ch.Members.Value(msg.Prefix.Name) + if memberships != nil { + ch.Members.Delete(msg.Prefix.Name) + ch.Members.SetValue(newNick, memberships) uc.appendLog(ch.Name, msg) } } @@ -750,13 +778,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } for _, ch := range strings.Split(channels, ",") { - if msg.Prefix.Name == uc.nick { + if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("joined channel %q", ch) - uc.channels[ch] = &upstreamChannel{ + members := membershipsCasemapMap{newCasemapMap(0)} + members.casemap = uc.network.casemap + uc.channels.SetValue(ch, &upstreamChannel{ Name: ch, conn: uc, - Members: make(map[string]*memberships), - } + Members: members, + }) uc.updateChannelAutoDetach(ch) uc.SendMessage(&irc.Message{ @@ -768,7 +798,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { if err != nil { return err } - ch.Members[msg.Prefix.Name] = &memberships{} + ch.Members.SetValue(msg.Prefix.Name, &memberships{}) } chMsg := msg.Copy() @@ -786,10 +816,11 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } for _, ch := range strings.Split(channels, ",") { - if msg.Prefix.Name == uc.nick { + if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("parted channel %q", ch) - if uch, ok := uc.channels[ch]; ok { - delete(uc.channels, ch) + uch := uc.channels.Value(ch) + if uch != nil { + uc.channels.Delete(ch) uch.updateAutoDetach(0) } } else { @@ -797,7 +828,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { if err != nil { return err } - delete(ch.Members, msg.Prefix.Name) + ch.Members.Delete(msg.Prefix.Name) } chMsg := msg.Copy() @@ -814,15 +845,15 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - if user == uc.nick { + if uc.isOurNick(user) { uc.logger.Printf("kicked from channel %q by %s", channel, msg.Prefix.Name) - delete(uc.channels, channel) + uc.channels.Delete(channel) } else { ch, err := uc.getChannel(channel) if err != nil { return err } - delete(ch.Members, user) + ch.Members.Delete(user) } uc.produce(channel, msg, nil) @@ -831,13 +862,14 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return fmt.Errorf("expected a prefix") } - if msg.Prefix.Name == uc.nick { + if uc.isOurNick(msg.Prefix.Name) { uc.logger.Printf("quit") } - for _, ch := range uc.channels { - if _, ok := ch.Members[msg.Prefix.Name]; ok { - delete(ch.Members, msg.Prefix.Name) + for _, entry := range uc.channels.innerMap { + ch := entry.value.(*upstreamChannel) + if ch.Members.Has(msg.Prefix.Name) { + ch.Members.Delete(msg.Prefix.Name) uc.appendLog(ch.Name, msg) } @@ -908,7 +940,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { uc.appendLog(ch.Name, msg) - if ch, ok := uc.network.channels[name]; !ok || !ch.Detached { + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { uc.forEachDownstream(func(dc *downstreamConn) { params := make([]string, len(msg.Params)) params[0] = dc.marshalEntity(uc.network, name) @@ -964,7 +997,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } if firstMode { - if c, ok := uc.network.channels[channel]; !ok || !c.Detached { + c := uc.network.channels.Value(channel) + if c == nil || !c.Detached { modeStr, modeParams := ch.modes.Format() uc.forEachDownstream(func(dc *downstreamConn) { @@ -1061,8 +1095,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - ch, ok := uc.channels[name] - if !ok { + ch := uc.channels.Value(name) + if ch == nil { // NAMES on a channel we have not joined, forward to downstream uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { channel := dc.marshalEntity(uc.network, name) @@ -1090,7 +1124,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { for _, s := range splitSpace(members) { memberships, nick := uc.parseMembershipPrefix(s) - ch.Members[nick] = memberships + ch.Members.SetValue(nick, memberships) } case irc.RPL_ENDOFNAMES: var name string @@ -1098,8 +1132,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - ch, ok := uc.channels[name] - if !ok { + ch := uc.channels.Value(name) + if ch == nil { // NAMES on a channel we have not joined, forward to downstream uc.forEachDownstreamByID(downstreamID, func(dc *downstreamConn) { channel := dc.marshalEntity(uc.network, name) @@ -1118,7 +1152,8 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } ch.complete = true - if c, ok := uc.network.channels[name]; !ok || !c.Detached { + c := uc.network.channels.Value(name) + if c == nil || !c.Detached { uc.forEachDownstream(func(dc *downstreamConn) { forwardChannel(dc, ch) }) @@ -1272,7 +1307,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - weAreInvited := nick == uc.nick + weAreInvited := uc.isOurNick(nick) uc.forEachDownstream(func(dc *downstreamConn) { if !weAreInvited && !dc.caps["invite-notify"] { @@ -1395,7 +1430,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { // Ignore case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: // Ignore - case irc.RPL_MOTDSTART, irc.RPL_MOTD, irc.RPL_ENDOFMOTD: + case irc.RPL_MOTDSTART, irc.RPL_MOTD: // Ignore case irc.RPL_LISTSTART: // Ignore @@ -1601,6 +1636,7 @@ func splitSpace(s string) []string { func (uc *upstreamConn) register() { uc.nick = uc.network.Nick + uc.nickCM = uc.network.casemap(uc.nick) uc.username = uc.network.GetUsername() uc.realname = uc.network.GetRealname() @@ -1705,20 +1741,21 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string } detached := false - if ch, ok := uc.network.channels[entity]; ok { + if ch := uc.network.channels.Value(entity); ch != nil { detached = ch.Detached } - delivered, ok := uc.network.delivered[entity] - if !ok { - lastID, err := uc.user.msgStore.LastMsgID(uc.network, entity, time.Now()) + delivered := uc.network.delivered.Value(entity) + entityCM := uc.network.casemap(entity) + if delivered == nil { + lastID, err := uc.user.msgStore.LastMsgID(uc.network, entityCM, time.Now()) if err != nil { uc.logger.Printf("failed to log message: failed to get last message ID: %v", err) return "" } delivered = make(map[string]string) - uc.network.delivered[entity] = delivered + uc.network.delivered.SetValue(entity, delivered) for clientName, _ := range uc.network.offlineClients { delivered[clientName] = lastID @@ -1733,7 +1770,7 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string } } - msgID, err := uc.user.msgStore.Append(uc.network, entity, msg) + msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg) if err != nil { uc.logger.Printf("failed to log message: %v", err) return "" @@ -1754,7 +1791,8 @@ func (uc *upstreamConn) produce(target string, msg *irc.Message, origin *downstr } // Don't forward messages if it's a detached channel - if ch, ok := uc.network.channels[target]; ok && ch.Detached { + ch := uc.network.channels.Value(target) + if ch != nil && ch.Detached { return } @@ -1789,9 +1827,13 @@ func (uc *upstreamConn) updateAway() { } func (uc *upstreamConn) updateChannelAutoDetach(name string) { - if uch, ok := uc.channels[name]; ok { - if ch, ok := uc.network.channels[name]; ok && !ch.Detached { - uch.updateAutoDetach(ch.DetachAfter) - } + uch := uc.channels.Value(name) + if uch == nil { + return } + ch := uc.network.channels.Value(name) + if ch == nil || ch.Detached { + return + } + uch.updateAutoDetach(ch.DetachAfter) } diff --git a/user.go b/user.go index 73a9058..c7a0aa9 100644 --- a/user.go +++ b/user.go @@ -61,17 +61,18 @@ type network struct { stopped chan struct{} conn *upstreamConn - channels map[string]*Channel - delivered map[string]map[string]string // entity -> client name -> msg ID - offlineClients map[string]struct{} // indexed by client name + channels channelCasemapMap + delivered mapStringStringCasemapMap // entity -> client name -> msg ID + offlineClients map[string]struct{} // indexed by client name lastError error + casemap casemapping } func newNetwork(user *user, record *Network, channels []Channel) *network { - m := make(map[string]*Channel, len(channels)) + m := channelCasemapMap{newCasemapMap(0)} for _, ch := range channels { ch := ch - m[ch.Name] = &ch + m.SetValue(ch.Name, &ch) } return &network{ @@ -79,8 +80,9 @@ func newNetwork(user *user, record *Network, channels []Channel) *network { user: user, stopped: make(chan struct{}), channels: m, - delivered: make(map[string]map[string]string), + delivered: mapStringStringCasemapMap{newCasemapMap(0)}, offlineClients: make(map[string]struct{}), + casemap: casemapRFC1459, } } @@ -185,7 +187,8 @@ func (net *network) detach(ch *Channel) { net.user.srv.Logger.Printf("network %q: detaching channel %q", net.GetName(), ch.Name) if net.conn != nil { - if uch, ok := net.conn.channels[ch.Name]; ok { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { uch.updateAutoDetach(0) } } @@ -210,7 +213,7 @@ func (net *network) attach(ch *Channel) { var uch *upstreamChannel if net.conn != nil { - uch = net.conn.channels[ch.Name] + uch = net.conn.channels.Value(ch.Name) net.conn.updateChannelAutoDetach(ch.Name) } @@ -231,12 +234,13 @@ func (net *network) attach(ch *Channel) { } func (net *network) deleteChannel(name string) error { - ch, ok := net.channels[name] - if !ok { + ch := net.channels.Value(name) + if ch == nil { return fmt.Errorf("unknown channel %q", name) } if net.conn != nil { - if uch, ok := net.conn.channels[ch.Name]; ok { + uch := net.conn.channels.Value(ch.Name) + if uch != nil { uch.updateAutoDetach(0) } } @@ -244,10 +248,23 @@ func (net *network) deleteChannel(name string) error { if err := net.user.srv.db.DeleteChannel(ch.ID); err != nil { return err } - delete(net.channels, name) + net.channels.Delete(name) return nil } +func (net *network) updateCasemapping(newCasemap casemapping) { + net.casemap = newCasemap + net.channels.SetCasemapping(newCasemap) + net.delivered.SetCasemapping(newCasemap) + if net.conn != nil { + net.conn.channels.SetCasemapping(newCasemap) + for _, entry := range net.conn.channels.innerMap { + uch := entry.value.(*upstreamChannel) + uch.Members.SetCasemapping(newCasemap) + } + } +} + type user struct { User srv *Server @@ -410,8 +427,8 @@ func (u *user) run() { } case eventChannelDetach: uc, name := e.uc, e.name - c, ok := uc.network.channels[name] - if !ok || c.Detached { + c := uc.network.channels.Value(name) + if c == nil || c.Detached { continue } uc.network.detach(c) @@ -499,7 +516,8 @@ func (u *user) handleUpstreamDisconnected(uc *upstreamConn) { uc.endPendingLISTs(true) - for _, uch := range uc.channels { + for _, entry := range uc.channels.innerMap { + uch := entry.value.(*upstreamChannel) uch.updateAutoDetach(0) } @@ -570,8 +588,9 @@ func (u *user) updateNetwork(record *Network) (*network, error) { // Most network changes require us to re-connect to the upstream server - channels := make([]Channel, 0, len(network.channels)) - for _, ch := range network.channels { + channels := make([]Channel, 0, network.channels.Len()) + for _, entry := range network.channels.innerMap { + ch := entry.value.(*Channel) channels = append(channels, *ch) }