Skip to content

Instantly share code, notes, and snippets.

@Shikugawa
Created October 26, 2025 03:55
Show Gist options
  • Select an option

  • Save Shikugawa/9e5cba6c056ca734400d4591d670982d to your computer and use it in GitHub Desktop.

Select an option

Save Shikugawa/9e5cba6c056ca734400d4591d670982d to your computer and use it in GitHub Desktop.
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fs;
use std::path::Path;
/// Athenz ZTS OAuth2 access token response, as decoded from the JSON body
/// returned by the `/oauth2/token` endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct AccessTokenResponse {
    /// The bearer access token.
    pub access_token: String,
    /// Token type reported by the server.
    pub token_type: String,
    /// Token lifetime in seconds, as reported by the server.
    pub expires_in: u64,
    /// Optional ID token. The field may be absent from the response, in which
    /// case it decodes as `None`. (This struct is only deserialized, so a
    /// serialization attribute such as `skip_serializing_if` would have no effect.)
    #[serde(default)]
    pub id_token: Option<String>,
}
/// Form parameters posted to the Athenz ZTS OAuth2 token endpoint.
///
/// Built via [`AccessTokenRequest::new`]; `expires_in` is dropped from the
/// encoded form whenever it is `None`.
#[derive(Serialize, Debug)]
pub struct AccessTokenRequest {
    /// OAuth2 grant type sent with the request.
    grant_type: String,
    /// Space-separated list of requested scopes.
    scope: String,
    /// Requested token lifetime in seconds; omitted from the form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
}
impl AccessTokenRequest {
    /// Create a new `client_credentials` access token request for `domain`.
    ///
    /// The scope always starts with `openid` and is followed by either:
    /// * one `<domain>:role.<role>` entry per requested role, separated by single
    ///   spaces, when `roles` is `Some`; or
    /// * the single `<domain>:domain` entry (request all roles in the domain)
    ///   when `roles` is `None`.
    ///
    /// # Arguments
    /// * `domain` - Athenz domain name
    /// * `roles` - Optional list of bare role names (without the domain prefix).
    ///   NOTE(review): `Some(vec![])` produces the scope `"openid "` with no
    ///   domain component; presumably ZTS rejects this — confirm before relying on it.
    /// * `expires_in` - Optional token lifetime in seconds. When `None` the field
    ///   is omitted from the request, so the server-side default applies.
    pub fn new(domain: &str, roles: Option<Vec<&str>>, expires_in: Option<u64>) -> Self {
        let scope = match roles {
            Some(role_list) => {
                // Format: openid domain:role.role1 domain:role.role2
                let role_scopes: Vec<String> = role_list
                    .iter()
                    .map(|role| format!("{}:role.{}", domain, role))
                    .collect();
                format!("openid {}", role_scopes.join(" "))
            }
            // Request all roles in the domain.
            None => format!("openid {}:domain", domain),
        };
        Self {
            grant_type: "client_credentials".to_string(),
            scope,
            expires_in,
        }
    }
}
/// Athenz access token client.
///
/// Holds the ZTS base URL and the filesystem locations of the mTLS identity
/// (certificate + private key) and an optional CA bundle. Only paths are
/// stored, never key material, so deriving `Debug` does not leak secrets.
#[derive(Debug, Clone)]
pub struct AthenzClient {
    /// ZTS server base URL, e.g. `https://zts.athenz.io/zts/v1`.
    zts_url: String,
    /// Path to the service identity certificate (PEM).
    cert_path: String,
    /// Path to the service identity private key (PEM).
    key_path: String,
    /// Optional path to a CA certificate bundle (PEM).
    ca_cert_path: Option<String>,
}
impl AthenzClient {
/// Create a new Athenz client
///
/// # Arguments
/// * `zts_url` - ZTS server URL (e.g., "https://zts.athenz.io/zts/v1")
/// * `cert_path` - Path to the service identity certificate (PEM format)
/// * `key_path` - Path to the private key (PEM format)
/// * `ca_cert_path` - Optional path to CA certificate bundle
pub fn new(
zts_url: impl Into<String>,
cert_path: impl Into<String>,
key_path: impl Into<String>,
ca_cert_path: Option<String>,
) -> Self {
Self {
zts_url: zts_url.into(),
cert_path: cert_path.into(),
key_path: key_path.into(),
ca_cert_path,
}
}
/// Request an access token from ZTS
///
/// # Arguments
/// * `domain` - Athenz domain name
/// * `roles` - Optional list of specific roles to request
/// * `expires_in` - Optional expiration time in seconds (default: 7200)
pub fn get_access_token(
&self,
domain: &str,
roles: Option<Vec<&str>>,
expires_in: Option<u64>,
) -> Result<AccessTokenResponse, Box<dyn Error>> {
// Load the certificate and key
let cert_pem = fs::read(&self.cert_path)?;
let key_pem = fs::read(&self.key_path)?;
// Create identity from certificate and key
let identity = reqwest::Identity::from_pem(&[cert_pem, key_pem].concat())?;
// Build the client with mTLS
let mut client_builder = Client::builder().identity(identity);
// Add CA certificate if provided
if let Some(ca_path) = &self.ca_cert_path {
if Path::new(ca_path).exists() {
let ca_cert = fs::read(ca_path)?;
let cert = reqwest::Certificate::from_pem(&ca_cert)?;
client_builder = client_builder.add_root_certificate(cert);
}
}
let client = client_builder.build()?;
// Create the access token request
let token_request = AccessTokenRequest::new(domain, roles, expires_in);
// Make the POST request to ZTS OAuth2 token endpoint
let endpoint = format!("{}/oauth2/token", self.zts_url);
let response = client
.post(&endpoint)
.form(&token_request)
.send()?;
// Check if the request was successful
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
return Err(format!("ZTS request failed with status {}: {}", status, error_text).into());
}
// Parse the response
let token_response: AccessTokenResponse = response.json()?;
Ok(token_response)
}
/// Convenience method to get an access token for all roles in a domain
pub fn get_domain_token(&self, domain: &str) -> Result<AccessTokenResponse, Box<dyn Error>> {
self.get_access_token(domain, None, None)
}
/// Convenience method to get an access token for specific roles
pub fn get_role_token(
&self,
domain: &str,
roles: Vec<&str>,
) -> Result<AccessTokenResponse, Box<dyn Error>> {
self.get_access_token(domain, Some(roles), None)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `None` for roles requests every role in the domain.
    #[test]
    fn test_access_token_request_all_roles() {
        let request = AccessTokenRequest::new("my-domain", None, None);
        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.scope, "openid my-domain:domain");
    }

    /// Each requested role becomes its own space-separated `<domain>:role.<role>` scope.
    #[test]
    fn test_access_token_request_specific_roles() {
        let request = AccessTokenRequest::new(
            "my-domain",
            Some(vec!["reader", "writer"]),
            None,
        );
        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.scope, "openid my-domain:role.reader my-domain:role.writer");
    }

    /// A single role produces exactly one role scope with no trailing separator.
    #[test]
    fn test_access_token_request_single_role() {
        let request = AccessTokenRequest::new("my-domain", Some(vec!["admin"]), None);
        assert_eq!(request.scope, "openid my-domain:role.admin");
    }

    /// `expires_in` is carried through unchanged, and stays `None` when not given.
    #[test]
    fn test_access_token_request_expires_in() {
        let with_expiry = AccessTokenRequest::new("my-domain", None, Some(3600));
        assert_eq!(with_expiry.expires_in, Some(3600));

        let without_expiry = AccessTokenRequest::new("my-domain", None, None);
        assert_eq!(without_expiry.expires_in, None);
    }
}
// Example usage
fn main() -> Result<(), Box<dyn Error>> {
// Initialize the Athenz client with your service identity certificate
let client = AthenzClient::new(
"https://your-zts-server.com/zts/v1",
"/var/lib/sia/certs/my-domain.my-service.cert.pem",
"/var/lib/sia/keys/my-domain.my-service.key.pem",
Some("/path/to/ca-bundle.pem".to_string()),
);
// Example 1: Get access token for all roles in a domain
println!("=== Getting access token for all roles ===");
match client.get_domain_token("my-domain") {
Ok(response) => {
println!("Access Token: {}", response.access_token);
println!("Token Type: {}", response.token_type);
println!("Expires In: {} seconds", response.expires_in);
if let Some(id_token) = response.id_token {
println!("ID Token: {}", id_token);
}
}
Err(e) => eprintln!("Error: {}", e),
}
// Example 2: Get access token for specific roles
println!("\n=== Getting access token for specific roles ===");
match client.get_role_token("my-domain", vec!["reader", "writer"]) {
Ok(response) => {
println!("Access Token: {}", response.access_token);
println!("Expires In: {} seconds", response.expires_in);
}
Err(e) => eprintln!("Error: {}", e),
}
// Example 3: Get access token with custom expiration
println!("\n=== Getting access token with custom expiration ===");
match client.get_access_token("my-domain", None, Some(3600)) {
Ok(response) => {
println!("Access Token: {}", response.access_token);
println!("Expires In: {} seconds", response.expires_in);
}
Err(e) => eprintln!("Error: {}", e),
}
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment