diff --git a/fiber/adapter.go b/fiber/adapter.go index 6da955f..592674e 100644 --- a/fiber/adapter.go +++ b/fiber/adapter.go @@ -5,9 +5,12 @@ package fiberadapter import ( "context" - "io/ioutil" + "fmt" + "io" + "log" "net" "net/http" + "strings" "github.com/aws/aws-lambda-go/events" "github.com/gofiber/fiber/v2" @@ -103,7 +106,7 @@ func (f *FiberLambda) adaptor(w http.ResponseWriter, r *http.Request) { defer fasthttp.ReleaseRequest(req) // Convert net/http -> fasthttp request - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) return @@ -129,8 +132,16 @@ func (f *FiberLambda) adaptor(w http.ResponseWriter, r *http.Request) { } } - remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + // We need to make sure the net.ResolveTCPAddr call works as it expects a port + addrWithPort := r.RemoteAddr + if !strings.Contains(r.RemoteAddr, ":") { + addrWithPort = r.RemoteAddr + ":80" // assuming a default port + } + + remoteAddr, err := net.ResolveTCPAddr("tcp", addrWithPort) if err != nil { + fmt.Printf("could not resolve TCP address for addr %s\n", r.RemoteAddr) + log.Println(err) http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) return } diff --git a/fiber/fiberlambda_test.go b/fiber/fiberlambda_test.go index 75846e5..7813911 100644 --- a/fiber/fiberlambda_test.go +++ b/fiber/fiberlambda_test.go @@ -39,6 +39,39 @@ var _ = Describe("FiberLambda tests", func() { }) }) + Context("RemoteAddr handling", func() { + It("Properly parses the IP address", func() { + app := fiber.New() + app.Get("/ping", func(c *fiber.Ctx) error { + // make sure the ip address is actually set properly + Expect(c.Context().RemoteAddr().String()).To(Equal("8.8.8.8:80")) + return c.SendString("pong") + }) + + adapter := fiberadaptor.New(app) + + req := events.APIGatewayProxyRequest{ + Path: "/ping", + HTTPMethod: "GET", + RequestContext: events.APIGatewayProxyRequestContext{ + Identity: events.APIGatewayRequestIdentity{ + SourceIP: "8.8.8.8", + }, + }, + } + + resp, err := adapter.ProxyWithContext(context.Background(), req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + + resp, err = adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + }) + }) + Context("Request header", func() { It("Check pass canonical header to fiber", func() { app := fiber.New()