relentless_grpc/
client.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    marker::PhantomData,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::Buf;
10use http::{uri::PathAndQuery, Uri};
11use prost::Message;
12use prost_reflect::{DynamicMessage, MessageDescriptor, MethodDescriptor, ServiceDescriptor};
13use serde::{Deserializer, Serialize, Serializer};
14use tonic::{
15    body::Body as BoxBody,
16    client::GrpcService,
17    codec::{Codec, Decoder, Encoder},
18    transport::{Body, Channel},
19    Status,
20};
21use tower::Service;
22
23use crate::error::GrpcClientError;
24
25#[derive(Debug, Clone, PartialEq)]
26pub struct GrpcMethodRequest<D, S> {
27    pub destination: http::Uri,
28    pub service: ServiceDescriptor,
29    pub method: MethodDescriptor,
30    pub codec: MethodCodec<D, S>,
31    pub message: D,
32}
33impl<D, S> GrpcMethodRequest<D, S> {
34    pub fn format_method_path(&self) -> PathAndQuery {
35        // https://github.com/hyperium/tonic/blob/master/tonic-build/src/lib.rs#L212-L218
36        format!("/{}/{}", self.service.full_name(), self.method.name())
37            .parse()
38            .unwrap_or_else(|e| unreachable!("{}", e))
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct GrpcClient<S> {
44    inner: HashMap<Uri, tonic::client::Grpc<S>>,
45}
46
47impl GrpcClient<tonic::transport::Channel> {
48    pub async fn new(all_destinations: &[Uri]) -> Result<Self, GrpcClientError> {
49        let mut clients = HashMap::new();
50        for d in all_destinations {
51            let channel = Channel::builder(d.clone()).connect().await.unwrap_or_else(|e| todo!("{}", e));
52            clients.insert(d.clone(), tonic::client::Grpc::new(channel));
53        }
54        Ok(Self { inner: clients })
55    }
56}
57impl<S> GrpcClient<S>
58where
59    S: Clone,
60{
61    pub async fn from_services(services: &HashMap<Uri, S>) -> Result<Self, GrpcClientError> {
62        let clients = services.iter().map(|(d, s)| (d.clone(), tonic::client::Grpc::new(s.clone()))).collect();
63        Ok(Self { inner: clients })
64    }
65}
66
67impl<S, De, Se> Service<GrpcMethodRequest<De, Se>> for GrpcClient<S>
68where
69    S: GrpcService<BoxBody> + Clone + Send + 'static,
70    S::ResponseBody: Send,
71    <S::ResponseBody as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
72    S::Future: Send + 'static,
73    De: for<'a> Deserializer<'a> + Send + Sync + 'static,
74    for<'a> <De as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
75    Se: Serializer + Clone + Send + Sync + 'static,
76    Se::Ok: Send + Sync + 'static,
77    Se::Error: std::error::Error + Send + Sync + 'static,
78{
79    type Response = tonic::Response<Se::Ok>;
80    type Error = GrpcClientError;
81    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        Poll::Ready(Ok(())) // TODO
85    }
86
87    fn call(&mut self, req: GrpcMethodRequest<De, Se>) -> Self::Future {
88        let mut inner = self.inner[&req.destination].clone();
89        Box::pin(async move {
90            let path = req.format_method_path();
91            let GrpcMethodRequest { codec, message, .. } = req;
92            inner.ready().await.map_err(|_| GrpcClientError::Todo)?;
93            inner.unary(tonic::Request::new(message), path, codec).await.map_err(|_| GrpcClientError::Todo)
94        })
95    }
96}
97
98#[derive(Debug, PartialEq, Eq)]
99pub struct MethodCodec<D, S> {
100    method: MethodDescriptor,
101    serializer: S,
102    phantom: PhantomData<(D, S)>,
103}
104impl<D, S: Clone> Clone for MethodCodec<D, S> {
105    fn clone(&self) -> Self {
106        Self { method: self.method.clone(), serializer: self.serializer.clone(), phantom: PhantomData }
107    }
108}
109impl<D, S> MethodCodec<D, S> {
110    pub fn new(method: MethodDescriptor, serializer: S) -> Self {
111        Self { method, serializer, phantom: PhantomData }
112    }
113}
114
115impl<D, S> Codec for MethodCodec<D, S>
116where
117    D: for<'a> Deserializer<'a> + Send + 'static,
118    for<'a> <D as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
119    S: Serializer + Clone + Send + 'static,
120    S::Ok: Send + 'static,
121    S::Error: std::error::Error + Send + Sync + 'static,
122{
123    type Encode = D;
124    type Decode = S::Ok;
125    type Encoder = MethodEncoder<D>;
126    type Decoder = MethodDecoder<S>;
127
128    fn encoder(&mut self) -> Self::Encoder {
129        MethodEncoder(self.method.input(), PhantomData)
130    }
131
132    fn decoder(&mut self) -> Self::Decoder {
133        MethodDecoder(self.method.output(), self.serializer.clone())
134    }
135}
136
137#[derive(Debug)]
138pub struct MethodEncoder<D>(MessageDescriptor, PhantomData<D>);
139impl<D> Encoder for MethodEncoder<D>
140where
141    D: for<'a> Deserializer<'a>,
142    for<'a> <D as Deserializer<'a>>::Error: std::error::Error + Send + Sync + 'static,
143{
144    type Item = D;
145    type Error = Status;
146
147    fn encode(&mut self, item: Self::Item, dst: &mut tonic::codec::EncodeBuf<'_>) -> Result<(), Self::Error> {
148        let Self(descriptor, _phantom) = self;
149        DynamicMessage::deserialize(descriptor.clone(), item)
150            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?
151            .encode(dst)
152            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
153        Ok(())
154    }
155}
156
157#[derive(Debug)]
158pub struct MethodDecoder<S>(MessageDescriptor, S);
159impl<S> Decoder for MethodDecoder<S>
160where
161    S: Serializer + Clone + Send + 'static,
162    S::Ok: Send + 'static,
163    S::Error: std::error::Error + Send + Sync + 'static,
164{
165    type Item = S::Ok;
166    type Error = Status;
167
168    fn decode(&mut self, src: &mut tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
169        if !src.has_remaining() {
170            return Ok(None);
171        }
172        let Self(descriptor, serializer) = self;
173        let dynamic_message = DynamicMessage::decode(descriptor.clone(), src) // TODO `decode` requires ownership of MethodDescriptor
174            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
175        Ok(Some(
176            dynamic_message
177                .serialize(serializer.clone())
178                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
179        ))
180    }
181}