Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gcp_auth"
version = "0.8.0"
version = "0.8.1"
repository = "https://github.com/hrvolapeter/gcp_auth"
description = "Google cloud platform (GCP) authentication using default and custom service accounts"
documentation = "https://docs.rs/gcp_auth/"
Expand All @@ -23,7 +23,7 @@ hyper-rustls = { version = "0.24", default-features = false, features = ["tokio-
ring = "0.16.20"
rustls = "0.21"
rustls-pemfile = "1.0.0"
serde = {version = "1.0", features = ["derive", "rc"]}
serde = { version = "1.0", features = ["derive", "rc"] }
serde_json = "1.0"
thiserror = "1.0"
time = { version = "0.3.5", features = ["serde"] }
Expand Down
3 changes: 2 additions & 1 deletion src/authentication_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ pub(crate) trait ServiceAccount: Send + Sync {
async fn refresh_token(&self, client: &HyperClient, scopes: &[&str]) -> Result<Token, Error>;
}

/// Authentication manager is responsible for caching and obtaing credentials for the required scope
/// Authentication manager is responsible for caching and obtaining credentials for the required
/// scope
///
/// Construct the authentication manager with [`AuthenticationManager::new()`] or by creating
/// a [`CustomServiceAccount`], then converting it into an `AuthenticationManager` using the `From`
Expand Down
39 changes: 27 additions & 12 deletions src/custom_service_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,30 @@ impl ServiceAccount for CustomServiceAccount {
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
.finish();

let request = hyper::Request::post(&self.credentials.token_uri)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(rqbody))
.unwrap();

tracing::debug!("requesting token from service account: {:?}", request);
let token = client
.request(request)
.await
.map_err(Error::OAuthConnectionError)?
.deserialize::<Token>()
.await?;
let mut retries = 0;
let response = loop {
let request = hyper::Request::post(&self.credentials.token_uri)
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(hyper::Body::from(rqbody.clone()))
.unwrap();

tracing::debug!("requesting token from service account: {request:?}");
let err = match client.request(request).await {
// Early return when the request succeeds
Ok(response) => break response,
Err(err) => err,
};

tracing::warn!(
"Failed to refresh token with GCP oauth2 token endpoint: {err}, trying again..."
);
retries += 1;
if retries >= RETRY_COUNT {
return Err(Error::OAuthConnectionError(err));
}
};

let token = response.deserialize::<Token>().await?;

let key = scopes.iter().map(|x| (*x).to_string()).collect();
self.tokens.write().unwrap().insert(key, token.clone());
Expand Down Expand Up @@ -154,3 +166,6 @@ impl fmt::Debug for ApplicationCredentials {
.finish()
}
}

/// How many times to attempt to fetch a token from the set credentials token endpoint.
const RETRY_COUNT: u8 = 5;
42 changes: 28 additions & 14 deletions src/default_authorized_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,31 @@ impl DefaultAuthorizedUser {

#[tracing::instrument]
async fn get_token(cred: &UserCredentials, client: &HyperClient) -> Result<Token, Error> {
let req = Self::build_token_request(&RefreshRequest {
client_id: &cred.client_id,
client_secret: &cred.client_secret,
grant_type: "refresh_token",
refresh_token: &cred.refresh_token,
});

let token = client
.request(req)
.await
.map_err(Error::OAuthConnectionError)?
.deserialize()
.await?;
Ok(token)
let mut retries = 0;
let response = loop {
let req = Self::build_token_request(&RefreshRequest {
client_id: &cred.client_id,
client_secret: &cred.client_secret,
grant_type: "refresh_token",
refresh_token: &cred.refresh_token,
});

let err = match client.request(req).await {
// Early return when the request succeeds
Ok(response) => break response,
Err(err) => err,
};

tracing::warn!(
"Failed to get token from GCP oauth2 token endpoint: {err}, trying again..."
);
retries += 1;
if retries >= RETRY_COUNT {
return Err(Error::OAuthConnectionError(err));
}
};

response.deserialize().await.map_err(Into::into)
}
}

Expand Down Expand Up @@ -106,3 +117,6 @@ struct UserCredentials {
/// Type
pub(crate) r#type: String,
}

/// How many times to attempt to fetch a token from the GCP token endpoint.
const RETRY_COUNT: u8 = 5;
31 changes: 23 additions & 8 deletions src/default_service_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,27 @@ impl DefaultServiceAccount {

#[tracing::instrument]
async fn get_token(client: &HyperClient) -> Result<Token, Error> {
let mut retries = 0;
tracing::debug!("Getting token from GCP instance metadata server");
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);
let token = client
.request(req)
.await
.map_err(Error::ConnectionError)?
.deserialize()
.await?;
Ok(token)
let response = loop {
let req = Self::build_token_request(Self::DEFAULT_TOKEN_GCP_URI);

let err = match client.request(req).await {
// Early return when the request succeeds
Ok(response) => break response,
Err(err) => err,
};

tracing::warn!(
"Failed to get token from GCP instance metadata server: {err}, trying again..."
);
retries += 1;
if retries >= RETRY_COUNT {
return Err(Error::ConnectionError(err));
}
};

response.deserialize().await.map_err(Into::into)
}
}

Expand Down Expand Up @@ -75,3 +87,6 @@ impl ServiceAccount for DefaultServiceAccount {
Ok(token)
}
}

/// How many times to attempt to fetch a token from the GCP metadata server.
const RETRY_COUNT: u8 = 5;
8 changes: 7 additions & 1 deletion src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,14 @@ impl Token {
///
/// This takes an additional 30s margin to ensure the token can still be reasonably used
/// instead of expiring right after having checked.
///
/// Note:
/// The official Python implementation uses 20s and states it should be no more than 30s.
/// The official Go implementation uses 10s (0s for the metadata server).
/// The docs state, the metadata server caches tokens until 5 minutes before expiry.
/// We use 20s to be on the safe side.
pub fn has_expired(&self) -> bool {
self.inner.expires_at - Duration::seconds(30) <= OffsetDateTime::now_utc()
self.inner.expires_at - Duration::seconds(20) <= OffsetDateTime::now_utc()
}

/// Get str representation of the token.
Expand Down