// relentless_http/factory.rs

1use bytes::Bytes;
2use http::{
3    header::{CONTENT_LENGTH, CONTENT_TYPE},
4    HeaderMap,
5};
6use http_body::Body;
7use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN};
8use serde::{Deserialize, Serialize};
9#[cfg(feature = "json")]
10use serde_json::Value;
11
12use relentless::{
13    assault::factory::RequestFactory,
14    error::IntoResult,
15    interface::{
16        helper::{coalesce::Coalesce, http_serde_priv, is_default::IsDefault},
17        template::Template,
18    },
19};
20
21#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
22#[serde(deny_unknown_fields, rename_all = "kebab-case")]
23pub struct HttpRequest {
24    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
25    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
26    pub no_additional_headers: bool,
27    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
28    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
29    pub method: Option<http_serde_priv::Method>,
30    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
31    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
32    pub headers: Option<http_serde_priv::HeaderMap>,
33    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
34    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
35    pub body: HttpBody,
36}
37#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
38#[serde(deny_unknown_fields, rename_all = "kebab-case")]
39pub enum HttpBody {
40    #[default]
41    Empty,
42    Plaintext(String),
43    #[cfg(feature = "json")]
44    Json(Value),
45}
46impl HttpBody {
47    pub fn body_with_headers<ReqB>(&self, template: &Template) -> relentless::Result<(ReqB, HeaderMap)>
48    where
49        ReqB: Body,
50        Self: BodyFactory<ReqB>,
51        <Self as BodyFactory<ReqB>>::Error: std::error::Error + Send + Sync + 'static,
52    {
53        let mut headers = HeaderMap::new();
54        self.content_type()
55            .map(|t| headers.insert(CONTENT_TYPE, t.as_ref().parse().unwrap_or_else(|_| unreachable!())));
56        let body = self.produce(template).box_err()?;
57        body.size_hint().exact().filter(|size| *size > 0).map(|size| headers.insert(CONTENT_LENGTH, size.into())); // TODO remove ?
58        Ok((body, headers))
59    }
60    pub fn content_type(&self) -> Option<Mime> {
61        match self {
62            HttpBody::Empty => None,
63            HttpBody::Plaintext(_) => Some(TEXT_PLAIN),
64            #[cfg(feature = "json")]
65            HttpBody::Json(_) => Some(APPLICATION_JSON),
66        }
67    }
68}
69
70impl Coalesce for HttpRequest {
71    fn coalesce(self, other: &Self) -> Self {
72        Self {
73            no_additional_headers: self.no_additional_headers || other.no_additional_headers,
74            method: self.method.or(other.method.clone()),
75            headers: self.headers.or(other.headers.clone()),
76            body: self.body.coalesce(&other.body),
77        }
78    }
79}
80impl Coalesce for HttpBody {
81    fn coalesce(self, other: &Self) -> Self {
82        match self {
83            HttpBody::Empty => other.clone(),
84            _ => self,
85        }
86    }
87}
88
89impl<B, S> RequestFactory<http::Request<B>, S> for HttpRequest
90where
91    B: Body,
92    HttpBody: BodyFactory<B>,
93    <HttpBody as BodyFactory<B>>::Error: std::error::Error + Send + Sync + 'static,
94{
95    type Error = relentless::Error;
96
97    async fn produce(
98        &self,
99        _service: S,
100        destination: &http::Uri,
101        target: &str,
102        template: &Template,
103    ) -> Result<http::Request<B>, Self::Error> {
104        let HttpRequest { no_additional_headers, method, headers, body } = self;
105        let uri =
106            http::uri::Builder::from(destination.clone()).path_and_query(template.render(target)?).build().box_err()?;
107        let unwrapped_method = method.as_ref().map(|m| (**m).clone()).unwrap_or_default();
108        let unwrapped_headers: HeaderMap = headers
109            .as_ref()
110            .map(|h| {
111                (**h)
112                    .clone()
113                    .into_iter()
114                    .fold((None, HeaderMap::default()), |(prev, mut map), (k, v)| {
115                        // duplicate key will cause None https://docs.rs/http/latest/http/header/struct.HeaderMap.html#impl-IntoIterator-for-HeaderMap%3CT%3E
116                        let curr = k.or(prev);
117                        map.insert(
118                            curr.as_ref().unwrap_or_else(|| unreachable!()),
119                            template.render_as_string(v.clone()).unwrap_or(v),
120                        );
121                        (curr, map)
122                    })
123                    .1
124            })
125            .unwrap_or_default();
126        let (actual_body, additional_headers) = body.clone().body_with_headers(template)?;
127
128        let mut request = http::Request::builder().uri(uri).method(unwrapped_method).body(actual_body).box_err()?;
129        let header_map = request.headers_mut();
130        header_map.extend(unwrapped_headers);
131        if !no_additional_headers {
132            header_map.extend(additional_headers);
133        }
134        Ok(request)
135    }
136}
137
138pub trait BodyFactory<B: Body> {
139    type Error;
140    fn produce(&self, template: &Template) -> Result<B, Self::Error>;
141}
142impl<B> BodyFactory<B> for HttpBody
143where
144    B: Body + From<Bytes> + Default,
145{
146    type Error = relentless::Error;
147    fn produce(&self, template: &Template) -> Result<B, Self::Error> {
148        match self {
149            HttpBody::Empty => Ok(Default::default()),
150            HttpBody::Plaintext(s) => Ok(Bytes::from(template.render(s).unwrap_or(s.to_string())).into()),
151            #[cfg(feature = "json")]
152            HttpBody::Json(v) => {
153                Ok(Bytes::from(serde_json::to_vec(&template.render_json_recursive(v).as_ref().unwrap_or(v)).box_err()?)
154                    .into())
155            }
156        }
157    }
158}