76 lines
No EOL
2.4 KiB
Rust
76 lines
No EOL
2.4 KiB
Rust
use std::{future::Future, pin::Pin, task::{Context, Poll}};
|
|
use std::net::SocketAddr;
|
|
use std::ops::{Deref, DerefMut};
|
|
|
|
use rocket::{Request, State, http::Status, request::{FromRequest, Outcome}};
|
|
use rocket::outcome::try_outcome;
|
|
use tokio::{net::TcpStream as TokioTcpStream, task};
|
|
use trust_dns_client::{client::AsyncClient, error::ClientError, op::DnsResponse, tcp::TcpClientStream};
|
|
use trust_dns_proto::error::ProtoError;
|
|
use trust_dns_proto::iocompat::AsyncIoTokioAsStd;
|
|
|
|
use crate::config::Config;
|
|
|
|
|
|
/// Newtype wrapper around a trust-dns [`AsyncClient`].
///
/// The wrapper exists so that this crate can attach its own trait
/// implementations to the client type.
pub struct DnsClient(AsyncClient);
|
|
|
|
impl Deref for DnsClient {
|
|
type Target = AsyncClient;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl DerefMut for DnsClient {
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
&mut self.0
|
|
}
|
|
}
|
|
|
|
impl DnsClient {
|
|
pub async fn new(addr: SocketAddr) -> Result<Self, ProtoError> {
|
|
let (stream, handle) = TcpClientStream::<AsyncIoTokioAsStd<TokioTcpStream>>::new(addr);
|
|
let client = AsyncClient::with_timeout(
|
|
stream,
|
|
handle,
|
|
std::time::Duration::from_secs(5),
|
|
None);
|
|
let (client, bg) = client.await?;
|
|
task::spawn(bg);
|
|
return Ok(DnsClient(client))
|
|
}
|
|
}
|
|
|
|
|
|
#[rocket::async_trait]
|
|
impl<'r> FromRequest<'r> for DnsClient {
|
|
type Error = ();
|
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
|
let config = try_outcome!(request.guard::<&State<Config>>().await);
|
|
match DnsClient::new(config.dns.server).await {
|
|
Err(e) => {
|
|
println!("Failed to connect to DNS server: {}", e);
|
|
Outcome::Failure((Status::InternalServerError, ()))
|
|
},
|
|
Ok(c) => Outcome::Success(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Local stand-in for the `ClientResponse` type of the trust-dns crate,
/// re-implemented here because the upstream type has private fields.
///
/// Wraps a response future `R`. The bounds on `R` are deliberately not
/// repeated on the type itself; they are stated on the impls that need them,
/// so the type can be named without restating them.
pub struct ClientResponse<R>(pub(crate) R);
|
|
|
|
impl<R> Future for ClientResponse<R>
|
|
where
|
|
R: Future<Output = Result<DnsResponse, ProtoError>> + Send + Unpin + 'static,
|
|
{
|
|
type Output = Result<DnsResponse, ClientError>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
// This is from the future_utils crate, we simply reuse the reexport from Rocket
|
|
rocket::futures::FutureExt::poll_unpin(&mut self.0, cx).map_err(ClientError::from)
|
|
}
|
|
} |