1use std::{
2 collections::HashMap,
3 future::Future,
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use bytes::Buf;
10use http::{uri::PathAndQuery, Uri};
11use prost::Message;
12use prost_reflect::{DynamicMessage, MessageDescriptor, MethodDescriptor, ServiceDescriptor};
13use serde::{Deserializer, Serialize, Serializer};
14use tonic::{
15 body::Body as BoxBody,
16 client::GrpcService,
17 codec::{Codec, Decoder, Encoder},
18 transport::{Body, Channel},
19 Status,
20};
21use tower::Service;
22
23use crate::error::GrpcClientError;
24
25#[derive(Debug, Clone, PartialEq)]
26pub struct GrpcMethodRequest<D, S> {
27 pub destination: http::Uri,
28 pub service: ServiceDescriptor,
29 pub method: MethodDescriptor,
30 pub codec: MethodCodec<D, S>,
31 pub message: D,
32}
33impl<D, S> GrpcMethodRequest<D, S> {
34 pub fn format_method_path(&self) -> PathAndQuery {
35 format!("/{}/{}", self.service.full_name(), self.method.name())
37 .parse()
38 .unwrap_or_else(|e| unreachable!("{}", e))
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct GrpcClient<S> {
44 inner: HashMap<Uri, tonic::client::Grpc<S>>,
45}
46
47impl GrpcClient<tonic::transport::Channel> {
48 pub async fn new(all_destinations: &[Uri]) -> Result<Self, GrpcClientError> {
49 let mut clients = HashMap::new();
50 for d in all_destinations {
51 let channel = Channel::builder(d.clone()).connect().await.unwrap_or_else(|e| todo!("{}", e));
52 clients.insert(d.clone(), tonic::client::Grpc::new(channel));
53 }
54 Ok(Self { inner: clients })
55 }
56}
57impl<S> GrpcClient<S>
58where
59 S: Clone,
60{
61 pub async fn from_services(services: &HashMap<Uri, S>) -> Result<Self, GrpcClientError> {
62 let clients = services.iter().map(|(d, s)| (d.clone(), tonic::client::Grpc::new(s.clone()))).collect();
63 Ok(Self { inner: clients })
64 }
65}
66
67impl<S, De, Se> Service<GrpcMethodRequest<De, Se>> for GrpcClient<S>
68where
69 S: GrpcService<BoxBody> + Clone + Send + 'static,
70 S::ResponseBody: Send,
71 <S::ResponseBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
72 S::Future: Send + 'static,
73 De: for<'a> Deserializer<'a> + Send + Sync + 'static,
74 for<'a> <De as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
75 Se: Serializer + Clone + Send + Sync + 'static,
76 Se::Ok: Send + Sync + 'static,
77 Se::Error: std::error::Error + Send + Sync + 'static,
78{
79 type Response = tonic::Response<Se::Ok>;
80 type Error = GrpcClientError;
81 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 Poll::Ready(Ok(())) }
86
87 fn call(&mut self, req: GrpcMethodRequest<De, Se>) -> Self::Future {
88 let mut inner = self.inner[&req.destination].clone();
89 Box::pin(async move {
90 let path = req.format_method_path();
91 let GrpcMethodRequest { codec, message, .. } = req;
92 inner.ready().await.map_err(|_| GrpcClientError::Todo)?;
93 inner.unary(tonic::Request::new(message), path, codec).await.map_err(|_| GrpcClientError::Todo)
94 })
95 }
96}
97
98#[derive(Debug, PartialEq, Eq)]
99pub struct MethodCodec<D, S> {
100 method: MethodDescriptor,
101 serializer: S,
102 phantom: PhantomData<(D, S)>,
103}
104impl<D, S: Clone> Clone for MethodCodec<D, S> {
105 fn clone(&self) -> Self {
106 Self { method: self.method.clone(), serializer: self.serializer.clone(), phantom: PhantomData }
107 }
108}
109impl<D, S> MethodCodec<D, S> {
110 pub fn new(method: MethodDescriptor, serializer: S) -> Self {
111 Self { method, serializer, phantom: PhantomData }
112 }
113}
114
115impl<D, S> Codec for MethodCodec<D, S>
116where
117 D: for<'a> Deserializer<'a> + Send + 'static,
118 for<'a> <D as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
119 S: Serializer + Clone + Send + 'static,
120 S::Ok: Send + 'static,
121 S::Error: std::error::Error + Send + Sync + 'static,
122{
123 type Encode = D;
124 type Decode = S::Ok;
125 type Encoder = MethodEncoder<D>;
126 type Decoder = MethodDecoder<S>;
127
128 fn encoder(&mut self) -> Self::Encoder {
129 MethodEncoder(self.method.input(), PhantomData)
130 }
131
132 fn decoder(&mut self) -> Self::Decoder {
133 MethodDecoder(self.method.output(), self.serializer.clone())
134 }
135}
136
137#[derive(Debug)]
138pub struct MethodEncoder<D>(MessageDescriptor, PhantomData<D>);
139impl<D> Encoder for MethodEncoder<D>
140where
141 D: for<'a> Deserializer<'a>,
142 for<'a> <D as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
143{
144 type Item = D;
145 type Error = Status;
146
147 fn encode(&mut self, item: Self::Item, dst: &mut tonic::codec::EncodeBuf<'_>) -> Result<(), Self::Error> {
148 let Self(descriptor, _phantom) = self;
149 DynamicMessage::deserialize(descriptor.clone(), item)
150 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
151 .encode(dst)
152 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
153 Ok(())
154 }
155}
156
157#[derive(Debug)]
158pub struct MethodDecoder<S>(MessageDescriptor, S);
159impl<S> Decoder for MethodDecoder<S>
160where
161 S: Serializer + Clone + Send + 'static,
162 S::Ok: Send + 'static,
163 S::Error: std::error::Error + Send + Sync + 'static,
164{
165 type Item = S::Ok;
166 type Error = Status;
167
168 fn decode(&mut self, src: &mut tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
169 if !src.has_remaining() {
170 return Ok(None);
171 }
172 let Self(descriptor, serializer) = self;
173 let dynamic_message = DynamicMessage::decode(descriptor.clone(), src) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
175 Ok(Some(
176 dynamic_message
177 .serialize(serializer.clone())
178 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
179 ))
180 }
181}