diff --git a/http.go b/http.go index d28c66f..4ad386c 100644 --- a/http.go +++ b/http.go @@ -38,16 +38,34 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) { // // The ipPort is any valid net.Listen TCP address. func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) { - p.addRoute(ipPort, httpHostMatch{match, dest}) + p.addRoute(ipPort, httpHostMatch{matcher: match, target: dest}) +} + +// HTTPHostTargetFunc is the func callback used by Proxy.AddHTTPHostRouteFunc. +type HTTPHostTargetFunc func(ctx context.Context, httpHost string) (t Target, ok bool) + +// AddHTTPHostRouteFunc adds a route to ipPort that matches an HTTP request and calls +// fn to map it to a target. +func (p *Proxy) AddHTTPHostRouteFunc(ipPort string, fn HTTPHostTargetFunc) { + p.addRoute(ipPort, httpHostMatch{targetFunc: fn}) } type httpHostMatch struct { matcher Matcher target Target + + // Alternatively, if targetFunc is non-nil, it's used instead: + targetFunc HTTPHostTargetFunc } func (m httpHostMatch) match(br *bufio.Reader) (Target, string) { hh := httpHostHeader(br) + if m.targetFunc != nil { + if t, ok := m.targetFunc(context.TODO(), hh); ok { + return t, hh + } + return nil, "" + } if m.matcher(context.TODO(), hh) { return m.target, hh } diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 38feb06..525bc86 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -72,7 +72,7 @@ func TestMatchHTTPHost(t *testing.T) { } t.Run(name, func(t *testing.T) { br := bufio.NewReader(tt.r) - r := httpHostMatch{equals(tt.host), noopTarget{}} + r := httpHostMatch{matcher: equals(tt.host), target: noopTarget{}} m, name := r.match(br) got := m != nil if got != tt.want { @@ -247,6 +247,50 @@ func TestProxyHTTP(t *testing.T) { } } +func TestProxyHTTPFunc(t *testing.T) { + front := newLocalListener(t) + defer front.Close() + + backFoo := newLocalListener(t) + defer backFoo.Close() + backBar := newLocalListener(t) + defer backBar.Close() + + p := testProxy(t, front) + p.AddHTTPHostRouteFunc(testFrontAddr, func(ctx context.Context, httpHost string) (_ Target, ok bool) { + if httpHost == "bar.com" { + return To(backBar.Addr().String()), true + } + t.Fatalf("failed to match %q", httpHost) + return nil, false + }) + if err := p.Start(); err != nil { + t.Fatal(err) + } + + toFront, err := net.Dial("tcp", front.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer toFront.Close() + + const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n" + io.WriteString(toFront, msg) + + fromProxy, err := backBar.Accept() + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(fromProxy, buf); err != nil { + t.Fatal(err) + } + if string(buf) != msg { + t.Fatalf("got %q; want %q", buf, msg) + } +} + func TestProxySNI(t *testing.T) { front := newLocalListener(t) defer front.Close()