Nuke user.lock

Split user.register into two functions, one to make sure the user is
authenticated, the other to send our current state. This allows to get
rid of data races by doing the second part in the user goroutine.

Closes: https://todo.sr.ht/~emersion/soju/22
This commit is contained in:
Simon Ser 2020-03-27 19:17:58 +01:00
parent c0f5850e5b
commit 08bb06c164
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 67 additions and 60 deletions

View File

@ -71,6 +71,7 @@ type downstreamConn struct {
nick string
username string
rawUsername string
networkName string
realname string
hostname string
password string // empty after authentication
@ -582,42 +583,6 @@ func unmarshalUsername(rawUsername string) (username, network string) {
return username, network
}
func (dc *downstreamConn) setNetwork(networkName string) error {
if networkName == "" {
return nil
}
network := dc.user.getNetwork(networkName)
if network == nil {
addr := networkName
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
}}
}
dc.logger.Printf("auto-saving network %q", networkName)
var err error
network, err = dc.user.createNetwork(&Network{
Addr: networkName,
Nick: dc.nick,
})
if err != nil {
return err
}
}
dc.network = network
return nil
}
func (dc *downstreamConn) authenticate(username, password string) error {
username, networkName := unmarshalUsername(username)
@ -634,31 +599,82 @@ func (dc *downstreamConn) authenticate(username, password string) error {
}
dc.user = u
return dc.setNetwork(networkName)
dc.networkName = networkName
return nil
}
func (dc *downstreamConn) register() error {
if dc.registered {
return fmt.Errorf("tried to register twice")
}
password := dc.password
dc.password = ""
if dc.user == nil {
if err := dc.authenticate(dc.rawUsername, password); err != nil {
return err
}
} else if dc.network == nil {
_, networkName := unmarshalUsername(dc.rawUsername)
if err := dc.setNetwork(networkName); err != nil {
return err
}
}
if dc.networkName == "" {
_, dc.networkName = unmarshalUsername(dc.rawUsername)
}
dc.registered = true
dc.username = dc.user.Username
dc.logger.Printf("registration complete for user %q", dc.username)
return nil
}
func (dc *downstreamConn) loadNetwork() error {
if dc.networkName == "" {
return nil
}
network := dc.user.getNetwork(dc.networkName)
if network == nil {
addr := dc.networkName
if !strings.ContainsRune(addr, ':') {
addr = addr + ":6697"
}
dc.logger.Printf("trying to connect to new network %q", addr)
if err := sanityCheckServer(addr); err != nil {
dc.logger.Printf("failed to connect to %q: %v", addr, err)
return ircError{&irc.Message{
Command: irc.ERR_PASSWDMISMATCH,
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)},
}}
}
dc.logger.Printf("auto-saving network %q", dc.networkName)
var err error
network, err = dc.user.createNetwork(&Network{
Addr: dc.networkName,
Nick: dc.nick,
})
if err != nil {
return err
}
}
dc.network = network
return nil
}
func (dc *downstreamConn) welcome() error {
if dc.user == nil || !dc.registered {
panic("tried to welcome an unregistered connection")
}
// TODO: doing this might take some time. We should do it in dc.register
// instead, but we'll potentially be adding a new network and this must be
// done in the user goroutine.
if err := dc.loadNetwork(); err != nil {
return err
}
dc.user.lock.Lock()
firstDownstream := len(dc.user.downstreamConns) == 0
dc.user.lock.Unlock()
dc.SendMessage(&irc.Message{
Prefix: dc.srv.prefix(),

21
user.go
View File

@ -91,7 +91,6 @@ type user struct {
events chan event
lock sync.Mutex
networks []*network
downstreamConns []*downstreamConn
}
@ -105,15 +104,12 @@ func newUser(srv *Server, record *User) *user {
}
func (u *user) forEachNetwork(f func(*network)) {
u.lock.Lock()
for _, network := range u.networks {
f(network)
}
u.lock.Unlock()
}
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
u.lock.Lock()
for _, network := range u.networks {
uc := network.upstream()
if uc == nil || !uc.registered || uc.closed {
@ -121,15 +117,12 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
}
f(uc)
}
u.lock.Unlock()
}
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
u.lock.Lock()
for _, dc := range u.downstreamConns {
f(dc)
}
u.lock.Unlock()
}
func (u *user) getNetwork(name string) *network {
@ -148,14 +141,12 @@ func (u *user) run() {
return
}
u.lock.Lock()
for _, record := range networks {
network := newNetwork(u, &record)
u.networks = append(u.networks, network)
go network.run()
}
u.lock.Unlock()
for e := range u.events {
switch e := e.(type) {
@ -170,19 +161,21 @@ func (u *user) run() {
}
case eventDownstreamConnected:
dc := e.dc
u.lock.Lock()
if err := dc.welcome(); err != nil {
dc.logger.Printf("failed to handle new registered connection: %v", err)
break
}
u.downstreamConns = append(u.downstreamConns, dc)
u.lock.Unlock()
case eventDownstreamDisconnected:
dc := e.dc
u.lock.Lock()
for i := range u.downstreamConns {
if u.downstreamConns[i] == dc {
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
break
}
}
u.lock.Unlock()
case eventDownstreamMessage:
msg, dc := e.msg, e.dc
if dc.isClosed() {
@ -220,9 +213,7 @@ func (u *user) createNetwork(net *Network) (*network, error) {
}
})
u.lock.Lock()
u.networks = append(u.networks, network)
u.lock.Unlock()
go network.run()
return network, nil