diff --git a/downstream.go b/downstream.go index 6030256..5782250 100644 --- a/downstream.go +++ b/downstream.go @@ -268,7 +268,7 @@ func (dc *downstreamConn) SendMessage(msg *irc.Message) { // marshalMessage re-formats a message coming from an upstream connection so // that it's suitable for being sent on this downstream connection. Only -// messages that may appear in logs are supported. +// messages that may appear in logs are supported, except MODE. func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Message { msg = msg.Copy() msg.Prefix = dc.marshalUserPrefix(net, msg.Prefix) @@ -286,8 +286,6 @@ func (dc *downstreamConn) marshalMessage(msg *irc.Message, net *network) *irc.Me msg.Params[1] = dc.marshalEntity(net, msg.Params[1]) case "TOPIC": msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) - case "MODE": - msg.Params[0] = dc.marshalEntity(net, msg.Params[0]) case "QUIT": // This space is intentionally left blank default: diff --git a/irc.go b/irc.go index ea3cc5f..50a22c1 100644 --- a/irc.go +++ b/irc.go @@ -84,9 +84,18 @@ var stdChannelModes = map[byte]channelModeType{ type channelModes map[byte]string -func (cm channelModes) Apply(modeTypes map[byte]channelModeType, modeStr string, arguments ...string) error { +// applyChannelModes parses a mode string and mode arguments from a MODE message, +// and applies the corresponding channel mode and user membership changes on that channel. +// +// If ch.modes is nil, channel modes are not updated. +// +// needMarshaling is a list of indexes of mode arguments that represent entities +// that must be marshaled when sent downstream. +func applyChannelModes(ch *upstreamChannel, modeStr string, arguments []string) (needMarshaling map[int]struct{}, err error) { + needMarshaling = make(map[int]struct{}, len(arguments)) nextArgument := 0 var plusMinus byte +outer: for i := 0; i < len(modeStr); i++ { mode := modeStr[i] if mode == '+' || mode == '-' { @@ -94,10 +103,30 @@ func (cm channelModes) Apply(modeTypes map[byte]channelModeType, modeStr string, continue } if plusMinus != '+' && plusMinus != '-' { - return fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) + return nil, fmt.Errorf("malformed modestring %q: missing plus/minus", modeStr) } - mt, ok := modeTypes[mode] + for _, membership := range ch.conn.availableMemberships { + if membership.Mode == mode { + if nextArgument >= len(arguments) { + return nil, fmt.Errorf("malformed modestring %q: missing mode argument for %c%c", modeStr, plusMinus, mode) + } + member := arguments[nextArgument] + if _, ok := ch.Members[member]; ok { + if plusMinus == '+' { + ch.Members[member].Add(ch.conn.availableMemberships, membership) + } else { + // TODO: for upstreams without multi-prefix, query the user modes again + ch.Members[member].Remove(membership) + } + } + needMarshaling[nextArgument] = struct{}{} + nextArgument++ + continue outer + } + } + + mt, ok := ch.conn.availableChannelModes[mode] if !ok { continue } @@ -109,20 +138,24 @@ func (cm channelModes) Apply(modeTypes map[byte]channelModeType, modeStr string, if nextArgument < len(arguments) { argument = arguments[nextArgument] } - cm[mode] = argument + if ch.modes != nil { + ch.modes[mode] = argument + } } else { - delete(cm, mode) + delete(ch.modes, mode) } nextArgument++ } else if mt == modeTypeC || mt == modeTypeD { if plusMinus == '+' { - cm[mode] = "" + if ch.modes != nil { + ch.modes[mode] = "" + } } else { - delete(cm, mode) + delete(ch.modes, mode) } } } - return nil + return needMarshaling, nil } func (cm channelModes) Format() (modeString string, parameters []string) { diff --git a/upstream.go b/upstream.go index 1f281fe..c16ca47 100644 --- a/upstream.go +++ b/upstream.go @@ -817,13 +817,30 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { return err } - if ch.modes != nil { - if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[2:]...); err != nil { - return err - } + needMarshaling, err := applyChannelModes(ch, modeStr, msg.Params[2:]) + if err != nil { + return err } - uc.produce(ch.Name, msg, nil) + uc.appendLog(ch.Name, msg) + uc.forEachDownstream(func(dc *downstreamConn) { + params := make([]string, len(msg.Params)) + params[0] = dc.marshalEntity(uc.network, name) + params[1] = modeStr + + copy(params[2:], msg.Params[2:]) + for i, modeParam := range params[2:] { + if _, ok := needMarshaling[i]; ok { + params[2+i] = dc.marshalEntity(uc.network, modeParam) + } + } + + dc.SendMessage(&irc.Message{ + Prefix: dc.marshalUserPrefix(uc.network, msg.Prefix), + Command: "MODE", + Params: params, + }) + }) } case irc.RPL_UMODEIS: if err := parseMessageParams(msg, nil); err != nil { @@ -856,7 +873,7 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { firstMode := ch.modes == nil ch.modes = make(map[byte]string) - if err := ch.modes.Apply(uc.availableChannelModes, modeStr, msg.Params[3:]...); err != nil { + if _, err := applyChannelModes(ch, modeStr, msg.Params[3:]); err != nil { return err } if firstMode {