Skip to content

Commit bc6743a

Browse files
committed
feat: add support for ALBTargetGroupRequest/Response
1 parent 5f61940 commit bc6743a

File tree

3 files changed

+173
-2
lines changed

3 files changed

+173
-2
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Simple HTTP adapter for AWS Lambda
88
- AWS Lambda Function URL (both normal and streaming)
99
- API Gateway (v1)
1010
- API Gateway (v2)
11+
- Application Load Balancer
1112

1213
## Builtin support for these HTTP frameworks:
1314
- `net/http`
@@ -252,6 +253,7 @@ Once this build-tag is present, the following build-tags are available:
252253
- `lambdahttpadapter.apigwv1` (enables API Gateway V1 handler)
253254
- `lambdahttpadapter.apigwv2` (enables API Gateway V2 handler)
254255
- `lambdahttpadapter.functionurl` (enables Lambda Function URL handler)
256+
- `lambdahttpadapter.alb` (enables Application Load Balancer handler)
255257

256258
Also note that Lambda Function URL in Streaming-Mode requires the following build-tag to be set:
257259
- `lambda.norpc`

handler/alb.go

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.alb)
2+
3+
package handler
4+
5+
import (
6+
"bytes"
7+
"context"
8+
"encoding/base64"
9+
"github.com/aws/aws-lambda-go/events"
10+
"io"
11+
"net/http"
12+
"net/url"
13+
"strconv"
14+
"strings"
15+
"unicode/utf8"
16+
)
17+
18+
func convertALBRequest(ctx context.Context, event events.ALBTargetGroupRequest) (*http.Request, error) {
19+
q := make(url.Values)
20+
21+
if len(event.MultiValueQueryStringParameters) > 0 {
22+
for k, values := range event.MultiValueQueryStringParameters {
23+
for _, v := range values {
24+
q.Add(k, v)
25+
}
26+
}
27+
} else if len(event.QueryStringParameters) > 0 {
28+
for k, v := range event.QueryStringParameters {
29+
q.Add(k, v)
30+
}
31+
}
32+
33+
headers := make(http.Header)
34+
if event.Headers != nil {
35+
for k, v := range event.Headers {
36+
headers.Add(k, v)
37+
}
38+
}
39+
40+
if event.MultiValueHeaders != nil {
41+
for k, values := range event.MultiValueHeaders {
42+
for _, v := range values {
43+
headers.Add(k, v)
44+
}
45+
}
46+
}
47+
48+
host := headers.Get("X-Forwarded-Host")
49+
if host == "" {
50+
host = headers.Get("Host")
51+
if host == "" {
52+
host = "127.0.0.1"
53+
}
54+
}
55+
56+
sourceIp := headers.Get("X-Forwarded-For")
57+
if sourceIp == "" {
58+
sourceIp = "127.0.0.1"
59+
}
60+
61+
proto := headers.Get("X-Forwarded-Proto")
62+
if proto == "" {
63+
proto = "http"
64+
}
65+
66+
rUrl := buildFullRequestURLWithProto(proto, host, event.Path, "", q.Encode())
67+
req, err := http.NewRequestWithContext(ctx, event.HTTPMethod, rUrl, getBody(event.Body, event.IsBase64Encoded))
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
req.Header = headers
73+
req.RemoteAddr = buildRemoteAddr(sourceIp)
74+
req.RequestURI = req.URL.RequestURI()
75+
76+
return req, nil
77+
}
78+
79+
type albResponseWriter struct {
80+
multiValueHeaders bool
81+
headersWritten bool
82+
contentTypeSet bool
83+
contentLengthSet bool
84+
headers http.Header
85+
body bytes.Buffer
86+
res events.ALBTargetGroupResponse
87+
}
88+
89+
func (w *albResponseWriter) Header() http.Header {
90+
return w.headers
91+
}
92+
93+
func (w *albResponseWriter) Write(p []byte) (int, error) {
94+
w.WriteHeader(http.StatusOK)
95+
return w.body.Write(p)
96+
}
97+
98+
func (w *albResponseWriter) WriteHeader(statusCode int) {
99+
if !w.headersWritten {
100+
w.headersWritten = true
101+
w.res.StatusCode = statusCode
102+
103+
for k, values := range w.headers {
104+
if w.multiValueHeaders {
105+
w.res.MultiValueHeaders[k] = values
106+
} else {
107+
w.res.Headers[k] = strings.Join(values, ",")
108+
}
109+
}
110+
}
111+
}
112+
113+
func handleALB(multiValueHeaders bool) func(ctx context.Context, event events.ALBTargetGroupRequest, adapter AdapterFunc) (events.ALBTargetGroupResponse, error) {
114+
return func(ctx context.Context, event events.ALBTargetGroupRequest, adapter AdapterFunc) (events.ALBTargetGroupResponse, error) {
115+
req, err := convertALBRequest(ctx, event)
116+
if err != nil {
117+
var def events.ALBTargetGroupResponse
118+
return def, err
119+
}
120+
121+
w := albResponseWriter{
122+
multiValueHeaders: multiValueHeaders,
123+
headers: make(http.Header),
124+
res: events.ALBTargetGroupResponse{},
125+
}
126+
127+
if multiValueHeaders {
128+
w.res.MultiValueHeaders = make(map[string][]string)
129+
} else {
130+
w.res.Headers = make(map[string]string)
131+
}
132+
133+
if err = adapter(ctx, req, &w); err != nil {
134+
var def events.ALBTargetGroupResponse
135+
return def, err
136+
}
137+
138+
b, err := io.ReadAll(&w.body)
139+
if err != nil {
140+
var def events.ALBTargetGroupResponse
141+
return def, err
142+
}
143+
144+
if !w.contentTypeSet {
145+
w.res.Headers["Content-Type"] = http.DetectContentType(b)
146+
}
147+
148+
if !w.contentLengthSet {
149+
w.res.Headers["Content-Length"] = strconv.Itoa(len(b))
150+
}
151+
152+
if utf8.Valid(b) {
153+
w.res.Body = string(b)
154+
} else {
155+
w.res.IsBase64Encoded = true
156+
w.res.Body = base64.StdEncoding.EncodeToString(b)
157+
}
158+
159+
return w.res, nil
160+
}
161+
}
162+
163+
func NewALBHandler(adapter AdapterFunc, multiValueHeaders bool) func(context.Context, events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) {
164+
return NewHandler(handleALB(multiValueHeaders), adapter)
165+
}

handler/common.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ func buildQuery(rawQuery string, queryParams map[string]string) string {
2424
return ""
2525
}
2626

27-
func buildFullRequestURL(host string, path string, altPath string, query string) string {
27+
func buildFullRequestURL(host, path, altPath, query string) string {
28+
return buildFullRequestURLWithProto("https", host, path, altPath, query)
29+
}
30+
31+
func buildFullRequestURLWithProto(proto, host, path, altPath, query string) string {
2832
rUrl := path
2933

3034
if rUrl == "" {
@@ -35,7 +39,7 @@ func buildFullRequestURL(host string, path string, altPath string, query string)
3539
rUrl = "/" + rUrl
3640
}
3741

38-
rUrl = "https://" + host + rUrl
42+
rUrl = proto + "://" + host + rUrl
3943

4044
if query != "" {
4145
rUrl += "?" + query

0 commit comments

Comments
 (0)