From 73287f242e8cd127f6b08e4675619374a4a5e4e3 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Thu, 2 Dec 2021 10:53:43 +0100 Subject: [PATCH] Add context to connectToUpstream --- upstream.go | 12 ++++++------ user.go | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/upstream.go b/upstream.go index 6cc02fc..f4789ec 100644 --- a/upstream.go +++ b/upstream.go @@ -123,7 +123,7 @@ type upstreamConn struct { gotMotd bool } -func connectToUpstream(network *network) (*upstreamConn, error) { +func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) { logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())} dialer := net.Dialer{Timeout: connectTimeout} @@ -143,7 +143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { addr = u.Host + ":6697" } - dialer.LocalAddr, err = network.user.localTCPAddrForHost(host) + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) if err != nil { return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) } @@ -171,7 +171,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) } - netConn, err = dialer.Dial("tcp", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, fmt.Errorf("failed to dial %q: %v", addr, err) } @@ -188,19 +188,19 @@ func connectToUpstream(network *network) (*upstreamConn, error) { addr = u.Host + ":6667" } - dialer.LocalAddr, err = network.user.localTCPAddrForHost(host) + dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host) if err != nil { return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err) } logger.Printf("connecting to plain-text server at address %q", addr) - netConn, err = dialer.Dial("tcp", addr) + netConn, err = dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, fmt.Errorf("failed to dial %q: %v", addr, err) } case "irc+unix", "unix": logger.Printf("connecting to Unix socket at path %q", u.Path) - netConn, err = dialer.Dial("unix", u.Path) + netConn, err = dialer.DialContext(ctx, "unix", u.Path) if err != nil { return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err) } diff --git a/user.go b/user.go index 4e43f31..ef3df38 100644 --- a/user.go +++ b/user.go @@ -202,7 +202,7 @@ func (net *network) run() { } lastTry = time.Now() - uc, err := connectToUpstream(net) + uc, err := connectToUpstream(context.TODO(), net) if err != nil { net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err) net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)} @@ -1015,13 +1015,13 @@ func (u *user) hasPersistentMsgStore() bool { // localAddrForHost returns the local address to use when connecting to host. // A nil address is returned when the OS should automatically pick one. -func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) { +func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) { upstreamUserIPs := u.srv.Config().UpstreamUserIPs if len(upstreamUserIPs) == 0 { return nil, nil } - ips, err := net.LookupIP(host) + ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host) if err != nil { return nil, err }