Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion tgb/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type Webhook struct {
securitySubnets []netip.Prefix
securityToken string

ipFromRequestFunc func(r *http.Request) string

isSetup bool
}

Expand Down Expand Up @@ -66,6 +68,18 @@ func WithWebhookIP(ip string) WebhookOption {
}
}

// DefaultWebhookRequestIP is the default function to get the IP address from the request.
// By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers.
var DefaultWebhookRequestIP = realip.FromRequest

// WithWebhookRequestIP sets function to get the IP address from the request.
// By default the IP address is resolved through the X-Real-Ip and X-Forwarded-For headers.
func WithWebhookRequestIP(ip func(r *http.Request) string) WebhookOption {
return func(webhook *Webhook) {
webhook.ipFromRequestFunc = ip
}
}

// WithWebhookSecuritySubnets sets list of subnets which are allowed to send webhook requests.
func WithWebhookSecuritySubnets(subnets ...netip.Prefix) WebhookOption {
return func(webhook *Webhook) {
Expand Down Expand Up @@ -117,6 +131,8 @@ func NewWebhook(handler Handler, client *tg.Client, url string, options ...Webho
allowedUpdates: []tg.UpdateType{},
securitySubnets: defaultSubnets,
securityToken: token,

ipFromRequestFunc: DefaultWebhookRequestIP,
}

for _, option := range options {
Expand Down Expand Up @@ -322,7 +338,7 @@ func (webhook *Webhook) ServeRequest(ctx context.Context, r *WebhookRequest) *We
// ServeHTTP is the HTTP handler for webhook requests.
// Implementation of http.Handler.
func (webhook *Webhook) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ip, err := netip.ParseAddr(realip.FromRequest(r))
ip, err := netip.ParseAddr(webhook.ipFromRequestFunc(r))
if err != nil {
webhook.log("failed to parse ip: %s", err)
http.Error(w, "failed to parse ip", http.StatusBadRequest)
Expand Down
6 changes: 6 additions & 0 deletions tgb/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func TestNewWebhook(t *testing.T) {

assert.Equal(t, "https://example.com/webhook", webhook.url)
assert.NotNil(t, webhook.handler)
assert.NotNil(t, webhook.ipFromRequestFunc)

assert.NotNil(t, webhook.securityToken)
assert.Len(t, webhook.securitySubnets, 2)
})
Expand All @@ -39,10 +41,14 @@ func TestNewWebhook(t *testing.T) {
WithWebhookSecuritySubnets(netip.MustParsePrefix("1.1.1.1/24")),
WithWebhookSecurityToken("12345"),
WithWebhookMaxConnections(10),
WithWebhookRequestIP(func(r *http.Request) string {
return ""
}),
)

assert.Equal(t, "https://example.com/webhook", webhook.url)
assert.NotNil(t, webhook.handler)
assert.NotNil(t, webhook.ipFromRequestFunc)
assert.Equal(t, "12345", webhook.securityToken)
assert.Len(t, webhook.securitySubnets, 1)
assert.Equal(t, 10, webhook.maxConnections)
Expand Down