Fix race condition in upstreamConn.Close

upstreamConn.closed was a bool accessed from different goroutines. Use
the same pattern as downstreamConn instead.
This commit is contained in:
Simon Ser 2020-03-27 23:08:35 +01:00
parent f08063c943
commit b33e5f29ab
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 31 additions and 11 deletions

View File

@ -36,6 +36,7 @@ type upstreamConn struct {
srv *Server srv *Server
user *user user *user
outgoing chan<- *irc.Message outgoing chan<- *irc.Message
closed chan struct{}
serverName string serverName string
availableUserModes string availableUserModes string
@ -47,7 +48,6 @@ type upstreamConn struct {
nick string nick string
username string username string
realname string realname string
closed bool
modes userModes modes userModes
channels map[string]*upstreamChannel channels map[string]*upstreamChannel
caps map[string]string caps map[string]string
@ -95,12 +95,21 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
} }
go func() { go func() {
for msg := range outgoing { for {
if uc.srv.Debug { var closed bool
uc.logger.Printf("sent: %v", msg) select {
case msg := <-outgoing:
if uc.srv.Debug {
uc.logger.Printf("sent: %v", msg)
}
if err := uc.irc.WriteMessage(msg); err != nil {
uc.logger.Printf("failed to write message: %v", err)
}
case <-uc.closed:
closed = true
} }
if err := uc.irc.WriteMessage(msg); err != nil { if closed {
uc.logger.Printf("failed to write message: %v", err) break
} }
} }
if err := uc.net.Close(); err != nil { if err := uc.net.Close(); err != nil {
@ -113,12 +122,20 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
return uc, nil return uc, nil
} }
func (uc *upstreamConn) isClosed() bool {
select {
case <-uc.closed:
return true
default:
return false
}
}
func (uc *upstreamConn) Close() error { func (uc *upstreamConn) Close() error {
if uc.closed { if uc.isClosed() {
return fmt.Errorf("upstream connection already closed") return fmt.Errorf("upstream connection already closed")
} }
close(uc.outgoing) close(uc.closed)
uc.closed = true
return nil return nil
} }

View File

@ -64,6 +64,9 @@ func (net *network) run() {
uc.register() uc.register()
// TODO: wait for the connection to be registered before adding it to
// net, otherwise messages might be sent to it while still being
// unauthenticated
net.lock.Lock() net.lock.Lock()
net.conn = uc net.conn = uc
net.lock.Unlock() net.lock.Unlock()
@ -112,7 +115,7 @@ func (u *user) forEachNetwork(f func(*network)) {
func (u *user) forEachUpstream(f func(uc *upstreamConn)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
for _, network := range u.networks { for _, network := range u.networks {
uc := network.upstream() uc := network.upstream()
if uc == nil || !uc.registered || uc.closed { if uc == nil || !uc.registered {
continue continue
} }
f(uc) f(uc)
@ -152,7 +155,7 @@ func (u *user) run() {
switch e := e.(type) { switch e := e.(type) {
case eventUpstreamMessage: case eventUpstreamMessage:
msg, uc := e.msg, e.uc msg, uc := e.msg, e.uc
if uc.closed { if uc.isClosed() {
uc.logger.Printf("ignoring message on closed connection: %v", msg) uc.logger.Printf("ignoring message on closed connection: %v", msg)
break break
} }