relentless_grpc/
factory.rs

1use std::{
2    fs::File,
3    io::Read,
4    path::{Path, PathBuf},
5};
6
7use bytes::Bytes;
8use futures::{StreamExt, TryStreamExt};
9use prost::Message;
10use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
11use prost_types::FileDescriptorProto;
12use serde::{Deserialize, Serialize};
13use tonic::transport::Channel;
14use tonic_reflection::pb::v1::{
15    server_reflection_client::ServerReflectionClient, server_reflection_request::MessageRequest,
16    server_reflection_response::MessageResponse, ServerReflectionRequest,
17};
18
19use relentless::{
20    assault::factory::RequestFactory,
21    error::IntoResult,
22    interface::{
23        helper::{coalesce::Coalesce, is_default::IsDefault},
24        template::Template,
25    },
26};
27use tower::Service;
28
29use crate::helper::JsonSerializer;
30
31use super::{
32    client::{GrpcMethodRequest, MethodCodec},
33    error::GrpcRequestError,
34};
35
36#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
37#[serde(deny_unknown_fields, rename_all = "kebab-case")]
38pub struct GrpcRequest {
39    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
40    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
41    descriptor: DescriptorFrom,
42    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
43    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
44    message: GrpcMessage,
45}
46#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
47#[serde(deny_unknown_fields, rename_all = "kebab-case", untagged)]
48pub enum DescriptorFrom {
49    Protos {
50        #[serde(default, skip_serializing_if = "IsDefault::is_default")]
51        protos: Vec<PathBuf>,
52        #[serde(default, skip_serializing_if = "IsDefault::is_default")]
53        import_path: Vec<PathBuf>,
54    },
55    Bin(PathBuf),
56    #[default]
57    Reflection,
58}
59#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
60#[serde(deny_unknown_fields, rename_all = "kebab-case")]
61pub enum GrpcMessage {
62    #[default]
63    Empty,
64    Plaintext(String),
65    Json(serde_json::Value),
66}
67impl Coalesce for GrpcRequest {
68    fn coalesce(self, other: &Self) -> Self {
69        Self { descriptor: self.descriptor.coalesce(&other.descriptor), message: self.message.coalesce(&other.message) }
70    }
71}
72impl Coalesce for DescriptorFrom {
73    fn coalesce(self, other: &Self) -> Self {
74        if self.is_default() {
75            other.clone()
76        } else {
77            self
78        }
79    }
80}
81impl Coalesce for GrpcMessage {
82    fn coalesce(self, other: &Self) -> Self {
83        if self.is_default() {
84            other.clone()
85        } else {
86            self
87        }
88    }
89}
90
91impl<S> RequestFactory<GrpcMethodRequest<serde_json::Value, JsonSerializer>, S> for GrpcRequest
92where
93    S: Service<GrpcMethodRequest<serde_json::Value, JsonSerializer>>,
94{
95    type Error = relentless::Error;
96    async fn produce(
97        &self,
98        service: S,
99        destination: &http::Uri,
100        target: &str,
101        template: &Template,
102    ) -> Result<GrpcMethodRequest<serde_json::Value, JsonSerializer>, Self::Error> {
103        let (svc, mth) = target.split_once('/').ok_or_else(|| GrpcRequestError::FailToParse(target.to_string()))?; // TODO only one '/' ?
104        let pool = self.descriptor_pool(service, destination, (svc, mth)).await?;
105        let destination = destination.clone();
106        let (service, method) = Self::service_method(&pool, (svc, mth))?;
107        let message = template.render_json_recursive(&self.message.produce())?;
108        let codec = MethodCodec::new(method.clone(), JsonSerializer::default()); // TODO remove clone
109
110        Ok(GrpcMethodRequest { destination, service, method, codec, message })
111    }
112}
113impl GrpcRequest {
114    pub fn service_method(
115        pool: &DescriptorPool,
116        (service, method): (&str, &str),
117    ) -> relentless::Result<(ServiceDescriptor, MethodDescriptor)> {
118        let svc = pool.get_service_by_name(service).ok_or_else(|| GrpcRequestError::NoService(service.to_string()))?;
119        let mth =
120            svc.methods().find(|m| m.name() == method).ok_or_else(|| GrpcRequestError::NoMethod(method.to_string()))?;
121        Ok((svc, mth))
122    }
123    pub async fn descriptor_pool<S>(
124        &self,
125        service: S,
126        destination: &http::Uri,
127        (svc, _mth): (&str, &str),
128    ) -> relentless::Result<DescriptorPool> {
129        // TODO cache
130        match &self.descriptor {
131            DescriptorFrom::Protos { protos, import_path } => Self::descriptor_from_protos(protos, import_path).await,
132            DescriptorFrom::Bin(path) => Self::descriptor_from_file(path).await,
133            DescriptorFrom::Reflection => Self::descriptor_from_reflection(service, destination, svc).await,
134        }
135    }
136
137    pub async fn descriptor_from_protos<A: AsRef<Path>>(
138        protos: &[A],
139        import_path: &[A],
140    ) -> relentless::Result<DescriptorPool> {
141        let builder = &mut prost_build::Config::new();
142        let fds = builder.load_fds(protos, import_path).box_err()?;
143        DescriptorPool::from_file_descriptor_set(fds).box_err()
144    }
145
146    pub async fn descriptor_from_file(path: &PathBuf) -> relentless::Result<DescriptorPool> {
147        let mut descriptor_bytes = Vec::new();
148        File::open(path).box_err()?.read_to_end(&mut descriptor_bytes).box_err()?;
149        DescriptorPool::decode(Bytes::from(descriptor_bytes)).box_err()
150    }
151
152    pub async fn descriptor_from_reflection<S>(
153        _service: S,
154        destination: &http::Uri,
155        svc: &str,
156    ) -> relentless::Result<DescriptorPool> {
157        // TODO!!! do not use Channel directly, use Service
158        let mut client = ServerReflectionClient::new(Channel::builder(destination.clone()).connect().await.box_err()?);
159        let (host, service) = (
160            destination.host().ok_or_else(|| GrpcRequestError::NoHost(destination.clone()))?.to_string(),
161            svc.to_string(),
162        );
163        let request_stream = futures::stream::once({
164            let host = host.clone();
165            async move {
166                ServerReflectionRequest { host, message_request: Some(MessageRequest::FileContainingSymbol(service)) }
167            }
168        });
169        let streaming = client.server_reflection_info(request_stream).await.box_err()?.into_inner();
170        let descriptors = streaming
171            .map(|recv| async { recv.box_err() })
172            .buffer_unordered(1)
173            .try_fold(DescriptorPool::new(), move |mut pool, recv| {
174                let host = host.to_string();
175                async move {
176                    let MessageResponse::FileDescriptorResponse(descriptor) =
177                        recv.message_response.ok_or_else(|| GrpcRequestError::EmptyResponse)?
178                    else {
179                        return Err(GrpcRequestError::UnexpectedReflectionResponse.into());
180                    };
181                    futures::stream::iter(descriptor.file_descriptor_proto.into_iter())
182                        .map(|d| async { Ok(d) })
183                        .buffer_unordered(16)
184                        .try_fold(&mut pool, move |p, d| {
185                            let host = host.clone();
186                            async move {
187                                let fd = FileDescriptorProto::decode(&*d).box_err()?;
188                                Self::fetch_all_descriptors(destination, &host, p, fd).await.map(|()| p)
189                            }
190                        })
191                        .await?;
192                    Ok(pool)
193                }
194            })
195            .await?;
196        Ok(descriptors)
197    }
198
199    pub async fn fetch_all_descriptors(
200        destination: &http::Uri,
201        host: &str,
202        pool: &mut DescriptorPool,
203        fd: FileDescriptorProto,
204    ) -> relentless::Result<()> {
205        let mut stack = vec![fd]; // TODO use stream as stack ?
206        let mut client = ServerReflectionClient::new(Channel::builder(destination.clone()).connect().await.box_err()?);
207        while let Some(proto) = stack.pop() {
208            if pool.add_file_descriptor_proto(proto.clone()).is_err() {
209                stack.push(proto.clone());
210                let host = host.to_string();
211                let dep_streaming = client
212                    .server_reflection_info(futures::stream::iter(proto.dependency.into_iter().map(move |dep| {
213                        let host = host.clone();
214                        ServerReflectionRequest { host, message_request: Some(MessageRequest::FileByFilename(dep)) }
215                    })))
216                    .await
217                    .box_err()?
218                    .into_inner();
219                dep_streaming
220                    .map(|recv| async { recv.box_err() })
221                    .buffer_unordered(16)
222                    .try_fold(&mut stack, |dfs, recv| async move {
223                        let MessageResponse::FileDescriptorResponse(descriptor) =
224                            recv.message_response.ok_or_else(|| GrpcRequestError::EmptyResponse)?
225                        else {
226                            return Err(GrpcRequestError::UnexpectedReflectionResponse.into());
227                        };
228                        let dep_protos: relentless::Result<Vec<_>> = descriptor
229                            .file_descriptor_proto
230                            .into_iter()
231                            .map(|d| FileDescriptorProto::decode(&*d).box_err())
232                            .collect();
233                        dfs.extend(dep_protos?); // TODO dedup in advance?
234                        Ok(dfs)
235                    })
236                    .await?;
237            }
238        }
239        Ok(())
240    }
241}
242
243impl GrpcMessage {
244    // TODO type of grpc message
245    pub fn produce(&self) -> serde_json::Value {
246        match self {
247            Self::Empty => serde_json::Value::Object(serde_json::Map::new()),
248            Self::Plaintext(_) => unimplemented!(),
249            Self::Json(v) => v.clone(),
250        }
251    }
252}