diff --git a/downstream.go b/downstream.go index 98668dd..7a0b248 100644 --- a/downstream.go +++ b/downstream.go @@ -71,12 +71,12 @@ type downstreamConn struct { negociatingCaps bool capVersion int - caps map[string]bool saslServer sasl.Server lock sync.Mutex ourMessages map[*irc.Message]struct{} + caps map[string]bool } func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn { @@ -85,8 +85,8 @@ func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn conn: *newConn(srv, netConn, logger), id: id, ringConsumers: make(map[*network]*RingConsumer), - caps: make(map[string]bool), ourMessages: make(map[*irc.Message]struct{}), + caps: make(map[string]bool), } dc.hostname = netConn.RemoteAddr().String() if host, _, err := net.SplitHostPort(dc.hostname); err == nil { @@ -209,8 +209,27 @@ func (dc *downstreamConn) readMessages(ch chan<- event) error { return nil } +func (dc *downstreamConn) getCap(name string) bool { + dc.lock.Lock() + defer dc.lock.Unlock() + return dc.caps[name] +} + func (dc *downstreamConn) SendMessage(msg *irc.Message) { - // TODO: strip tags if the client doesn't support them (see runNetwork) + if !dc.getCap("message-tags") { + msg = msg.Copy() + for name := range msg.Tags { + supported := false + switch name { + case "time": + supported = dc.getCap("server-time") + } + if !supported { + delete(msg.Tags, name) + } + } + } + dc.conn.SendMessage(msg) } @@ -258,7 +277,7 @@ func (dc *downstreamConn) handleMessageUnregistered(msg *irc.Message) error { return err } case "AUTHENTICATE": - if !dc.caps["sasl"] { + if !dc.getCap("sasl") { return ircError{&irc.Message{ Command: irc.ERR_SASLFAIL, Params: []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"}, @@ -399,9 +418,11 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { } case "LIST": var caps []string + dc.lock.Lock() for name := range dc.caps { caps = append(caps, name) } + dc.lock.Unlock() // TODO: multi-line replies dc.SendMessage(&irc.Message{ @@ -419,6 +440,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { caps := strings.Fields(args[0]) ack := true + dc.lock.Lock() for _, name := range caps { name = strings.ToLower(name) enable := !strings.HasPrefix(name, "-") @@ -438,6 +460,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error { ack = false } } + dc.lock.Unlock() reply := "NAK" if ack { @@ -663,11 +686,6 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { } } - // TODO: can't be enabled/disabled on-the-fly - msgTagsEnabled := dc.caps["message-tags"] - serverTimeEnabled := dc.caps["server-time"] - echoMessageEnabled := dc.caps["echo-message"] - consumer, ch := net.ring.NewConsumer(seqPtr) if _, ok := dc.ringConsumers[net]; ok { @@ -693,7 +711,7 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { _, ours := dc.ourMessages[msg] delete(dc.ourMessages, msg) dc.lock.Unlock() - if ours && !echoMessageEnabled { + if ours && !dc.getCap("echo-message") { // The message comes from our connection, don't echo it // back consumer.Consume() @@ -709,19 +727,6 @@ func (dc *downstreamConn) runNetwork(net *network, loadHistory bool) { panic("expected to consume a PRIVMSG message") } - if !msgTagsEnabled { - for name := range msg.Tags { - supported := false - switch name { - case "time": - supported = serverTimeEnabled - } - if !supported { - delete(msg.Tags, name) - } - } - } - dc.SendMessage(msg) consumer.Consume() } diff --git a/logger.go b/logger.go index ab8f15e..43b3eda 100644 --- a/logger.go +++ b/logger.go @@ -31,6 +31,8 @@ func (ml *messageLogger) Append(msg *irc.Message) error { return nil } + // TODO: parse time from msg.Tags["time"], if available + // TODO: enforce maximum open file handles (LRU cache of file handles) // TODO: handle non-monotonic clock behaviour now := time.Now() diff --git a/upstream.go b/upstream.go index f21740d..d9f358b 100644 --- a/upstream.go +++ b/upstream.go @@ -245,6 +245,10 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { } } + if _, ok := msg.Tags["time"]; !ok { + msg.Tags["time"] = irc.TagValue(time.Now().Format(serverTimeLayout)) + } + switch msg.Command { case "PING": uc.SendMessage(&irc.Message{ @@ -1149,10 +1153,6 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { break } - if _, ok := msg.Tags["time"]; !ok { - msg.Tags["time"] = irc.TagValue(time.Now().Format(serverTimeLayout)) - } - target := nick if nick == uc.nick { target = msg.Prefix.Name