diff --git a/upstream.go b/upstream.go index b24d78b..3eb08e2 100644 --- a/upstream.go +++ b/upstream.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/url" "strconv" "strings" "time" @@ -80,31 +81,30 @@ type upstreamConn struct { func connectToUpstream(network *network) (*upstreamConn, error) { logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)} - var scheme string - var addr string - - addrParts := strings.SplitN(network.Addr, "://", 2) - if len(addrParts) == 2 { - scheme = addrParts[0] - addr = addrParts[1] - } else { - scheme = "ircs" - addr = addrParts[0] - } - dialer := net.Dialer{Timeout: connectTimeout} + s := network.Addr + if !strings.Contains(s, "://") { + // This is a raw domain name, make it an URL with the default scheme + s = "ircs://" + s + } + + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse upstream server URL: %v", err) + } + var netConn net.Conn - var err error - switch scheme { + switch u.Scheme { case "ircs": + addr := u.Host if _, _, err := net.SplitHostPort(addr); err != nil { addr = addr + ":6697" } logger.Printf("connecting to TLS server at address %q", addr) - var cfg *tls.Config + var tlsConfig *tls.Config if network.SASL.Mechanism == "EXTERNAL" { if network.SASL.External.CertBlob == nil { return nil, fmt.Errorf("missing certificate for authentication") @@ -116,7 +116,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) { if err != nil { return nil, fmt.Errorf("failed to parse private key: %v", err) } - cfg = &tls.Config{ + tlsConfig = &tls.Config{ Certificates: []tls.Certificate{ { Certificate: [][]byte{network.SASL.External.CertBlob}, @@ -127,19 +127,23 @@ func connectToUpstream(network *network) (*upstreamConn, error) { logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) } - netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg) + netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, tlsConfig) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } case "irc+insecure": + addr := u.Host if _, _, err := net.SplitHostPort(addr); err != nil { addr = addr + ":6667" } logger.Printf("connecting to plain-text server at address %q", addr) netConn, err = dialer.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + } default: - return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", addr, scheme) - } - if err != nil { - return nil, fmt.Errorf("failed to dial %q: %v", addr, err) + return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme) } uc := &upstreamConn{ @@ -156,7 +160,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) { pendingLISTDownstreamSet: make(map[uint64]struct{}), messageLoggers: make(map[string]*messageLogger), } - return uc, nil }