Add user.forEachDownstream

This commit is contained in:
Simon Ser 2020-02-07 11:56:36 +01:00
parent 059a799d16
commit 636ede13da
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 37 additions and 25 deletions

View File

@ -115,12 +115,21 @@ func (c *downstreamConn) Close() error {
if c.closed {
return fmt.Errorf("downstream connection already closed")
}
if err := c.net.Close(); err != nil {
return err
if u := c.user; u != nil {
u.lock.Lock()
for i := range u.downstreamConns {
if u.downstreamConns[i] == c {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
}
}
u.lock.Unlock()
}
close(c.messages)
c.closed = true
return nil
return c.net.Close()
}
func (c *downstreamConn) handleMessage(msg *irc.Message) error {
@ -182,6 +191,10 @@ func (c *downstreamConn) register() error {
c.registered = true
c.user = u
u.lock.Lock()
u.downstreamConns = append(u.downstreamConns, c)
u.lock.Unlock()
c.messages <- &irc.Message{
Prefix: c.srv.prefix(),
Command: irc.RPL_WELCOME,

View File

@ -35,8 +35,9 @@ type user struct {
username string
srv *Server
lock sync.Mutex
upstreamConns []*upstreamConn
lock sync.Mutex
upstreamConns []*upstreamConn
downstreamConns []*downstreamConn
}
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
@ -50,6 +51,14 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
u.lock.Unlock()
}
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Lock()
for _, dc := range u.downstreamConns {
f(dc)
}
u.lock.Unlock()
}
type Upstream struct {
Addr string
Nick string

View File

@ -125,11 +125,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
return err
}
c.srv.lock.Lock()
for _, dc := range c.srv.downstreamConns {
c.user.forEachDownstream(func(dc *downstreamConn) {
dc.messages <- msg
}
c.srv.lock.Unlock()
})
}
case "NOTICE":
c.logger.Print(msg)
@ -174,11 +172,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
}
}
c.srv.lock.Lock()
for _, dc := range c.srv.downstreamConns {
c.user.forEachDownstream(func(dc *downstreamConn) {
dc.messages <- msg
}
c.srv.lock.Unlock()
})
case "PART":
if len(msg.Params) < 1 {
return newNeedMoreParamsError(msg.Command)
@ -197,11 +193,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
}
}
c.srv.lock.Lock()
for _, dc := range c.srv.downstreamConns {
c.user.forEachDownstream(func(dc *downstreamConn) {
dc.messages <- msg
}
c.srv.lock.Unlock()
})
case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
if len(msg.Params) < 3 {
return newNeedMoreParamsError(msg.Command)
@ -275,17 +269,13 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
}
ch.complete = true
c.srv.lock.Lock()
for _, dc := range c.srv.downstreamConns {
c.user.forEachDownstream(func(dc *downstreamConn) {
forwardChannel(dc, ch)
}
c.srv.lock.Unlock()
})
case "PRIVMSG":
c.srv.lock.Lock()
for _, dc := range c.srv.downstreamConns {
c.user.forEachDownstream(func(dc *downstreamConn) {
dc.messages <- msg
}
c.srv.lock.Unlock()
})
case irc.RPL_YOURHOST, irc.RPL_CREATED:
// Ignore
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME: