Parse upstream URLs with net/url
This allows us to ignore the path part of the URL. This is preliminary work for unix URLs.
This commit is contained in:
parent
b46a2554e1
commit
7af21d9d81
45
upstream.go
45
upstream.go
@ -10,6 +10,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -80,31 +81,30 @@ type upstreamConn struct {
|
|||||||
func connectToUpstream(network *network) (*upstreamConn, error) {
|
func connectToUpstream(network *network) (*upstreamConn, error) {
|
||||||
logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
|
logger := &prefixLogger{network.user.srv.Logger, fmt.Sprintf("upstream %q: ", network.Addr)}
|
||||||
|
|
||||||
var scheme string
|
|
||||||
var addr string
|
|
||||||
|
|
||||||
addrParts := strings.SplitN(network.Addr, "://", 2)
|
|
||||||
if len(addrParts) == 2 {
|
|
||||||
scheme = addrParts[0]
|
|
||||||
addr = addrParts[1]
|
|
||||||
} else {
|
|
||||||
scheme = "ircs"
|
|
||||||
addr = addrParts[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := net.Dialer{Timeout: connectTimeout}
|
dialer := net.Dialer{Timeout: connectTimeout}
|
||||||
|
|
||||||
|
s := network.Addr
|
||||||
|
if !strings.Contains(s, "://") {
|
||||||
|
// This is a raw domain name, make it an URL with the default scheme
|
||||||
|
s = "ircs://" + s
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse upstream server URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var netConn net.Conn
|
var netConn net.Conn
|
||||||
var err error
|
switch u.Scheme {
|
||||||
switch scheme {
|
|
||||||
case "ircs":
|
case "ircs":
|
||||||
|
addr := u.Host
|
||||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||||
addr = addr + ":6697"
|
addr = addr + ":6697"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("connecting to TLS server at address %q", addr)
|
logger.Printf("connecting to TLS server at address %q", addr)
|
||||||
|
|
||||||
var cfg *tls.Config
|
var tlsConfig *tls.Config
|
||||||
if network.SASL.Mechanism == "EXTERNAL" {
|
if network.SASL.Mechanism == "EXTERNAL" {
|
||||||
if network.SASL.External.CertBlob == nil {
|
if network.SASL.External.CertBlob == nil {
|
||||||
return nil, fmt.Errorf("missing certificate for authentication")
|
return nil, fmt.Errorf("missing certificate for authentication")
|
||||||
@ -116,7 +116,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse private key: %v", err)
|
return nil, fmt.Errorf("failed to parse private key: %v", err)
|
||||||
}
|
}
|
||||||
cfg = &tls.Config{
|
tlsConfig = &tls.Config{
|
||||||
Certificates: []tls.Certificate{
|
Certificates: []tls.Certificate{
|
||||||
{
|
{
|
||||||
Certificate: [][]byte{network.SASL.External.CertBlob},
|
Certificate: [][]byte{network.SASL.External.CertBlob},
|
||||||
@ -127,20 +127,24 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
|
logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg)
|
netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||||
|
}
|
||||||
case "irc+insecure":
|
case "irc+insecure":
|
||||||
|
addr := u.Host
|
||||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||||
addr = addr + ":6667"
|
addr = addr + ":6667"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("connecting to plain-text server at address %q", addr)
|
logger.Printf("connecting to plain-text server at address %q", addr)
|
||||||
netConn, err = dialer.Dial("tcp", addr)
|
netConn, err = dialer.Dial("tcp", addr)
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", addr, scheme)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
return nil, fmt.Errorf("failed to dial %q: %v", addr, err)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("failed to dial %q: unknown scheme: %v", network.Addr, u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
uc := &upstreamConn{
|
uc := &upstreamConn{
|
||||||
conn: *newConn(network.user.srv, newNetIRCConn(netConn), logger),
|
conn: *newConn(network.user.srv, newNetIRCConn(netConn), logger),
|
||||||
@ -156,7 +160,6 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
|
|||||||
pendingLISTDownstreamSet: make(map[uint64]struct{}),
|
pendingLISTDownstreamSet: make(map[uint64]struct{}),
|
||||||
messageLoggers: make(map[string]*messageLogger),
|
messageLoggers: make(map[string]*messageLogger),
|
||||||
}
|
}
|
||||||
|
|
||||||
return uc, nil
|
return uc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user