Add eventDownstreamConnected

In a later commit, we'll be able to move part of downstreamConn.register
into the user goroutine to prevent races.

References: https://todo.sr.ht/~emersion/soju/22
This commit is contained in:
Simon Ser 2020-03-27 17:21:05 +01:00
parent 474f2889d9
commit 36ab6ece09
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 13 additions and 9 deletions

View File

@ -660,7 +660,9 @@ func (dc *downstreamConn) register() error {
dc.username = dc.user.Username dc.username = dc.user.Username
dc.logger.Printf("registration complete for user %q", dc.username) dc.logger.Printf("registration complete for user %q", dc.username)
firstDownstream := dc.user.addDownstream(dc) dc.user.lock.Lock()
firstDownstream := len(dc.user.downstreamConns) == 0
dc.user.lock.Unlock()
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),

View File

@ -119,6 +119,7 @@ func (s *Server) Serve(ln net.Listener) error {
if err := dc.runUntilRegistered(); err != nil { if err := dc.runUntilRegistered(); err != nil {
dc.logger.Print(err) dc.logger.Print(err)
} else { } else {
dc.user.events <- eventDownstreamConnected{dc}
if err := dc.readMessages(dc.user.events); err != nil { if err := dc.readMessages(dc.user.events); err != nil {
dc.logger.Print(err) dc.logger.Print(err)
} }

17
user.go
View File

@ -19,6 +19,10 @@ type eventDownstreamMessage struct {
dc *downstreamConn dc *downstreamConn
} }
type eventDownstreamConnected struct {
dc *downstreamConn
}
type network struct { type network struct {
Network Network
user *user user *user
@ -160,6 +164,11 @@ func (u *user) run() {
if err := uc.handleMessage(msg); err != nil { if err := uc.handleMessage(msg); err != nil {
uc.logger.Printf("failed to handle message %q: %v", msg, err) uc.logger.Printf("failed to handle message %q: %v", msg, err)
} }
case eventDownstreamConnected:
dc := e.dc
u.lock.Lock()
u.downstreamConns = append(u.downstreamConns, dc)
u.lock.Unlock()
case eventDownstreamMessage: case eventDownstreamMessage:
msg, dc := e.msg, e.dc msg, dc := e.msg, e.dc
if dc.isClosed() { if dc.isClosed() {
@ -180,14 +189,6 @@ func (u *user) run() {
} }
} }
func (u *user) addDownstream(dc *downstreamConn) (first bool) {
u.lock.Lock()
first = len(dc.user.downstreamConns) == 0
u.downstreamConns = append(u.downstreamConns, dc)
u.lock.Unlock()
return first
}
func (u *user) removeDownstream(dc *downstreamConn) { func (u *user) removeDownstream(dc *downstreamConn) {
u.lock.Lock() u.lock.Lock()
for i := range u.downstreamConns { for i := range u.downstreamConns {