1use bytes::Bytes;
2use http::{
3 header::{CONTENT_LENGTH, CONTENT_TYPE},
4 HeaderMap,
5};
6use http_body::Body;
7use mime::{Mime, APPLICATION_JSON, TEXT_PLAIN};
8use serde::{Deserialize, Serialize};
9#[cfg(feature = "json")]
10use serde_json::Value;
11
12use relentless::{
13 assault::factory::RequestFactory,
14 error::IntoResult,
15 interface::{
16 helper::{coalesce::Coalesce, http_serde_priv, is_default::IsDefault},
17 template::Template,
18 },
19};
20
/// HTTP request description as written in a test configuration.
///
/// Every field may be omitted when deserializing (`#[serde(default)]`) and is skipped when
/// serializing if it equals its default value.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct HttpRequest {
    /// When `true`, the headers derived from the body (`Content-Type` / `Content-Length`,
    /// see `HttpBody::body_with_headers`) are not added to the produced request.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    pub no_additional_headers: bool,
    /// Request method; `None` falls back to `http::Method::default()` when the request is produced.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    pub method: Option<http_serde_priv::Method>,
    /// Extra request headers; each value is rendered through the template before sending.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    pub headers: Option<http_serde_priv::HeaderMap>,
    /// Request body; defaults to `HttpBody::Empty`.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    pub body: HttpBody,
}
/// Body of an [`HttpRequest`]; variant names are serialized in kebab-case.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum HttpBody {
    /// No body: produced as the body type's `Default` value, with no `Content-Type`.
    #[default]
    Empty,
    /// Text body sent as `text/plain`; the string is rendered through the template first.
    Plaintext(String),
    /// JSON body sent as `application/json`; only available with the `json` feature.
    #[cfg(feature = "json")]
    Json(Value),
}
46impl HttpBody {
47 pub fn body_with_headers<ReqB>(&self, template: &Template) -> relentless::Result<(ReqB, HeaderMap)>
48 where
49 ReqB: Body,
50 Self: BodyFactory<ReqB>,
51 <Self as BodyFactory<ReqB>>::Error: std::error::Error + Send + Sync + 'static,
52 {
53 let mut headers = HeaderMap::new();
54 self.content_type()
55 .map(|t| headers.insert(CONTENT_TYPE, t.as_ref().parse().unwrap_or_else(|_| unreachable!())));
56 let body = self.produce(template).box_err()?;
57 body.size_hint().exact().filter(|size| *size > 0).map(|size| headers.insert(CONTENT_LENGTH, size.into())); Ok((body, headers))
59 }
60 pub fn content_type(&self) -> Option<Mime> {
61 match self {
62 HttpBody::Empty => None,
63 HttpBody::Plaintext(_) => Some(TEXT_PLAIN),
64 #[cfg(feature = "json")]
65 HttpBody::Json(_) => Some(APPLICATION_JSON),
66 }
67 }
68}
69
70impl Coalesce for HttpRequest {
71 fn coalesce(self, other: &Self) -> Self {
72 Self {
73 no_additional_headers: self.no_additional_headers || other.no_additional_headers,
74 method: self.method.or(other.method.clone()),
75 headers: self.headers.or(other.headers.clone()),
76 body: self.body.coalesce(&other.body),
77 }
78 }
79}
80impl Coalesce for HttpBody {
81 fn coalesce(self, other: &Self) -> Self {
82 match self {
83 HttpBody::Empty => other.clone(),
84 _ => self,
85 }
86 }
87}
88
89impl<B, S> RequestFactory<http::Request<B>, S> for HttpRequest
90where
91 B: Body,
92 HttpBody: BodyFactory<B>,
93 <HttpBody as BodyFactory<B>>::Error: std::error::Error + Send + Sync + 'static,
94{
95 type Error = relentless::Error;
96
97 async fn produce(
98 &self,
99 _service: S,
100 destination: &http::Uri,
101 target: &str,
102 template: &Template,
103 ) -> Result<http::Request<B>, Self::Error> {
104 let HttpRequest { no_additional_headers, method, headers, body } = self;
105 let uri =
106 http::uri::Builder::from(destination.clone()).path_and_query(template.render(target)?).build().box_err()?;
107 let unwrapped_method = method.as_ref().map(|m| (**m).clone()).unwrap_or_default();
108 let unwrapped_headers: HeaderMap = headers
109 .as_ref()
110 .map(|h| {
111 (**h)
112 .clone()
113 .into_iter()
114 .fold((None, HeaderMap::default()), |(prev, mut map), (k, v)| {
115 let curr = k.or(prev);
117 map.insert(
118 curr.as_ref().unwrap_or_else(|| unreachable!()),
119 template.render_as_string(v.clone()).unwrap_or(v),
120 );
121 (curr, map)
122 })
123 .1
124 })
125 .unwrap_or_default();
126 let (actual_body, additional_headers) = body.clone().body_with_headers(template)?;
127
128 let mut request = http::Request::builder().uri(uri).method(unwrapped_method).body(actual_body).box_err()?;
129 let header_map = request.headers_mut();
130 header_map.extend(unwrapped_headers);
131 if !no_additional_headers {
132 header_map.extend(additional_headers);
133 }
134 Ok(request)
135 }
136}
137
/// Converts a body description into a concrete HTTP body of type `B`.
pub trait BodyFactory<B: Body> {
    /// Error returned when the body cannot be produced.
    type Error;
    /// Produces the body. `template` is available for rendering placeholders; the `HttpBody`
    /// implementation renders text and JSON bodies through it.
    fn produce(&self, template: &Template) -> Result<B, Self::Error>;
}
142impl<B> BodyFactory<B> for HttpBody
143where
144 B: Body + From<Bytes> + Default,
145{
146 type Error = relentless::Error;
147 fn produce(&self, template: &Template) -> Result<B, Self::Error> {
148 match self {
149 HttpBody::Empty => Ok(Default::default()),
150 HttpBody::Plaintext(s) => Ok(Bytes::from(template.render(s).unwrap_or(s.to_string())).into()),
151 #[cfg(feature = "json")]
152 HttpBody::Json(v) => {
153 Ok(Bytes::from(serde_json::to_vec(&template.render_json_recursive(v).as_ref().unwrap_or(v)).box_err()?)
154 .into())
155 }
156 }
157 }
158}