Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 8 additions & 94 deletions v2/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,22 @@ package lark

import (
"context"
"sync/atomic"
"time"
)

// URLs for auth
const (
appAccessTokenInternalURL = "/open-apis/auth/v3/app_access_token/internal"
tenantAppAccessTokenInternalURL = "/open-apis/auth/v3/tenant_access_token/internal/"
tenantAccessTokenInternalURL = "/open-apis/auth/v3/tenant_access_token/internal"
)

// AuthTokenInternalResponse .
type AuthTokenInternalResponse struct {
// TenantAccessTokenInternalResponse .
type TenantAccessTokenInternalResponse struct {
BaseResponse
AppAccessToken string `json:"app_access_token"`
Expire int `json:"expire"`
}

// TenantAuthTokenInternalResponse .
type TenantAuthTokenInternalResponse struct {
BaseResponse
TenantAppAccessToken string `json:"tenant_access_token"`
Expire int `json:"expire"`
}

// GetAccessTokenInternal gets AppAccessToken for internal use
func (bot *Bot) GetAccessTokenInternal(ctx context.Context, updateToken bool) (*AuthTokenInternalResponse, error) {
if !bot.requireType(ChatBot) {
return nil, ErrBotTypeError
}

params := map[string]interface{}{
"app_id": bot.appID,
"app_secret": bot.appSecret,
}
var respData AuthTokenInternalResponse
err := bot.PostAPIRequest(ctx, "GetAccessTokenInternal", appAccessTokenInternalURL, false, params, &respData)
if err == nil && updateToken {
bot.accessToken.Store(respData.AppAccessToken)
}
return &respData, err
TenantAccessToken string `json:"tenant_access_token"`
Expire int `json:"expire"`
}

// GetTenantAccessTokenInternal gets AppAccessToken for internal use
func (bot *Bot) GetTenantAccessTokenInternal(ctx context.Context, updateToken bool) (*TenantAuthTokenInternalResponse, error) {
func (bot *Bot) GetTenantAccessTokenInternal(ctx context.Context) (*TenantAccessTokenInternalResponse, error) {
if !bot.requireType(ChatBot) {
return nil, ErrBotTypeError
}
Expand All @@ -54,65 +26,7 @@ func (bot *Bot) GetTenantAccessTokenInternal(ctx context.Context, updateToken bo
"app_id": bot.appID,
"app_secret": bot.appSecret,
}
var respData TenantAuthTokenInternalResponse
err := bot.PostAPIRequest(ctx, "GetTenantAccessTokenInternal", tenantAppAccessTokenInternalURL, false, params, &respData)
if err == nil && updateToken {
bot.tenantAccessToken.Store(respData.TenantAppAccessToken)
}
var respData TenantAccessTokenInternalResponse
err := bot.PostAPIRequest(ctx, "GetTenantAccessTokenInternal", tenantAccessTokenInternalURL, false, params, &respData)
return &respData, err
}

// StopHeartbeat stops auto-renew
func (bot *Bot) StopHeartbeat() {
bot.heartbeat <- true
}

// StartHeartbeat renews auth token periodically
func (bot *Bot) StartHeartbeat() error {
return bot.startHeartbeat(10 * time.Second)
}

// SetHeartbeatCtx assigns a context for heartbeat
func (bot *Bot) SetHeartbeatCtx(ctx context.Context) {
bot.heartbeatCtx = ctx
}

func (bot *Bot) startHeartbeat(defaultInterval time.Duration) error {
if !bot.requireType(ChatBot) {
return ErrBotTypeError
}

// First initialize the token in blocking mode
if bot.heartbeatCtx == nil {
return ErrHeartbeatContextNotSet
}
_, err := bot.GetTenantAccessTokenInternal(bot.heartbeatCtx, true)
if err != nil {
bot.httpErrorLog(bot.heartbeatCtx, "Heartbeat", "failed to get tenant access token", err)
return err
}
atomic.AddInt64(&bot.heartbeatCounter, 1)

interval := defaultInterval
bot.heartbeat = make(chan bool)
go func() {
for {
t := time.NewTimer(interval)
select {
case <-bot.heartbeat:
return
case <-t.C:
interval = defaultInterval
resp, err := bot.GetTenantAccessTokenInternal(bot.heartbeatCtx, true)
if err != nil {
bot.httpErrorLog(bot.heartbeatCtx, "Heartbeat", "failed to get tenant access token", err)
}
atomic.AddInt64(&bot.heartbeatCounter, 1)
if resp != nil && resp.Expire-20 > 0 {
interval = time.Duration(resp.Expire-20) * time.Second
}
}
}
}()
return nil
}
52 changes: 4 additions & 48 deletions v2/api_auth_test.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,18 @@
package lark

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestAuthAccessTokenInternal(t *testing.T) {
func TestGetTenantAccessTokenInternal(t *testing.T) {
bot := newTestBot()
resp, err := bot.GetAccessTokenInternal(t.Context(), true)
resp, err := bot.GetTenantAccessTokenInternal(t.Context())
if assert.NoError(t, err) {
assert.Equal(t, 0, resp.Code)
assert.NotEmpty(t, resp.AppAccessToken)
t.Log(resp.AppAccessToken)
assert.NotEmpty(t, resp.TenantAccessToken)
assert.NotEmpty(t, resp.Expire)
t.Log(resp)
}
}

func TestAuthTenantAccessTokenInternal(t *testing.T) {
bot := newTestBot()
resp, err := bot.GetTenantAccessTokenInternal(t.Context(), true)
if assert.NoError(t, err) {
assert.Equal(t, 0, resp.Code)
assert.NotEmpty(t, resp.TenantAppAccessToken)
t.Log(resp.TenantAppAccessToken)
assert.NotEmpty(t, resp.Expire)
}
}

func TestHeartbeat(t *testing.T) {
bot := newTestBot()
assert.Nil(t, bot.heartbeat)
assert.Nil(t, bot.startHeartbeat(time.Second*1))
assert.NotEmpty(t, bot.tenantAccessToken)
assert.Equal(t, int64(1), bot.heartbeatCounter)
time.Sleep(2 * time.Second)
assert.Equal(t, int64(2), bot.heartbeatCounter)
bot.StopHeartbeat()
time.Sleep(2 * time.Second)
assert.Equal(t, int64(2), bot.heartbeatCounter)
// restart heartbeat
assert.Nil(t, bot.startHeartbeat(time.Second*1))
time.Sleep(2 * time.Second)
assert.Equal(t, int64(4), bot.heartbeatCounter)
}

func TestInvalidHeartbeat(t *testing.T) {
bot := NewNotificationBot("")
err := bot.StartHeartbeat()
assert.Error(t, err, ErrBotTypeError)
}

func TestSetHeartbeatContext(t *testing.T) {
bot := newTestBot()
assert.Equal(t, "context.Background", fmt.Sprintf("%s", bot.heartbeatCtx))
bot.SetHeartbeatCtx(context.TODO())
assert.Equal(t, "context.TODO", fmt.Sprintf("%s", bot.heartbeatCtx))
}
2 changes: 0 additions & 2 deletions v2/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ var (
ErrParamExceedInputLimit = errors.New("Param error: Exceed input limit")
ErrMessageTypeNotSuppored = errors.New("Message type not supported")
ErrEncryptionNotEnabled = errors.New("Encryption is not enabled")
ErrCustomHTTPClientNotSet = errors.New("Custom HTTP client not set")
ErrMessageNotBuild = errors.New("Message not build")
ErrUnsupportedUIDType = errors.New("Unsupported UID type")
ErrInvalidReceiveID = errors.New("Invalid receive ID")
ErrEventTypeNotMatch = errors.New("Event type not match")
ErrMessageType = errors.New("Message type error")
ErrHeartbeatContextNotSet = errors.New("Heartbeat context not set")
)

// APIError constructs an error with given response
Expand Down
77 changes: 52 additions & 25 deletions v2/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
"io"
"net/http"
"net/http/httputil"
"time"
)

// HTTPClient is an interface handling http requests
type HTTPClient interface {
Do(ctx context.Context, req *http.Request) (*http.Response, error)
}

// ExpandURL expands url path to full url
func (bot Bot) ExpandURL(urlPath string) string {
url := fmt.Sprintf("%s%s", bot.domain, urlPath)
Expand All @@ -20,6 +26,34 @@
bot.logger.Log(ctx, LogLevelError, fmt.Sprintf("[%s] %s: %+v\n", prefix, text, err))
}

func (bot *Bot) loadAndRenewToken(ctx context.Context) (string, error) {
now := time.Now()
// check token
token, ok := bot.tenantAccessToken.Load().(TenantAccessToken)
tenantAccessToken := token.TenantAccessToken
if !ok || token.TenantAccessToken == "" || (token.EstimatedExpireAt != nil && now.After(*token.EstimatedExpireAt)) {
// renew token
if bot.autoRenew {
tacResp, err := bot.GetTenantAccessTokenInternal(ctx)
if err != nil {
return "", err

Check warning on line 39 in v2/http.go

View check run for this annotation

Codecov / codecov/patch

v2/http.go#L39

Added line #L39 was not covered by tests
}
now := time.Now()
expire := time.Duration(tacResp.Expire - 10)
eta := now.Add(expire)
token := TenantAccessToken{
TenantAccessToken: tacResp.TenantAccessToken,
Expire: expire,
LastUpdatedAt: &now,
EstimatedExpireAt: &eta,
}
bot.tenantAccessToken.Store(token)
tenantAccessToken = tacResp.TenantAccessToken
}
}
return tenantAccessToken, nil
}

// PerformAPIRequest performs API request
func (bot Bot) PerformAPIRequest(
ctx context.Context,
Expand All @@ -38,35 +72,28 @@
header = make(http.Header)
}
if auth {
header.Add("Authorization", fmt.Sprintf("Bearer %s", bot.TenantAccessToken()))
}
if bot.useCustomClient {
if bot.customClient == nil {
return ErrCustomHTTPClientNotSet
}
respBody, err = bot.customClient.Do(ctx, method, url, header, body)
if err != nil {
bot.httpErrorLog(ctx, prefix, "call failed", err)
return err
}
} else {
req, err := http.NewRequestWithContext(ctx, method, url, body)
tenantAccessToken, err := bot.loadAndRenewToken(ctx)
if err != nil {
bot.httpErrorLog(ctx, prefix, "init request failed", err)
return err
}
req.Header = header
resp, err := bot.client.Do(req)
if err != nil {
bot.httpErrorLog(ctx, prefix, "call failed", err)
return err
}
if bot.debug {
b, _ := httputil.DumpResponse(resp, true)
bot.logger.Log(ctx, LogLevelDebug, string(b))
}
respBody = resp.Body
header.Add("Authorization", fmt.Sprintf("Bearer %s", tenantAccessToken))
}
req, err := http.NewRequest(method, url, body)
if err != nil {
bot.httpErrorLog(ctx, prefix, "init request failed", err)
return err

Check warning on line 84 in v2/http.go

View check run for this annotation

Codecov / codecov/patch

v2/http.go#L83-L84

Added lines #L83 - L84 were not covered by tests
}
req.Header = header
resp, err := bot.client.Do(ctx, req)
if err != nil {
bot.httpErrorLog(ctx, prefix, "call failed", err)
return err

Check warning on line 90 in v2/http.go

View check run for this annotation

Codecov / codecov/patch

v2/http.go#L89-L90

Added lines #L89 - L90 were not covered by tests
}
if bot.debug {
b, _ := httputil.DumpResponse(resp, true)
bot.logger.Log(ctx, LogLevelDebug, string(b))

Check warning on line 94 in v2/http.go

View check run for this annotation

Codecov / codecov/patch

v2/http.go#L93-L94

Added lines #L93 - L94 were not covered by tests
}
respBody = resp.Body
defer respBody.Close()
buffer, err := io.ReadAll(respBody)
if err != nil {
Expand Down
27 changes: 27 additions & 0 deletions v2/http_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package lark

import (
"context"
"net/http"
"time"
)

// defaultClient .
type defaultClient struct {
c *http.Client
}

// newDefaultClient .
func newDefaultClient() *defaultClient {
return &defaultClient{
c: &http.Client{
Timeout: 5 * time.Second,
},
}
}

// Do .
func (dc defaultClient) Do(ctx context.Context, req *http.Request) (*http.Response, error) {
req.WithContext(ctx)
return dc.c.Do(req)
}
17 changes: 0 additions & 17 deletions v2/http_wrapper.go

This file was deleted.

Loading
Loading