Nuke user.lock
Split user.register into two functions, one to make sure the user is authenticated, the other to send our current state. This allows to get rid of data races by doing the second part in the user goroutine. Closes: https://todo.sr.ht/~emersion/soju/22
This commit is contained in:
parent
c0f5850e5b
commit
08bb06c164
104
downstream.go
104
downstream.go
@ -71,6 +71,7 @@ type downstreamConn struct {
|
||||
nick string
|
||||
username string
|
||||
rawUsername string
|
||||
networkName string
|
||||
realname string
|
||||
hostname string
|
||||
password string // empty after authentication
|
||||
@ -582,42 +583,6 @@ func unmarshalUsername(rawUsername string) (username, network string) {
|
||||
return username, network
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) setNetwork(networkName string) error {
|
||||
if networkName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
network := dc.user.getNetwork(networkName)
|
||||
if network == nil {
|
||||
addr := networkName
|
||||
if !strings.ContainsRune(addr, ':') {
|
||||
addr = addr + ":6697"
|
||||
}
|
||||
|
||||
dc.logger.Printf("trying to connect to new network %q", addr)
|
||||
if err := sanityCheckServer(addr); err != nil {
|
||||
dc.logger.Printf("failed to connect to %q: %v", addr, err)
|
||||
return ircError{&irc.Message{
|
||||
Command: irc.ERR_PASSWDMISMATCH,
|
||||
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", networkName)},
|
||||
}}
|
||||
}
|
||||
|
||||
dc.logger.Printf("auto-saving network %q", networkName)
|
||||
var err error
|
||||
network, err = dc.user.createNetwork(&Network{
|
||||
Addr: networkName,
|
||||
Nick: dc.nick,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dc.network = network
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) authenticate(username, password string) error {
|
||||
username, networkName := unmarshalUsername(username)
|
||||
|
||||
@ -634,31 +599,82 @@ func (dc *downstreamConn) authenticate(username, password string) error {
|
||||
}
|
||||
|
||||
dc.user = u
|
||||
|
||||
return dc.setNetwork(networkName)
|
||||
dc.networkName = networkName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) register() error {
|
||||
if dc.registered {
|
||||
return fmt.Errorf("tried to register twice")
|
||||
}
|
||||
|
||||
password := dc.password
|
||||
dc.password = ""
|
||||
if dc.user == nil {
|
||||
if err := dc.authenticate(dc.rawUsername, password); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if dc.network == nil {
|
||||
_, networkName := unmarshalUsername(dc.rawUsername)
|
||||
if err := dc.setNetwork(networkName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dc.networkName == "" {
|
||||
_, dc.networkName = unmarshalUsername(dc.rawUsername)
|
||||
}
|
||||
|
||||
dc.registered = true
|
||||
dc.username = dc.user.Username
|
||||
dc.logger.Printf("registration complete for user %q", dc.username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) loadNetwork() error {
|
||||
if dc.networkName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
network := dc.user.getNetwork(dc.networkName)
|
||||
if network == nil {
|
||||
addr := dc.networkName
|
||||
if !strings.ContainsRune(addr, ':') {
|
||||
addr = addr + ":6697"
|
||||
}
|
||||
|
||||
dc.logger.Printf("trying to connect to new network %q", addr)
|
||||
if err := sanityCheckServer(addr); err != nil {
|
||||
dc.logger.Printf("failed to connect to %q: %v", addr, err)
|
||||
return ircError{&irc.Message{
|
||||
Command: irc.ERR_PASSWDMISMATCH,
|
||||
Params: []string{"*", fmt.Sprintf("Failed to connect to %q", dc.networkName)},
|
||||
}}
|
||||
}
|
||||
|
||||
dc.logger.Printf("auto-saving network %q", dc.networkName)
|
||||
var err error
|
||||
network, err = dc.user.createNetwork(&Network{
|
||||
Addr: dc.networkName,
|
||||
Nick: dc.nick,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dc.network = network
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *downstreamConn) welcome() error {
|
||||
if dc.user == nil || !dc.registered {
|
||||
panic("tried to welcome an unregistered connection")
|
||||
}
|
||||
|
||||
// TODO: doing this might take some time. We should do it in dc.register
|
||||
// instead, but we'll potentially be adding a new network and this must be
|
||||
// done in the user goroutine.
|
||||
if err := dc.loadNetwork(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dc.user.lock.Lock()
|
||||
firstDownstream := len(dc.user.downstreamConns) == 0
|
||||
dc.user.lock.Unlock()
|
||||
|
||||
dc.SendMessage(&irc.Message{
|
||||
Prefix: dc.srv.prefix(),
|
||||
|
21
user.go
21
user.go
@ -91,7 +91,6 @@ type user struct {
|
||||
|
||||
events chan event
|
||||
|
||||
lock sync.Mutex
|
||||
networks []*network
|
||||
downstreamConns []*downstreamConn
|
||||
}
|
||||
@ -105,15 +104,12 @@ func newUser(srv *Server, record *User) *user {
|
||||
}
|
||||
|
||||
func (u *user) forEachNetwork(f func(*network)) {
|
||||
u.lock.Lock()
|
||||
for _, network := range u.networks {
|
||||
f(network)
|
||||
}
|
||||
u.lock.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
||||
u.lock.Lock()
|
||||
for _, network := range u.networks {
|
||||
uc := network.upstream()
|
||||
if uc == nil || !uc.registered || uc.closed {
|
||||
@ -121,15 +117,12 @@ func (u *user) forEachUpstream(f func(uc *upstreamConn)) {
|
||||
}
|
||||
f(uc)
|
||||
}
|
||||
u.lock.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) forEachDownstream(f func(dc *downstreamConn)) {
|
||||
u.lock.Lock()
|
||||
for _, dc := range u.downstreamConns {
|
||||
f(dc)
|
||||
}
|
||||
u.lock.Unlock()
|
||||
}
|
||||
|
||||
func (u *user) getNetwork(name string) *network {
|
||||
@ -148,14 +141,12 @@ func (u *user) run() {
|
||||
return
|
||||
}
|
||||
|
||||
u.lock.Lock()
|
||||
for _, record := range networks {
|
||||
network := newNetwork(u, &record)
|
||||
u.networks = append(u.networks, network)
|
||||
|
||||
go network.run()
|
||||
}
|
||||
u.lock.Unlock()
|
||||
|
||||
for e := range u.events {
|
||||
switch e := e.(type) {
|
||||
@ -170,19 +161,21 @@ func (u *user) run() {
|
||||
}
|
||||
case eventDownstreamConnected:
|
||||
dc := e.dc
|
||||
u.lock.Lock()
|
||||
|
||||
if err := dc.welcome(); err != nil {
|
||||
dc.logger.Printf("failed to handle new registered connection: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
u.downstreamConns = append(u.downstreamConns, dc)
|
||||
u.lock.Unlock()
|
||||
case eventDownstreamDisconnected:
|
||||
dc := e.dc
|
||||
u.lock.Lock()
|
||||
for i := range u.downstreamConns {
|
||||
if u.downstreamConns[i] == dc {
|
||||
u.downstreamConns = append(u.downstreamConns[:i], u.downstreamConns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
u.lock.Unlock()
|
||||
case eventDownstreamMessage:
|
||||
msg, dc := e.msg, e.dc
|
||||
if dc.isClosed() {
|
||||
@ -220,9 +213,7 @@ func (u *user) createNetwork(net *Network) (*network, error) {
|
||||
}
|
||||
})
|
||||
|
||||
u.lock.Lock()
|
||||
u.networks = append(u.networks, network)
|
||||
u.lock.Unlock()
|
||||
|
||||
go network.run()
|
||||
return network, nil
|
||||
|
Loading…
Reference in New Issue
Block a user