Last active
January 27, 2026 21:31
-
-
Save darrenmothersele/f99af74fb7c2bfea3ccdb8a7d6124c8d to your computer and use it in GitHub Desktop.
Create an HTTP client from a Restate service.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use proc_macro::TokenStream; | |
| use quote::{format_ident, quote}; | |
| use syn::{parse_macro_input, ItemTrait, ReturnType, TraitItem}; | |
| /// Generates a Restate HTTP client from a service trait. | |
| /// | |
| /// Apply this macro to a trait that's also annotated with `#[restate_sdk::service]`. | |
| /// It will generate a `{TraitName}HttpClient` struct with methods that call the | |
| /// Restate ingress HTTP API. | |
| /// | |
| /// # Example | |
| /// | |
| /// ```ignore | |
| /// #[restate_client] | |
| /// #[restate_sdk::service] | |
| /// pub trait Greeter { | |
| /// async fn greet(name: String) -> HandlerResult<String>; | |
| /// } | |
| /// | |
| /// // Generated: | |
| /// // pub struct GreeterHttpClient { ... } | |
| /// // impl GreeterHttpClient { | |
| /// // pub fn new(base_url: impl Into<String>) -> Self { ... } | |
| /// // pub async fn greet(&self, name: String) -> Result<String, GreeterHttpClientError> { ... } | |
| /// // } | |
| /// ``` | |
| #[proc_macro_attribute] | |
| pub fn restate_client(_attr: TokenStream, item: TokenStream) -> TokenStream { | |
| let input = parse_macro_input!(item as ItemTrait); | |
| let trait_name = &input.ident; | |
| let service_name = trait_name.to_string(); | |
| let client_name = format_ident!("{}HttpClient", trait_name); | |
| let error_name = format_ident!("{}HttpClientError", trait_name); | |
| let mut client_methods = Vec::new(); | |
| for item in &input.items { | |
| if let TraitItem::Fn(method) = item { | |
| let method_name = &method.sig.ident; | |
| let method_name_str = method_name.to_string(); | |
| // Extract the single parameter (Restate handlers accept 0 or 1 param) | |
| let param = method.sig.inputs.iter().find_map(|arg| { | |
| if let syn::FnArg::Typed(pat_type) = arg { | |
| Some(pat_type) | |
| } else { | |
| None | |
| } | |
| }); | |
| // Extract return type from HandlerResult<T> | |
| let return_type = match &method.sig.output { | |
| ReturnType::Type(_, ty) => extract_inner_type(ty), | |
| ReturnType::Default => quote!(()), | |
| }; | |
| // Unwrap Json<T> from return type for deserialization | |
| let return_type_unwrapped = unwrap_json_type(&return_type); | |
| let (method_impl, param_signature) = if let Some(p) = param { | |
| let param_name = &p.pat; | |
| let param_type = &p.ty; | |
| // Unwrap Json<T> to just T for the client signature | |
| let param_type_unwrapped = unwrap_json_type_from_syn(param_type); | |
| let impl_code = quote! { | |
| let url = format!("{}/{}/{}", self.base_url, #service_name, #method_name_str); | |
| let body = serde_json::to_string(&#param_name).map_err(#error_name::Serialization)?; | |
| let response = self.client | |
| .post(&url) | |
| .header("Content-Type", "application/json") | |
| .body(body) | |
| .send() | |
| .await | |
| .map_err(#error_name::Request)?; | |
| if !response.status().is_success() { | |
| let status = response.status(); | |
| let text = response.text().await.unwrap_or_default(); | |
| return Err(#error_name::Status(status.as_u16(), text)); | |
| } | |
| let text = response.text().await.map_err(#error_name::Request)?; | |
| serde_json::from_str(&text).map_err(#error_name::Deserialization) | |
| }; | |
| (impl_code, quote!(#param_name: #param_type_unwrapped)) | |
| } else { | |
| // No parameters | |
| let impl_code = quote! { | |
| let url = format!("{}/{}/{}", self.base_url, #service_name, #method_name_str); | |
| let response = self.client | |
| .post(&url) | |
| .header("Content-Type", "application/json") | |
| .send() | |
| .await | |
| .map_err(#error_name::Request)?; | |
| if !response.status().is_success() { | |
| let status = response.status(); | |
| let text = response.text().await.unwrap_or_default(); | |
| return Err(#error_name::Status(status.as_u16(), text)); | |
| } | |
| let text = response.text().await.map_err(#error_name::Request)?; | |
| serde_json::from_str(&text).map_err(#error_name::Deserialization) | |
| }; | |
| (impl_code, quote!()) | |
| }; | |
| client_methods.push(quote! { | |
| pub async fn #method_name(&self, #param_signature) -> Result<#return_type_unwrapped, #error_name> { | |
| #method_impl | |
| } | |
| }); | |
| } | |
| } | |
| let expanded = quote! { | |
| #input | |
| #[derive(Debug)] | |
| pub enum #error_name { | |
| Request(reqwest::Error), | |
| Serialization(serde_json::Error), | |
| Deserialization(serde_json::Error), | |
| Status(u16, String), | |
| } | |
| impl std::fmt::Display for #error_name { | |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | |
| match self { | |
| Self::Request(e) => write!(f, "request error: {}", e), | |
| Self::Serialization(e) => write!(f, "serialization error: {}", e), | |
| Self::Deserialization(e) => write!(f, "deserialization error: {}", e), | |
| Self::Status(code, msg) => write!(f, "HTTP {}: {}", code, msg), | |
| } | |
| } | |
| } | |
| impl std::error::Error for #error_name {} | |
| #[derive(Clone)] | |
| pub struct #client_name { | |
| base_url: String, | |
| client: reqwest::Client, | |
| } | |
| impl #client_name { | |
| pub fn new(base_url: impl Into<String>) -> Self { | |
| Self { | |
| base_url: base_url.into(), | |
| client: reqwest::Client::new(), | |
| } | |
| } | |
| pub fn with_client(base_url: impl Into<String>, client: reqwest::Client) -> Self { | |
| Self { | |
| base_url: base_url.into(), | |
| client, | |
| } | |
| } | |
| #(#client_methods)* | |
| } | |
| }; | |
| TokenStream::from(expanded) | |
| } | |
| /// Extract the inner type T from HandlerResult<T> or Result<T, E> | |
| fn extract_inner_type(ty: &syn::Type) -> proc_macro2::TokenStream { | |
| if let syn::Type::Path(type_path) = ty { | |
| if let Some(segment) = type_path.path.segments.last() { | |
| if segment.ident == "HandlerResult" || segment.ident == "Result" { | |
| if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { | |
| if let Some(syn::GenericArgument::Type(inner)) = args.args.first() { | |
| return quote!(#inner); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| quote!(#ty) | |
| } | |
| /// Unwrap Json<T> to T from a token stream | |
| fn unwrap_json_type(tokens: &proc_macro2::TokenStream) -> proc_macro2::TokenStream { | |
| let s = tokens.to_string(); | |
| if s.starts_with("Json <") || s.starts_with("Json<") { | |
| // Parse and extract inner type | |
| if let Ok(ty) = syn::parse2::<syn::Type>(tokens.clone()) { | |
| return unwrap_json_type_from_syn(&ty); | |
| } | |
| } | |
| tokens.clone() | |
| } | |
| /// Unwrap Json<T> to T from a syn::Type | |
| fn unwrap_json_type_from_syn(ty: &syn::Type) -> proc_macro2::TokenStream { | |
| if let syn::Type::Path(type_path) = ty { | |
| if let Some(segment) = type_path.path.segments.last() { | |
| if segment.ident == "Json" { | |
| if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { | |
| if let Some(syn::GenericArgument::Type(inner)) = args.args.first() { | |
| return quote!(#inner); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| quote!(#ty) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment