Created
October 26, 2025 03:55
-
-
Save Shikugawa/9e5cba6c056ca734400d4591d670982d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use reqwest::blocking::Client; | |
| use serde::{Deserialize, Serialize}; | |
| use std::error::Error; | |
| use std::fs; | |
| use std::path::Path; | |
| /// Athenz ZTS OAuth2 Access Token Response | |
| #[derive(Debug, Deserialize)] | |
| pub struct AccessTokenResponse { | |
| pub access_token: String, | |
| pub token_type: String, | |
| pub expires_in: u64, | |
| #[serde(skip_serializing_if = "Option::is_none")] | |
| pub id_token: Option<String>, | |
| } | |
| /// Athenz Access Token Request Parameters | |
| #[derive(Debug, Serialize)] | |
| pub struct AccessTokenRequest { | |
| grant_type: String, | |
| scope: String, | |
| #[serde(skip_serializing_if = "Option::is_none")] | |
| expires_in: Option<u64>, | |
| } | |
| impl AccessTokenRequest { | |
| /// Create a new access token request | |
| /// | |
| /// # Arguments | |
| /// * `domain` - Athenz domain name | |
| /// * `roles` - Optional list of specific roles to request (comma-separated) | |
| /// * `expires_in` - Optional expiration time in seconds (default: 7200) | |
| pub fn new(domain: &str, roles: Option<Vec<&str>>, expires_in: Option<u64>) -> Self { | |
| let scope = if let Some(role_list) = roles { | |
| // Format: openid domain:role.role1 domain:role.role2 | |
| let role_scopes: Vec<String> = role_list | |
| .iter() | |
| .map(|role| format!("{}:role.{}", domain, role)) | |
| .collect(); | |
| format!("openid {}", role_scopes.join(" ")) | |
| } else { | |
| // Request all roles in the domain | |
| format!("openid {}:domain", domain) | |
| }; | |
| Self { | |
| grant_type: "client_credentials".to_string(), | |
| scope, | |
| expires_in, | |
| } | |
| } | |
| } | |
| /// Athenz Access Token Client | |
| pub struct AthenzClient { | |
| zts_url: String, | |
| cert_path: String, | |
| key_path: String, | |
| ca_cert_path: Option<String>, | |
| } | |
| impl AthenzClient { | |
| /// Create a new Athenz client | |
| /// | |
| /// # Arguments | |
| /// * `zts_url` - ZTS server URL (e.g., "https://zts.athenz.io/zts/v1") | |
| /// * `cert_path` - Path to the service identity certificate (PEM format) | |
| /// * `key_path` - Path to the private key (PEM format) | |
| /// * `ca_cert_path` - Optional path to CA certificate bundle | |
| pub fn new( | |
| zts_url: impl Into<String>, | |
| cert_path: impl Into<String>, | |
| key_path: impl Into<String>, | |
| ca_cert_path: Option<String>, | |
| ) -> Self { | |
| Self { | |
| zts_url: zts_url.into(), | |
| cert_path: cert_path.into(), | |
| key_path: key_path.into(), | |
| ca_cert_path, | |
| } | |
| } | |
| /// Request an access token from ZTS | |
| /// | |
| /// # Arguments | |
| /// * `domain` - Athenz domain name | |
| /// * `roles` - Optional list of specific roles to request | |
| /// * `expires_in` - Optional expiration time in seconds (default: 7200) | |
| pub fn get_access_token( | |
| &self, | |
| domain: &str, | |
| roles: Option<Vec<&str>>, | |
| expires_in: Option<u64>, | |
| ) -> Result<AccessTokenResponse, Box<dyn Error>> { | |
| // Load the certificate and key | |
| let cert_pem = fs::read(&self.cert_path)?; | |
| let key_pem = fs::read(&self.key_path)?; | |
| // Create identity from certificate and key | |
| let identity = reqwest::Identity::from_pem(&[cert_pem, key_pem].concat())?; | |
| // Build the client with mTLS | |
| let mut client_builder = Client::builder().identity(identity); | |
| // Add CA certificate if provided | |
| if let Some(ca_path) = &self.ca_cert_path { | |
| if Path::new(ca_path).exists() { | |
| let ca_cert = fs::read(ca_path)?; | |
| let cert = reqwest::Certificate::from_pem(&ca_cert)?; | |
| client_builder = client_builder.add_root_certificate(cert); | |
| } | |
| } | |
| let client = client_builder.build()?; | |
| // Create the access token request | |
| let token_request = AccessTokenRequest::new(domain, roles, expires_in); | |
| // Make the POST request to ZTS OAuth2 token endpoint | |
| let endpoint = format!("{}/oauth2/token", self.zts_url); | |
| let response = client | |
| .post(&endpoint) | |
| .form(&token_request) | |
| .send()?; | |
| // Check if the request was successful | |
| if !response.status().is_success() { | |
| let status = response.status(); | |
| let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string()); | |
| return Err(format!("ZTS request failed with status {}: {}", status, error_text).into()); | |
| } | |
| // Parse the response | |
| let token_response: AccessTokenResponse = response.json()?; | |
| Ok(token_response) | |
| } | |
| /// Convenience method to get an access token for all roles in a domain | |
| pub fn get_domain_token(&self, domain: &str) -> Result<AccessTokenResponse, Box<dyn Error>> { | |
| self.get_access_token(domain, None, None) | |
| } | |
| /// Convenience method to get an access token for specific roles | |
| pub fn get_role_token( | |
| &self, | |
| domain: &str, | |
| roles: Vec<&str>, | |
| ) -> Result<AccessTokenResponse, Box<dyn Error>> { | |
| self.get_access_token(domain, Some(roles), None) | |
| } | |
| } | |
#[cfg(test)]
mod tests {
    use super::*;

    /// `None` for roles requests the whole domain.
    #[test]
    fn test_access_token_request_all_roles() {
        let request = AccessTokenRequest::new("my-domain", None, None);
        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.scope, "openid my-domain:domain");
    }

    /// Each requested role becomes its own space-separated scope entry.
    #[test]
    fn test_access_token_request_specific_roles() {
        let request = AccessTokenRequest::new(
            "my-domain",
            Some(vec!["reader", "writer"]),
            None,
        );
        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.scope, "openid my-domain:role.reader my-domain:role.writer");
    }

    /// A single role produces exactly one role scope with no stray separators.
    #[test]
    fn test_access_token_request_single_role() {
        let request = AccessTokenRequest::new("my-domain", Some(vec!["admin"]), None);
        assert_eq!(request.scope, "openid my-domain:role.admin");
    }

    /// The requested lifetime is carried through unchanged, including `None`.
    #[test]
    fn test_access_token_request_expiry_is_forwarded() {
        let with_expiry = AccessTokenRequest::new("my-domain", None, Some(3600));
        assert_eq!(with_expiry.expires_in, Some(3600));

        let without_expiry = AccessTokenRequest::new("my-domain", None, None);
        assert_eq!(without_expiry.expires_in, None);
    }
}
| // Example usage | |
| fn main() -> Result<(), Box<dyn Error>> { | |
| // Initialize the Athenz client with your service identity certificate | |
| let client = AthenzClient::new( | |
| "https://your-zts-server.com/zts/v1", | |
| "/var/lib/sia/certs/my-domain.my-service.cert.pem", | |
| "/var/lib/sia/keys/my-domain.my-service.key.pem", | |
| Some("/path/to/ca-bundle.pem".to_string()), | |
| ); | |
| // Example 1: Get access token for all roles in a domain | |
| println!("=== Getting access token for all roles ==="); | |
| match client.get_domain_token("my-domain") { | |
| Ok(response) => { | |
| println!("Access Token: {}", response.access_token); | |
| println!("Token Type: {}", response.token_type); | |
| println!("Expires In: {} seconds", response.expires_in); | |
| if let Some(id_token) = response.id_token { | |
| println!("ID Token: {}", id_token); | |
| } | |
| } | |
| Err(e) => eprintln!("Error: {}", e), | |
| } | |
| // Example 2: Get access token for specific roles | |
| println!("\n=== Getting access token for specific roles ==="); | |
| match client.get_role_token("my-domain", vec!["reader", "writer"]) { | |
| Ok(response) => { | |
| println!("Access Token: {}", response.access_token); | |
| println!("Expires In: {} seconds", response.expires_in); | |
| } | |
| Err(e) => eprintln!("Error: {}", e), | |
| } | |
| // Example 3: Get access token with custom expiration | |
| println!("\n=== Getting access token with custom expiration ==="); | |
| match client.get_access_token("my-domain", None, Some(3600)) { | |
| Ok(response) => { | |
| println!("Access Token: {}", response.access_token); | |
| println!("Expires In: {} seconds", response.expires_in); | |
| } | |
| Err(e) => eprintln!("Error: {}", e), | |
| } | |
| Ok(()) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment