Add user.forEachDownstream
This commit is contained in:
parent
059a799d16
commit
636ede13da
@ -115,12 +115,21 @@ func (c *downstreamConn) Close() error {
|
|||||||
if c.closed {
|
if c.closed {
|
||||||
return fmt.Errorf("downstream connection already closed")
|
return fmt.Errorf("downstream connection already closed")
|
||||||
}
|
}
|
||||||
if err := c.net.Close(); err != nil {
|
|
||||||
return err
|
if u := c.user; u != nil {
|
||||||
|
u.lock.Lock()
|
||||||
|
for i := range u.downstreamConns {
|
||||||
|
if u.downstreamConns[i] == c {
|
||||||
|
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
close(c.messages)
|
close(c.messages)
|
||||||
c.closed = true
|
c.closed = true
|
||||||
return nil
|
|
||||||
|
return c.net.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *downstreamConn) handleMessage(msg *irc.Message) error {
|
func (c *downstreamConn) handleMessage(msg *irc.Message) error {
|
||||||
@ -182,6 +191,10 @@ func (c *downstreamConn) register() error {
|
|||||||
c.registered = true
|
c.registered = true
|
||||||
c.user = u
|
c.user = u
|
||||||
|
|
||||||
|
u.lock.Lock()
|
||||||
|
u.downstreamConns = append(u.downstreamConns, c)
|
||||||
|
u.lock.Unlock()
|
||||||
|
|
||||||
c.messages <- &irc.Message{
|
c.messages <- &irc.Message{
|
||||||
Prefix: c.srv.prefix(),
|
Prefix: c.srv.prefix(),
|
||||||
Command: irc.RPL_WELCOME,
|
Command: irc.RPL_WELCOME,
|
||||||
|
13
server.go
13
server.go
@ -35,8 +35,9 @@ type user struct {
|
|||||||
username string
|
username string
|
||||||
srv *Server
|
srv *Server
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
upstreamConns []*upstreamConn
|
upstreamConns []*upstreamConn
|
||||||
|
downstreamConns []*downstreamConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
||||||
@ -50,6 +51,14 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
|||||||
u.lock.Unlock()
|
u.lock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
||||||
|
u.lock.Lock()
|
||||||
|
for _, dc := range u.downstreamConns {
|
||||||
|
f(dc)
|
||||||
|
}
|
||||||
|
u.lock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
type Upstream struct {
|
type Upstream struct {
|
||||||
Addr string
|
Addr string
|
||||||
Nick string
|
Nick string
|
||||||
|
30
upstream.go
30
upstream.go
@ -125,11 +125,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.srv.lock.Lock()
|
c.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
for _, dc := range c.srv.downstreamConns {
|
|
||||||
dc.messages <- msg
|
dc.messages <- msg
|
||||||
}
|
})
|
||||||
c.srv.lock.Unlock()
|
|
||||||
}
|
}
|
||||||
case "NOTICE":
|
case "NOTICE":
|
||||||
c.logger.Print(msg)
|
c.logger.Print(msg)
|
||||||
@ -174,11 +172,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.srv.lock.Lock()
|
c.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
for _, dc := range c.srv.downstreamConns {
|
|
||||||
dc.messages <- msg
|
dc.messages <- msg
|
||||||
}
|
})
|
||||||
c.srv.lock.Unlock()
|
|
||||||
case "PART":
|
case "PART":
|
||||||
if len(msg.Params) < 1 {
|
if len(msg.Params) < 1 {
|
||||||
return newNeedMoreParamsError(msg.Command)
|
return newNeedMoreParamsError(msg.Command)
|
||||||
@ -197,11 +193,9 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.srv.lock.Lock()
|
c.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
for _, dc := range c.srv.downstreamConns {
|
|
||||||
dc.messages <- msg
|
dc.messages <- msg
|
||||||
}
|
})
|
||||||
c.srv.lock.Unlock()
|
|
||||||
case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
|
case irc.RPL_TOPIC, irc.RPL_NOTOPIC:
|
||||||
if len(msg.Params) < 3 {
|
if len(msg.Params) < 3 {
|
||||||
return newNeedMoreParamsError(msg.Command)
|
return newNeedMoreParamsError(msg.Command)
|
||||||
@ -275,17 +269,13 @@ func (c *upstreamConn) handleMessage(msg *irc.Message) error {
|
|||||||
}
|
}
|
||||||
ch.complete = true
|
ch.complete = true
|
||||||
|
|
||||||
c.srv.lock.Lock()
|
c.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
for _, dc := range c.srv.downstreamConns {
|
|
||||||
forwardChannel(dc, ch)
|
forwardChannel(dc, ch)
|
||||||
}
|
})
|
||||||
c.srv.lock.Unlock()
|
|
||||||
case "PRIVMSG":
|
case "PRIVMSG":
|
||||||
c.srv.lock.Lock()
|
c.user.forEachDownstream(func(dc *downstreamConn) {
|
||||||
for _, dc := range c.srv.downstreamConns {
|
|
||||||
dc.messages <- msg
|
dc.messages <- msg
|
||||||
}
|
})
|
||||||
c.srv.lock.Unlock()
|
|
||||||
case irc.RPL_YOURHOST, irc.RPL_CREATED:
|
case irc.RPL_YOURHOST, irc.RPL_CREATED:
|
||||||
// Ignore
|
// Ignore
|
||||||
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
|
case irc.RPL_LUSERCLIENT, irc.RPL_LUSEROP, irc.RPL_LUSERUNKNOWN, irc.RPL_LUSERCHANNELS, irc.RPL_LUSERME:
|
||||||
|
Loading…
Reference in New Issue
Block a user