Add context to connectToUpstream
This commit is contained in:
parent
33a639ecf0
commit
73287f242e
12
upstream.go
12
upstream.go
@ -123,7 +123,7 @@ type upstreamConn struct {
|
||||
gotMotd bool
|
||||
}
|
||||
|
||||
func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
|
||||
logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
|
||||
|
||||
dialer := net.Dialer{Timeout: connectTimeout}
|
||||
@ -143,7 +143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
addr = u.Host + ":6697"
|
||||
}
|
||||
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||
}
|
||||
@ -171,7 +171,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
|
||||
}
|
||||
|
||||
netConn, err = dialer.Dial("tcp", addr)
|
||||
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||
}
|
||||
@ -188,19 +188,19 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||
addr = u.Host + ":6667"
|
||||
}
|
||||
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
|
||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||
}
|
||||
|
||||
logger.Printf("connecting to plain-text server at address %q", addr)
|
||||
netConn, err = dialer.Dial("tcp", addr)
|
||||
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||
}
|
||||
case "irc+unix", "unix":
|
||||
logger.Printf("connecting to Unix socket at path %q", u.Path)
|
||||
netConn, err = dialer.Dial("unix", u.Path)
|
||||
netConn, err = dialer.DialContext(ctx, "unix", u.Path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
|
||||
}
|
||||
|
6
user.go
6
user.go
@ -202,7 +202,7 @@ func (net *network) run() {
|
||||
}
|
||||
lastTry = time.Now()
|
||||
|
||||
uc, err := connectToUpstream(net)
|
||||
uc, err := connectToUpstream(context.TODO(), net)
|
||||
if err != nil {
|
||||
net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
|
||||
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
|
||||
@ -1015,13 +1015,13 @@ func (u *user) hasPersistentMsgStore() bool {
|
||||
|
||||
// localAddrForHost returns the local address to use when connecting to host.
|
||||
// A nil address is returned when the OS should automatically pick one.
|
||||
func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) {
|
||||
func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
|
||||
upstreamUserIPs := u.srv.Config().UpstreamUserIPs
|
||||
if len(upstreamUserIPs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(host)
|
||||
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user