Skip to content

Commit 84d2efa

Browse files
authored
feat(sdk): Refactoring using keyflow.init() (#2889)
1 parent a25fa57 commit 84d2efa

File tree

2 files changed

+253
-147
lines changed

2 files changed

+253
-147
lines changed

core/clients/key_flow_continuous_refresh_test.go

Lines changed: 174 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"net/url"
1011
"testing"
1112
"time"
1213

@@ -21,27 +22,28 @@ func TestContinuousRefreshToken(t *testing.T) {
2122
jwt.TimePrecision = time.Millisecond
2223

2324
// Refresher settings
24-
timeStartBeforeTokenExpiration := 100 * time.Millisecond
25-
timeBetweenContextCheck := 5 * time.Millisecond
26-
timeBetweenTries := 40 * time.Millisecond
25+
timeStartBeforeTokenExpiration := 500 * time.Millisecond
26+
timeBetweenContextCheck := 10 * time.Millisecond
27+
timeBetweenTries := 100 * time.Millisecond
2728

2829
// All generated acess tokens will have this time to live
29-
accessTokensTimeToLive := 200 * time.Millisecond
30+
accessTokensTimeToLive := 1 * time.Second
3031

3132
tests := []struct {
3233
desc string
3334
contextClosesIn time.Duration
3435
doError error
3536
expectedNumberDoCalls int
37+
expectedCallRange []int // Optional: for tests that can have variable call counts
3638
}{
3739
{
3840
desc: "update access token once",
39-
contextClosesIn: 150 * time.Millisecond,
41+
contextClosesIn: 700 * time.Millisecond, // Should allow one refresh
4042
expectedNumberDoCalls: 1,
4143
},
4244
{
4345
desc: "update access token twice",
44-
contextClosesIn: 250 * time.Millisecond,
46+
contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes
4547
expectedNumberDoCalls: 2,
4648
},
4749
{
@@ -61,25 +63,26 @@ func TestContinuousRefreshToken(t *testing.T) {
6163
},
6264
{
6365
desc: "refresh token fails - non-API error",
64-
contextClosesIn: 250 * time.Millisecond,
66+
contextClosesIn: 700 * time.Millisecond,
6567
doError: fmt.Errorf("something went wrong"),
6668
expectedNumberDoCalls: 1,
6769
},
6870
{
6971
desc: "refresh token fails - API non-5xx error",
70-
contextClosesIn: 250 * time.Millisecond,
72+
contextClosesIn: 700 * time.Millisecond,
7173
doError: &oapierror.GenericOpenAPIError{
7274
StatusCode: http.StatusBadRequest,
7375
},
7476
expectedNumberDoCalls: 1,
7577
},
7678
{
7779
desc: "refresh token fails - API 5xx error",
78-
contextClosesIn: 200 * time.Millisecond,
80+
contextClosesIn: 800 * time.Millisecond,
7981
doError: &oapierror.GenericOpenAPIError{
8082
StatusCode: http.StatusInternalServerError,
8183
},
8284
expectedNumberDoCalls: 3,
85+
expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition
8386
},
8487
}
8588

@@ -101,19 +104,16 @@ func TestContinuousRefreshToken(t *testing.T) {
101104

102105
numberDoCalls := 0
103106
mockDo := func(_ *http.Request) (resp *http.Response, err error) {
104-
numberDoCalls++
105-
107+
numberDoCalls++ // count refresh attempts
106108
if tt.doError != nil {
107109
return nil, tt.doError
108110
}
109-
110111
newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
111112
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)),
112113
}).SignedString([]byte("test"))
113114
if err != nil {
114115
t.Fatalf("Do call: failed to create access token: %v", err)
115116
}
116-
117117
responseBodyStruct := TokenResponseBody{
118118
AccessToken: newAccessToken,
119119
RefreshToken: refreshToken,
@@ -133,19 +133,34 @@ func TestContinuousRefreshToken(t *testing.T) {
133133
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
134134
defer cancel()
135135

136-
keyFlow := &KeyFlow{
137-
config: &KeyFlowConfig{
138-
BackgroundTokenRefreshContext: ctx,
139-
},
140-
authClient: &http.Client{
136+
keyFlow := &KeyFlow{}
137+
privateKeyBytes, err := generatePrivateKey()
138+
if err != nil {
139+
t.Fatalf("Error generating private key: %s", err)
140+
}
141+
keyFlowConfig := &KeyFlowConfig{
142+
ServiceAccountKey: fixtureServiceAccountKey(),
143+
PrivateKey: string(privateKeyBytes),
144+
AuthHTTPClient: &http.Client{
141145
Transport: mockTransportFn{mockDo},
142146
},
143-
token: &TokenResponseBody{
144-
AccessToken: accessToken,
145-
RefreshToken: refreshToken,
146-
},
147+
HTTPTransport: mockTransportFn{mockDo},
148+
BackgroundTokenRefreshContext: nil,
149+
}
150+
err = keyFlow.Init(keyFlowConfig)
151+
if err != nil {
152+
t.Fatalf("failed to initialize key flow: %v", err)
147153
}
148154

155+
// Set the token after initialization
156+
err = keyFlow.SetToken(accessToken, refreshToken)
157+
if err != nil {
158+
t.Fatalf("failed to set token: %v", err)
159+
}
160+
161+
// Set the context for continuous refresh
162+
keyFlow.config.BackgroundTokenRefreshContext = ctx
163+
149164
refresher := &continuousTokenRefresher{
150165
keyFlow: keyFlow,
151166
timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration,
@@ -157,7 +172,13 @@ func TestContinuousRefreshToken(t *testing.T) {
157172
if err == nil {
158173
t.Fatalf("routine finished with non-nil error")
159174
}
160-
if numberDoCalls != tt.expectedNumberDoCalls {
175+
176+
// Check if we have a range of expected calls (for timing-sensitive tests)
177+
if tt.expectedCallRange != nil {
178+
if !contains(tt.expectedCallRange, numberDoCalls) {
179+
t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls)
180+
}
181+
} else if numberDoCalls != tt.expectedNumberDoCalls {
161182
t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls)
162183
}
163184
})
@@ -194,7 +215,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
194215

195216
// The access token at the start
196217
accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
197-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(100 * time.Millisecond)),
218+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)),
198219
}).SignedString([]byte("token-first"))
199220
if err != nil {
200221
t.Fatalf("failed to create first access token: %v", err)
@@ -225,60 +246,98 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
225246
ctx, cancel := context.WithCancel(ctx)
226247
defer cancel() // This cancels the refresher goroutine
227248

249+
// Extract host from tokenAPI constant for consistency
250+
tokenURL, _ := url.Parse(tokenAPI)
251+
tokenHost := tokenURL.Host
252+
228253
// The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests
229254
// The bools are used to make sure only one request goes through on each test phase
230255
doTestPhase1RequestDone := false
231256
doTestPhase2RequestDone := false
232257
doTestPhase4RequestDone := false
233258
mockDo := func(req *http.Request) (resp *http.Response, err error) {
234-
switch currentTestPhase {
235-
default:
236-
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
237-
return nil, nil
238-
case 1: // Call by continuousRefreshToken()
239-
if doTestPhase1RequestDone {
240-
t.Fatalf("Do call: multiple requests during test phase 1")
241-
}
242-
doTestPhase1RequestDone = true
259+
// Handle auth requests (token refresh)
260+
if req.URL.Host == tokenHost {
261+
switch currentTestPhase {
262+
default:
263+
// After phase 1, allow additional auth requests but don't fail the test
264+
// This handles the continuous nature of the refresh routine
265+
if currentTestPhase > 1 {
266+
// Return a valid response for any additional auth requests
267+
newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
268+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
269+
}).SignedString([]byte("additional-token"))
270+
if err != nil {
271+
t.Fatalf("Do call: failed to create additional access token: %v", err)
272+
}
273+
responseBodyStruct := TokenResponseBody{
274+
AccessToken: newAccessToken,
275+
RefreshToken: refreshToken,
276+
}
277+
responseBody, err := json.Marshal(responseBodyStruct)
278+
if err != nil {
279+
t.Fatalf("Do call: failed to marshal additional response: %v", err)
280+
}
281+
response := &http.Response{
282+
StatusCode: http.StatusOK,
283+
Body: io.NopCloser(bytes.NewReader(responseBody)),
284+
}
285+
return response, nil
286+
}
287+
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
288+
return nil, nil
289+
case 1: // Call by continuousRefreshToken()
290+
if doTestPhase1RequestDone {
291+
t.Fatalf("Do call: multiple requests during test phase 1")
292+
}
293+
doTestPhase1RequestDone = true
243294

244-
currentTestPhase = 2
245-
chanBlockContinuousRefreshToken <- true
295+
currentTestPhase = 2
296+
chanBlockContinuousRefreshToken <- true
246297

247-
// Wait until continuousRefreshToken() is to be unblocked
248-
<-chanUnblockContinuousRefreshToken
298+
// Wait until continuousRefreshToken() is to be unblocked
299+
<-chanUnblockContinuousRefreshToken
249300

250-
if currentTestPhase != 3 {
251-
t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase)
252-
}
301+
if currentTestPhase != 3 {
302+
t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase)
303+
}
253304

254-
// Check required fields are passed
255-
err = req.ParseForm()
256-
if err != nil {
257-
t.Fatalf("Do call: failed to parse body form: %v", err)
258-
}
259-
reqGrantType := req.Form.Get("grant_type")
260-
if reqGrantType != "refresh_token" {
261-
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
262-
}
263-
reqRefreshToken := req.Form.Get("refresh_token")
264-
if reqRefreshToken != refreshToken {
265-
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
266-
}
305+
// Check required fields are passed
306+
err = req.ParseForm()
307+
if err != nil {
308+
t.Fatalf("Do call: failed to parse body form: %v", err)
309+
}
310+
reqGrantType := req.Form.Get("grant_type")
311+
if reqGrantType != "refresh_token" {
312+
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
313+
}
314+
reqRefreshToken := req.Form.Get("refresh_token")
315+
if reqRefreshToken != refreshToken {
316+
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
317+
}
267318

268-
// Return response with accessTokenSecond
269-
responseBodyStruct := TokenResponseBody{
270-
AccessToken: accessTokenSecond,
271-
RefreshToken: refreshToken,
272-
}
273-
responseBody, err := json.Marshal(responseBodyStruct)
274-
if err != nil {
275-
t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err)
276-
}
277-
response := &http.Response{
278-
StatusCode: http.StatusOK,
279-
Body: io.NopCloser(bytes.NewReader(responseBody)),
319+
// Return response with accessTokenSecond
320+
responseBodyStruct := TokenResponseBody{
321+
AccessToken: accessTokenSecond,
322+
RefreshToken: refreshToken,
323+
}
324+
responseBody, err := json.Marshal(responseBodyStruct)
325+
if err != nil {
326+
t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err)
327+
}
328+
response := &http.Response{
329+
StatusCode: http.StatusOK,
330+
Body: io.NopCloser(bytes.NewReader(responseBody)),
331+
}
332+
return response, nil
280333
}
281-
return response, nil
334+
}
335+
336+
// Handle regular HTTP requests
337+
switch currentTestPhase {
338+
default:
339+
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
340+
return nil, nil
282341
case 2: // Call by tokenFlow, first request
283342
if doTestPhase2RequestDone {
284343
t.Fatalf("Do call: multiple requests during test phase 2")
@@ -292,8 +351,9 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
292351
t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host)
293352
}
294353
authHeader := req.Header.Get("Authorization")
295-
if authHeader != fmt.Sprintf("Bearer %s", accessTokenFirst) {
296-
t.Fatalf("Do call: first request didn't carry first access token")
354+
expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst)
355+
if authHeader != expectedAuthHeader {
356+
t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader)
297357
}
298358

299359
// Return empty response
@@ -328,23 +388,49 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
328388
}
329389
}
330390

331-
keyFlow := &KeyFlow{
332-
config: &KeyFlowConfig{
333-
BackgroundTokenRefreshContext: ctx,
334-
},
335-
authClient: &http.Client{
391+
keyFlow := &KeyFlow{}
392+
privateKeyBytes, err := generatePrivateKey()
393+
if err != nil {
394+
t.Fatalf("Error generating private key: %s", err)
395+
}
396+
keyFlowConfig := &KeyFlowConfig{
397+
ServiceAccountKey: fixtureServiceAccountKey(),
398+
PrivateKey: string(privateKeyBytes),
399+
AuthHTTPClient: &http.Client{
336400
Transport: mockTransportFn{mockDo},
337401
},
338-
rt: mockTransportFn{mockDo},
339-
token: &TokenResponseBody{
340-
AccessToken: accessTokenFirst,
341-
RefreshToken: refreshToken,
342-
},
402+
HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests
403+
// Don't start continuous refresh automatically
404+
BackgroundTokenRefreshContext: nil,
405+
}
406+
err = keyFlow.Init(keyFlowConfig)
407+
if err != nil {
408+
t.Fatalf("failed to initialize key flow: %v", err)
409+
}
410+
411+
// Set the token after initialization
412+
err = keyFlow.SetToken(accessTokenFirst, refreshToken)
413+
if err != nil {
414+
t.Fatalf("failed to set token: %v", err)
415+
}
416+
417+
// Set the context for continuous refresh
418+
keyFlow.config.BackgroundTokenRefreshContext = ctx
419+
420+
// Create a custom refresher with shorter timing for the test
421+
refresher := &continuousTokenRefresher{
422+
keyFlow: keyFlow,
423+
timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration
424+
timeBetweenContextCheck: 5 * time.Millisecond,
425+
timeBetweenTries: 40 * time.Millisecond,
343426
}
344427

345428
// TEST START
346429
currentTestPhase = 1
347-
go continuousRefreshToken(keyFlow)
430+
// Ignore returned error as expected in test
431+
go func() {
432+
_ = refresher.continuousRefreshToken()
433+
}()
348434

349435
// Wait until continuousRefreshToken() is blocked
350436
<-chanBlockContinuousRefreshToken
@@ -389,3 +475,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
389475
t.Fatalf("Second request body failed to close: %v", err)
390476
}
391477
}
478+
479+
func contains(arr []int, val int) bool {
480+
for _, v := range arr {
481+
if v == val {
482+
return true
483+
}
484+
}
485+
return false
486+
}

0 commit comments

Comments
 (0)