Add context to connectToUpstream
This commit is contained in:
parent
33a639ecf0
commit
73287f242e
12
upstream.go
12
upstream.go
@ -123,7 +123,7 @@ type upstreamConn struct {
|
|||||||
gotMotd bool
|
gotMotd bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectToUpstream(network *network) (*upstreamConn, error) {
|
func connectToUpstream(ctx context.Context, network *network) (*upstreamConn, error) {
|
||||||
logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
|
logger := &prefixLogger{network.user.logger, fmt.Sprintf("upstream %q: ", network.GetName())}
|
||||||
|
|
||||||
dialer := net.Dialer{Timeout: connectTimeout}
|
dialer := net.Dialer{Timeout: connectTimeout}
|
||||||
@ -143,7 +143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
addr = u.Host + ":6697"
|
addr = u.Host + ":6697"
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
|
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||||
}
|
}
|
||||||
@ -171,7 +171,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
|
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err = dialer.Dial("tcp", addr)
|
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||||
}
|
}
|
||||||
@ -188,19 +188,19 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
addr = u.Host + ":6667"
|
addr = u.Host + ":6667"
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer.LocalAddr, err = network.user.localTCPAddrForHost(host)
|
dialer.LocalAddr, err = network.user.localTCPAddrForHost(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
return nil, fmt.Errorf("failed to pick local IP for remote host %q: %v", host, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("connecting to plain-text server at address %q", addr)
|
logger.Printf("connecting to plain-text server at address %q", addr)
|
||||||
netConn, err = dialer.Dial("tcp", addr)
|
netConn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||||
}
|
}
|
||||||
case "irc+unix", "unix":
|
case "irc+unix", "unix":
|
||||||
logger.Printf("connecting to Unix socket at path %q", u.Path)
|
logger.Printf("connecting to Unix socket at path %q", u.Path)
|
||||||
netConn, err = dialer.Dial("unix", u.Path)
|
netConn, err = dialer.DialContext(ctx, "unix", u.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
|
return nil, fmt.Errorf("failed to connect to Unix socket %q: %v", u.Path, err)
|
||||||
}
|
}
|
||||||
|
6
user.go
6
user.go
@ -202,7 +202,7 @@ func (net *network) run() {
|
|||||||
}
|
}
|
||||||
lastTry = time.Now()
|
lastTry = time.Now()
|
||||||
|
|
||||||
uc, err := connectToUpstream(net)
|
uc, err := connectToUpstream(context.TODO(), net)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
|
net.logger.Printf("failed to connect to upstream server %q: %v", net.Addr, err)
|
||||||
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
|
net.user.events <- eventUpstreamConnectionError{net, fmt.Errorf("failed to connect: %v", err)}
|
||||||
@ -1015,13 +1015,13 @@ func (u *user) hasPersistentMsgStore() bool {
|
|||||||
|
|
||||||
// localAddrForHost returns the local address to use when connecting to host.
|
// localAddrForHost returns the local address to use when connecting to host.
|
||||||
// A nil address is returned when the OS should automatically pick one.
|
// A nil address is returned when the OS should automatically pick one.
|
||||||
func (u *user) localTCPAddrForHost(host string) (*net.TCPAddr, error) {
|
func (u *user) localTCPAddrForHost(ctx context.Context, host string) (*net.TCPAddr, error) {
|
||||||
upstreamUserIPs := u.srv.Config().UpstreamUserIPs
|
upstreamUserIPs := u.srv.Config().UpstreamUserIPs
|
||||||
if len(upstreamUserIPs) == 0 {
|
if len(upstreamUserIPs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := net.DefaultResolver.LookupIP(ctx, "ip", host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user