micromegas/servers/
connect_info_layer.rs

1//! Connection info utilities for capturing client socket addresses in Tonic services.
2
3use futures::Stream;
4use std::{
5    io,
6    net::SocketAddr,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tonic::transport::server::Connected;
12
/// A transport stream paired with the remote socket address of its peer.
///
/// This is used with Tonic's `serve_with_incoming` so that the client address is
/// available to the server: it is reported through this type's [`Connected`]
/// implementation as a `SocketAddr`.
#[derive(Debug)]
pub struct ConnectedStream<T> {
    /// The wrapped I/O object; reads and writes are forwarded to it.
    inner: T,
    /// Peer address supplied when the wrapper was constructed.
    remote_addr: SocketAddr,
}
19
20impl<T> ConnectedStream<T> {
21    pub fn new(inner: T, remote_addr: SocketAddr) -> Self {
22        Self { inner, remote_addr }
23    }
24}
25
26impl<T: AsyncRead + Unpin> AsyncRead for ConnectedStream<T> {
27    fn poll_read(
28        mut self: Pin<&mut Self>,
29        cx: &mut Context<'_>,
30        buf: &mut ReadBuf<'_>,
31    ) -> Poll<io::Result<()>> {
32        Pin::new(&mut self.inner).poll_read(cx, buf)
33    }
34}
35
36impl<T: AsyncWrite + Unpin> AsyncWrite for ConnectedStream<T> {
37    fn poll_write(
38        mut self: Pin<&mut Self>,
39        cx: &mut Context<'_>,
40        buf: &[u8],
41    ) -> Poll<io::Result<usize>> {
42        Pin::new(&mut self.inner).poll_write(cx, buf)
43    }
44
45    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46        Pin::new(&mut self.inner).poll_flush(cx)
47    }
48
49    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
50        Pin::new(&mut self.inner).poll_shutdown(cx)
51    }
52}
53
54/// Implement Connected trait to provide the remote address to Tonic
55impl<T> Connected for ConnectedStream<T> {
56    type ConnectInfo = SocketAddr;
57
58    fn connect_info(&self) -> Self::ConnectInfo {
59        self.remote_addr
60    }
61}
62
63/// An incoming stream adapter that wraps TCP connections with ConnectedStream
64pub struct ConnectedIncoming {
65    inner: Pin<Box<dyn Stream<Item = Result<tokio::net::TcpStream, io::Error>> + Send>>,
66}
67
68impl ConnectedIncoming {
69    pub fn from_std_listener(listener: std::net::TcpListener) -> io::Result<Self> {
70        listener.set_nonblocking(true)?;
71        let listener = tokio::net::TcpListener::from_std(listener)?;
72        Ok(Self::new(listener))
73    }
74
75    pub fn new(listener: tokio::net::TcpListener) -> Self {
76        let stream = async_stream::stream! {
77            loop {
78                match listener.accept().await {
79                    Ok((stream, _addr)) => yield Ok(stream),
80                    Err(e) => yield Err(e),
81                }
82            }
83        };
84
85        Self {
86            inner: Box::pin(stream),
87        }
88    }
89}
90
91impl Stream for ConnectedIncoming {
92    type Item = Result<ConnectedStream<tokio::net::TcpStream>, io::Error>;
93
94    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        match self.inner.as_mut().poll_next(cx) {
96            Poll::Ready(Some(Ok(stream))) => {
97                let remote_addr = match stream.peer_addr() {
98                    Ok(addr) => addr,
99                    Err(e) => return Poll::Ready(Some(Err(e))),
100                };
101                Poll::Ready(Some(Ok(ConnectedStream::new(stream, remote_addr))))
102            }
103            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
104            Poll::Ready(None) => Poll::Ready(None),
105            Poll::Pending => Poll::Pending,
106        }
107    }
108}