From 5b0eca6203a0eceee231ed0e1b575bfc7188d71e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan-Luis=20de=20Sousa-Valadas=20Casta=C3=B1o?= Date: Fri, 30 Aug 2024 16:18:59 +0200 Subject: [PATCH] Implement Proxy.SetRoutes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows to replace all routes at once. Signed-off-by: Juan-Luis de Sousa-Valadas CastaƱo --- tcpproxy.go | 27 ++++++++++++++++++++++----- tcpproxy_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/tcpproxy.go b/tcpproxy.go index 1f03e32..d2c4077 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -144,6 +144,23 @@ func (p *Proxy) AddRoute(ipPort string, dest Target) { p.addRoute(ipPort, fixedTarget{dest}) } +func (p *Proxy) setRoutes(ipPort string, targets []Target) { + var routes []route + for _, target := range targets { + routes = append(routes, fixedTarget{target}) + } + cfg := p.configFor(ipPort) + cfg.routes = routes +} + +// SetRoutes replaces routes for the ipPort. +// +// It's possible that the old routes are still used once after this +// function is called. +func (p *Proxy) SetRoutes(ipPort string, targets []Target) { + p.setRoutes(ipPort, targets) +} + type fixedTarget struct { t Target } @@ -198,7 +215,7 @@ func (p *Proxy) Start() error { return err } p.lns = append(p.lns, ln) - go p.serveListener(errc, ln, config.routes) + go p.serveListener(errc, ln, config) } go p.awaitFirstError(errc) return nil @@ -209,22 +226,22 @@ func (p *Proxy) awaitFirstError(errc <-chan error) { close(p.donec) } -func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) { +func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) { for { c, err := ln.Accept() if err != nil { ret <- err return } - go p.serveConn(c, routes) + go p.serveConn(c, cfg) } } // serveConn runs in its own goroutine and matches c against routes. // It returns whether it matched purely for testing. -func (p *Proxy) serveConn(c net.Conn, routes []route) bool { +func (p *Proxy) serveConn(c net.Conn, cfg *config) bool { br := bufio.NewReader(c) - for _, route := range routes { + for _, route := range cfg.routes { if target, hostName := route.match(br); target != nil { if n := br.Buffered(); n > 0 { peeked, _ := br.Peek(br.Buffered()) diff --git a/tcpproxy_test.go b/tcpproxy_test.go index 0346a7a..95cd8eb 100644 --- a/tcpproxy_test.go +++ b/tcpproxy_test.go @@ -408,6 +408,50 @@ func TestProxyPROXYOut(t *testing.T) { } } +func TestSetRoutes(t *testing.T) { + + var p Proxy + ipPort := ":8080" + p.AddRoute(ipPort, To("127.0.0.2:8080")) + cfg := p.configFor(ipPort) + + expectedAddrsList := [][]string{ + {"127.0.0.1:80"}, + {"127.0.0.1:80", "127.0.0.1:443"}, + {}, + {"127.0.0.1:80"}, + } + + for _, expectedAddrs := range expectedAddrsList { + p.setRoutes(ipPort, stringsToTargets(expectedAddrs)) + if !equalRoutes(cfg.routes, expectedAddrs) { + t.Fatalf("got %v; want %v", cfg.routes, expectedAddrs) + } + } +} + +func stringsToTargets(s []string) []Target { + targets := make([]Target, len(s)) + for i, v := range s { + targets[i] = To(v) + } + + return targets +} +func equalRoutes(routes []route, expectedAddrs []string) bool { + if len(routes) != len(expectedAddrs) { + return false + } + + for i, _ := range routes { + addr := routes[i].(fixedTarget).t.(*DialProxy).Addr + if addr != expectedAddrs[i] { + return false + } + } + return true +} + type tlsServer struct { Listener net.Listener Domain string