Parse upstream URLs with net/url

This allows us to ignore the path part of the URL. This is preliminary
work for unix URLs.
This commit is contained in:
Simon Ser 2020-07-06 16:59:14 +02:00
parent b46a2554e1
commit 7af21d9d81
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48

View File

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -80,31 +81,30 @@ type upstreamConn struct {
func connectToUpstream(network *network) (*upstreamConn, error) { func connectToUpstream(network *network) (*upstreamConn, error) {
logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)} logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
var scheme string
var addr string
addrParts := strings.SplitN(network.Addr, "://", 2)
if len(addrParts) == 2 {
scheme = addrParts[0]
addr = addrParts[1]
} else {
scheme = "ircs"
addr = addrParts[0]
}
dialer := net.Dialer{Timeout: connectTimeout} dialer := net.Dialer{Timeout: connectTimeout}
s := network.Addr
if !strings.Contains(s, "://") {
// This is a raw domain name, make it an URL with the default scheme
s = "ircs://" + s
}
u, err := url.Parse(s)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
}
var netConn net.Conn var netConn net.Conn
var err error switch u.Scheme {
switch scheme {
case "ircs": case "ircs":
addr := u.Host
if _, _, err := net.SplitHostPort(addr); err != nil { if _, _, err := net.SplitHostPort(addr); err != nil {
addr = addr + ":6697" addr = addr + ":6697"
} }
logger.Printf("connecting to TLS server at address %q", addr) logger.Printf("connecting to TLS server at address %q", addr)
var cfg *tls.Config var tlsConfig *tls.Config
if network.SASL.Mechanism == "EXTERNAL" { if network.SASL.Mechanism == "EXTERNAL" {
if network.SASL.External.CertBlob == nil { if network.SASL.External.CertBlob == nil {
return nil, fmt.Errorf("missing certificate for authentication") return nil, fmt.Errorf("missing certificate for authentication")
@ -116,7 +116,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err) return nil, fmt.Errorf("failed to parse private key: %v", err)
} }
cfg = &tls.Config{ tlsConfig = &tls.Config{
Certificates: []tls.Certificate{ Certificates: []tls.Certificate{
{ {
Certificate: [][]byte{network.SASL.External.CertBlob}, Certificate: [][]byte{network.SASL.External.CertBlob},
@ -127,19 +127,23 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob)) logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
} }
netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg) netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
case "irc+insecure": case "irc+insecure":
addr := u.Host
if _, _, err := net.SplitHostPort(addr); err != nil { if _, _, err := net.SplitHostPort(addr); err != nil {
addr = addr + ":6667" addr = addr + ":6667"
} }
logger.Printf("connecting to plain-text server at address %q", addr) logger.Printf("connecting to plain-text server at address %q", addr)
netConn, err = dialer.Dial("tcp", addr) netConn, err = dialer.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
default: default:
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", addr, scheme) return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
}
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
} }
uc := &upstreamConn{ uc := &upstreamConn{
@ -156,7 +160,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
pendingLISTDownstreamSet: make(map[uint64]struct{}), pendingLISTDownstreamSet: make(map[uint64]struct{}),
messageLoggers: make(map[string]*messageLogger), messageLoggers: make(map[string]*messageLogger),
} }
return uc, nil return uc, nil
} }