Add user.forEachDownstream

This commit is contained in:
Simon Ser 2020-02-07 11:56:36 +01:00
parent 059a799d16
commit 636ede13da
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 37 additions and 25 deletions

View File

@ -115,12 +115,21 @@ func (c *downstreamConn) Close() error {
if c.closed { if c.closed {
return fmt.Errorf("downstream connection already closed") return fmt.Errorf("downstream connection already closed")
} }
if err := c.net.Close(); err != nil {
return err if u := c.user; u != nil {
u.lock.Lock()
for i := range u.downstreamConns {
if u.downstreamConns[i] == c {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
} }
}
u.lock.Unlock()
}
close(c.messages) close(c.messages)
c.closed = true c.closed = true
return nil
return c.net.Close()
} }
func (c *downstreamConn) handleMessage(msg *irc.Message) error { func (c *downstreamConn) handleMessage(msg *irc.Message) error {
@ -182,6 +191,10 @@ func (c *downstreamConn) register() error {
c.registered = true c.registered = true
c.user = u c.user = u
u.lock.Lock()
u.downstreamConns = append(u.downstreamConns, c)
u.lock.Unlock()
c.messages <- &irc.Message{ c.messages <- &irc.Message{
Prefix: c.srv.prefix(), Prefix: c.srv.prefix(),
Command: irc.RPL_WELCOME, Command: irc.RPL_WELCOME,

View File

@ -37,6 +37,7 @@ type user struct {
lock sync.Mutex lock sync.Mutex
upstreamConns []*upstreamConn upstreamConns []*upstreamConn
downstreamConns []*downstreamConn
} }
func (u *user) forEachUpstream(f func(uc *upstreamConn)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
@ -50,6 +51,14 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
u.lock.Unlock() u.lock.Unlock()
} }
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Lock()
for _, dc := range u.downstreamConns {
f(dc)
}
u.lock.Unlock()
}
type Upstream struct { type Upstream struct {
Addr string Addr string
Nick string Nick string

View File

@ -125,11 +125,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
return err return err
} }
c.srv.lock.Lock() c.user.forEachDownstream(func(dc *downstreamConn) {
for _, dc := range c.srv.downstreamConns {
dc.messages <- msg dc.messages <- msg
} })
c.srv.lock.Unlock()
} }
case "NOTICE": case "NOTICE":
c.logger.Print(msg) c.logger.Print(msg)
@ -174,11 +172,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
} }
} }
c.srv.lock.Lock() c.user.forEachDownstream(func(dc *downstreamConn) {
for _, dc := range c.srv.downstreamConns {
dc.messages <- msg dc.messages <- msg
} })
c.srv.lock.Unlock()
case "PART": case "PART":
if len(msg.Params) < 1 { if len(msg.Params) < 1 {
return newNeedMoreParamsError(msg.Command) return newNeedMoreParamsError(msg.Command)
@ -197,11 +193,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
} }
} }
c.srv.lock.Lock() c.user.forEachDownstream(func(dc *downstreamConn) {
for _, dc := range c.srv.downstreamConns {
dc.messages <- msg dc.messages <- msg
} })
c.srv.lock.Unlock()
case irc.RPL_TOPIC, irc.RPL_NOTOPIC: case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
if len(msg.Params) < 3 { if len(msg.Params) < 3 {
return newNeedMoreParamsError(msg.Command) return newNeedMoreParamsError(msg.Command)
@ -275,17 +269,13 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
} }
ch.complete = true ch.complete = true
c.srv.lock.Lock() c.user.forEachDownstream(func(dc *downstreamConn) {
for _, dc := range c.srv.downstreamConns {
forwardChannel(dc, ch) forwardChannel(dc, ch)
} })
c.srv.lock.Unlock()
case "PRIVMSG": case "PRIVMSG":
c.srv.lock.Lock() c.user.forEachDownstream(func(dc *downstreamConn) {
for _, dc := range c.srv.downstreamConns {
dc.messages <- msg dc.messages <- msg
} })
c.srv.lock.Unlock()
case irc.RPL_YOURHOST, irc.RPL_CREATED: case irc.RPL_YOURHOST, irc.RPL_CREATED:
// Ignore // Ignore
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: