micromegas/servers/
connect_info_layer.rs1use futures::Stream;
4use std::{
5 io,
6 net::SocketAddr,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tonic::transport::server::Connected;
12
/// A transport of type `T` bundled with the socket address of its remote peer.
///
/// The I/O trait implementations on this type delegate to `inner`;
/// `remote_addr` is only stored and handed back on request.
pub struct ConnectedStream<T> {
    /// The wrapped transport that performs the actual I/O.
    inner: T,
    /// Address of the remote end, captured when the wrapper was built.
    remote_addr: SocketAddr,
}
19
20impl<T> ConnectedStream<T> {
21 pub fn new(inner: T, remote_addr: SocketAddr) -> Self {
22 Self { inner, remote_addr }
23 }
24}
25
26impl<T: AsyncRead + Unpin> AsyncRead for ConnectedStream<T> {
27 fn poll_read(
28 mut self: Pin<&mut Self>,
29 cx: &mut Context<'_>,
30 buf: &mut ReadBuf<'_>,
31 ) -> Poll<io::Result<()>> {
32 Pin::new(&mut self.inner).poll_read(cx, buf)
33 }
34}
35
36impl<T: AsyncWrite + Unpin> AsyncWrite for ConnectedStream<T> {
37 fn poll_write(
38 mut self: Pin<&mut Self>,
39 cx: &mut Context<'_>,
40 buf: &[u8],
41 ) -> Poll<io::Result<usize>> {
42 Pin::new(&mut self.inner).poll_write(cx, buf)
43 }
44
45 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46 Pin::new(&mut self.inner).poll_flush(cx)
47 }
48
49 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
50 Pin::new(&mut self.inner).poll_shutdown(cx)
51 }
52}
53
54impl<T> Connected for ConnectedStream<T> {
56 type ConnectInfo = SocketAddr;
57
58 fn connect_info(&self) -> Self::ConnectInfo {
59 self.remote_addr
60 }
61}
62
63pub struct ConnectedIncoming {
65 inner: Pin<Box<dyn Stream<Item = Result<tokio::net::TcpStream, io::Error>> + Send>>,
66}
67
68impl ConnectedIncoming {
69 pub fn from_std_listener(listener: std::net::TcpListener) -> io::Result<Self> {
70 listener.set_nonblocking(true)?;
71 let listener = tokio::net::TcpListener::from_std(listener)?;
72 Ok(Self::new(listener))
73 }
74
75 pub fn new(listener: tokio::net::TcpListener) -> Self {
76 let stream = async_stream::stream! {
77 loop {
78 match listener.accept().await {
79 Ok((stream, _addr)) => yield Ok(stream),
80 Err(e) => yield Err(e),
81 }
82 }
83 };
84
85 Self {
86 inner: Box::pin(stream),
87 }
88 }
89}
90
91impl Stream for ConnectedIncoming {
92 type Item = Result<ConnectedStream<tokio::net::TcpStream>, io::Error>;
93
94 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 match self.inner.as_mut().poll_next(cx) {
96 Poll::Ready(Some(Ok(stream))) => {
97 let remote_addr = match stream.peer_addr() {
98 Ok(addr) => addr,
99 Err(e) => return Poll::Ready(Some(Err(e))),
100 };
101 Poll::Ready(Some(Ok(ConnectedStream::new(stream, remote_addr))))
102 }
103 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
104 Poll::Ready(None) => Poll::Ready(None),
105 Poll::Pending => Poll::Pending,
106 }
107 }
108}