relentless_grpc/
record.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
use std::path::PathBuf;

use bytes::Bytes;
use relentless::assault::service::record::{CollectClone, IoRecord, RequestIoRecord};
use serde::{de::DeserializeOwned, Deserialize};

use crate::client::GrpcMethodRequest;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub struct GrpcIoRecorder;

impl<De, Se> IoRecord<GrpcMethodRequest<De, Se>> for GrpcIoRecorder
where
    De: for<'a> serde::Deserializer<'a> + Send + Sync + 'static,
    for<'a> <De as serde::Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
    Se: Send,
{
    type Error = std::io::Error;
    fn extension(&self, _r: &GrpcMethodRequest<De, Se>) -> &'static str {
        "json"
    }
    async fn record<W: std::io::Write>(&self, w: &mut W, r: GrpcMethodRequest<De, Se>) -> Result<(), Self::Error> {
        let value = serde_json::Value::deserialize(r.message)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        write!(w, "{}", serde_json::to_string_pretty(&value).unwrap())
    }
    async fn record_raw<W: std::io::Write + Send>(
        &self,
        w: &mut W,
        r: GrpcMethodRequest<De, Se>,
    ) -> Result<(), Self::Error> {
        let uri = r.destination;
        let (metadata, extension, message) = tonic::Request::new(r.message).into_parts();
        let mut http_request_builder =
            http::Request::builder().method(http::Method::POST).uri(uri).extension(extension);
        if let Some(headers) = http_request_builder.headers_mut() {
            *headers = metadata.into_headers();
        }
        let body = Bytes::from(
            serde_json::to_vec(
                &serde_json::Value::deserialize(message)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
            )
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
        );
        let http_request = http_request_builder
            .body(http_body_util::Full::new(body))
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

        relentless_http::record::HttpIoRecorder.record_raw(w, http_request).await
    }
}
impl<De, Se> CollectClone<GrpcMethodRequest<De, Se>> for GrpcIoRecorder
where
    De: for<'a> serde::Deserializer<'a> + DeserializeOwned + Send + Sync + 'static,
    for<'a> <De as serde::Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
    Se: Clone + Send,
{
    type Error = std::io::Error;
    async fn collect_clone(
        &self,
        r: GrpcMethodRequest<De, Se>,
    ) -> Result<(GrpcMethodRequest<De, Se>, GrpcMethodRequest<De, Se>), Self::Error> {
        let GrpcMethodRequest { destination, service, method, codec, message } = r;
        let value = serde_json::Value::deserialize(message)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        let m1 = serde_json::from_value(value.clone())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        let m2 = serde_json::from_value(value).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        Ok((
            GrpcMethodRequest {
                destination: destination.clone(),
                service: service.clone(),
                method: method.clone(),
                codec: codec.clone(),
                message: m1,
            },
            GrpcMethodRequest { destination, service, method, codec, message: m2 },
        ))
    }
}
impl<De, Se> RequestIoRecord<GrpcMethodRequest<De, Se>> for GrpcIoRecorder {
    fn record_dir(&self, r: &GrpcMethodRequest<De, Se>) -> PathBuf {
        http::uri::Builder::from(r.destination.clone())
            .path_and_query(r.format_method_path())
            .build()
            .unwrap_or_else(|e| unreachable!("{}", e))
            .to_string()
            .into()
    }
}

impl IoRecord<tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>> for GrpcIoRecorder {
    type Error = std::io::Error;

    fn extension(
        &self,
        _r: &tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
    ) -> &'static str {
        "json"
    }
    async fn record<W: std::io::Write + Send>(
        &self,
        w: &mut W,
        r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
    ) -> Result<(), Self::Error> {
        let value = serde_json::Value::deserialize(r.into_inner())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        write!(w, "{}", serde_json::to_string_pretty(&value).unwrap())
    }
    async fn record_raw<W: std::io::Write + Send>(
        &self,
        w: &mut W,
        r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
    ) -> Result<(), Self::Error> {
        let (metadata, message, extension) = r.into_parts();
        let mut http_response_builder = http::Response::builder().extension(extension);
        if let Some(headers) = http_response_builder.headers_mut() {
            *headers = metadata.into_headers();
        }
        let body = Bytes::from(
            serde_json::to_vec(
                &serde_json::Value::deserialize(message)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
            )
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
        );
        let http_response = http_response_builder
            .body(http_body_util::Full::new(body))
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

        relentless_http::record::HttpIoRecorder.record_raw(w, http_response).await
    }
}
impl CollectClone<tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>> for GrpcIoRecorder {
    type Error = std::io::Error;
    async fn collect_clone(
        &self,
        r: tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
    ) -> Result<
        (
            tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
            tonic::Response<<serde_json::value::Serializer as serde::Serializer>::Ok>,
        ),
        Self::Error,
    > {
        let (metadata, message, extension) = r.into_parts();
        let value = serde_json::Value::deserialize(message)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        let m1 = serde_json::from_value(value.clone())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        let m2 = serde_json::from_value(value).map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        Ok((
            tonic::Response::from_parts(metadata.clone(), m1, extension.clone()),
            tonic::Response::from_parts(metadata, m2, extension),
        ))
    }
}