FuckHTTP3/main.go
2025-04-26 14:24:23 -04:00

470 lines
13 KiB
Go

package main
import (
	"context"
	"crypto/tls"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
)
// Command-line configuration. Each value is populated by flag.Parse in main.
var (
	// verbose enables per-request and diagnostic logging.
	verbose = flag.Bool("verbose", false, "verbose logging")
	// addr is the host:port both listeners bind to (TCP and UDP).
	addr = flag.String("addr", "localhost:8443", "Address to listen on")
	// certFile and keyFile hold the TLS certificate and private key paths.
	certFile = flag.String("cert", "cert.pem", "Certificate file")
	keyFile  = flag.String("key", "key.pem", "Private key file")
	// targetAddr selects reverse-proxy mode when non-empty.
	targetAddr = flag.String("target", "", "Target address to proxy to (if empty, acts as forward proxy)")
)
// hopByHopHeaders lists the connection-scoped headers that a proxy must not
// forward. Matching against this list is case-insensitive.
//
// "Trailer" replaces the original "Trailers": RFC 2616 §13.5.1 misspelled
// the header name (corrected by erratum 4522), and no "Trailers" header exists.
// "Proxy-Connection" is non-standard but is still sent by some clients.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
// main parses the flags, checks that the certificate and key files exist,
// and normalizes the optional -target URL. It then serves the proxy on -addr
// through two listeners that share one ProxyServer:
//   - an HTTP/3 server (quic-go);
//   - a TLS HTTP/1.1 + HTTP/2 server whose responses advertise HTTP/3 via Alt-Svc.
//
// On SIGINT or SIGTERM, both servers are shut down and every cached upstream
// RoundTripper is closed before the process exits.
func main() {
	flag.Parse()

	// Setup logger
	logger := log.New(os.Stdout, "[FuckHTTP3] ", log.LstdFlags)

	// Fail fast with a clear message when the TLS material is missing.
	if _, err := os.Stat(*certFile); os.IsNotExist(err) {
		logger.Fatalf("Certificate file %s does not exist", *certFile)
	}
	if _, err := os.Stat(*keyFile); os.IsNotExist(err) {
		logger.Fatalf("Key file %s does not exist", *keyFile)
	}

	// Normalize target address if specified
	if *targetAddr != "" {
		// Strip one trailing slash so request paths can be appended directly.
		*targetAddr = strings.TrimSuffix(*targetAddr, "/")
		// Default to https:// when no scheme was given.
		if !strings.HasPrefix(*targetAddr, "http://") && !strings.HasPrefix(*targetAddr, "https://") {
			*targetAddr = "https://" + *targetAddr
		}
		// Log after the scheme is added so the message shows the URL actually used.
		if *verbose {
			logger.Printf("Using normalized target: %s", *targetAddr)
		}
		// url.Parse accepts almost anything, so additionally require a host.
		u, err := url.Parse(*targetAddr)
		if err != nil {
			logger.Fatalf("Invalid target URL: %v", err)
		}
		if u.Host == "" {
			logger.Fatalf("Invalid target URL: %q has no host", *targetAddr)
		}
	}

	// Create a new proxy server
	proxyServer := &ProxyServer{
		logger:     logger,
		targetAddr: *targetAddr,
		clients:    make(map[string]*http3.RoundTripper),
		clientsMu:  &sync.Mutex{},
	}

	// Extract the port for Alt-Svc headers. SplitHostPort handles ":8443",
	// "host:8443" and "[::1]:8443". Splitting on ':' would panic when the
	// port is missing and would pick the wrong field for IPv6 literals.
	_, portStr, err := net.SplitHostPort(*addr)
	if err != nil {
		logger.Fatalf("Invalid listen address %q: %v", *addr, err)
	}

	// Configure TLS
	tlsConfig := &tls.Config{
		MinVersion:         tls.VersionTLS13, // HTTP/3 requires TLS 1.3
		NextProtos:         []string{"h3"},   // Specify HTTP/3 as the next protocol
		InsecureSkipVerify: false,            // only affects certificate verification when acting as a TLS client
	}

	// Configure the HTTP/3 server
	server := &http3.Server{
		Addr:      *addr,
		TLSConfig: http3.ConfigureTLSConfig(tlsConfig),
		Handler:   proxyServer,
		QuicConfig: &quic.Config{
			MaxIdleTimeout:  30 * time.Second,
			KeepAlivePeriod: 10 * time.Second,
		},
	}

	logger.Printf("Starting HTTP/3 proxy server on %s", *addr)
	logger.Printf("Using certificate: %s", *certFile)
	logger.Printf("Using key: %s", *keyFile)
	if *targetAddr != "" {
		logger.Printf("Proxying to target: %s", *targetAddr)
	} else {
		logger.Printf("Running as forward proxy")
	}

	// Create certificate pair from files
	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		logger.Fatalf("Failed to load certificates: %v", err)
	}

	// Create a standard HTTP server for HTTP/1.1 and HTTP/2
	standardServer := &http.Server{
		Addr: *addr,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Add Alt-Svc header to advertise HTTP/3 capability
			w.Header().Set("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, portStr))
			// Handle the request with the proxy
			proxyServer.ServeHTTP(w, r)
		}),
		// Bound header reads so idle or slow clients cannot hold connections open
		// forever. No WriteTimeout: proxied bodies may legitimately stream for a long time.
		ReadHeaderTimeout: 10 * time.Second,
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			NextProtos:   []string{"h2", "http/1.1"},
		},
	}

	// ctx is cancelled once a shutdown signal has been received.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Register before starting the servers so an early signal is not lost.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	// shutdownDone is closed when the shutdown goroutine has finished cleaning up.
	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		sig := <-sigCh
		logger.Printf("Received signal %v, shutting down...", sig)
		cancel()

		// Derive from Background, not ctx: ctx was just cancelled, so a child
		// context would already be expired and Shutdown would not drain anything.
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer shutdownCancel()

		// Shutdown both servers
		if err := standardServer.Shutdown(shutdownCtx); err != nil {
			logger.Printf("Error shutting down HTTP/1.1 server: %v", err)
		}
		if err := server.Close(); err != nil {
			logger.Printf("Error shutting down HTTP/3 server: %v", err)
		}
		// Close all active roundtrippers
		proxyServer.closeAllClients()
	}()

	// Start the HTTP/1.1 + HTTP/2 server in a goroutine
	go func() {
		logger.Printf("Starting HTTP/1.1 and HTTP/2 server on %s", *addr)
		if err := standardServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
			logger.Printf("HTTP server error: %v", err)
		}
	}()

	// Start the HTTP/3 server
	err = server.ListenAndServeTLS(*certFile, *keyFile)
	if err != nil && err != http.ErrServerClosed {
		logger.Fatalf("Failed to start HTTP/3 server: %v", err)
	}
	// The goroutine calls cancel() before server.Close(). So a non-nil ctx.Err()
	// means this return was shutdown-triggered; wait for cleanup to finish
	// before the process exits.
	if ctx.Err() != nil {
		<-shutdownDone
	}
}
// ProxyServer is the http.Handler behind both the HTTP/3 listener and the
// HTTP/1.1 + HTTP/2 listener. It forwards requests upstream through cached
// HTTP/3 round trippers.
type ProxyServer struct {
	// logger receives all diagnostic output.
	logger *log.Logger
	// targetAddr is the reverse-proxy base URL; empty selects forward-proxy mode.
	targetAddr string

	// clients caches one RoundTripper per key passed to getRoundTripper.
	// Access is serialized by clientsMu.
	clients   map[string]*http3.RoundTripper
	clientsMu *sync.Mutex
}
// ServeHTTP routes r to handleReverseProxy when a fixed target is configured,
// and to handleForwardProxy otherwise. With -verbose it logs the request line
// before handling and the elapsed time afterwards.
func (p *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	began := time.Now()
	if *verbose {
		p.logger.Printf("Received request: %s %s %s", r.Method, r.URL, r.Proto)
	}

	// Pick the mode-specific handler; forward mode is the default.
	handler := p.handleForwardProxy
	if p.targetAddr != "" {
		handler = p.handleReverseProxy
	}
	handler(w, r)

	if *verbose {
		p.logger.Printf("Request completed in %v", time.Since(began))
	}
}
// getRoundTripper returns the cached HTTP/3 RoundTripper for host. On first
// use for that host it creates one and stores it. The cache is guarded by
// clientsMu.
//
// NOTE(review): the upstream TLS config sets InsecureSkipVerify, so upstream
// certificates are never verified. This is only appropriate for testing.
func (p *ProxyServer) getRoundTripper(host string) *http3.RoundTripper {
	p.clientsMu.Lock()
	defer p.clientsMu.Unlock()

	cached, found := p.clients[host]
	if !found {
		cached = &http3.RoundTripper{
			TLSClientConfig: &tls.Config{
				NextProtos:         []string{"h3"},
				InsecureSkipVerify: true, // Allow insecure connections for testing
			},
		}
		p.clients[host] = cached
	}
	return cached
}
// closeAllClients calls Close on every cached RoundTripper, then replaces the
// cache with a fresh empty map. Close errors are logged only when -verbose is set.
func (p *ProxyServer) closeAllClients() {
	p.clientsMu.Lock()
	defer p.clientsMu.Unlock()

	for host, rt := range p.clients {
		err := rt.Close()
		if err == nil || !*verbose {
			continue
		}
		p.logger.Printf("Error closing roundtripper for %s: %v", host, err)
	}
	p.clients = make(map[string]*http3.RoundTripper)
}
// handleReverseProxy forwards r to p.targetAddr over HTTP/3 and streams the
// upstream response back to w.
//
// The upstream URL is p.targetAddr joined with the request's escaped path and
// raw query. Upstream redirects are passed back to the client rather than
// followed. A failure to build the outbound request yields 500; a transport
// error yields 502.
func (p *ProxyServer) handleReverseProxy(w http.ResponseWriter, r *http.Request) {
	client := &http.Client{
		Transport: p.getRoundTripper(p.targetAddr),
		Timeout:   30 * time.Second,
		// Return 3xx responses as-is. Following them here would serve the
		// redirected content under the client's original URL.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	targetURL := reverseProxyTargetURL(p.targetAddr, r.URL)
	if *verbose {
		p.logger.Printf("Proxying to: %s", targetURL)
	}

	// Create a new request
	outReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		p.logger.Printf("Error creating request: %v", err)
		return
	}
	// NewRequestWithContext cannot infer the length of an arbitrary body
	// reader, so carry over the client's declared length.
	outReq.ContentLength = r.ContentLength
	p.copyHeaders(outReq.Header, r.Header)

	// Send the request
	resp, err := client.Do(outReq)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		p.logger.Printf("Error sending request: %v", err)
		return
	}
	defer resp.Body.Close()

	p.copyHeaders(w.Header(), resp.Header)
	// Set, not Add. This replaces any Alt-Svc copied from upstream (which
	// would advertise the upstream's endpoint) and avoids duplicating one a
	// wrapping handler already set. SplitHostPort avoids panicking on an
	// address without a port.
	if _, port, err := net.SplitHostPort(*addr); err == nil {
		w.Header().Set("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, port))
	}
	w.WriteHeader(resp.StatusCode)

	// Copy the response body using io.Copy for efficiency
	if _, err := io.Copy(w, resp.Body); err != nil {
		p.logger.Printf("Error copying response body: %v", err)
	}
}

// reverseProxyTargetURL joins the base URL target with u's escaped path and
// raw query. Exactly one '/' is kept at the seam. A target without an
// http:// or https:// scheme gets https:// prepended.
//
// The escaped path is used so percent-encoded characters survive. For
// example, "%3F" must not turn into a literal '?' that starts a query string.
func reverseProxyTargetURL(target string, u *url.URL) string {
	path := u.EscapedPath()

	var joined string
	switch {
	case !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://"):
		joined = "https://" + target + path
	case strings.HasSuffix(target, "/") && strings.HasPrefix(path, "/"):
		joined = target + strings.TrimPrefix(path, "/")
	case !strings.HasSuffix(target, "/") && !strings.HasPrefix(path, "/") && path != "":
		joined = target + "/" + path
	default:
		joined = target + path
	}

	if u.RawQuery != "" {
		joined += "?" + u.RawQuery
	}
	return joined
}
// handleForwardProxy serves forward-proxy requests. CONNECT is delegated to
// handleConnect. Any other method must carry an absolute URL, which is
// fetched over HTTP/3 and streamed back to w. Upstream redirects are passed
// through to the client instead of being followed.
//
// NOTE(review): getRoundTripper caches one RoundTripper per r.URL.Host, and
// the cache is only cleared by closeAllClients. In this mode its size is
// therefore driven by whichever hosts clients request.
func (p *ProxyServer) handleForwardProxy(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodConnect {
		// Handle CONNECT method (for HTTPS)
		p.handleConnect(w, r)
		return
	}

	// Ensure absolute URL for forward proxy
	if !r.URL.IsAbs() {
		http.Error(w, "Request URL must be absolute in forward proxy mode", http.StatusBadRequest)
		return
	}

	client := &http.Client{
		Transport: p.getRoundTripper(r.URL.Host),
		Timeout:   30 * time.Second,
		// Don't follow redirects, let the client handle them
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	if *verbose {
		p.logger.Printf("Proxying to: %s", r.URL)
	}

	// Create a new request
	outReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), r.Body)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		p.logger.Printf("Error creating request: %v", err)
		return
	}
	// NewRequestWithContext cannot infer the length of an arbitrary body
	// reader, so carry over the client's declared length.
	outReq.ContentLength = r.ContentLength
	p.copyHeaders(outReq.Header, r.Header)

	// Send the request
	resp, err := client.Do(outReq)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		p.logger.Printf("Error sending request: %v", err)
		return
	}
	defer resp.Body.Close()

	p.copyHeaders(w.Header(), resp.Header)
	// Set, not Add. This replaces any Alt-Svc copied from upstream and avoids
	// duplicating one a wrapping handler already set. SplitHostPort avoids
	// panicking on an address without a port.
	if _, port, err := net.SplitHostPort(*addr); err == nil {
		w.Header().Set("Alt-Svc", fmt.Sprintf(`h3=":%s"; ma=2592000`, port))
	}
	w.WriteHeader(resp.StatusCode)

	// Copy the response body
	if _, err := io.Copy(w, resp.Body); err != nil {
		p.logger.Printf("Error copying response body: %v", err)
	}
}
// handleConnect implements CONNECT tunnelling for connections that can be
// hijacked. It dials r.Host over TCP, writes "200 Connection Established" on
// the raw client connection, and then copies bytes in both directions until
// either side closes or fails.
//
// Response writers that do not implement http.Hijacker get a 500 and no
// tunnel is attempted. The HTTP/3 and HTTP/2 writers presumably fall into
// this case (TODO confirm for the quic-go version in use). A failed dial
// gets a 502. Both statuses are sent before anything else is written.
func (p *ProxyServer) handleConnect(w http.ResponseWriter, r *http.Request) {
	// For CONNECT, r.Host carries the requested authority (host:port).
	targetHost := r.Host
	if targetHost == "" {
		http.Error(w, "CONNECT request is missing a target host", http.StatusBadRequest)
		return
	}

	// Check hijack support before committing to any status code. Writing 200
	// first would make the error response below a superfluous WriteHeader.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		p.logger.Printf("Connection doesn't support hijacking, can't establish tunnel")
		http.Error(w, "CONNECT not supported over HTTP/3", http.StatusInternalServerError)
		return
	}

	// Dial before replying, so an unreachable target is reported as a 502
	// rather than as a tunnel that closes immediately.
	dialer := &net.Dialer{Timeout: 10 * time.Second}
	targetConn, err := dialer.DialContext(r.Context(), "tcp", targetHost)
	if err != nil {
		p.logger.Printf("Failed to connect to %s: %v", targetHost, err)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer targetConn.Close()

	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		p.logger.Printf("Failed to hijack connection: %v", err)
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer clientConn.Close()

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		p.logger.Printf("Failed to confirm tunnel to %s: %v", targetHost, err)
		return
	}
	if *verbose {
		p.logger.Printf("Tunnel established to %s", targetHost)
	}

	// Pump both directions. When the first direction ends, the deferred
	// Closes above run. That unblocks the other copy, and the buffered
	// channel lets its goroutine exit without a receiver.
	done := make(chan error, 2)
	go func() {
		// Read through the hijack buffer so bytes the server already read past
		// the CONNECT request (e.g. an early TLS ClientHello) are not lost.
		_, err := io.Copy(targetConn, clientBuf.Reader)
		done <- err
	}()
	go func() {
		_, err := io.Copy(clientConn, targetConn)
		done <- err
	}()
	if err := <-done; err != nil && *verbose {
		p.logger.Printf("Tunnel to %s closed: %v", targetHost, err)
	}
}
// copyHeaders appends every end-to-end header in src onto dst (existing dst
// values are kept). Two kinds of header are skipped:
//   - headers for which isHopByHopHeader reports true;
//   - headers named in src's Connection header value, which RFC 9110 §7.6.1
//     defines as applying only to the current hop.
func (p *ProxyServer) copyHeaders(dst, src http.Header) {
	// Gather the header names nominated by Connection, canonicalized for lookup.
	connectionScoped := make(map[string]struct{})
	for _, value := range src["Connection"] {
		for _, token := range strings.Split(value, ",") {
			if token = strings.TrimSpace(token); token != "" {
				connectionScoped[http.CanonicalHeaderKey(token)] = struct{}{}
			}
		}
	}

	for k, vv := range src {
		if p.isHopByHopHeader(k) {
			continue
		}
		if _, scoped := connectionScoped[http.CanonicalHeaderKey(k)]; scoped {
			continue
		}
		for _, v := range vv {
			dst.Add(k, v)
		}
	}
}
// isHopByHopHeader reports whether the header name appears in
// hopByHopHeaders, compared case-insensitively.
//
// Only the name is inspected, so headers nominated through the value of a
// Connection header cannot be detected here; callers must handle those.
// The former second loop split the *name* on commas and compared each piece
// with the name itself. That made this function return true for every
// header, so nothing was ever forwarded.
func (p *ProxyServer) isHopByHopHeader(header string) bool {
	for _, h := range hopByHopHeaders {
		if strings.EqualFold(h, header) {
			return true
		}
	}
	return false
}