1 #![cfg(feature = "connect")]
2
3 use std::{
4 io,
5 net::{Ipv4Addr, SocketAddr},
6 };
7
8 use actix_rt::net::TcpStream;
9 use actix_server::TestServer;
10 use actix_service::{fn_service, Service, ServiceFactory};
11 use futures_core::future::LocalBoxFuture;
12
13 use actix_tls::connect::{new_connector_factory, Connect, Resolve, Resolver};
14
15 #[actix_rt::test]
custom_resolver()16 async fn custom_resolver() {
17 /// Always resolves to localhost with the given port.
18 struct LocalOnlyResolver;
19
20 impl Resolve for LocalOnlyResolver {
21 fn lookup<'a>(
22 &'a self,
23 _host: &'a str,
24 port: u16,
25 ) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>> {
26 Box::pin(async move {
27 let local = format!("127.0.0.1:{}", port).parse().unwrap();
28 Ok(vec![local])
29 })
30 }
31 }
32
33 let addr = LocalOnlyResolver.lookup("example.com", 8080).await.unwrap()[0];
34 assert_eq!(addr, SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080))
35 }
36
37 #[actix_rt::test]
custom_resolver_connect()38 async fn custom_resolver_connect() {
39 use trust_dns_resolver::TokioAsyncResolver;
40
41 let srv =
42 TestServer::with(|| fn_service(|_io: TcpStream| async { Ok::<_, io::Error>(()) }));
43
44 struct MyResolver {
45 trust_dns: TokioAsyncResolver,
46 }
47
48 impl Resolve for MyResolver {
49 fn lookup<'a>(
50 &'a self,
51 host: &'a str,
52 port: u16,
53 ) -> LocalBoxFuture<'a, Result<Vec<SocketAddr>, Box<dyn std::error::Error>>> {
54 Box::pin(async move {
55 let res = self
56 .trust_dns
57 .lookup_ip(host)
58 .await?
59 .iter()
60 .map(|ip| SocketAddr::new(ip, port))
61 .collect();
62 Ok(res)
63 })
64 }
65 }
66
67 let resolver = MyResolver {
68 trust_dns: TokioAsyncResolver::tokio_from_system_conf().unwrap(),
69 };
70
71 let resolver = Resolver::new_custom(resolver);
72 let factory = new_connector_factory(resolver);
73
74 let conn = factory.new_service(()).await.unwrap();
75 let con = conn
76 .call(Connect::with_addr("example.com", srv.addr()))
77 .await
78 .unwrap();
79 assert_eq!(con.peer_addr().unwrap(), srv.addr());
80 }
81