relentless_http/
record.rs1use bytes::Bytes;
2use http::header::CONTENT_TYPE;
3use http_body::Body;
4use http_body_util::{BodyExt, Collected};
5use relentless::assault::service::record::{CollectClone, IoRecord, RequestIoRecord};
6
/// Stateless recorder for `http::Request` and `http::Response` values.
///
/// It implements the following for both requests and responses:
/// - [`IoRecord`]: body-only and raw (start line + headers + body) dumps.
/// - [`CollectClone`]: buffers a message body and returns two identical copies.
///
/// For requests it also implements [`RequestIoRecord`], which derives a record directory from the URI.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub struct HttpIoRecorder;
9
10impl<B> IoRecord<http::Request<B>> for HttpIoRecorder
11where
12 B: Body + From<Bytes> + Send,
13 B::Data: Send,
14{
15 type Error = std::io::Error;
16 fn extension(&self, r: &http::Request<B>) -> &'static str {
17 if let Some(content_type) = r.headers().get(CONTENT_TYPE) {
18 if content_type == mime::APPLICATION_JSON.as_ref() {
19 "json"
20 } else {
21 "txt"
22 }
23 } else {
24 "txt"
25 }
26 }
27 async fn record<W: std::io::Write + Send>(&self, w: &mut W, r: http::Request<B>) -> Result<(), Self::Error> {
28 let body = BodyExt::collect(r.into_body()).await.map(Collected::to_bytes).unwrap_or_default();
29 write!(w, "{}", String::from_utf8_lossy(&body))
30 }
31 async fn record_raw<W: std::io::Write + Send>(&self, w: &mut W, r: http::Request<B>) -> Result<(), Self::Error> {
32 let (http::request::Parts { method, uri, version, headers, .. }, body) = r.into_parts();
33
34 writeln!(w, "{method} {uri} {version:?}")?;
35 for (header, value) in headers.iter() {
36 writeln!(w, "{header}: {value:?}")?;
37 }
38 writeln!(w)?;
39 if let Ok(b) = BodyExt::collect(body).await.map(Collected::to_bytes) {
40 write!(w, "{}", String::from_utf8_lossy(&b))?;
41 }
42
43 Ok(())
44 }
45}
46
47impl<B> CollectClone<http::Request<B>> for HttpIoRecorder
48where
49 B: Body + From<Bytes> + Send,
50 B::Data: Send,
51{
52 type Error = B::Error;
53 async fn collect_clone(&self, r: http::Request<B>) -> Result<(http::Request<B>, http::Request<B>), Self::Error> {
54 let (req_parts, req_body) = r.into_parts();
56 let req_bytes = BodyExt::collect(req_body).await.map(Collected::to_bytes)?;
57 let req1 = http::Request::from_parts(req_parts.clone(), B::from(req_bytes.clone()));
58 let req2 = http::Request::from_parts(req_parts, B::from(req_bytes));
59 Ok((req1, req2))
60 }
61}
62impl<B> RequestIoRecord<http::Request<B>> for HttpIoRecorder {
63 fn record_dir(&self, r: &http::Request<B>) -> std::path::PathBuf {
64 r.uri().to_string().into()
65 }
66}
67
68impl<B> IoRecord<http::Response<B>> for HttpIoRecorder
69where
70 B: Body + From<Bytes> + Send,
71 B::Data: Send,
72{
73 type Error = std::io::Error;
74 fn extension(&self, r: &http::Response<B>) -> &'static str {
75 if let Some(content_type) = r.headers().get(CONTENT_TYPE) {
76 if content_type == mime::APPLICATION_JSON.as_ref() {
77 "json"
78 } else {
79 "txt"
80 }
81 } else {
82 "txt"
83 }
84 }
85 async fn record<W: std::io::Write>(&self, w: &mut W, r: http::Response<B>) -> Result<(), Self::Error> {
86 let body = BodyExt::collect(r.into_body()).await.map(Collected::to_bytes).unwrap_or_default();
87 write!(w, "{}", String::from_utf8_lossy(&body))
88 }
89
90 async fn record_raw<W: std::io::Write>(&self, w: &mut W, r: http::Response<B>) -> Result<(), Self::Error> {
91 let (http::response::Parts { version, status, headers, .. }, body) = r.into_parts();
92
93 writeln!(w, "{version:?} {status}")?;
94 for (header, value) in headers.iter() {
95 writeln!(w, "{header}: {value:?}")?;
96 }
97 writeln!(w)?;
98 if let Ok(b) = BodyExt::collect(body).await.map(Collected::to_bytes) {
99 write!(w, "{}", String::from_utf8_lossy(&b))?;
100 }
101
102 Ok(())
103 }
104}
105impl<B> CollectClone<http::Response<B>> for HttpIoRecorder
106where
107 B: Body + From<Bytes> + Send,
108 B::Data: Send,
109{
110 type Error = B::Error;
111 async fn collect_clone(&self, r: http::Response<B>) -> Result<(http::Response<B>, http::Response<B>), Self::Error> {
112 let (res_parts, res_body) = r.into_parts();
114 let res_bytes = BodyExt::collect(res_body).await.map(Collected::to_bytes)?;
115 let res1 = http::Response::from_parts(res_parts.clone(), B::from(res_bytes.clone()));
116 let res2 = http::Response::from_parts(res_parts, B::from(res_bytes));
117 Ok((res1, res2))
118 }
119}
120
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::Full;

    use super::*;

    /// A header-less GET with an empty body dumps as just the request line
    /// followed by the blank separator line.
    #[tokio::test]
    async fn test_empty_body_request() {
        let request = http::Request::get("http://localhost:3000").body(Full::new(Bytes::new())).unwrap();

        let mut recorded = Vec::new();
        HttpIoRecorder.record_raw(&mut recorded, request).await.unwrap();

        assert_eq!(recorded, b"GET http://localhost:3000/ HTTP/1.1\n\n");
    }

    /// A default (200 OK, HTTP/1.1) response with an empty body dumps as just
    /// the status line followed by the blank separator line.
    #[tokio::test]
    async fn test_empty_body_response() {
        let response = http::Response::new(Full::new(Bytes::new()));

        let mut recorded = Vec::new();
        HttpIoRecorder.record_raw(&mut recorded, response).await.unwrap();

        assert_eq!(recorded, b"HTTP/1.1 200 OK\n\n");
    }
}