Skip to content

Instantly share code, notes, and snippets.

@smallfish
Last active October 19, 2021 13:07
Show Gist options
  • Select an option

  • Save smallfish/615783794808a01bf30851168b910937 to your computer and use it in GitHub Desktop.

Select an option

Save smallfish/615783794808a01bf30851168b910937 to your computer and use it in GitHub Desktop.
axum route with default layer and tower::ServiceBuilder layer
[dependencies]
tokio = { version = "1.0", features = ["full"] }
futures = "0.3.1"
axum = { version = "0.2", features = ["headers"] }
tower = { version = "0.4", features = ["full"] }
log = "0.4.0"
env_logger = "0.8.4"
async-trait = "0.1.51"
use async_trait::async_trait;
use axum::extract::{FromRequest, RequestParts};
use axum::http::StatusCode;
use log::info;
/// Binary entry point: writes a single greeting line to standard output.
fn main() {
    let greeting = "hello";
    println!("{}", greeting);
}
struct ExtractAaa;
#[async_trait]
impl<B> FromRequest<B> for ExtractAaa
where
B: Send,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
info!("extract::aaa headers {:?}", req.headers());
req.headers_mut()
.unwrap()
.insert("x-extra-aaa", "aaa".parse().unwrap());
Ok(Self)
}
}
struct ExtractBbb;
#[async_trait]
impl<B> FromRequest<B> for ExtractBbb
where
B: Send,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
info!("extract::bbb headers {:?}", req.headers());
req.headers_mut()
.unwrap()
.insert("x-extra-bbb", "bbb".parse().unwrap());
Ok(Self)
}
}
struct ExtractCcc;
#[async_trait]
impl<B> FromRequest<B> for ExtractCcc
where
B: Send,
{
type Rejection = StatusCode;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
info!("extract::ccc headers {:?}", req.headers());
req.headers_mut()
.unwrap()
.insert("x-extra-ccc", "ccc".parse().unwrap());
Ok(Self)
}
}
/// Drives one `GET /` request through the three extractor middlewares, once
/// stacked with chained `Router::layer` calls and once bundled through
/// `tower::ServiceBuilder`, so the logged header snapshots can be compared.
#[cfg(test)]
mod tests {
    use crate::{ExtractAaa, ExtractBbb, ExtractCcc};
    use axum::body::Body;
    use axum::extract::extractor_middleware;
    use axum::handler::get;
    use axum::http::{Method, Request, StatusCode};
    use axum::Router;
    use tower::{ServiceBuilder, ServiceExt};

    /// Minimal handler mounted at `/`.
    async fn hello() -> &'static str {
        "helloworld"
    }

    /// Installs the test logger; the result is ignored because every test
    /// calls this and only the first initialisation can succeed.
    fn init() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Builds the empty-bodied `GET /` request shared by both tests.
    fn root_get_request() -> Request<Body> {
        Request::builder()
            .uri("/")
            .method(Method::GET)
            .body(Body::empty())
            .expect("a GET / request built from literals is always valid")
    }

    /// Middlewares attached one by one with `Router::layer`.
    // NOTE(review): each later `.layer(..)` is expected to wrap the earlier
    // ones, so the log should show Ccc, then Bbb, then Aaa — confirm against
    // the log output.
    #[tokio::test]
    async fn test_axum_default_route() {
        init();
        let app = Router::new()
            .route("/", get(hello))
            .layer(extractor_middleware::<ExtractAaa>())
            .layer(extractor_middleware::<ExtractBbb>())
            .layer(extractor_middleware::<ExtractCcc>());

        let res = app.oneshot(root_get_request()).await.unwrap();
        // Every extractor returns `Ok`, so the handler must be reached.
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Same middlewares combined into a single layer via `ServiceBuilder`.
    // NOTE(review): tower's `ServiceBuilder` is documented to run the
    // first-added layer first, so the log should show Aaa, then Bbb, then
    // Ccc — confirm against the log output.
    #[tokio::test]
    async fn test_axum_tower_service() {
        init();
        let service_layer = ServiceBuilder::new()
            .layer(extractor_middleware::<ExtractAaa>())
            .layer(extractor_middleware::<ExtractBbb>())
            .layer(extractor_middleware::<ExtractCcc>())
            .into_inner();
        let app = Router::new().route("/", get(hello)).layer(service_layer);

        let res = app.oneshot(root_get_request()).await.unwrap();
        // Every extractor returns `Ok`, so the handler must be reached.
        assert_eq!(res.status(), StatusCode::OK);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment