diff --git a/user.go b/user.go index e2fac8f..e2fe9aa 100644 --- a/user.go +++ b/user.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/binary" "encoding/hex" + "errors" "fmt" "math/big" "net" @@ -186,6 +187,41 @@ func userIdent(u *User) string { return hex.EncodeToString(h[:16]) } +func (net *network) runConn(ctx context.Context) error { + net.user.srv.metrics.upstreams.Add(1) + defer net.user.srv.metrics.upstreams.Add(-1) + + uc, err := connectToUpstream(ctx, net) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + defer uc.Close() + + if net.user.srv.Identd != nil { + net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User)) + defer net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String()) + } + + uc.register(ctx) + if err := uc.runUntilRegistered(ctx); err != nil { + return fmt.Errorf("failed to register: %w", err) + } + + // TODO: this is racy with net.stopped. If the network is stopped + // before the user goroutine receives eventUpstreamConnected, the + // connection won't be closed. + net.user.events <- eventUpstreamConnected{uc} + defer func() { + net.user.events <- eventUpstreamDisconnected{uc} + }() + + if err := uc.readMessages(net.user.events); err != nil { + return fmt.Errorf("failed to handle messages: %w", err) + } + + return nil +} + func (net *network) run() { if !net.Enabled { return @@ -205,57 +241,25 @@ func (net *network) run() { } lastTry = time.Now() - net.user.srv.metrics.upstreams.Add(1) - - uc, err := connectToUpstream(context.TODO(), net) - if err != nil { - net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) - net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} - net.user.srv.metrics.upstreams.Add(-1) - net.user.srv.metrics.upstreamConnectErrorsTotal.Inc() - continue - } - - if net.user.srv.Identd != nil { - net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User)) - } - - uc.register(context.TODO()) - if err := uc.runUntilRegistered(context.TODO()); err != nil { + if err := net.runConn(context.TODO()); err != nil { text := err.Error() temp := true - if regErr, ok := err.(registrationError); ok { - text = regErr.Reason() + var regErr registrationError + if errors.As(err, ®Err) { + text = "failed to register: " + regErr.Reason() temp = regErr.Temporary() } - uc.logger.Printf("failed to register: %v", text) - net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to register: %v", text)} - uc.Close() - net.user.srv.metrics.upstreams.Add(-1) + + net.logger.Printf("connection error to %q: %v", net.Addr, text) + net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("connection error: %v", err)} net.user.srv.metrics.upstreamConnectErrorsTotal.Inc() + if !temp { return } - continue + } else { + backoff.Reset() } - - // TODO: this is racy with net.stopped. If the network is stopped - // before the user goroutine receives eventUpstreamConnected, the - // connection won't be closed. - net.user.events <- eventUpstreamConnected{uc} - if err := uc.readMessages(net.user.events); err != nil { - uc.logger.Printf("failed to handle messages: %v", err) - net.user.events <- eventUpstreamError{uc, fmt.Errorf("failed to handle messages: %v", err)} - } - uc.Close() - net.user.events <- eventUpstreamDisconnected{uc} - - if net.user.srv.Identd != nil { - net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String()) - } - - net.user.srv.metrics.upstreams.Add(-1) - backoff.Reset() } }