diff --git a/client/src/h1/proto/context.rs b/client/src/h1/proto/context.rs index 1b4affce8..1bbc79d21 100644 --- a/client/src/h1/proto/context.rs +++ b/client/src/h1/proto/context.rs @@ -29,7 +29,7 @@ impl DerefMut for Context<'_, '_, HEADER_LIMIT> { } impl<'c, 'd, const HEADER_LIMIT: usize> Context<'c, 'd, HEADER_LIMIT> { - pub(crate) fn new(date: &'c DateTimeHandle<'d>) -> Self { - Self(context::Context::new(date)) + pub(crate) fn new(date: &'c DateTimeHandle<'d>, is_tls: bool) -> Self { + Self(context::Context::new(date, is_tls)) } } diff --git a/client/src/h1/proto/dispatcher.rs b/client/src/h1/proto/dispatcher.rs index d49b890ff..66e3f9825 100644 --- a/client/src/h1/proto/dispatcher.rs +++ b/client/src/h1/proto/dispatcher.rs @@ -68,8 +68,12 @@ where } } + let is_tls = req + .uri() + .scheme() + .is_some_and(|scheme| scheme == "https" || scheme == "wss"); // TODO: make const generic params configurable. - let mut ctx = Context::<128>::new(&date); + let mut ctx = Context::<128>::new(&date, is_tls); // encode request head and return transfer encoding for request body let encoder = ctx.encode_head(&mut buf, req)?; diff --git a/http/benches/h1_decode.rs b/http/benches/h1_decode.rs index e0d9ce8c0..0154f79e6 100644 --- a/http/benches/h1_decode.rs +++ b/http/benches/h1_decode.rs @@ -31,7 +31,7 @@ impl DateTime for DT { fn decode(c: &mut Criterion) { let dt = DT::dummy_date_time(); - let mut ctx = Context::<_, 8>::new(&dt); + let mut ctx = Context::<_, 8>::new(&dt, false); let req = b"\ GET /HFQR/xitca-web HTTP/1.1\r\n\ diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index b2a13d447..2672b8d03 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -63,6 +63,7 @@ pub(crate) async fn run< config: HttpServiceConfig, service: &'a S, date: &'a D, + is_tls: bool, ) -> Result<(), Error> where S: Service, Response = Response>, @@ -77,7 +78,7 @@ where EitherBuf::Right(WriteBuf::::default()) }; - Dispatcher::new(io, addr, timer, config, service, date, write_buf) + Dispatcher::new(io, addr, timer, config, service, date, write_buf, is_tls) .run() .await } @@ -166,6 +167,7 @@ where W: H1BufWrite, D: DateTime, { + #[allow(clippy::too_many_arguments)] fn new( io: &'a mut St, addr: SocketAddr, @@ -174,11 +176,12 @@ where service: &'a S, date: &'a D, write_buf: W, + is_tls: bool, ) -> Self { Self { io: BufferedIo::new(io, write_buf), timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), - ctx: Context::with_addr(addr, date), + ctx: Context::with_addr(addr, date, is_tls), service, _phantom: PhantomData, } diff --git a/http/src/h1/dispatcher_compio.rs b/http/src/h1/dispatcher_compio.rs index a40eeb77f..ae9f89057 100644 --- a/http/src/h1/dispatcher_compio.rs +++ b/http/src/h1/dispatcher_compio.rs @@ -75,10 +75,16 @@ where ResB: Stream>, D: DateTime, { - pub async fn run(io: TcpStream, addr: SocketAddr, service: &'a S, date: &'a D) -> Result<(), Error> { + pub async fn run( + io: TcpStream, + addr: SocketAddr, + service: &'a S, + date: &'a D, + is_tls: bool, + ) -> Result<(), Error> { let mut dispatcher = Dispatcher::<_, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { io: SharedIo::new(io), - ctx: Context::with_addr(addr, date), + ctx: Context::with_addr(addr, date, is_tls), service, _phantom: PhantomData, }; diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs index 1d6f2d3d6..1ecf917c5 100644 --- a/http/src/h1/dispatcher_uring.rs +++ b/http/src/h1/dispatcher_uring.rs @@ -85,12 +85,13 @@ where config: HttpServiceConfig, service: &'a S, date: &'a D, + is_tls: bool, ) -> Result<(), Error> { let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { io, notify: Notify::default(), timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), - ctx: Context::with_addr(addr, date), + ctx: Context::with_addr(addr, date, is_tls), service, _phantom: PhantomData, }; diff --git a/http/src/h1/proto/context.rs b/http/src/h1/proto/context.rs index 1d0035f6f..22324af0b 100644 --- a/http/src/h1/proto/context.rs +++ b/http/src/h1/proto/context.rs @@ -11,6 +11,7 @@ pub struct Context<'a, D, const HEADER_LIMIT: usize> { // http extensions reused by next request. exts: Extensions, date: &'a D, + pub(crate) is_tls: bool, } // A set of state for current request that are used after request's ownership is passed @@ -49,21 +50,22 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> { /// /// [DateTime]: crate::date::DateTime #[inline] - pub fn new(date: &'a D) -> Self { - Self::with_addr(crate::unspecified_socket_addr(), date) + pub fn new(date: &'a D, is_tls: bool) -> Self { + Self::with_addr(crate::unspecified_socket_addr(), date, is_tls) } /// Context is constructed with [SocketAddr] and reference of certain type that impl [DateTime] trait. /// /// [DateTime]: crate::date::DateTime #[inline] - pub fn with_addr(addr: SocketAddr, date: &'a D) -> Self { + pub fn with_addr(addr: SocketAddr, date: &'a D, is_tls: bool) -> Self { Self { addr, state: ContextState::new(), header: None, exts: Extensions::new(), date, + is_tls, } } diff --git a/http/src/h1/proto/decode.rs b/http/src/h1/proto/decode.rs index cd03ad75d..39a2bc382 100644 --- a/http/src/h1/proto/decode.rs +++ b/http/src/h1/proto/decode.rs @@ -1,5 +1,6 @@ use core::mem::MaybeUninit; +use http::uri::{Authority, Scheme}; use httparse::Status; use crate::{ @@ -71,7 +72,7 @@ impl Context<'_, D, MAX_HEADERS> { // split the headers from buffer. let slice = buf.split_to(len).freeze(); - let uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?; + let mut uri = Uri::from_maybe_shared(slice.slice(path_head..path_head + path_len))?.into_parts(); // pop a cached headermap or construct a new one. let mut headers = self.take_headers(); @@ -87,6 +88,25 @@ impl Context<'_, D, MAX_HEADERS> { let extensions = self.take_extensions(); + // Try to set authority from host header if not present in request path + if uri.authority.is_none() { + // @TODO if it's a tls connection we could set the sni server name as authority instead + if let Some(host) = headers.get(http::header::HOST) { + uri.authority = Some(Authority::try_from(host.as_bytes())?); + } + } + + // If authority is set, this will set the correct scheme depending on the tls acceptor used in the service. + if uri.authority.is_some() && uri.scheme.is_none() { + uri.scheme = if self.is_tls { + Some(Scheme::HTTPS) + } else { + Some(Scheme::HTTP) + }; + } + + let uri = Uri::from_parts(uri)?; + *req.method_mut() = method; *req.version_mut() = version; *req.uri_mut() = uri; @@ -173,7 +193,7 @@ mod test { #[test] fn connection_multiple_value() { - let mut ctx = Context::<_, 4>::new(&()); + let mut ctx = Context::<_, 4>::new(&(), false); let head = b"\ GET / HTTP/1.1\r\n\ @@ -211,7 +231,7 @@ mod test { #[test] fn transfer_encoding() { - let mut ctx = Context::<_, 4>::new(&()); + let mut ctx = Context::<_, 4>::new(&(), false); let head = b"\ GET / HTTP/1.1\r\n\ @@ -311,4 +331,33 @@ mod test { "transfer coding is not decoded to chunked" ); } + + #[test] + fn test_host_with_scheme() { + let mut ctx = Context::<_, 4>::new(&(), true); + + let head = b"\ + GET / HTTP/1.1\r\n\ + Host: example.com\r\n\ + \r\n\ + "; + let mut buf = BytesMut::from(&head[..]); + + let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap(); + + assert_eq!(req.uri().scheme(), Some(&Scheme::HTTPS)); + assert_eq!(req.uri().authority(), Some(&Authority::from_static("example.com"))); + assert_eq!(req.headers().get(http::header::HOST).unwrap(), "example.com"); + + let head = b"\ + GET / HTTP/1.1\r\n\ + \r\n\ + "; + let mut buf = BytesMut::from(&head[..]); + + let (req, _) = ctx.decode_head::<128>(&mut buf).unwrap().unwrap(); + + assert_eq!(req.uri().scheme(), None); + assert_eq!(req.uri().authority(), None); + } } diff --git a/http/src/h1/proto/encode.rs b/http/src/h1/proto/encode.rs index 263a750d3..2d3c8717a 100644 --- a/http/src/h1/proto/encode.rs +++ b/http/src/h1/proto/encode.rs @@ -257,7 +257,7 @@ mod test { #[test] fn append_header() { - let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler); + let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false); let mut res = Response::new(BoxBody::new(Once::new(Bytes::new()))); @@ -287,7 +287,7 @@ mod test { #[test] fn multi_set_cookie() { - let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler); + let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false); let mut res = Response::new(BoxBody::new(Once::new(Bytes::new()))); diff --git a/http/src/h1/proto/error.rs b/http/src/h1/proto/error.rs index 8c8d505fb..8e2a81ab0 100644 --- a/http/src/h1/proto/error.rs +++ b/http/src/h1/proto/error.rs @@ -40,6 +40,12 @@ impl From for ProtoError { } } +impl From for ProtoError { + fn from(_: http::uri::InvalidUriParts) -> Self { + Self::Uri + } +} + impl From for ProtoError { fn from(_: http::status::InvalidStatusCode) -> Self { Self::Status diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index d1dbcc0e1..ab801d4c9 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -9,6 +9,7 @@ use crate::{ error::{HttpServiceError, TimeoutError}, http::{Request, RequestExt, Response}, service::HttpService, + tls::IsTls, util::timer::Timeout, }; @@ -21,7 +22,7 @@ impl for H1Service where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, St: AsyncIo, A::Response: AsyncIo, B: Stream>, @@ -41,9 +42,17 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get()) - .await - .map_err(Into::into) + super::dispatcher::run( + &mut io, + addr, + timer, + self.config, + &self.service, + self.date.get(), + self.tls_acceptor.is_tls(), + ) + .await + .map_err(Into::into) } } @@ -94,7 +103,7 @@ impl for H1UringService where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, A::Response: AsyncBufRead + AsyncBufWrite + Clone + 'static, B: Stream>, HttpServiceError: From, @@ -113,9 +122,17 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher_uring::Dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) - .await - .map_err(Into::into) + super::dispatcher_uring::Dispatcher::run( + io, + addr, + timer, + self.config, + &self.service, + self.date.get(), + self.tls_acceptor.is_tls(), + ) + .await + .map_err(Into::into) } } diff --git a/http/src/service.rs b/http/src/service.rs index 703a814bf..7c94ece2e 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -14,6 +14,7 @@ use super::{ date::{DateTime, DateTimeService}, error::{HttpServiceError, TimeoutError}, http::{Request, RequestExt, Response}, + tls::IsTls, util::timer::{KeepAlive, Timeout}, version::AsVersion, }; @@ -73,7 +74,7 @@ impl where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, A::Response: AsyncIo + AsVersion, HttpServiceError: From, S::Error: fmt::Debug, @@ -120,6 +121,7 @@ where self.config, &self.service, self.date.get(), + self.tls_acceptor.is_tls(), ) .await .map_err(From::from), @@ -168,6 +170,7 @@ where self.config, &self.service, self.date.get(), + self.tls_acceptor.is_tls(), ) .await .map_err(From::from) diff --git a/http/src/tls/mod.rs b/http/src/tls/mod.rs index bb8e3da63..625b40ca8 100644 --- a/http/src/tls/mod.rs +++ b/http/src/tls/mod.rs @@ -19,6 +19,13 @@ pub use error::TlsError; use xitca_service::Service; +/// A trait to check if an acceptor will create a Tls stream. +pub trait IsTls { + fn is_tls(&self) -> bool { + true + } +} + /// A NoOp Tls Acceptor pass through input Stream type. #[derive(Copy, Clone)] pub struct NoOpTlsAcceptorBuilder; @@ -42,3 +49,9 @@ impl Service for NoOpTlsAcceptorService { Ok(io) } } + +impl IsTls for NoOpTlsAcceptorService { + fn is_tls(&self) -> bool { + false + } +} diff --git a/http/src/tls/native_tls.rs b/http/src/tls/native_tls.rs index 177dc74d6..0204af6a0 100644 --- a/http/src/tls/native_tls.rs +++ b/http/src/tls/native_tls.rs @@ -15,7 +15,7 @@ use xitca_service::Service; use crate::{http::Version, version::AsVersion}; -use super::error::TlsError; +use super::{IsTls, error::TlsError}; /// A wrapper type for [TlsStream](native_tls::TlsStream). /// @@ -92,6 +92,8 @@ impl Service for TlsAcceptorService { } } +impl IsTls for TlsAcceptorService {} + impl AsyncIo for TlsStream { #[inline] fn ready(&mut self, interest: Interest) -> impl Future> + Send { diff --git a/http/src/tls/openssl.rs b/http/src/tls/openssl.rs index 1822254f4..422b46294 100644 --- a/http/src/tls/openssl.rs +++ b/http/src/tls/openssl.rs @@ -8,7 +8,7 @@ use xitca_tls::openssl::ssl; use crate::{http::Version, version::AsVersion}; -use super::error::TlsError; +use super::{IsTls, error::TlsError}; pub type TlsStream = xitca_tls::openssl::TlsStream; @@ -70,6 +70,8 @@ impl Service for TlsAcceptorService { } } +impl IsTls for TlsAcceptorService {} + /// Collection of 'openssl' error types. pub type OpensslError = xitca_tls::openssl::Error; diff --git a/http/src/tls/rustls.rs b/http/src/tls/rustls.rs index 320ed05d8..9d7e68fce 100644 --- a/http/src/tls/rustls.rs +++ b/http/src/tls/rustls.rs @@ -8,7 +8,7 @@ use xitca_tls::rustls::{Error, ServerConfig, ServerConnection, TlsStream as _Tls use crate::{http::Version, version::AsVersion}; -use super::error::TlsError; +use super::{IsTls, error::TlsError}; pub(crate) type RustlsConfig = Arc; @@ -65,6 +65,8 @@ impl Service for TlsAcceptorService { } } +impl IsTls for TlsAcceptorService {} + /// Collection of 'rustls' error types. pub enum RustlsError { Io(io::Error), diff --git a/http/src/tls/rustls_uring.rs b/http/src/tls/rustls_uring.rs index d99dd65e8..a62c36702 100644 --- a/http/src/tls/rustls_uring.rs +++ b/http/src/tls/rustls_uring.rs @@ -11,7 +11,7 @@ use xitca_tls::{ use crate::{http::Version, version::AsVersion}; -use super::rustls::RustlsError; +use super::{IsTls, rustls::RustlsError}; /// A stream managed by rustls for tls read/write. pub struct TlsStream { @@ -66,6 +66,8 @@ where } } +impl IsTls for TlsAcceptorService {} + impl AsyncBufRead for TlsStream where Io: AsyncBufRead,