1use std::path::PathBuf;
2
3use bytes::Bytes;
4use relentless::assault::service::record::{CollectClone, IoRecord, RequestIoRecord};
5use serde::{de::DeserializeOwned, Deserialize};
6
7use crate::client::GrpcMethodRequest;
8
/// Stateless (unit) recorder for gRPC traffic.
///
/// It implements the relentless recording traits (`IoRecord`, `CollectClone`,
/// `RequestIoRecord`) for `GrpcMethodRequest` and for JSON-valued
/// `tonic::Response`s.
#[derive(Default, Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct GrpcIoRecorder;
11
12impl<De, Se> IoRecord<GrpcMethodRequest<De, Se>> for GrpcIoRecorder
13where
14 De: for<'a> serde::Deserializer<'a> + Send + Sync + 'static,
15 for<'a> <De as serde::Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
16 Se: Send,
17{
18 type Error = std::io::Error;
19 fn extension(&self, _r: &GrpcMethodRequest<De, Se>) -> &'static str {
20 "json"
21 }
22 async fn record<W: std::io::Write>(&self, w: &mut W, r: GrpcMethodRequest<De, Se>) -> Result<(), Self::Error> {
23 let value = serde_json::Value::deserialize(r.message)
24 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
25 write!(w, "{}", serde_json::to_string_pretty(&value).unwrap())
26 }
27 async fn record_raw<W: std::io::Write + Send>(
28 &self,
29 w: &mut W,
30 r: GrpcMethodRequest<De, Se>,
31 ) -> Result<(), Self::Error> {
32 let uri = r.destination;
33 let (metadata, extension, message) = tonic::Request::new(r.message).into_parts();
34 let mut http_request_builder =
35 http::Request::builder().method(http::Method::POST).uri(uri).extension(extension);
36 if let Some(headers) = http_request_builder.headers_mut() {
37 *headers = metadata.into_headers();
38 }
39 let body = Bytes::from(
40 serde_json::to_vec(
41 &serde_json::Value::deserialize(message)
42 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
43 )
44 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
45 );
46 let http_request = http_request_builder
47 .body(http_body_util::Full::new(body))
48 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
49
50 relentless_http::record::HttpIoRecorder.record_raw(w, http_request).await
51 }
52}
53impl<De, Se> CollectClone<GrpcMethodRequest<De, Se>> for GrpcIoRecorder
54where
55 De: for<'a> serde::Deserializer<'a> + DeserializeOwned + Send + Sync + 'static,
56 for<'a> <De as serde::Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
57 Se: Clone + Send,
58{
59 type Error = std::io::Error;
60 async fn collect_clone(
61 &self,
62 r: GrpcMethodRequest<De, Se>,
63 ) -> Result<(GrpcMethodRequest<De, Se>, GrpcMethodRequest<De, Se>), Self::Error> {
64 let GrpcMethodRequest { destination, service, method, codec, message } = r;
65 let value = serde_json::Value::deserialize(message)
66 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
67 let m1 = serde_json::from_value(value.clone())
68 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
69 let m2 = serde_json::from_value(value).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
70 Ok((
71 GrpcMethodRequest {
72 destination: destination.clone(),
73 service: service.clone(),
74 method: method.clone(),
75 codec: codec.clone(),
76 message: m1,
77 },
78 GrpcMethodRequest { destination, service, method, codec, message: m2 },
79 ))
80 }
81}
82impl<De, Se> RequestIoRecord<GrpcMethodRequest<De, Se>> for GrpcIoRecorder {
83 fn record_dir(&self, r: &GrpcMethodRequest<De, Se>) -> PathBuf {
84 http::uri::Builder::from(r.destination.clone())
85 .path_and_query(r.format_method_path())
86 .build()
87 .unwrap_or_else(|e| unreachable!("{}", e))
88 .to_string()
89 .into()
90 }
91}
92
93impl IoRecord<tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>> for GrpcIoRecorder {
94 type Error = std::io::Error;
95
96 fn extension(
97 &self,
98 _r: &tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
99 ) -> &'static str {
100 "json"
101 }
102 async fn record<W: std::io::Write + Send>(
103 &self,
104 w: &mut W,
105 r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
106 ) -> Result<(), Self::Error> {
107 let value = serde_json::Value::deserialize(r.into_inner())
108 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
109 write!(w, "{}", serde_json::to_string_pretty(&value).unwrap())
110 }
111 async fn record_raw<W: std::io::Write + Send>(
112 &self,
113 w: &mut W,
114 r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
115 ) -> Result<(), Self::Error> {
116 let (metadata, message, extension) = r.into_parts();
117 let mut http_response_builder = http::Response::builder().extension(extension);
118 if let Some(headers) = http_response_builder.headers_mut() {
119 *headers = metadata.into_headers();
120 }
121 let body = Bytes::from(
122 serde_json::to_vec(
123 &serde_json::Value::deserialize(message)
124 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
125 )
126 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
127 );
128 let http_response = http_response_builder
129 .body(http_body_util::Full::new(body))
130 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
131
132 relentless_http::record::HttpIoRecorder.record_raw(w, http_response).await
133 }
134}
135impl CollectClone<tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>> for GrpcIoRecorder {
136 type Error = std::io::Error;
137 async fn collect_clone(
138 &self,
139 r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
140 ) -> Result<
141 (
142 tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
143 tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
144 ),
145 Self::Error,
146 > {
147 let (metadata, message, extension) = r.into_parts();
148 let value = serde_json::Value::deserialize(message)
149 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
150 let m1 = serde_json::from_value(value.clone())
151 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
152 let m2 = serde_json::from_value(value).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
153 Ok((
154 tonic::Response::from_parts(metadata.clone(), m1, extension.clone()),
155 tonic::Response::from_parts(metadata, m2, extension),
156 ))
157 }
158}