diff --git a/Cargo.toml b/Cargo.toml index 646974ca0d..1560c1ae60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,7 +82,7 @@ full = [ # HTTP versions http1 = ["dep:atomic-waker", "dep:futures-channel", "dep:futures-core", "dep:httparse", "dep:itoa"] -http2 = ["dep:futures-channel", "dep:futures-core", "dep:h2"] +http2 = ["dep:atomic-waker", "dep:futures-channel", "dep:futures-core", "dep:h2"] # Client/Server client = ["dep:want", "dep:pin-project-lite", "dep:smallvec"] diff --git a/src/proto/h2/upgrade.rs b/src/proto/h2/upgrade.rs index 80a110b5c2..f4e65466a2 100644 --- a/src/proto/h2/upgrade.rs +++ b/src/proto/h2/upgrade.rs @@ -1,8 +1,11 @@ use std::future::Future; use std::io::Cursor; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::task::{Context, Poll}; +use atomic_waker::AtomicWaker; use bytes::{Buf, Bytes}; use futures_channel::{mpsc, oneshot}; use futures_core::{ready, Stream}; @@ -20,10 +23,15 @@ pub(super) fn pair( ) -> (H2Upgraded, UpgradedSendStreamTask) { let (tx, rx) = mpsc::channel(1); let (error_tx, error_rx) = oneshot::channel(); + let close_notify = Arc::new(UpgradedCloseNotify::new()); ( H2Upgraded { - send_stream: UpgradedSendStreamBridge { tx, error_rx }, + send_stream: UpgradedSendStreamBridge { + tx, + error_rx, + close_notify: close_notify.clone(), + }, recv_stream, ping, buf: Bytes::new(), @@ -31,6 +39,7 @@ pub(super) fn pair( UpgradedSendStreamTask { h2_tx: send_stream, rx, + close_notify, error_tx: Some(error_tx), }, ) @@ -46,6 +55,46 @@ pub(super) struct H2Upgraded { struct UpgradedSendStreamBridge { tx: mpsc::Sender>>, error_rx: oneshot::Receiver, + close_notify: Arc, +} + +impl Drop for UpgradedSendStreamBridge { + fn drop(&mut self) { + self.close_notify.close(); + } +} + +struct UpgradedCloseNotify { + closed: AtomicBool, + task: AtomicWaker, +} + +impl UpgradedCloseNotify { + fn new() -> Self { + Self { + closed: AtomicBool::new(false), + task: AtomicWaker::new(), + } + } + + fn close(&self) { + self.closed.store(true, Ordering::Release); + self.task.wake(); + } + + fn poll_closed(&self, cx: &mut Context<'_>) -> Poll<()> { + if self.closed.load(Ordering::Acquire) { + return Poll::Ready(()); + } + + self.task.register(cx.waker()); + + if self.closed.load(Ordering::Acquire) { + Poll::Ready(()) + } else { + Poll::Pending + } + } } pin_project! { @@ -55,6 +104,7 @@ pin_project! { h2_tx: SendStream>, #[pin] rx: mpsc::Receiver>>, + close_notify: Arc, error_tx: Option>, } } @@ -78,12 +128,12 @@ where // for the actual body chunk. me.h2_tx.reserve_capacity(1); - if me.h2_tx.capacity() == 0 { + let h2_has_capacity = if me.h2_tx.capacity() == 0 { // poll_capacity oddly needs a loop - 'capacity: loop { + loop { match me.h2_tx.poll_capacity(cx) { Poll::Ready(Some(Ok(0))) => {} - Poll::Ready(Some(Ok(_))) => break, + Poll::Ready(Some(Ok(_))) => break true, Poll::Ready(Some(Err(e))) => { return Poll::Ready(Err(crate::Error::new_body_write(e))) } @@ -95,10 +145,12 @@ where "send stream capacity unexpectedly closed", ))); } - Poll::Pending => break 'capacity, + Poll::Pending => break false, } } - } + } else { + true + }; match me.h2_tx.poll_reset(cx) { Poll::Ready(Ok(reason)) => { @@ -113,6 +165,25 @@ where Poll::Pending => (), } + // If h2 has no capacity, don't pull another item from the mpsc + // receiver. That would free a channel slot and let the writer + // enqueue more data without h2 backpressure. + // + // Still allow the task to finish once the upgraded write side is + // gone and the mpsc queue is empty. + if !h2_has_capacity { + // `size_hint` reads the queued message count without popping, + // so an accepted write stays queued until h2 capacity returns. + if me.rx.size_hint().0 == 0 && me.close_notify.poll_closed(cx).is_ready() { + me.h2_tx + .send_data(SendBuf::None, true) + .map_err(crate::Error::new_body_write)?; + return Poll::Ready(Ok(())); + } + + return Poll::Pending; + } + match me.rx.as_mut().poll_next(cx) { Poll::Ready(Some(cursor)) => { me.h2_tx @@ -259,6 +330,7 @@ impl Write for H2Upgraded { cx: &mut Context<'_>, ) -> Poll> { self.send_stream.tx.close_channel(); + self.send_stream.close_notify.close(); match Pin::new(&mut self.send_stream.error_rx).poll(cx) { Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))), Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())), diff --git a/tests/server.rs b/tests/server.rs index b4c66022a6..b22f50b11a 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -2300,6 +2300,371 @@ async fn h2_connect_empty_frames() { .unwrap(); } +#[tokio::test] +async fn h2_connect_backpressure_respected() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(1024); + builder.initial_connection_window_size(1024); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const CHUNK: &[u8] = b"backpressure test data chunk!\n"; + const TOTAL_LEN: usize = CHUNK.len() * 2000; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, _send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = 0usize; + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received += len; + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(received, TOTAL_LEN); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + + for _ in 0..2000 { + upgraded.write_all(CHUNK).await.unwrap(); + } + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_zero_window_then_release() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(65535); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const DATA: &[u8] = b"Hello from upgraded stream"; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, _send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = Vec::new(); + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received.extend_from_slice(&chunk); + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(&received[..], DATA); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + upgraded.write_all(DATA).await.unwrap(); + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_shutdown_while_send_backpressured() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(1024); + builder.initial_connection_window_size(1024); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + let _ = connection.await; + }); + let mut h2 = h2.ready().await.unwrap(); + + let (shutdown_tx, shutdown_rx) = oneshot::channel::(); + let shutdown_tx = Arc::new(Mutex::new(Some(shutdown_tx))); + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, _send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(bytes.len(), 1024); + + // Do not release capacity. The server-side upgraded writer should + // still observe shutdown of its mpsc sender instead of waiting for + // more h2 send capacity. + let shutdown_completed = shutdown_rx.await.unwrap_or(false); + assert!( + shutdown_completed, + "upgraded shutdown should not wait for h2 capacity after the writer closes" + ); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + let shutdown_tx = shutdown_tx.clone(); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + upgraded.write_all(&[b'x'; 1024]).await.unwrap(); + + // Regression trigger: shutdown closes the mpsc sender while the + // send task is already parked waiting for h2 capacity. + let shutdown_completed = + tokio::time::timeout(Duration::from_secs(1), upgraded.shutdown()) + .await + .is_ok(); + + if let Some(tx) = shutdown_tx.lock().unwrap().take() { + let _ = tx.send(shutdown_completed); + } + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + let _ = http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await; + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_reset_during_backpressure() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(1024); + builder.initial_connection_window_size(1024); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + let _ = connection.await; + }); + let mut h2 = h2.ready().await.unwrap(); + + let (write_err_tx, write_err_rx) = oneshot::channel::(); + let write_err_tx = Arc::new(Mutex::new(Some(write_err_tx))); + let (reset_tx, reset_rx) = oneshot::channel::<()>(); + let reset_rx = Arc::new(Mutex::new(Some(reset_rx))); + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, mut send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = 0; + while received < 1024 { + let bytes = body.data().await.unwrap().unwrap(); + received += bytes.len(); + } + assert_eq!(received, 1024); + + send_stream.send_reset(h2::Reason::CANCEL); + let _ = reset_tx.send(()); + drop(body); + drop(send_stream); + + let got_err = write_err_rx.await.unwrap_or(false); + assert!(got_err, "server write side should have observed RST_STREAM"); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + let write_err_tx = write_err_tx.clone(); + let reset_rx = reset_rx.clone(); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + upgraded.write_all(&[b'x'; 1024]).await.unwrap(); + + let reset_rx = reset_rx.lock().unwrap().take().unwrap(); + reset_rx.await.unwrap(); + + let large_data = vec![b'x'; 1024 * 1024]; + let write = upgraded.write_all(&large_data).await; + let shutdown = upgraded.shutdown().await; + + if let Some(tx) = write_err_tx.lock().unwrap().take() { + let _ = tx.send(write.is_err() || shutdown.is_err()); + } + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + let _ = http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await; + + client_handle.await.unwrap(); +} + +#[tokio::test] +async fn h2_connect_backpressure_bidirectional() { + let (listener, addr) = setup_tcp_listener(); + let conn = connect_async(addr).await; + + let mut builder = h2::client::Builder::new(); + builder.initial_window_size(2048); + builder.initial_connection_window_size(4096); + let (h2, connection) = builder.handshake::<_, Bytes>(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const PATTERN: &[u8] = b"All work and no bread makes nox a dull boy.\n"; + const REPEAT: usize = 500; + let expected_len = PATTERN.len() * REPEAT; + + let client_handle = tokio::spawn(async move { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, mut send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let mut received = 0usize; + + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + if chunk.is_empty() { + break; + } + let len = chunk.len(); + received += len; + let _ = body.flow_control().release_capacity(len); + } + + assert_eq!(received, expected_len); + + send_stream.send_data("client done".into(), true).unwrap(); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = TokioIo::new(on_upgrade.await.expect("on_upgrade")); + + for _ in 0..REPEAT { + upgraded.write_all(PATTERN).await.unwrap(); + } + + upgraded.shutdown().await.unwrap(); + + let mut response_buf = vec![0u8; 64]; + let n = upgraded.read(&mut response_buf).await.unwrap(); + assert_eq!(&response_buf[..n], b"client done"); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(Empty::::new()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + let socket = TokioIo::new(socket); + http2::Builder::new(TokioExecutor) + .serve_connection(socket, svc) + .await + .unwrap(); + + client_handle.await.unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let (listener, addr) = setup_tcp_listener();