From ce4ebe0aaa67ea059fc23915b51ec564e66f4e2b Mon Sep 17 00:00:00 2001 From: Oleksandr Savchuk Date: Tue, 14 Nov 2023 20:08:04 +0200 Subject: [PATCH 1/3] allow to use custom func for get ip from request --- tgb/webhook.go | 14 +++++++++++++- tgb/webhook_test.go | 5 +++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tgb/webhook.go b/tgb/webhook.go index 613e1eb..3ae59de 100644 --- a/tgb/webhook.go +++ b/tgb/webhook.go @@ -33,6 +33,8 @@ type Webhook struct { securitySubnets []netip.Prefix securityToken string + ipFromRequestFunc func(r *http.Request) string + isSetup bool } @@ -66,6 +68,14 @@ func WithWebhookIP(ip string) WebhookOption { } } +// WithWebhookRequestIP sets function to get the IP address from the request. +// By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers. +func WithWebhookRequestIP(ip func(r *http.Request) string) WebhookOption { + return func(webhook *Webhook) { + webhook.ipFromRequestFunc = ip + } +} + // WithWebhookSecuritySubnets sets list of subnets which are allowed to send webhook requests. func WithWebhookSecuritySubnets(subnets ...netip.Prefix) WebhookOption { return func(webhook *Webhook) { @@ -117,6 +127,8 @@ func NewWebhook(handler Handler, client *tg.Client, url string, options ...Webho allowedUpdates: []tg.UpdateType{}, securitySubnets: defaultSubnets, securityToken: token, + + ipFromRequestFunc: realip.FromRequest, } for _, option := range options { @@ -322,7 +334,7 @@ func (webhook *Webhook) ServeRequest(ctx context.Context, r *WebhookRequest) *We // ServeHTTP is the HTTP handler for webhook requests. // Implementation of http.Handler. func (webhook *Webhook) ServeHTTP(w http.ResponseWriter, r *http.Request) { - ip, err := netip.ParseAddr(realip.FromRequest(r)) + ip, err := netip.ParseAddr(webhook.ipFromRequestFunc(r)) if err != nil { webhook.log("failed to parse ip: %s", err) http.Error(w, "failed to parse ip", http.StatusBadRequest) diff --git a/tgb/webhook_test.go b/tgb/webhook_test.go index 0b4b046..45d6c46 100644 --- a/tgb/webhook_test.go +++ b/tgb/webhook_test.go @@ -27,6 +27,7 @@ func TestNewWebhook(t *testing.T) { assert.Equal(t, "https://example.com/webhook", webhook.url) assert.NotNil(t, webhook.handler) + assert.NotNil(t, webhook.ipFromRequestFunc) assert.NotNil(t, webhook.securityToken) assert.Len(t, webhook.securitySubnets, 2) }) @@ -39,10 +40,14 @@ func TestNewWebhook(t *testing.T) { WithWebhookSecuritySubnets(netip.MustParsePrefix("1.1.1.1/24")), WithWebhookSecurityToken("12345"), WithWebhookMaxConnections(10), + WithWebhookRequestIP(func(r *http.Request) string { + return "" + }), ) assert.Equal(t, "https://example.com/webhook", webhook.url) assert.NotNil(t, webhook.handler) + assert.NotNil(t, webhook.ipFromRequestFunc) assert.Equal(t, "12345", webhook.securityToken) assert.Len(t, webhook.securitySubnets, 1) assert.Equal(t, 10, webhook.maxConnections) From 776b9c69aae1f212b8d83127eb716acc807b80f6 Mon Sep 17 00:00:00 2001 From: Oleksandr Savchuk Date: Tue, 14 Nov 2023 20:16:34 +0200 Subject: [PATCH 2/3] add `DefaultWebhookRequestIP` to public vars --- tgb/webhook.go | 6 +++++- tgb/webhook_test.go | 23 +++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/tgb/webhook.go b/tgb/webhook.go index 3ae59de..d547eea 100644 --- a/tgb/webhook.go +++ b/tgb/webhook.go @@ -68,6 +68,10 @@ func WithWebhookIP(ip string) WebhookOption { } } +// DefaultWebhookRequestIP is the default function to get the IP address from the request. +// By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers. +var DefaultWebhookRequestIP = realip.FromRequest + // WithWebhookRequestIP sets function to get the IP address from the request. // By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers. func WithWebhookRequestIP(ip func(r *http.Request) string) WebhookOption { @@ -128,7 +132,7 @@ func NewWebhook(handler Handler, client *tg.Client, url string, options ...Webho securitySubnets: defaultSubnets, securityToken: token, - ipFromRequestFunc: realip.FromRequest, + ipFromRequestFunc: DefaultWebhookRequestIP, } for _, option := range options { diff --git a/tgb/webhook_test.go b/tgb/webhook_test.go index 45d6c46..759bf12 100644 --- a/tgb/webhook_test.go +++ b/tgb/webhook_test.go @@ -6,13 +6,13 @@ import ( "io" "net/http" "net/http/httptest" - "net/netip" "net/url" "strings" "testing" "time" "github.com/mr-linch/go-tg" + "github.com/mr-linch/go-tg/tgb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -32,16 +32,23 @@ func TestNewWebhook(t *testing.T) { assert.Len(t, webhook.securitySubnets, 2) }) t.Run("Custom", func(t *testing.T) { + var ( + handler tgb.Handler + client *tg.Client + ) + webhook := NewWebhook( - HandlerFunc(func(ctx context.Context, update *Update) error { return nil }), - &tg.Client{}, + handler, + client, "https://example.com/webhook", - WithWebhookIP("1.1.1.1"), - WithWebhookSecuritySubnets(netip.MustParsePrefix("1.1.1.1/24")), - WithWebhookSecurityToken("12345"), - WithWebhookMaxConnections(10), WithWebhookRequestIP(func(r *http.Request) string { - return "" + ip := r.Header.Get("Cf-Connecting-Ip") + + if ip != "" { + return ip + } + + return tgb.DefaultWebhookRequestIP(r) }), ) From 6f8233935b730d1f0eb8a76e675a78e328bbff58 Mon Sep 17 00:00:00 2001 From: Oleksandr Savchuk Date: Tue, 14 Nov 2023 20:18:05 +0200 Subject: [PATCH 3/3] fix tests --- tgb/webhook_test.go | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tgb/webhook_test.go b/tgb/webhook_test.go index 759bf12..60455ca 100644 --- a/tgb/webhook_test.go +++ b/tgb/webhook_test.go @@ -6,13 +6,13 @@ import ( "io" "net/http" "net/http/httptest" + "net/netip" "net/url" "strings" "testing" "time" "github.com/mr-linch/go-tg" - "github.com/mr-linch/go-tg/tgb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -28,27 +28,21 @@ func TestNewWebhook(t *testing.T) { assert.Equal(t, "https://example.com/webhook", webhook.url) assert.NotNil(t, webhook.handler) assert.NotNil(t, webhook.ipFromRequestFunc) + assert.NotNil(t, webhook.securityToken) assert.Len(t, webhook.securitySubnets, 2) }) t.Run("Custom", func(t *testing.T) { - var ( - handler tgb.Handler - client *tg.Client - ) - webhook := NewWebhook( - handler, - client, + HandlerFunc(func(ctx context.Context, update *Update) error { return nil }), + &tg.Client{}, "https://example.com/webhook", + WithWebhookIP("1.1.1.1"), + WithWebhookSecuritySubnets(netip.MustParsePrefix("1.1.1.1/24")), + WithWebhookSecurityToken("12345"), + WithWebhookMaxConnections(10), WithWebhookRequestIP(func(r *http.Request) string { - ip := r.Header.Get("Cf-Connecting-Ip") - - if ip != "" { - return ip - } - - return tgb.DefaultWebhookRequestIP(r) + return "" }), )