upstream: consoldate TCP dial into function
This commit is contained in:
parent
d423a1ca24
commit
e184c30cef
43
upstream.go
43
upstream.go
@ -242,8 +242,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
|
||||
ctx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dialer net.Dialer
|
||||
|
||||
u, err := network.URL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -259,13 +257,6 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
|
||||
addr = u.Host + ":6697"
|
||||
}
|
||||
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||
}
|
||||
|
||||
logger.Printf("connecting to TLS server at address %q", addr)
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host, NextProtos: []string{"irc"}}
|
||||
if network.SASL.Mechanism == "EXTERNAL" {
|
||||
if network.SASL.External.CertBlob == nil {
|
||||
@ -321,9 +312,10 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
|
||||
}
|
||||
}
|
||||
|
||||
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
logger.Printf("connecting to TLS server at address %q", addr)
|
||||
netConn, err = dialTCP(ctx, network.user, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Don't do the TLS handshake immediately, because we need to register
|
||||
@ -332,23 +324,17 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
|
||||
netConn = tls.Client(netConn, tlsConfig)
|
||||
case "irc+insecure":
|
||||
addr := u.Host
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = u.Host
|
||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||
addr = u.Host + ":6667"
|
||||
}
|
||||
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||
}
|
||||
|
||||
logger.Printf("connecting to plain-text server at address %q", addr)
|
||||
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
netConn, err = dialTCP(ctx, network.user, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||
return nil, err
|
||||
}
|
||||
case "irc+unix", "unix":
|
||||
var dialer net.Dialer
|
||||
logger.Printf("connecting to Unix socket at path %q", u.Path)
|
||||
netConn, err = dialer.DialContext(ctx, "unix", u.Path)
|
||||
if err != nil {
|
||||
@ -386,6 +372,21 @@ func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, er
|
||||
return uc, nil
|
||||
}
|
||||
|
||||
func dialTCP(ctx context.Context, user *user, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
localAddr, err := user.localTCPAddrForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||
}
|
||||
|
||||
dialer := net.Dialer{LocalAddr: localAddr}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
|
||||
func (uc *upstreamConn) forEachDownstream(f func(*downstreamConn)) {
|
||||
uc.network.forEachDownstream(f)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user