1use bytes::Bytes;
2use http_body::Body;
3use http_body_util::{BodyExt, Collected};
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "json")]
7use relentless::assault::evaluator::json::JsonEvaluator;
8use relentless::{
9 assault::{
10 destinations::{AllOr, Destinations},
11 evaluate::{Acceptable, Evaluate},
12 evaluator::plaintext::PlaintextEvaluator,
13 messages::Messages,
14 result::RequestResult,
15 },
16 interface::helper::{coalesce::Coalesce, http_serde_priv, is_default::IsDefault},
17};
18
19use super::error::HttpEvaluateError;
20
21#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
22#[serde(deny_unknown_fields, rename_all = "kebab-case")]
23pub struct HttpResponse {
24 #[serde(default, skip_serializing_if = "IsDefault::is_default")]
25 #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
26 pub status: HttpStatus,
27 #[serde(default, skip_serializing_if = "IsDefault::is_default")]
28 #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
29 pub header: HttpHeaders,
30 #[serde(default, skip_serializing_if = "IsDefault::is_default")]
31 #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
32 pub body: HttpBody,
33}
34#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
35#[serde(deny_unknown_fields, rename_all = "kebab-case")]
36pub enum HttpStatus {
37 #[default]
38 OkOrEqual,
39 Expect(AllOr<http_serde_priv::StatusCode>),
40 Ignore,
41}
42#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
43#[serde(deny_unknown_fields, rename_all = "kebab-case")]
44pub enum HttpHeaders {
45 #[default]
46 AnyOrEqual,
47 Expect(AllOr<http_serde_priv::HeaderMap>),
48 Ignore,
49}
50#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
51#[serde(deny_unknown_fields, rename_all = "kebab-case")]
52pub enum HttpBody {
53 #[default]
54 AnyOrEqual,
55 Plaintext(PlaintextEvaluator),
56 #[cfg(feature = "json")]
57 Json(JsonEvaluator),
58}
59
60impl Coalesce for HttpResponse {
61 fn coalesce(self, other: &Self) -> Self {
62 Self {
63 status: self.status.coalesce(&other.status),
64 header: self.header.coalesce(&other.header),
65 body: self.body.coalesce(&other.body),
66 }
67 }
68}
69impl Coalesce for HttpStatus {
70 fn coalesce(self, other: &Self) -> Self {
71 if self.is_default() {
72 other.clone()
73 } else {
74 self
75 }
76 }
77}
78impl Coalesce for HttpHeaders {
79 fn coalesce(self, other: &Self) -> Self {
80 if self.is_default() {
81 other.clone()
82 } else {
83 self
84 }
85 }
86}
87impl Coalesce for HttpBody {
88 fn coalesce(self, other: &Self) -> Self {
89 if self.is_default() {
90 other.clone()
91 } else {
92 self
93 }
94 }
95}
96
97impl<B> Evaluate<http::Response<B>> for HttpResponse
98where
99 B: Body,
100 B::Error: std::error::Error + Sync + Send + 'static,
101{
102 type Message = HttpEvaluateError;
103 async fn evaluate(
104 &self,
105 res: Destinations<RequestResult<http::Response<B>>>,
106 msg: &mut Messages<Self::Message>,
107 ) -> bool {
108 let Some(responses) = msg.response_destinations_with(res, HttpEvaluateError::RequestError) else {
109 return false;
110 };
111 let Some(parts) = msg.push_if_err(HttpResponse::unzip_parts(responses).await) else {
112 return false;
113 };
114
115 self.accept(&parts, msg)
116 }
117}
118impl Acceptable<(http::StatusCode, http::HeaderMap, Bytes)> for HttpResponse {
119 type Message = HttpEvaluateError;
120 fn accept(
121 &self,
122 parts: &Destinations<(http::StatusCode, http::HeaderMap, Bytes)>,
123 msg: &mut Messages<Self::Message>,
124 ) -> bool {
125 let (mut status, mut headers, mut body) = (Destinations::new(), Destinations::new(), Destinations::new());
126 for (name, (s, h, b)) in parts {
127 status.insert(name.clone(), s);
128 headers.insert(name.clone(), h);
129 body.insert(name.clone(), b);
130 }
131 self.status.accept(&status, msg) && self.header.accept(&headers, msg) && self.body.accept(&body, msg)
132 }
133}
134impl HttpResponse {
135 pub async fn unzip_parts<B>(
136 responses: Destinations<http::Response<B>>,
137 ) -> Result<Destinations<(http::StatusCode, http::HeaderMap, Bytes)>, HttpEvaluateError>
138 where
139 B: Body,
140 B::Error: std::error::Error + Sync + Send + 'static,
141 {
142 let mut parts = Destinations::new();
143 for (name, response) in responses {
144 let (http::response::Parts { status, headers, .. }, body) = response.into_parts();
145 let bytes = BodyExt::collect(body)
146 .await
147 .map(Collected::to_bytes)
148 .map_err(|e| HttpEvaluateError::FailToCollectBody(e.into()))?;
149 parts.insert(name, (status, headers, bytes));
150 }
151 Ok(parts)
152 }
153}
154
155impl Acceptable<&http::StatusCode> for HttpStatus {
156 type Message = HttpEvaluateError;
157 fn accept(&self, status: &Destinations<&http::StatusCode>, msg: &mut Messages<Self::Message>) -> bool {
158 let acceptable = match &self {
159 HttpStatus::OkOrEqual => Self::assault_or_compare(status, |(_, s)| s.is_success()),
160 HttpStatus::Expect(AllOr::All(code)) => Self::validate_all(status, |(_, s)| s == &&**code),
161 HttpStatus::Expect(AllOr::Destinations(code)) => {
162 status == &code.iter().map(|(d, c)| (d.to_string(), &**c)).collect()
164 }
165 HttpStatus::Ignore => true,
166 };
167 if !acceptable {
168 msg.push_err(HttpEvaluateError::UnacceptableStatus);
169 }
170 acceptable
171 }
172}
173
174impl Acceptable<&http::HeaderMap> for HttpHeaders {
175 type Message = HttpEvaluateError;
176 fn accept(&self, headers: &Destinations<&http::HeaderMap>, msg: &mut Messages<Self::Message>) -> bool {
177 let acceptable = match &self {
178 HttpHeaders::AnyOrEqual => Self::assault_or_compare(headers, |_| true),
179 HttpHeaders::Expect(AllOr::All(header)) => Self::validate_all(headers, |(_, h)| h == &&**header),
180 HttpHeaders::Expect(AllOr::Destinations(header)) => {
181 headers == &header.iter().map(|(d, h)| (d.to_string(), &**h)).collect()
183 }
184 HttpHeaders::Ignore => true,
185 };
186 if !acceptable {
187 msg.push_err(HttpEvaluateError::UnacceptableHeaderMap);
188 }
189 acceptable
190 }
191}
192
193impl Acceptable<&Bytes> for HttpBody {
194 type Message = HttpEvaluateError;
195 fn accept(&self, body: &Destinations<&Bytes>, msg: &mut Messages<Self::Message>) -> bool {
196 match &self {
197 HttpBody::AnyOrEqual => Self::assault_or_compare(body, |_| true),
198 HttpBody::Plaintext(p) => Self::sub_accept(p, body, msg, HttpEvaluateError::PlaintextEvaluateError),
199 #[cfg(feature = "json")]
200 HttpBody::Json(e) => Self::sub_accept(e, body, msg, HttpEvaluateError::JsonEvaluateError),
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use std::time::{Instant, SystemTime};

    use relentless::assault::measure::metrics::MeasuredResponse;

    use super::*;

    /// Builds a batch with a single destination named `"test"`. It holds one successful
    /// request whose response has the given status code and an empty body.
    fn single_response(
        status: http::StatusCode,
    ) -> Destinations<RequestResult<http::Response<http_body_util::Empty<Bytes>>>> {
        let response =
            http::Response::builder().status(status).body(http_body_util::Empty::<Bytes>::new()).unwrap();
        let measured = MeasuredResponse::new(response, SystemTime::now(), (Instant::now(), Instant::now()));
        Destinations::from_iter([("test".to_string(), Ok(measured))])
    }

    #[tokio::test]
    async fn test_default_assault_evaluate() {
        let evaluator = HttpResponse::default();

        // 200 OK satisfies the default expectation and produces no messages.
        let mut msg = Messages::new();
        assert!(evaluator.evaluate(single_response(http::StatusCode::OK), &mut msg).await);
        assert!(msg.is_empty());

        // 503 is not a success code: evaluation stops at the status check with exactly one message.
        let mut msg = Messages::new();
        assert!(!evaluator.evaluate(single_response(http::StatusCode::SERVICE_UNAVAILABLE), &mut msg).await);
        assert!(matches!(msg.as_slice(), [HttpEvaluateError::UnacceptableStatus]));
    }
}