470 lines
13 KiB
Go
470 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
"github.com/quic-go/quic-go/http3"
|
|
)
|
|
|
|
var (
|
|
verbose = flag.Bool("verbose", false, "verbose logging")
|
|
addr = flag.String("addr", "localhost:8443", "Address to listen on")
|
|
certFile = flag.String("cert", "cert.pem", "Certificate file")
|
|
keyFile = flag.String("key", "key.pem", "Private key file")
|
|
targetAddr = flag.String("target", "", "Target address to proxy to (if empty, acts as forward proxy)")
|
|
)
|
|
|
|
// List of hop-by-hop headers to be removed when proxying
|
|
var hopByHopHeaders = []string{
|
|
"Connection",
|
|
"Keep-Alive",
|
|
"Proxy-Authenticate",
|
|
"Proxy-Authorization",
|
|
"Te",
|
|
"Trailers",
|
|
"Transfer-Encoding",
|
|
"Upgrade",
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
|
|
// Setup logger
|
|
logger := log.New(os.Stdout, "[FuckHTTP3] ", log.LstdFlags)
|
|
|
|
// Check if cert and key files exist
|
|
if _, err := os.Stat(*certFile); os.IsNotExist(err) {
|
|
logger.Fatalf("Certificate file %s does not exist", *certFile)
|
|
}
|
|
if _, err := os.Stat(*keyFile); os.IsNotExist(err) {
|
|
logger.Fatalf("Key file %s does not exist", *keyFile)
|
|
}
|
|
|
|
// Normalize target address if specified
|
|
if *targetAddr != "" {
|
|
// Strip trailing slash for consistency
|
|
*targetAddr = strings.TrimSuffix(*targetAddr, "/")
|
|
|
|
// Log the actual target we're using
|
|
if *verbose {
|
|
logger.Printf("Using normalized target: %s", *targetAddr)
|
|
}
|
|
|
|
// Check if target is a valid URL
|
|
if !strings.HasPrefix(*targetAddr, "http://") && !strings.HasPrefix(*targetAddr, "https://") {
|
|
// Add https:// by default
|
|
*targetAddr = "https://" + *targetAddr
|
|
}
|
|
|
|
// Validate the URL
|
|
_, err := url.Parse(*targetAddr)
|
|
if err != nil {
|
|
logger.Fatalf("Invalid target URL: %v", err)
|
|
}
|
|
}
|
|
|
|
// Create a new proxy server
|
|
proxyServer := &ProxyServer{
|
|
logger: logger,
|
|
targetAddr: *targetAddr,
|
|
clients: make(map[string]*http3.RoundTripper),
|
|
clientsMu: &sync.Mutex{},
|
|
}
|
|
|
|
// Extract port for Alt-Svc headers
|
|
portStr := strings.Split(*addr, ":")[1]
|
|
|
|
// Configure TLS
|
|
tlsConfig := &tls.Config{
|
|
MinVersion: tls.VersionTLS13, // HTTP/3 requires TLS 1.3
|
|
NextProtos: []string{"h3"}, // Specify HTTP/3 as the next protocol
|
|
InsecureSkipVerify: false, // Set to true for development only
|
|
}
|
|
|
|
// Configure the HTTP/3 server
|
|
server := &http3.Server{
|
|
Addr: *addr,
|
|
TLSConfig: http3.ConfigureTLSConfig(tlsConfig),
|
|
Handler: proxyServer,
|
|
QuicConfig: &quic.Config{
|
|
MaxIdleTimeout: 30 * time.Second,
|
|
KeepAlivePeriod: 10 * time.Second,
|
|
},
|
|
}
|
|
|
|
logger.Printf("Starting HTTP/3 proxy server on %s", *addr)
|
|
logger.Printf("Using certificate: %s", *certFile)
|
|
logger.Printf("Using key: %s", *keyFile)
|
|
|
|
if *targetAddr != "" {
|
|
logger.Printf("Proxying to target: %s", *targetAddr)
|
|
} else {
|
|
logger.Printf("Running as forward proxy")
|
|
}
|
|
|
|
// Create certificate pair from files
|
|
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
|
if err != nil {
|
|
logger.Fatalf("Failed to load certificates: %v", err)
|
|
}
|
|
|
|
// Create a standard HTTP server for HTTP/1.1 and HTTP/2
|
|
standardServer := &http.Server{
|
|
Addr: *addr,
|
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Add Alt-Svc header to advertise HTTP/3 capability
|
|
w.Header().Set("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, portStr))
|
|
|
|
// Handle the request with the proxy
|
|
proxyServer.ServeHTTP(w, r)
|
|
}),
|
|
TLSConfig: &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
},
|
|
}
|
|
|
|
// Setup context for graceful shutdown
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Handle shutdown signals
|
|
go func() {
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
sig := <-sigCh
|
|
logger.Printf("Received signal %v, shutting down...", sig)
|
|
|
|
// Cancel context
|
|
cancel()
|
|
|
|
// Create shutdown context with timeout
|
|
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer shutdownCancel()
|
|
|
|
// Shutdown both servers
|
|
if err := standardServer.Shutdown(shutdownCtx); err != nil {
|
|
logger.Printf("Error shutting down HTTP/1.1 server: %v", err)
|
|
}
|
|
|
|
if err := server.Close(); err != nil {
|
|
logger.Printf("Error shutting down HTTP/3 server: %v", err)
|
|
}
|
|
|
|
// Close all active roundtrippers
|
|
proxyServer.closeAllClients()
|
|
}()
|
|
|
|
// Start the HTTP/1.1 + HTTP/2 server in a goroutine
|
|
go func() {
|
|
logger.Printf("Starting HTTP/1.1 and HTTP/2 server on %s", *addr)
|
|
if err := standardServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
|
logger.Printf("HTTP server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Start the HTTP/3 server
|
|
err = server.ListenAndServeTLS(*certFile, *keyFile)
|
|
if err != nil && err != http.ErrServerClosed {
|
|
logger.Fatalf("Failed to start HTTP/3 server: %v", err)
|
|
}
|
|
}
|
|
|
|
// ProxyServer implements the http.Handler interface
|
|
type ProxyServer struct {
|
|
logger *log.Logger
|
|
targetAddr string
|
|
clients map[string]*http3.RoundTripper
|
|
clientsMu *sync.Mutex
|
|
}
|
|
|
|
// ServeHTTP handles the HTTP requests
|
|
func (p *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
startTime := time.Now()
|
|
|
|
// Log the incoming request
|
|
if *verbose {
|
|
p.logger.Printf("Received request: %s %s %s", r.Method, r.URL, r.Proto)
|
|
}
|
|
|
|
// Check if we're operating as a reverse proxy (with fixed target) or forward proxy
|
|
if p.targetAddr != "" {
|
|
// Reverse proxy mode
|
|
p.handleReverseProxy(w, r)
|
|
} else {
|
|
// Forward proxy mode
|
|
p.handleForwardProxy(w, r)
|
|
}
|
|
|
|
// Log completion time
|
|
if *verbose {
|
|
p.logger.Printf("Request completed in %v", time.Since(startTime))
|
|
}
|
|
}
|
|
|
|
// getRoundTripper gets or creates an HTTP/3 RoundTripper for the given host
|
|
func (p *ProxyServer) getRoundTripper(host string) *http3.RoundTripper {
|
|
p.clientsMu.Lock()
|
|
defer p.clientsMu.Unlock()
|
|
|
|
if rt, ok := p.clients[host]; ok {
|
|
return rt
|
|
}
|
|
|
|
rt := &http3.RoundTripper{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true, // Allow insecure connections for testing
|
|
NextProtos: []string{"h3"},
|
|
},
|
|
}
|
|
|
|
p.clients[host] = rt
|
|
return rt
|
|
}
|
|
|
|
// closeAllClients closes all RoundTripper instances
|
|
func (p *ProxyServer) closeAllClients() {
|
|
p.clientsMu.Lock()
|
|
defer p.clientsMu.Unlock()
|
|
|
|
for host, rt := range p.clients {
|
|
if err := rt.Close(); err != nil && *verbose {
|
|
p.logger.Printf("Error closing roundtripper for %s: %v", host, err)
|
|
}
|
|
}
|
|
|
|
p.clients = make(map[string]*http3.RoundTripper)
|
|
}
|
|
|
|
// handleReverseProxy handles requests in reverse proxy mode
|
|
func (p *ProxyServer) handleReverseProxy(w http.ResponseWriter, r *http.Request) {
|
|
// Get or create a RoundTripper for the target
|
|
roundTripper := p.getRoundTripper(p.targetAddr)
|
|
|
|
// Create a new client
|
|
client := &http.Client{
|
|
Transport: roundTripper,
|
|
Timeout: 30 * time.Second,
|
|
}
|
|
|
|
// Create a new request to the target
|
|
var targetURL string
|
|
if strings.HasPrefix(p.targetAddr, "http://") || strings.HasPrefix(p.targetAddr, "https://") {
|
|
// Target already has a scheme
|
|
if strings.HasSuffix(p.targetAddr, "/") && strings.HasPrefix(r.URL.Path, "/") {
|
|
// Avoid double slashes when both target and path have slashes
|
|
targetURL = p.targetAddr + strings.TrimPrefix(r.URL.Path, "/")
|
|
} else if !strings.HasSuffix(p.targetAddr, "/") && !strings.HasPrefix(r.URL.Path, "/") && r.URL.Path != "" {
|
|
// Add slash when neither has one
|
|
targetURL = p.targetAddr + "/" + r.URL.Path
|
|
} else {
|
|
// Normal case
|
|
targetURL = p.targetAddr + r.URL.Path
|
|
}
|
|
} else {
|
|
// No scheme in target, add https:// (original behavior)
|
|
targetURL = fmt.Sprintf("https://%s%s", p.targetAddr, r.URL.Path)
|
|
}
|
|
|
|
if r.URL.RawQuery != "" {
|
|
targetURL += "?" + r.URL.RawQuery
|
|
}
|
|
|
|
if *verbose {
|
|
p.logger.Printf("Proxying to: %s", targetURL)
|
|
}
|
|
|
|
// Create a new request
|
|
outReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
p.logger.Printf("Error creating request: %v", err)
|
|
return
|
|
}
|
|
|
|
// Copy headers
|
|
p.copyHeaders(outReq.Header, r.Header)
|
|
|
|
// Send the request
|
|
resp, err := client.Do(outReq)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
p.logger.Printf("Error sending request: %v", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Copy the response headers
|
|
p.copyHeaders(w.Header(), resp.Header)
|
|
|
|
// Add Alt-Svc header with the correct port
|
|
portStr := strings.Split(*addr, ":")[1]
|
|
w.Header().Add("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, portStr))
|
|
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
// Copy the response body using io.Copy for efficiency
|
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
p.logger.Printf("Error copying response body: %v", err)
|
|
}
|
|
}
|
|
|
|
// handleForwardProxy handles requests in forward proxy mode
|
|
func (p *ProxyServer) handleForwardProxy(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodConnect {
|
|
// Handle CONNECT method (for HTTPS)
|
|
p.handleConnect(w, r)
|
|
return
|
|
}
|
|
|
|
// Ensure absolute URL for forward proxy
|
|
if !r.URL.IsAbs() {
|
|
http.Error(w, "Request URL must be absolute in forward proxy mode", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Get or create a RoundTripper for the target host
|
|
roundTripper := p.getRoundTripper(r.URL.Host)
|
|
|
|
// Create a new client
|
|
client := &http.Client{
|
|
Transport: roundTripper,
|
|
Timeout: 30 * time.Second,
|
|
// Don't follow redirects, let the client handle them
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
|
|
if *verbose {
|
|
p.logger.Printf("Proxying to: %s", r.URL)
|
|
}
|
|
|
|
// Create a new request
|
|
outReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
p.logger.Printf("Error creating request: %v", err)
|
|
return
|
|
}
|
|
|
|
// Copy headers
|
|
p.copyHeaders(outReq.Header, r.Header)
|
|
|
|
// Send the request
|
|
resp, err := client.Do(outReq)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
p.logger.Printf("Error sending request: %v", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Copy the response headers
|
|
p.copyHeaders(w.Header(), resp.Header)
|
|
|
|
// Add Alt-Svc header with the correct port
|
|
portStr := strings.Split(*addr, ":")[1]
|
|
w.Header().Add("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, portStr))
|
|
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
// Copy the response body
|
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
p.logger.Printf("Error copying response body: %v", err)
|
|
}
|
|
}
|
|
|
|
// handleConnect handles the CONNECT method for HTTPS tunneling
|
|
func (p *ProxyServer) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|
// For HTTP/3 CONNECT, we need to establish a QUIC connection
|
|
// to the target and then setup a bidirectional stream
|
|
targetHost := r.Host
|
|
|
|
// Get roundTripper but don't use it yet - we'll access it via client in future implementation
|
|
_ = p.getRoundTripper(targetHost)
|
|
|
|
// Notify the client that tunnel has been established
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Check if we're dealing with a hijackable connection
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
p.logger.Printf("Connection doesn't support hijacking, can't establish tunnel")
|
|
http.Error(w, "CONNECT not supported over HTTP/3", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Hijack the connection
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
p.logger.Printf("Failed to hijack connection: %v", err)
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
// For QUIC connections, this is more complex as we can't directly hijack streams
|
|
// This is a basic implementation that won't work for all cases
|
|
p.logger.Printf("CONNECT tunneling is limited in HTTP/3 due to protocol differences")
|
|
p.logger.Printf("Simple pass-through enabled for %s", targetHost)
|
|
|
|
// Note: In a full implementation, we would need to:
|
|
// 1. Open a QUIC connection to the target
|
|
// 2. Create a bidirectional stream
|
|
// 3. Set up forwarding between the client stream and target stream
|
|
// This requires direct access to the QUIC connection which http3.RoundTripper doesn't expose easily
|
|
}
|
|
|
|
// copyHeaders copies HTTP headers from src to dst, removing hop-by-hop headers
|
|
func (p *ProxyServer) copyHeaders(dst, src http.Header) {
|
|
for k, vv := range src {
|
|
// Skip hop-by-hop headers
|
|
if p.isHopByHopHeader(k) {
|
|
continue
|
|
}
|
|
|
|
for _, v := range vv {
|
|
dst.Add(k, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isHopByHopHeader checks if a header is hop-by-hop
|
|
func (p *ProxyServer) isHopByHopHeader(header string) bool {
|
|
header = strings.ToLower(header)
|
|
for _, h := range hopByHopHeaders {
|
|
if strings.ToLower(h) == header {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check for Connection header values
|
|
for _, h := range hopByHopHeaders {
|
|
if h == "Connection" {
|
|
values := strings.Split(header, ",")
|
|
for _, v := range values {
|
|
if strings.TrimSpace(strings.ToLower(v)) == header {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|