relentless_http/
evaluate.rs

1use bytes::Bytes;
2use http_body::Body;
3use http_body_util::{BodyExt, Collected};
4use serde::{Deserialize, Serialize};
5
6#[cfg(feature = "json")]
7use relentless::assault::evaluator::json::JsonEvaluator;
8use relentless::{
9    assault::{
10        destinations::{AllOr, Destinations},
11        evaluate::{Acceptable, Evaluate},
12        evaluator::plaintext::PlaintextEvaluator,
13        messages::Messages,
14        result::RequestResult,
15    },
16    interface::helper::{coalesce::Coalesce, http_serde_priv, is_default::IsDefault},
17};
18
19use super::error::HttpEvaluateError;
20
21#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
22#[serde(deny_unknown_fields, rename_all = "kebab-case")]
23pub struct HttpResponse {
24    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
25    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
26    pub status: HttpStatus,
27    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
28    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
29    pub header: HttpHeaders,
30    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
31    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
32    pub body: HttpBody,
33}
34#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
35#[serde(deny_unknown_fields, rename_all = "kebab-case")]
36pub enum HttpStatus {
37    #[default]
38    OkOrEqual,
39    Expect(AllOr<http_serde_priv::StatusCode>),
40    Ignore,
41}
42#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
43#[serde(deny_unknown_fields, rename_all = "kebab-case")]
44pub enum HttpHeaders {
45    #[default]
46    AnyOrEqual,
47    Expect(AllOr<http_serde_priv::HeaderMap>),
48    Ignore,
49}
50#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
51#[serde(deny_unknown_fields, rename_all = "kebab-case")]
52pub enum HttpBody {
53    #[default]
54    AnyOrEqual,
55    Plaintext(PlaintextEvaluator),
56    #[cfg(feature = "json")]
57    Json(JsonEvaluator),
58}
59
60impl Coalesce for HttpResponse {
61    fn coalesce(self, other: &Self) -> Self {
62        Self {
63            status: self.status.coalesce(&other.status),
64            header: self.header.coalesce(&other.header),
65            body: self.body.coalesce(&other.body),
66        }
67    }
68}
69impl Coalesce for HttpStatus {
70    fn coalesce(self, other: &Self) -> Self {
71        if self.is_default() {
72            other.clone()
73        } else {
74            self
75        }
76    }
77}
78impl Coalesce for HttpHeaders {
79    fn coalesce(self, other: &Self) -> Self {
80        if self.is_default() {
81            other.clone()
82        } else {
83            self
84        }
85    }
86}
87impl Coalesce for HttpBody {
88    fn coalesce(self, other: &Self) -> Self {
89        if self.is_default() {
90            other.clone()
91        } else {
92            self
93        }
94    }
95}
96
97impl<B> Evaluate<http::Response<B>> for HttpResponse
98where
99    B: Body,
100    B::Error: std::error::Error + Sync + Send + 'static,
101{
102    type Message = HttpEvaluateError;
103    async fn evaluate(
104        &self,
105        res: Destinations<RequestResult<http::Response<B>>>,
106        msg: &mut Messages<Self::Message>,
107    ) -> bool {
108        let Some(responses) = msg.response_destinations_with(res, HttpEvaluateError::RequestError) else {
109            return false;
110        };
111        let Some(parts) = msg.push_if_err(HttpResponse::unzip_parts(responses).await) else {
112            return false;
113        };
114
115        self.accept(&parts, msg)
116    }
117}
118impl Acceptable<(http::StatusCode, http::HeaderMap, Bytes)> for HttpResponse {
119    type Message = HttpEvaluateError;
120    fn accept(
121        &self,
122        parts: &Destinations<(http::StatusCode, http::HeaderMap, Bytes)>,
123        msg: &mut Messages<Self::Message>,
124    ) -> bool {
125        let (mut status, mut headers, mut body) = (Destinations::new(), Destinations::new(), Destinations::new());
126        for (name, (s, h, b)) in parts {
127            status.insert(name.clone(), s);
128            headers.insert(name.clone(), h);
129            body.insert(name.clone(), b);
130        }
131        self.status.accept(&status, msg) && self.header.accept(&headers, msg) && self.body.accept(&body, msg)
132    }
133}
134impl HttpResponse {
135    pub async fn unzip_parts<B>(
136        responses: Destinations<http::Response<B>>,
137    ) -> Result<Destinations<(http::StatusCode, http::HeaderMap, Bytes)>, HttpEvaluateError>
138    where
139        B: Body,
140        B::Error: std::error::Error + Sync + Send + 'static,
141    {
142        let mut parts = Destinations::new();
143        for (name, response) in responses {
144            let (http::response::Parts { status, headers, .. }, body) = response.into_parts();
145            let bytes = BodyExt::collect(body)
146                .await
147                .map(Collected::to_bytes)
148                .map_err(|e| HttpEvaluateError::FailToCollectBody(e.into()))?;
149            parts.insert(name, (status, headers, bytes));
150        }
151        Ok(parts)
152    }
153}
154
155impl Acceptable<&http::StatusCode> for HttpStatus {
156    type Message = HttpEvaluateError;
157    fn accept(&self, status: &Destinations<&http::StatusCode>, msg: &mut Messages<Self::Message>) -> bool {
158        let acceptable = match &self {
159            HttpStatus::OkOrEqual => Self::assault_or_compare(status, |(_, s)| s.is_success()),
160            HttpStatus::Expect(AllOr::All(code)) => Self::validate_all(status, |(_, s)| s == &&**code),
161            HttpStatus::Expect(AllOr::Destinations(code)) => {
162                // TODO subset ?
163                status == &code.iter().map(|(d, c)| (d.to_string(), &**c)).collect()
164            }
165            HttpStatus::Ignore => true,
166        };
167        if !acceptable {
168            msg.push_err(HttpEvaluateError::UnacceptableStatus);
169        }
170        acceptable
171    }
172}
173
174impl Acceptable<&http::HeaderMap> for HttpHeaders {
175    type Message = HttpEvaluateError;
176    fn accept(&self, headers: &Destinations<&http::HeaderMap>, msg: &mut Messages<Self::Message>) -> bool {
177        let acceptable = match &self {
178            HttpHeaders::AnyOrEqual => Self::assault_or_compare(headers, |_| true),
179            HttpHeaders::Expect(AllOr::All(header)) => Self::validate_all(headers, |(_, h)| h == &&**header),
180            HttpHeaders::Expect(AllOr::Destinations(header)) => {
181                // TODO subset ?
182                headers == &header.iter().map(|(d, h)| (d.to_string(), &**h)).collect()
183            }
184            HttpHeaders::Ignore => true,
185        };
186        if !acceptable {
187            msg.push_err(HttpEvaluateError::UnacceptableHeaderMap);
188        }
189        acceptable
190    }
191}
192
193impl Acceptable<&Bytes> for HttpBody {
194    type Message = HttpEvaluateError;
195    fn accept(&self, body: &Destinations<&Bytes>, msg: &mut Messages<Self::Message>) -> bool {
196        match &self {
197            HttpBody::AnyOrEqual => Self::assault_or_compare(body, |_| true),
198            HttpBody::Plaintext(p) => Self::sub_accept(p, body, msg, HttpEvaluateError::PlaintextEvaluateError),
199            #[cfg(feature = "json")]
200            HttpBody::Json(e) => Self::sub_accept(e, body, msg, HttpEvaluateError::JsonEvaluateError),
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use std::time::{Instant, SystemTime};

    use relentless::assault::measure::metrics::MeasuredResponse;

    use super::*;

    /// Builds a response with the given status code and an empty body.
    fn empty_response(status: http::StatusCode) -> http::Response<http_body_util::Empty<Bytes>> {
        http::Response::builder().status(status).body(http_body_util::Empty::<Bytes>::new()).unwrap()
    }

    /// Wraps `response` as the successful result of a single destination named `"test"`.
    fn single_destination(
        response: http::Response<http_body_util::Empty<Bytes>>,
    ) -> Destinations<RequestResult<http::Response<http_body_util::Empty<Bytes>>>> {
        Destinations::from_iter(vec![(
            "test".to_string(),
            Ok(MeasuredResponse::new(response, SystemTime::now(), (Instant::now(), Instant::now()))),
        )])
    }

    #[tokio::test]
    async fn test_default_assault_evaluate() {
        let evaluator = HttpResponse::default();

        let mut msg = Messages::new();
        let result = evaluator.evaluate(single_destination(empty_response(http::StatusCode::OK)), &mut msg).await;
        assert!(result);
        assert!(msg.is_empty());

        let mut msg = Messages::new();
        let responses = single_destination(empty_response(http::StatusCode::SERVICE_UNAVAILABLE));
        let result = evaluator.evaluate(responses, &mut msg).await;
        assert!(!result);
        assert!(matches!(msg.as_slice(), [HttpEvaluateError::UnacceptableStatus]));
    }

    #[tokio::test]
    async fn test_ignore_status_assault_evaluate() {
        // With the status rule ignored, an error status alone must not fail evaluation.
        let evaluator = HttpResponse { status: HttpStatus::Ignore, ..Default::default() };

        let mut msg = Messages::new();
        let responses = single_destination(empty_response(http::StatusCode::SERVICE_UNAVAILABLE));
        let result = evaluator.evaluate(responses, &mut msg).await;
        assert!(result);
        assert!(msg.is_empty());
    }

    // TODO more tests (`Expect` variants, multi-destination comparison, body evaluators)
}