Skip to content

Commit 026c5e9

Browse files
committed
implement retries for fetching tokens from the various endpoints
The official SDKs implement a retry mechanism for fetching the tokens from the metadata server in the case of I/O errors, etc. This adds a similar mechanism to the provided ServiceAccount implementations
1 parent cdfcfdd commit 026c5e9

File tree

3 files changed

+79
-34
lines changed

3 files changed

+79
-34
lines changed

src/custom_service_account.rs

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,27 @@ impl ServiceAccount for CustomServiceAccount {
104104
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
105105
.finish();
106106

107-
let request = hyper::Request::post(&self.credentials.token_uri)
108-
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
109-
.body(hyper::Body::from(rqbody))
110-
.unwrap();
111-
112-
tracing::debug!("requesting token from service account: {:?}", request);
113-
let token = client
114-
.request(request)
115-
.await
116-
.map_err(Error::OAuthConnectionError)?
117-
.deserialize::<Token>()
118-
.await?;
107+
let mut retries = 0;
108+
let res = loop {
109+
let request = hyper::Request::post(&self.credentials.token_uri)
110+
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
111+
.body(hyper::Body::from(rqbody.clone()))
112+
.unwrap();
113+
114+
tracing::debug!("requesting token from service account: {:?}", request);
115+
let res = client.request(request).await;
116+
match res {
117+
Ok(res) => break Ok(res),
118+
Err(e) => {
119+
retries += 1;
120+
if retries >= RETRY_COUNT {
121+
break Err(Error::OAuthConnectionError(e));
122+
}
123+
}
124+
}
125+
}?;
126+
127+
let token = res.deserialize::<Token>().await?;
119128

120129
let key = scopes.iter().map(|x| (*x).to_string()).collect();
121130
self.tokens.write().unwrap().insert(key, token.clone());
@@ -154,3 +163,6 @@ impl fmt::Debug for ApplicationCredentials {
154163
.finish()
155164
}
156165
}
166+
167+
/// How many times to attempt to fetch a token from the set credentials token endpoint.
168+
const RETRY_COUNT: u8 = 5;

src/default_authorized_user.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,33 @@ impl DefaultAuthorizedUser {
4848

4949
#[tracing::instrument]
5050
async fn get_token(cred: &UserCredentials, client: &HyperClient) -> Result<Token, Error> {
51-
let req = Self::build_token_request(&RefreshRequest {
52-
client_id: &cred.client_id,
53-
client_secret: &cred.client_secret,
54-
grant_type: "refresh_token",
55-
refresh_token: &cred.refresh_token,
56-
});
57-
58-
let token = client
59-
.request(req)
60-
.await
61-
.map_err(Error::OAuthConnectionError)?
62-
.deserialize()
63-
.await?;
64-
Ok(token)
51+
let mut retries = 0;
52+
let res = loop {
53+
let req = Self::build_token_request(&RefreshRequest {
54+
client_id: &cred.client_id,
55+
client_secret: &cred.client_secret,
56+
grant_type: "refresh_token",
57+
refresh_token: &cred.refresh_token,
58+
});
59+
let res = client.request(req).await;
60+
61+
match res {
62+
Ok(res) => break Ok(res),
63+
Err(e) => {
64+
tracing::warn!(
65+
"Failed to get token from GCP instance metadata server: {}, trying again...",
66+
e
67+
);
68+
69+
retries += 1;
70+
if retries >= RETRY_COUNT {
71+
break Err(Error::OAuthConnectionError(e));
72+
}
73+
}
74+
}
75+
}?;
76+
77+
res.deserialize().await.map_err(Into::into)
6578
}
6679
}
6780

@@ -106,3 +119,6 @@ struct UserCredentials {
106119
/// Type
107120
pub(crate) r#type: String,
108121
}
122+
123+
/// How many times to attempt to fetch a token from the GCP token endpoint.
124+
const RETRY_COUNT: u8 = 5;

src/default_service_account.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,29 @@ impl DefaultServiceAccount {
3636

3737
#[tracing::instrument]
3838
async fn get_token(client: &HyperClient) -> Result<Token, Error> {
39+
let mut retries = 0;
3940
tracing::debug!("Getting token from GCP instance metadata server");
40-
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);
41-
let token = client
42-
.request(req)
43-
.await
44-
.map_err(Error::ConnectionError)?
45-
.deserialize()
46-
.await?;
47-
Ok(token)
41+
let res = loop {
42+
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);
43+
let res = client.request(req).await;
44+
45+
match res {
46+
Ok(res) => break Ok(res),
47+
Err(e) => {
48+
tracing::warn!(
49+
"Failed to get token from GCP instance metadata server: {}, trying again...",
50+
e
51+
);
52+
53+
retries += 1;
54+
if retries >= RETRY_COUNT {
55+
break Err(Error::ConnectionError(e));
56+
}
57+
}
58+
}
59+
}?;
60+
61+
res.deserialize().await.map_err(Into::into)
4862
}
4963
}
5064

@@ -75,3 +89,6 @@ impl ServiceAccount for DefaultServiceAccount {
7589
Ok(token)
7690
}
7791
}
92+
93+
/// How many times to attempt to fetch a token from the GCP metadata server.
94+
const RETRY_COUNT: u8 = 5;

0 commit comments

Comments
 (0)