From 9a93c56cdf775c1067a395965e0f4b05dfe1a7b0 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 17 Feb 2020 15:46:29 +0100 Subject: [PATCH] Fix issues related to Ring - RingConsumer is now used directly in the goroutine responsible for writing downstream messages. This allows the ring buffer not to be consumed on write error. - RingConsumer now has a channel attached. This allows PRIVMSG messages to always use RingConsumer, instead of also directly pushing messages to all downstream connections. - Multiple clients with the same history name are now supported. - Ring is now protected by a mutex --- downstream.go | 116 ++++++++++++++++++++++++++++++++++++-------------- ring.go | 85 +++++++++++++++++++++++++++--------- server.go | 9 +++- upstream.go | 5 +-- 4 files changed, 159 insertions(+), 56 deletions(-) diff --git a/downstream.go b/downstream.go index 419370d..0dd6848 100644 --- a/downstream.go +++ b/downstream.go @@ -40,15 +40,16 @@ func (err ircError) Error() string { } type downstreamConn struct { - net net.Conn - irc *irc.Conn - srv *Server - logger Logger - messages chan *irc.Message + net net.Conn + irc *irc.Conn + srv *Server + logger Logger + messages chan *irc.Message + consumers chan *RingConsumer + closed chan struct{} registered bool user *user - closed bool nick string username string realname string @@ -56,11 +57,13 @@ type downstreamConn struct { func newDownstreamConn(srv *Server, netConn net.Conn) *downstreamConn { dc := &downstreamConn{ - net: netConn, - irc: irc.NewConn(netConn), - srv: srv, - logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, - messages: make(chan *irc.Message, 64), + net: netConn, + irc: irc.NewConn(netConn), + srv: srv, + logger: &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}, + messages: make(chan *irc.Message, 64), + consumers: make(chan *RingConsumer), + closed: make(chan struct{}), } go func() { @@ -85,6 +88,15 @@ func (dc *downstreamConn) prefix() *irc.Prefix { } } +func (dc *downstreamConn) isClosed() bool { + select { + case <-dc.closed: + return true + default: + return false + } +} + func (dc *downstreamConn) readMessages() error { dc.logger.Printf("new connection") @@ -104,7 +116,7 @@ func (dc *downstreamConn) readMessages() error { return fmt.Errorf("failed to handle IRC command %q: %v", msg.Command, err) } - if dc.closed { + if dc.isClosed() { return nil } } @@ -113,16 +125,39 @@ func (dc *downstreamConn) readMessages() error { } func (dc *downstreamConn) writeMessages() error { - for msg := range dc.messages { - if err := dc.irc.WriteMessage(msg); err != nil { + for { + var err error + var closed bool + select { + case msg := <-dc.messages: + err = dc.irc.WriteMessage(msg) + case consumer := <-dc.consumers: + for { + msg := consumer.Peek() + if msg == nil { + break + } + err = dc.irc.WriteMessage(msg) + if err != nil { + break + } + consumer.Consume() + } + case <-dc.closed: + closed = true + } + if err != nil { return err } + if closed { + break + } } return nil } func (dc *downstreamConn) Close() error { - if dc.closed { + if dc.isClosed() { return fmt.Errorf("downstream connection already closed") } @@ -134,17 +169,9 @@ func (dc *downstreamConn) Close() error { } } u.lock.Unlock() - - // TODO: figure out a better way to advance the ring buffer consumer cursor - u.forEachUpstream(func(uc *upstreamConn) { - // TODO: let clients specify the ring buffer name in their username - uc.ring.Consumer("").Reset() - }) } - close(dc.messages) - dc.closed = true - + close(dc.closed) return nil } @@ -211,6 +238,7 @@ func (dc *downstreamConn) register() error { dc.user = u u.lock.Lock() + firstDownstream := len(u.downstreamConns) == 0 u.downstreamConns = append(u.downstreamConns, dc) u.lock.Unlock() @@ -249,15 +277,41 @@ func (dc *downstreamConn) register() error { } // TODO: let clients specify the ring buffer name in their username - consumer := uc.ring.Consumer("") - for { - // TODO: these messages will get lost if the connection is closed - msg := consumer.Consume() - if msg == nil { - break + historyName := "" + + var seqPtr *uint64 + if firstDownstream { + seq, ok := uc.history[historyName] + if ok { + seqPtr = &seq } - dc.SendMessage(msg) } + + consumer, ch := uc.ring.Consumer(seqPtr) + go func() { + for { + var closed bool + select { + case <-ch: + dc.consumers <- consumer + case <-dc.closed: + closed = true + } + if closed { + break + } + } + + seq := consumer.Close() + + dc.user.lock.Lock() + lastDownstream := len(dc.user.downstreamConns) == 0 + dc.user.lock.Unlock() + + if lastDownstream { + uc.history[historyName] = seq + } + }() }) return nil diff --git a/ring.go b/ring.go index 5d7086c..f7d3d0d 100644 --- a/ring.go +++ b/ring.go @@ -1,52 +1,77 @@ package jounce import ( + "sync" + "gopkg.in/irc.v3" ) // Ring implements a single producer, multiple consumer ring buffer. The ring // buffer size is fixed. The ring buffer is stored in memory. type Ring struct { - buffer []*irc.Message - cap, cur uint64 + buffer []*irc.Message + cap uint64 - consumers map[string]*RingConsumer + lock sync.Mutex + cur uint64 + consumers []*RingConsumer } func NewRing(capacity int) *Ring { return &Ring{ - buffer: make([]*irc.Message, capacity), - cap: uint64(capacity), - consumers: make(map[string]*RingConsumer), + buffer: make([]*irc.Message, capacity), + cap: uint64(capacity), } } func (r *Ring) Produce(msg *irc.Message) { + r.lock.Lock() + defer r.lock.Unlock() + i := int(r.cur % r.cap) r.buffer[i] = msg r.cur++ + + for _, consumer := range r.consumers { + select { + case consumer.ch <- struct{}{}: + // This space is intentionally left blank + default: + // The channel already has a pending item + } + } } -func (r *Ring) Consumer(name string) *RingConsumer { - consumer, ok := r.consumers[name] - if ok { - return consumer +func (r *Ring) Consumer(seq *uint64) (*RingConsumer, <-chan struct{}) { + consumer := &RingConsumer{ + ring: r, + ch: make(chan struct{}, 1), } - consumer = &RingConsumer{ - ring: r, - cur: r.cur, + r.lock.Lock() + if seq != nil { + consumer.cur = *seq + } else { + consumer.cur = r.cur } - r.consumers[name] = consumer - return consumer + if consumer.diff() > 0 { + consumer.ch <- struct{}{} + } + r.consumers = append(r.consumers, consumer) + r.lock.Unlock() + + return consumer, consumer.ch } type RingConsumer struct { - ring *Ring - cur uint64 + ring *Ring + cur uint64 + ch chan struct{} + closed bool } -func (rc *RingConsumer) Diff() uint64 { +// diff returns the number of pending messages. It assumes the Ring is locked. +func (rc *RingConsumer) diff() uint64 { if rc.cur > rc.ring.cur { panic("jounce: consumer cursor greater than producer cursor") } @@ -54,7 +79,14 @@ func (rc *RingConsumer) Diff() uint64 { } func (rc *RingConsumer) Peek() *irc.Message { - diff := rc.Diff() + if rc.closed { + panic("jounce: RingConsumer.Peek called after Close") + } + + rc.ring.lock.Lock() + defer rc.ring.lock.Unlock() + + diff := rc.diff() if diff == 0 { return nil } @@ -78,6 +110,17 @@ func (rc *RingConsumer) Consume() *irc.Message { return msg } -func (rc *RingConsumer) Reset() { - rc.cur = rc.ring.cur +func (rc *RingConsumer) Close() uint64 { + rc.ring.lock.Lock() + for i := range rc.ring.consumers { + if rc.ring.consumers[i] == rc { + rc.ring.consumers = append(rc.ring.consumers[:i], rc.ring.consumers[i+1:]...) + break + } + } + rc.ring.lock.Unlock() + + close(rc.ch) + rc.closed = true + return rc.cur } diff --git a/server.go b/server.go index e88a5e9..e4d96e3 100644 --- a/server.go +++ b/server.go @@ -40,6 +40,13 @@ type user struct { downstreamConns []*downstreamConn } +func newUser(srv *Server, username string) *user { + return &user{ + username: username, + srv: srv, + } +} + func (u *user) forEachUpstream(f func(uc *upstreamConn)) { u.lock.Lock() for _, uc := range u.upstreamConns { @@ -116,7 +123,7 @@ func (s *Server) prefix() *irc.Prefix { func (s *Server) Run() { // TODO: multi-user - u := &user{username: "jounce", srv: s} + u := newUser(s, "jounce") s.lock.Lock() s.users[u.username] = u diff --git a/upstream.go b/upstream.go index 225500a..342b06f 100644 --- a/upstream.go +++ b/upstream.go @@ -44,6 +44,7 @@ type upstreamConn struct { closed bool modes modeSet channels map[string]*upstreamChannel + history map[string]uint64 } func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) { @@ -66,6 +67,7 @@ func connectToUpstream(u *user, upstream *Upstream) (*upstreamConn, error) { messages: msgs, ring: NewRing(u.srv.RingCap), channels: make(map[string]*upstreamChannel), + history: make(map[string]uint64), } go func() { @@ -305,9 +307,6 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error { }) case "PRIVMSG": uc.ring.Produce(msg) - uc.user.forEachDownstream(func(dc *downstreamConn) { - dc.SendMessage(msg) - }) case irc.RPL_YOURHOST, irc.RPL_CREATED: // Ignore case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: