Skip to content

Instantly share code, notes, and snippets.

@darrenmothersele
Last active January 27, 2026 21:31
Show Gist options
  • Select an option

  • Save darrenmothersele/f99af74fb7c2bfea3ccdb8a7d6124c8d to your computer and use it in GitHub Desktop.

Select an option

Save darrenmothersele/f99af74fb7c2bfea3ccdb8a7d6124c8d to your computer and use it in GitHub Desktop.
Create an HTTP client from a Restate service.
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, ItemTrait, ReturnType, TraitItem};
/// Generates a Restate HTTP client from a service trait.
///
/// Apply this macro to a trait that's also annotated with `#[restate_sdk::service]`.
/// It will generate a `{TraitName}HttpClient` struct with methods that call the
/// Restate ingress HTTP API.
///
/// # Example
///
/// ```ignore
/// #[restate_client]
/// #[restate_sdk::service]
/// pub trait Greeter {
/// async fn greet(name: String) -> HandlerResult<String>;
/// }
///
/// // Generated:
/// // pub struct GreeterHttpClient { ... }
/// // impl GreeterHttpClient {
/// // pub fn new(base_url: impl Into<String>) -> Self { ... }
/// // pub async fn greet(&self, name: String) -> Result<String, GreeterHttpClientError> { ... }
/// // }
/// ```
#[proc_macro_attribute]
pub fn restate_client(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemTrait);
let trait_name = &input.ident;
let service_name = trait_name.to_string();
let client_name = format_ident!("{}HttpClient", trait_name);
let error_name = format_ident!("{}HttpClientError", trait_name);
let mut client_methods = Vec::new();
for item in &input.items {
if let TraitItem::Fn(method) = item {
let method_name = &method.sig.ident;
let method_name_str = method_name.to_string();
// Extract the single parameter (Restate handlers accept 0 or 1 param)
let param = method.sig.inputs.iter().find_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg {
Some(pat_type)
} else {
None
}
});
// Extract return type from HandlerResult<T>
let return_type = match &method.sig.output {
ReturnType::Type(_, ty) => extract_inner_type(ty),
ReturnType::Default => quote!(()),
};
// Unwrap Json<T> from return type for deserialization
let return_type_unwrapped = unwrap_json_type(&return_type);
let (method_impl, param_signature) = if let Some(p) = param {
let param_name = &p.pat;
let param_type = &p.ty;
// Unwrap Json<T> to just T for the client signature
let param_type_unwrapped = unwrap_json_type_from_syn(param_type);
let impl_code = quote! {
let url = format!("{}/{}/{}", self.base_url, #service_name, #method_name_str);
let body = serde_json::to_string(&#param_name).map_err(#error_name::Serialization)?;
let response = self.client
.post(&url)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
.map_err(#error_name::Request)?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(#error_name::Status(status.as_u16(), text));
}
let text = response.text().await.map_err(#error_name::Request)?;
serde_json::from_str(&text).map_err(#error_name::Deserialization)
};
(impl_code, quote!(#param_name: #param_type_unwrapped))
} else {
// No parameters
let impl_code = quote! {
let url = format!("{}/{}/{}", self.base_url, #service_name, #method_name_str);
let response = self.client
.post(&url)
.header("Content-Type", "application/json")
.send()
.await
.map_err(#error_name::Request)?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(#error_name::Status(status.as_u16(), text));
}
let text = response.text().await.map_err(#error_name::Request)?;
serde_json::from_str(&text).map_err(#error_name::Deserialization)
};
(impl_code, quote!())
};
client_methods.push(quote! {
pub async fn #method_name(&self, #param_signature) -> Result<#return_type_unwrapped, #error_name> {
#method_impl
}
});
}
}
let expanded = quote! {
#input
#[derive(Debug)]
pub enum #error_name {
Request(reqwest::Error),
Serialization(serde_json::Error),
Deserialization(serde_json::Error),
Status(u16, String),
}
impl std::fmt::Display for #error_name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Request(e) => write!(f, "request error: {}", e),
Self::Serialization(e) => write!(f, "serialization error: {}", e),
Self::Deserialization(e) => write!(f, "deserialization error: {}", e),
Self::Status(code, msg) => write!(f, "HTTP {}: {}", code, msg),
}
}
}
impl std::error::Error for #error_name {}
#[derive(Clone)]
pub struct #client_name {
base_url: String,
client: reqwest::Client,
}
impl #client_name {
pub fn new(base_url: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
client: reqwest::Client::new(),
}
}
pub fn with_client(base_url: impl Into<String>, client: reqwest::Client) -> Self {
Self {
base_url: base_url.into(),
client,
}
}
#(#client_methods)*
}
};
TokenStream::from(expanded)
}
/// Returns the success type `T` of a `HandlerResult<T>` or `Result<T, E>` type.
///
/// Only the final path segment is compared, so qualified forms such as
/// `std::io::Result<T>` match as well. Any type that is not one of these wrappers,
/// or that has no leading type argument, is emitted unchanged.
fn extract_inner_type(ty: &syn::Type) -> proc_macro2::TokenStream {
    let success_type = match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .filter(|seg| seg.ident == "HandlerResult" || seg.ident == "Result")
            .and_then(|seg| match &seg.arguments {
                syn::PathArguments::AngleBracketed(generics) => generics.args.first(),
                _ => None,
            })
            .and_then(|arg| match arg {
                syn::GenericArgument::Type(inner) => Some(inner),
                _ => None,
            }),
        _ => None,
    };
    match success_type {
        Some(inner) => quote!(#inner),
        None => quote!(#ty),
    }
}
/// Unwraps `Json<T>` to `T` for a type given as a token stream.
///
/// The tokens are parsed as a `syn::Type` and handed to `unwrap_json_type_from_syn`.
/// Plain and fully-qualified wrappers (e.g. `restate_sdk::serde::Json<T>`) are
/// therefore recognised in the same way as for parameter types. Tokens that do not
/// parse as a type are returned unchanged.
fn unwrap_json_type(tokens: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
    // Matching on the parsed path rather than the stringified tokens avoids
    // depending on `TokenStream`'s `Display` spacing and on the path being unqualified.
    match syn::parse2::<syn::Type>(tokens.clone()) {
        Ok(ty) => unwrap_json_type_from_syn(&ty),
        Err(_) => tokens.clone(),
    }
}
/// Strips a `Json<T>` wrapper from a parsed type, producing the tokens of `T`.
///
/// Only the final path segment is compared against `Json`, so both `Json<T>` and a
/// qualified `some::path::Json<T>` are unwrapped. Anything else is emitted unchanged:
/// non-path types, other names, or a `Json` without a leading type argument.
fn unwrap_json_type_from_syn(ty: &syn::Type) -> proc_macro2::TokenStream {
    if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
        let json_segment = path
            .segments
            .last()
            .filter(|segment| segment.ident == "Json");
        if let Some(syn::PathArguments::AngleBracketed(generics)) =
            json_segment.map(|segment| &segment.arguments)
        {
            if let Some(syn::GenericArgument::Type(payload)) = generics.args.first() {
                return quote!(#payload);
            }
        }
    }
    quote!(#ty)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment