diff --git a/upstream.go b/upstream.go index 95bad3a..e3bdf4e 100644 --- a/upstream.go +++ b/upstream.go @@ -242,8 +242,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er ctx, cancel := context.WithTimeout(ctx, connectTimeout) defer cancel() - var dialer net.Dialer - u, err := network.URL() if err != nil { return nil, err @@ -259,13 +257,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er addr = u.Host + ":6697" } - dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) - if err != nil { - return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) - } - - logger.Printf("connecting to TLS server at address %q", addr) - tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}} if network.SASL.Mechanism == "EXTERNAL" { if network.SASL.External.CertBlob == nil { @@ -321,9 +312,10 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er } } - netConn, err = dialer.DialContext(ctx, "tcp", addr) + logger.Printf("connecting to TLS server at address %q", addr) + netConn, err = dialTCP(ctx, network.user, addr) if err != nil { - return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + return nil, err } // Don't do the TLS handshake immediately, because we need to register @@ -332,23 +324,17 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er netConn = tls.Client(netConn, tlsConfig) case "irc+insecure": addr := u.Host - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = u.Host + if _, _, err := net.SplitHostPort(addr); err != nil { addr = u.Host + ":6667" } - dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) - if err != nil { - return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) - } - logger.Printf("connecting to plain-text server at address %q", addr) - netConn, err = dialer.DialContext(ctx, "tcp", addr) + netConn, err = dialTCP(ctx, network.user, addr) if err != nil { - return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + return nil, err } case "irc+unix", "unix": + var dialer net.Dialer logger.Printf("connecting to Unix socket at path %q", u.Path) netConn, err = dialer.DialContext(ctx, "unix", u.Path) if err != nil { @@ -386,6 +372,21 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er return uc, nil } +func dialTCP(ctx context.Context, user *user, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + localAddr, err := user.localTCPAddrForHost(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) + } + + dialer := net.Dialer{LocalAddr: localAddr} + return dialer.DialContext(ctx, "tcp", addr) +} + func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) { uc.network.forEachDownstream(f) }