diff --git a/upstream.go b/upstream.go index 5614f7d..efbed31 100644 --- a/upstream.go +++ b/upstream.go @@ -1326,6 +1326,25 @@ func (uc *upstreamConn) register() { }) } +func (uc *upstreamConn) runUntilRegistered() error { + for !uc.registered { + msg, err := uc.irc.ReadMessage() + if err != nil { + return fmt.Errorf("failed to read message: %v", err) + } + + if uc.srv.Debug { + uc.logger.Printf("received: %v", msg) + } + + if err := uc.handleMessage(msg); err != nil { + return fmt.Errorf("failed to handle message %q: %v", msg, err) + } + } + + return nil +} + func (uc *upstreamConn) requestSASL() bool { if uc.network.SASL.Mechanism == "" { return false diff --git a/user.go b/user.go index 45613e7..dc84bcc 100644 --- a/user.go +++ b/user.go @@ -71,10 +71,12 @@ func (net *network) run() { } uc.register() + if err := uc.runUntilRegistered(); err != nil { + uc.logger.Printf("failed to register: %v", err) + uc.Close() + continue + } - // TODO: wait for the connection to be registered before adding it to - // net, otherwise messages might be sent to it while still being - // unauthenticated net.lock.Lock() net.conn = uc net.lock.Unlock() @@ -134,7 +136,7 @@ func (u *user) forEachNetwork(f func(*network)) { func (u *user) forEachUpstream(f func(uc *upstreamConn)) { for _, network := range u.networks { uc := network.upstream() - if uc == nil || !uc.registered { + if uc == nil { continue } f(uc)