diff --git a/upstream.go b/upstream.go index e2dd8d6..2d6b4eb 100644 --- a/upstream.go +++ b/upstream.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "strings" "gopkg.in/irc.v3" ) @@ -14,12 +15,52 @@ const ( rpl_globalusers = "266" ) +type modeSet string + +func (ms modeSet) Has(c byte) bool { + return strings.IndexByte(string(ms), c) >= 0 +} + +func (ms *modeSet) Add(c byte) { + if !ms.Has(c) { + *ms += modeSet(c) + } +} + +func (ms *modeSet) Del(c byte) { + i := strings.IndexByte(string(*ms), c) + if i >= 0 { + *ms = (*ms)[:i] + (*ms)[i+1:] + } +} + +func (ms *modeSet) Apply(s string) error { + var plusMinus byte + for i := 0; i < len(s); i++ { + switch c := s[i]; c { + case '+', '-': + plusMinus = c + default: + switch plusMinus { + case '+': + ms.Add(c) + case '-': + ms.Del(c) + default: + return fmt.Errorf("malformed modestring %q: missing plus/minus", s) + } + } + } + return nil +} + type upstreamConn struct { upstream *Upstream net net.Conn irc *irc.Conn srv *Server registered bool + modes modeSet serverName string availableUserModes string @@ -35,6 +76,14 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error { Command: "PONG", Params: []string{c.srv.Hostname}, }) + case "MODE": + if len(msg.Params) < 2 { + return newNeedMoreParamsError(msg.Command) + } + if nick := msg.Params[0]; nick != c.upstream.Nick { + return fmt.Errorf("received MODE message for unknow nick %q", nick) + } + return c.modes.Apply(msg.Params[1]) case irc.RPL_WELCOME: c.registered = true c.srv.Logger.Printf("Connection to %q registered", c.upstream.Addr)