Skip to content

Commit 3a757bb

Browse files
committed
Added the test helper to assert no message reached backend
- Add the test case for existing backends - Add the test case for add/remove route after server starts
1 parent cfba6a1 commit 3a757bb

File tree

1 file changed

+93
-51
lines changed

1 file changed

+93
-51
lines changed

tcpproxy_test.go

+93-51
Original file line numberDiff line numberDiff line change
@@ -169,24 +169,75 @@ func testProxy(t *testing.T, front net.Listener) *Proxy {
169169
}
170170
}
171171

172-
func testRouteToBackendWithExpected(t *testing.T, front net.Conn, back net.Listener, msg string, expected string) {
173-
io.WriteString(front, msg)
174-
fromProxy, err := back.Accept()
175-
if err != nil {
176-
t.Fatal(err)
177-
}
178-
179-
buf := make([]byte, len(msg))
180-
if _, err := io.ReadFull(fromProxy, buf); err != nil {
181-
t.Fatal(err)
182-
}
183-
if string(buf) != expected {
184-
t.Fatalf("got %q; want %q", buf, expected)
185-
}
172+
func testRouteToBackendWithExpected(t *testing.T, toFront net.Conn, back net.Listener, msg string, expected string) {
173+
io.WriteString(toFront, msg)
174+
fromProxy, err := back.Accept()
175+
if err != nil {
176+
t.Fatal(err)
177+
}
178+
179+
buf := make([]byte, len(expected))
180+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
181+
t.Fatal(err)
182+
}
183+
if string(buf) != expected {
184+
t.Fatalf("got %q; want %q", buf, expected)
185+
}
186186
}
187187

188-
func testRouteToBackend(t *testing.T, front net.Conn, back net.Listener, msg string) {
189-
testRouteToBackendWithExpected(t, front, back, msg, msg)
188+
func testRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) {
189+
toFront, err := net.Dial("tcp", front.Addr().String())
190+
if err != nil {
191+
t.Fatal(err)
192+
}
193+
defer toFront.Close()
194+
195+
testRouteToBackendWithExpected(t, toFront, back, msg, msg)
196+
}
197+
198+
// test the backend is not receiving traffic
199+
func testNotRouteToBackend(t *testing.T, front net.Listener, back net.Listener, msg string) <-chan bool {
200+
done := make(chan bool)
201+
toFront, err := net.Dial("tcp", front.Addr().String())
202+
if err != nil {
203+
t.Fatal(err)
204+
}
205+
defer toFront.Close()
206+
207+
timeC := time.NewTimer(10 * time.Millisecond).C
208+
acceptC := make(chan struct{})
209+
go func() {
210+
io.WriteString(toFront, msg)
211+
fromProxy, err := back.Accept()
212+
acceptC <- struct{}{}
213+
{
214+
if err == nil {
215+
buf := make([]byte, len(msg))
216+
if _, err := io.ReadFull(fromProxy, buf); err != nil {
217+
t.Fatal(err)
218+
}
219+
t.Fatalf("Expect backend to not receive message, but found %s", string(buf))
220+
}
221+
err, ok := err.(net.Error)
222+
if !ok || !err.Timeout() {
223+
t.Fatalf("Expect backend to timeout, but found err: %v", err)
224+
}
225+
}
226+
}()
227+
go func() {
228+
select {
229+
case <-timeC:
230+
{
231+
done <- true
232+
}
233+
case <-acceptC:
234+
{
235+
t.Fatal("Expect backend to not receive message")
236+
done <- true
237+
}
238+
}
239+
}()
240+
return done
190241
}
191242

192243
func TestProxyAlwaysMatch(t *testing.T) {
@@ -201,13 +252,7 @@ func TestProxyAlwaysMatch(t *testing.T) {
201252
t.Fatal(err)
202253
}
203254

204-
toFront, err := net.Dial("tcp", front.Addr().String())
205-
if err != nil {
206-
t.Fatal(err)
207-
}
208-
defer toFront.Close()
209-
210-
testRouteToBackend(t, toFront, back, "message")
255+
testRouteToBackend(t, front, back, "message")
211256
}
212257

213258
func TestProxyHTTP(t *testing.T) {
@@ -226,14 +271,9 @@ func TestProxyHTTP(t *testing.T) {
226271
t.Fatal(err)
227272
}
228273

229-
toFront, err := net.Dial("tcp", front.Addr().String())
230-
if err != nil {
231-
t.Fatal(err)
232-
}
233-
defer toFront.Close()
234-
235-
const msg = "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n"
236-
testRouteToBackend(t, toFront, backBar, msg)
274+
testRouteToBackend(t, front, backBar, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n")
275+
<-testNotRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: bar.com\r\n\r\n")
276+
testRouteToBackend(t, front, backFoo, "GET / HTTP/1.1\r\nHost: foo.com\r\n\r\n")
237277
}
238278

239279
func TestProxySNI(t *testing.T) {
@@ -252,30 +292,32 @@ func TestProxySNI(t *testing.T) {
252292
t.Fatal(err)
253293
}
254294

255-
toFront, err := net.Dial("tcp", front.Addr().String())
256-
if err != nil {
257-
t.Fatal(err)
258-
}
259-
defer toFront.Close()
260-
261-
msg := clientHelloRecord(t, "bar.com")
262-
testRouteToBackend(t, toFront, backBar, msg)
295+
testRouteToBackend(t, front, backBar, clientHelloRecord(t, "bar.com"))
296+
<-testNotRouteToBackend(t, front, backBar, clientHelloRecord(t, "foo.com"))
297+
testRouteToBackend(t, front, backFoo, clientHelloRecord(t, "foo.com"))
263298
}
264-
msg := clientHelloRecord(t, "bar.com")
265-
io.WriteString(toFront, msg)
266299

267-
fromProxy, err := backBar.Accept()
268-
if err != nil {
269-
t.Fatal(err)
270-
}
300+
func TestProxyRemoveRoute(t *testing.T) {
301+
front := newLocalListener(t)
302+
defer front.Close()
303+
p := testProxy(t, front)
271304

272-
buf := make([]byte, len(msg))
273-
if _, err := io.ReadFull(fromProxy, buf); err != nil {
305+
// NOTE: Needs to register testFrontAddr before server starts
306+
p.AddSNIRoute(testFrontAddr, "unused.com", noopTarget{})
307+
308+
if err := p.Start(); err != nil {
274309
t.Fatal(err)
275310
}
276-
if string(buf) != msg {
277-
t.Fatalf("got %q; want %q", buf, msg)
278-
}
311+
312+
backBar := newLocalListener(t)
313+
defer backBar.Close()
314+
routeId := p.AddSNIRoute(testFrontAddr, "bar.com", To(backBar.Addr().String()))
315+
316+
msg := clientHelloRecord(t, "bar.com")
317+
testRouteToBackend(t, front, backBar, msg)
318+
319+
p.RemoveRoute(testFrontAddr, routeId)
320+
<-testNotRouteToBackend(t, front, backBar, msg)
279321
}
280322

281323
func TestProxyPROXYOut(t *testing.T) {
@@ -299,7 +341,7 @@ func TestProxyPROXYOut(t *testing.T) {
299341
}
300342

301343
want := fmt.Sprintf("PROXY TCP4 %s %d %s %d\r\nfoo", toFront.LocalAddr().(*net.TCPAddr).IP, toFront.LocalAddr().(*net.TCPAddr).Port, toFront.RemoteAddr().(*net.TCPAddr).IP, toFront.RemoteAddr().(*net.TCPAddr).Port)
302-
testRouteToBackendWithExpected(t, toFront, back, "foo", want)
344+
testRouteToBackendWithExpected(t, toFront, back, "foo", want)
303345
}
304346

305347
type tlsServer struct {

0 commit comments

Comments
 (0)