Nuke user.lock

Split user.register into two functions, one to make sure the user is
authenticated, the other to send our current state. This allows to get
rid of data races by doing the second part in the user goroutine.

Closes: https://todo.sr.ht/~emersion/soju/22
This commit is contained in:
Simon Ser 2020-03-27 19:17:58 +01:00
parent c0f5850e5b
commit 08bb06c164
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 67 additions and 60 deletions

View File

@ -71,6 +71,7 @@ type downstreamConn struct {
nick string nick string
username string username string
rawUsername string rawUsername string
networkName string
realname string realname string
hostname string hostname string
password string // empty after authentication password string // empty after authentication
@ -582,42 +583,6 @@ func unmarshalUsername(rawUsername string) (username, network string) {
return username, network return username, network
} }
func (dc *downstreamConn) setNetwork(networkName string) error {
if networkName == "" {
return nil
}
network := dc.user.getNetwork(networkName)
if network == nil {
addr := networkName
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
}}
}
dc.logger.Printf("auto-saving network %q", networkName)
var err error
network, err = dc.user.createNetwork(&Network{
Addr: networkName,
Nick: dc.nick,
})
if err != nil {
return err
}
}
dc.network = network
return nil
}
func (dc *downstreamConn) authenticate(username, password string) error { func (dc *downstreamConn) authenticate(username, password string) error {
username, networkName := unmarshalUsername(username) username, networkName := unmarshalUsername(username)
@ -634,31 +599,82 @@ func (dc *downstreamConn) authenticate(username, password string) error {
} }
dc.user = u dc.user = u
dc.networkName = networkName
return dc.setNetwork(networkName) return nil
} }
func (dc *downstreamConn) register() error { func (dc *downstreamConn) register() error {
if dc.registered {
return fmt.Errorf("tried to register twice")
}
password := dc.password password := dc.password
dc.password = "" dc.password = ""
if dc.user == nil { if dc.user == nil {
if err := dc.authenticate(dc.rawUsername, password); err != nil { if err := dc.authenticate(dc.rawUsername, password); err != nil {
return err return err
} }
} else if dc.network == nil {
_, networkName := unmarshalUsername(dc.rawUsername)
if err := dc.setNetwork(networkName); err != nil {
return err
} }
if dc.networkName == "" {
_, dc.networkName = unmarshalUsername(dc.rawUsername)
} }
dc.registered = true dc.registered = true
dc.username = dc.user.Username dc.username = dc.user.Username
dc.logger.Printf("registration complete for user %q", dc.username) dc.logger.Printf("registration complete for user %q", dc.username)
return nil
}
func (dc *downstreamConn) loadNetwork() error {
if dc.networkName == "" {
return nil
}
network := dc.user.getNetwork(dc.networkName)
if network == nil {
addr := dc.networkName
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)},
}}
}
dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error
network, err = dc.user.createNetwork(&Network{
Addr: dc.networkName,
Nick: dc.nick,
})
if err != nil {
return err
}
}
dc.network = network
return nil
}
func (dc *downstreamConn) welcome() error {
if dc.user == nil || !dc.registered {
panic("tried to welcome an unregistered connection")
}
// TODO: doing this might take some time. We should do it in dc.register
// instead, but we'll potentially be adding a new network and this must be
// done in the user goroutine.
if err := dc.loadNetwork(); err != nil {
return err
}
dc.user.lock.Lock()
firstDownstream := len(dc.user.downstreamConns) == 0 firstDownstream := len(dc.user.downstreamConns) == 0
dc.user.lock.Unlock()
dc.SendMessage(&irc.Message{ dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(), Prefix: dc.srv.prefix(),

21
user.go
View File

@ -91,7 +91,6 @@ type user struct {
events chan event events chan event
lock sync.Mutex
networks []*network networks []*network
downstreamConns []*downstreamConn downstreamConns []*downstreamConn
} }
@ -105,15 +104,12 @@ func newUser(srv *Server, record *User) *user {
} }
func (u *user) forEachNetwork(f func(*network)) { func (u *user) forEachNetwork(f func(*network)) {
u.lock.Lock()
for _, network := range u.networks { for _, network := range u.networks {
f(network) f(network)
} }
u.lock.Unlock()
} }
func (u *user) forEachUpstream(f func(uc *upstreamConn)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
u.lock.Lock()
for _, network := range u.networks { for _, network := range u.networks {
uc := network.upstream() uc := network.upstream()
if uc == nil || !uc.registered || uc.closed { if uc == nil || !uc.registered || uc.closed {
@ -121,15 +117,12 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
} }
f(uc) f(uc)
} }
u.lock.Unlock()
} }
func (u *user) forEachDownstream(f func(dc *downstreamConn)) { func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Lock()
for _, dc := range u.downstreamConns { for _, dc := range u.downstreamConns {
f(dc) f(dc)
} }
u.lock.Unlock()
} }
func (u *user) getNetwork(name string) *network { func (u *user) getNetwork(name string) *network {
@ -148,14 +141,12 @@ func (u *user) run() {
return return
} }
u.lock.Lock()
for _, record := range networks { for _, record := range networks {
network := newNetwork(u, &record) network := newNetwork(u, &record)
u.networks = append(u.networks, network) u.networks = append(u.networks, network)
go network.run() go network.run()
} }
u.lock.Unlock()
for e := range u.events { for e := range u.events {
switch e := e.(type) { switch e := e.(type) {
@ -170,19 +161,21 @@ func (u *user) run() {
} }
case eventDownstreamConnected: case eventDownstreamConnected:
dc := e.dc dc := e.dc
u.lock.Lock()
if err := dc.welcome(); err != nil {
dc.logger.Printf("failed to handle new registered connection: %v", err)
break
}
u.downstreamConns = append(u.downstreamConns, dc) u.downstreamConns = append(u.downstreamConns, dc)
u.lock.Unlock()
case eventDownstreamDisconnected: case eventDownstreamDisconnected:
dc := e.dc dc := e.dc
u.lock.Lock()
for i := range u.downstreamConns { for i := range u.downstreamConns {
if u.downstreamConns[i] == dc { if u.downstreamConns[i] == dc {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...) u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
break break
} }
} }
u.lock.Unlock()
case eventDownstreamMessage: case eventDownstreamMessage:
msg, dc := e.msg, e.dc msg, dc := e.msg, e.dc
if dc.isClosed() { if dc.isClosed() {
@ -220,9 +213,7 @@ func (u *user) createNetwork(net *Network) (*network, error) {
} }
}) })
u.lock.Lock()
u.networks = append(u.networks, network) u.networks = append(u.networks, network)
u.lock.Unlock()
go network.run() go network.run()
return network, nil return network, nil