upstream: consoldate TCP dial into function

This commit is contained in:
Simon Ser 2023-12-21 13:57:28 +01:00
parent d423a1ca24
commit e184c30cef

View File

@ -242,8 +242,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
ctx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
var dialer net.Dialer
u, err := network.URL()
if err != nil {
return nil, err
@ -259,13 +257,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
addr = u.Host + ":6697"
}
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
}
logger.Printf("connecting to TLS server at address %q", addr)
tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
if network.SASL.Mechanism == "EXTERNAL" {
if network.SASL.External.CertBlob == nil {
@ -321,9 +312,10 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
}
}
netConn, err = dialer.DialContext(ctx, "tcp", addr)
logger.Printf("connecting to TLS server at address %q", addr)
netConn, err = dialTCP(ctx, network.user, addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
return nil, err
}
// Don't do the TLS handshake immediately, because we need to register
@ -332,23 +324,17 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
netConn = tls.Client(netConn, tlsConfig)
case "irc+insecure":
addr := u.Host
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = u.Host
if _, _, err := net.SplitHostPort(addr); err != nil {
addr = u.Host + ":6667"
}
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
}
logger.Printf("connecting to plain-text server at address %q", addr)
netConn, err = dialer.DialContext(ctx, "tcp", addr)
netConn, err = dialTCP(ctx, network.user, addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
return nil, err
}
case "irc+unix", "unix":
var dialer net.Dialer
logger.Printf("connecting to Unix socket at path %q", u.Path)
netConn, err = dialer.DialContext(ctx, "unix", u.Path)
if err != nil {
@ -386,6 +372,21 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
return uc, nil
}
func dialTCP(ctx context.Context, user *user, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
localAddr, err := user.localTCPAddrForHost(ctx, host)
if err != nil {
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
}
dialer := net.Dialer{LocalAddr: localAddr}
return dialer.DialContext(ctx, "tcp", addr)
}
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
uc.network.forEachDownstream(f)
}