Parse upstream URLs with net/url

This allows us to ignore the path part of the URL. This is preliminary
work for unix URLs.
This commit is contained in:
Simon Ser 2020-07-06 16:59:14 +02:00
parent b46a2554e1
commit 7af21d9d81
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
1 changed files with 25 additions and 22 deletions

View File

@ -10,6 +10,7 @@ import (
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"time"
@ -80,31 +81,30 @@ type upstreamConn struct {
func connectToUpstream(network *network) (*upstreamConn, error) {
logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
var scheme string
var addr string
addrParts := strings.SplitN(network.Addr, "://", 2)
if len(addrParts) == 2 {
scheme = addrParts[0]
addr = addrParts[1]
} else {
scheme = "ircs"
addr = addrParts[0]
}
dialer := net.Dialer{Timeout: connectTimeout}
s := network.Addr
if !strings.Contains(s, "://") {
// This is a raw domain name, make it an URL with the default scheme
s = "ircs://" + s
}
u, err := url.Parse(s)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
}
var netConn net.Conn
var err error
switch scheme {
switch u.Scheme {
case "ircs":
addr := u.Host
if _, _, err := net.SplitHostPort(addr); err != nil {
addr = addr + ":6697"
}
logger.Printf("connecting to TLS server at address %q", addr)
var cfg *tls.Config
var tlsConfig *tls.Config
if network.SASL.Mechanism == "EXTERNAL" {
if network.SASL.External.CertBlob == nil {
return nil, fmt.Errorf("missing certificate for authentication")
@ -116,7 +116,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
cfg = &tls.Config{
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{network.SASL.External.CertBlob},
@ -127,19 +127,23 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
}
netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg)
netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
case "irc+insecure":
addr := u.Host
if _, _, err := net.SplitHostPort(addr); err != nil {
addr = addr + ":6667"
}
logger.Printf("connecting to plain-text server at address %q", addr)
netConn, err = dialer.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
}
default:
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", addr, scheme)
}
if err != nil {
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
}
uc := &upstreamConn{
@ -156,7 +160,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
pendingLISTDownstreamSet: make(map[uint64]struct{}),
messageLoggers: make(map[string]*messageLogger),
}
return uc, nil
}