1use std::{
2 fs::File,
3 io::Read,
4 path::{Path, PathBuf},
5};
6
7use bytes::Bytes;
8use futures::{StreamExt, TryStreamExt};
9use prost::Message;
10use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
11use prost_types::FileDescriptorProto;
12use serde::{Deserialize, Serialize};
13use tonic::transport::Channel;
14use tonic_reflection::pb::v1::{
15 server_reflection_client::ServerReflectionClient, server_reflection_request::MessageRequest,
16 server_reflection_response::MessageResponse, ServerReflectionRequest,
17};
18
19use relentless::{
20 assault::factory::RequestFactory,
21 error::IntoResult,
22 interface::{
23 helper::{coalesce::Coalesce, is_default::IsDefault},
24 template::Template,
25 },
26};
27use tower::Service;
28
29use crate::helper::JsonSerializer;
30
31use super::{
32 client::{GrpcMethodRequest, MethodCodec},
33 error::GrpcRequestError,
34};
35
/// gRPC request settings: where method descriptors come from and which message to send.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct GrpcRequest {
    /// Source of the protobuf descriptors; when omitted, `DescriptorFrom::default()` is used.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    descriptor: DescriptorFrom,
    /// Request message; when omitted, `GrpcMessage::default()` is used.
    #[serde(default, skip_serializing_if = "IsDefault::is_default")]
    #[cfg_attr(feature = "yaml", serde(with = "serde_yaml::with::singleton_map_recursive"))]
    message: GrpcMessage,
}
/// Where the protobuf descriptors used to resolve `service/method` targets come from.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Protos`, then `Bin`, then `Reflection`.
// NOTE(review): `rename_all` on an enum renames variants, not the fields of struct
// variants, so the key below is `import_path` rather than the kebab-case
// `import-path` used by `GrpcRequest` — confirm this is intended.
// NOTE(review): because `Bin` is tried before `Reflection`, a bare string such as
// `reflection` presumably deserializes as `Bin("reflection")` — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case", untagged)]
pub enum DescriptorFrom {
    /// Compile these `.proto` files with `prost_build`, passing `import_path` as the include directories.
    Protos {
        #[serde(default, skip_serializing_if = "IsDefault::is_default")]
        protos: Vec<PathBuf>,
        #[serde(default, skip_serializing_if = "IsDefault::is_default")]
        import_path: Vec<PathBuf>,
    },
    /// Read an encoded descriptor set from this file.
    Bin(PathBuf),
    /// Ask the target server through gRPC server reflection (the default).
    #[default]
    Reflection,
}
/// The request message to send.
///
/// Uses serde's default (externally tagged) representation with kebab-case
/// variant names: `empty`, `plaintext`, `json`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum GrpcMessage {
    /// No explicit message; produced as an empty JSON object (the default).
    #[default]
    Empty,
    /// Message given as text; not supported yet (`GrpcMessage::produce` panics on it).
    Plaintext(String),
    /// Message given as JSON; rendered through the request template before sending.
    Json(serde_json::Value),
}
67impl Coalesce for GrpcRequest {
68 fn coalesce(self, other: &Self) -> Self {
69 Self { descriptor: self.descriptor.coalesce(&other.descriptor), message: self.message.coalesce(&other.message) }
70 }
71}
72impl Coalesce for DescriptorFrom {
73 fn coalesce(self, other: &Self) -> Self {
74 if self.is_default() {
75 other.clone()
76 } else {
77 self
78 }
79 }
80}
81impl Coalesce for GrpcMessage {
82 fn coalesce(self, other: &Self) -> Self {
83 if self.is_default() {
84 other.clone()
85 } else {
86 self
87 }
88 }
89}
90
91impl<S> RequestFactory<GrpcMethodRequest<serde_json::Value, JsonSerializer>, S> for GrpcRequest
92where
93 S: Service<GrpcMethodRequest<serde_json::Value, JsonSerializer>>,
94{
95 type Error = relentless::Error;
96 async fn produce(
97 &self,
98 service: S,
99 destination: &http::Uri,
100 target: &str,
101 template: &Template,
102 ) -> Result<GrpcMethodRequest<serde_json::Value, JsonSerializer>, Self::Error> {
103 let (svc, mth) = target.split_once('/').ok_or_else(|| GrpcRequestError::FailToParse(target.to_string()))?; let pool = self.descriptor_pool(service, destination, (svc, mth)).await?;
105 let destination = destination.clone();
106 let (service, method) = Self::service_method(&pool, (svc, mth))?;
107 let message = template.render_json_recursive(&self.message.produce())?;
108 let codec = MethodCodec::new(method.clone(), JsonSerializer::default()); Ok(GrpcMethodRequest { destination, service, method, codec, message })
111 }
112}
113impl GrpcRequest {
114 pub fn service_method(
115 pool: &DescriptorPool,
116 (service, method): (&str, &str),
117 ) -> relentless::Result<(ServiceDescriptor, MethodDescriptor)> {
118 let svc = pool.get_service_by_name(service).ok_or_else(|| GrpcRequestError::NoService(service.to_string()))?;
119 let mth =
120 svc.methods().find(|m| m.name() == method).ok_or_else(|| GrpcRequestError::NoMethod(method.to_string()))?;
121 Ok((svc, mth))
122 }
123 pub async fn descriptor_pool<S>(
124 &self,
125 service: S,
126 destination: &http::Uri,
127 (svc, _mth): (&str, &str),
128 ) -> relentless::Result<DescriptorPool> {
129 match &self.descriptor {
131 DescriptorFrom::Protos { protos, import_path } => Self::descriptor_from_protos(protos, import_path).await,
132 DescriptorFrom::Bin(path) => Self::descriptor_from_file(path).await,
133 DescriptorFrom::Reflection => Self::descriptor_from_reflection(service, destination, svc).await,
134 }
135 }
136
137 pub async fn descriptor_from_protos<A: AsRef<Path>>(
138 protos: &[A],
139 import_path: &[A],
140 ) -> relentless::Result<DescriptorPool> {
141 let builder = &mut prost_build::Config::new();
142 let fds = builder.load_fds(protos, import_path).box_err()?;
143 DescriptorPool::from_file_descriptor_set(fds).box_err()
144 }
145
146 pub async fn descriptor_from_file(path: &PathBuf) -> relentless::Result<DescriptorPool> {
147 let mut descriptor_bytes = Vec::new();
148 File::open(path).box_err()?.read_to_end(&mut descriptor_bytes).box_err()?;
149 DescriptorPool::decode(Bytes::from(descriptor_bytes)).box_err()
150 }
151
152 pub async fn descriptor_from_reflection<S>(
153 _service: S,
154 destination: &http::Uri,
155 svc: &str,
156 ) -> relentless::Result<DescriptorPool> {
157 let mut client = ServerReflectionClient::new(Channel::builder(destination.clone()).connect().await.box_err()?);
159 let (host, service) = (
160 destination.host().ok_or_else(|| GrpcRequestError::NoHost(destination.clone()))?.to_string(),
161 svc.to_string(),
162 );
163 let request_stream = futures::stream::once({
164 let host = host.clone();
165 async move {
166 ServerReflectionRequest { host, message_request: Some(MessageRequest::FileContainingSymbol(service)) }
167 }
168 });
169 let streaming = client.server_reflection_info(request_stream).await.box_err()?.into_inner();
170 let descriptors = streaming
171 .map(|recv| async { recv.box_err() })
172 .buffer_unordered(1)
173 .try_fold(DescriptorPool::new(), move |mut pool, recv| {
174 let host = host.to_string();
175 async move {
176 let MessageResponse::FileDescriptorResponse(descriptor) =
177 recv.message_response.ok_or_else(|| GrpcRequestError::EmptyResponse)?
178 else {
179 return Err(GrpcRequestError::UnexpectedReflectionResponse.into());
180 };
181 futures::stream::iter(descriptor.file_descriptor_proto.into_iter())
182 .map(|d| async { Ok(d) })
183 .buffer_unordered(16)
184 .try_fold(&mut pool, move |p, d| {
185 let host = host.clone();
186 async move {
187 let fd = FileDescriptorProto::decode(&*d).box_err()?;
188 Self::fetch_all_descriptors(destination, &host, p, fd).await.map(|()| p)
189 }
190 })
191 .await?;
192 Ok(pool)
193 }
194 })
195 .await?;
196 Ok(descriptors)
197 }
198
199 pub async fn fetch_all_descriptors(
200 destination: &http::Uri,
201 host: &str,
202 pool: &mut DescriptorPool,
203 fd: FileDescriptorProto,
204 ) -> relentless::Result<()> {
205 let mut stack = vec![fd]; let mut client = ServerReflectionClient::new(Channel::builder(destination.clone()).connect().await.box_err()?);
207 while let Some(proto) = stack.pop() {
208 if pool.add_file_descriptor_proto(proto.clone()).is_err() {
209 stack.push(proto.clone());
210 let host = host.to_string();
211 let dep_streaming = client
212 .server_reflection_info(futures::stream::iter(proto.dependency.into_iter().map(move |dep| {
213 let host = host.clone();
214 ServerReflectionRequest { host, message_request: Some(MessageRequest::FileByFilename(dep)) }
215 })))
216 .await
217 .box_err()?
218 .into_inner();
219 dep_streaming
220 .map(|recv| async { recv.box_err() })
221 .buffer_unordered(16)
222 .try_fold(&mut stack, |dfs, recv| async move {
223 let MessageResponse::FileDescriptorResponse(descriptor) =
224 recv.message_response.ok_or_else(|| GrpcRequestError::EmptyResponse)?
225 else {
226 return Err(GrpcRequestError::UnexpectedReflectionResponse.into());
227 };
228 let dep_protos: relentless::Result<Vec<_>> = descriptor
229 .file_descriptor_proto
230 .into_iter()
231 .map(|d| FileDescriptorProto::decode(&*d).box_err())
232 .collect();
233 dfs.extend(dep_protos?); Ok(dfs)
235 })
236 .await?;
237 }
238 }
239 Ok(())
240 }
241}
242
243impl GrpcMessage {
244 pub fn produce(&self) -> serde_json::Value {
246 match self {
247 Self::Empty => serde_json::Value::Object(serde_json::Map::new()),
248 Self::Plaintext(_) => unimplemented!(),
249 Self::Json(v) => v.clone(),
250 }
251 }
252}