Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ full = [

# HTTP versions
http1 = ["dep:atomic-waker", "dep:futures-channel", "dep:futures-core", "dep:httparse", "dep:itoa"]
http2 = ["dep:futures-channel", "dep:futures-core", "dep:h2"]
http2 = ["dep:atomic-waker", "dep:futures-channel", "dep:futures-core", "dep:h2"]

# Client/Server
client = ["dep:want", "dep:pin-project-lite", "dep:smallvec"]
Expand Down
84 changes: 78 additions & 6 deletions src/proto/h2/upgrade.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::future::Future;
use std::io::Cursor;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

use atomic_waker::AtomicWaker;
use bytes::{Buf, Bytes};
use futures_channel::{mpsc, oneshot};
use futures_core::{ready, Stream};
Expand All @@ -20,17 +23,23 @@ pub(super) fn pair<B>(
) -> (H2Upgraded, UpgradedSendStreamTask<B>) {
let (tx, rx) = mpsc::channel(1);
let (error_tx, error_rx) = oneshot::channel();
let close_notify = Arc::new(UpgradedCloseNotify::new());

(
H2Upgraded {
send_stream: UpgradedSendStreamBridge { tx, error_rx },
send_stream: UpgradedSendStreamBridge {
tx,
error_rx,
close_notify: close_notify.clone(),
},
recv_stream,
ping,
buf: Bytes::new(),
},
UpgradedSendStreamTask {
h2_tx: send_stream,
rx,
close_notify,
error_tx: Some(error_tx),
},
)
Expand All @@ -46,6 +55,46 @@ pub(super) struct H2Upgraded {
struct UpgradedSendStreamBridge {
tx: mpsc::Sender<Cursor<Box<[u8]>>>,
error_rx: oneshot::Receiver<crate::Error>,
close_notify: Arc<UpgradedCloseNotify>,
}

impl Drop for UpgradedSendStreamBridge {
fn drop(&mut self) {
self.close_notify.close();
}
}

struct UpgradedCloseNotify {
closed: AtomicBool,
task: AtomicWaker,
}

impl UpgradedCloseNotify {
fn new() -> Self {
Self {
closed: AtomicBool::new(false),
task: AtomicWaker::new(),
}
}

fn close(&self) {
self.closed.store(true, Ordering::Release);
self.task.wake();
}

fn poll_closed(&self, cx: &mut Context<'_>) -> Poll<()> {
if self.closed.load(Ordering::Acquire) {
return Poll::Ready(());
}

self.task.register(cx.waker());

if self.closed.load(Ordering::Acquire) {
Poll::Ready(())
} else {
Poll::Pending
}
}
}

pin_project! {
Expand All @@ -55,6 +104,7 @@ pin_project! {
h2_tx: SendStream<SendBuf<B>>,
#[pin]
rx: mpsc::Receiver<Cursor<Box<[u8]>>>,
close_notify: Arc<UpgradedCloseNotify>,
error_tx: Option<oneshot::Sender<crate::Error>>,
}
}
Expand All @@ -78,12 +128,12 @@ where
// for the actual body chunk.
me.h2_tx.reserve_capacity(1);

if me.h2_tx.capacity() == 0 {
let h2_has_capacity = if me.h2_tx.capacity() == 0 {
// poll_capacity oddly needs a loop
'capacity: loop {
loop {
match me.h2_tx.poll_capacity(cx) {
Poll::Ready(Some(Ok(0))) => {}
Poll::Ready(Some(Ok(_))) => break,
Poll::Ready(Some(Ok(_))) => break true,
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Err(crate::Error::new_body_write(e)))
}
Expand All @@ -95,10 +145,12 @@ where
"send stream capacity unexpectedly closed",
)));
}
Poll::Pending => break 'capacity,
Poll::Pending => break false,
}
}
}
} else {
true
};

match me.h2_tx.poll_reset(cx) {
Poll::Ready(Ok(reason)) => {
Expand All @@ -113,6 +165,25 @@ where
Poll::Pending => (),
}

// If h2 has no capacity, don't pull another item from the mpsc
// receiver. That would free a channel slot and let the writer
// enqueue more data without h2 backpressure.
//
// Still allow the task to finish once the upgraded write side is
// gone and the mpsc queue is empty.
if !h2_has_capacity {
// `size_hint` reads the queued message count without popping,
// so an accepted write stays queued until h2 capacity returns.
if me.rx.size_hint().0 == 0 && me.close_notify.poll_closed(cx).is_ready() {
me.h2_tx
.send_data(SendBuf::None, true)
.map_err(crate::Error::new_body_write)?;
return Poll::Ready(Ok(()));
}

return Poll::Pending;
}

match me.rx.as_mut().poll_next(cx) {
Poll::Ready(Some(cursor)) => {
me.h2_tx
Expand Down Expand Up @@ -259,6 +330,7 @@ impl Write for H2Upgraded {
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.send_stream.tx.close_channel();
self.send_stream.close_notify.close();
match Pin::new(&mut self.send_stream.error_rx).poll(cx) {
Poll::Ready(Ok(reason)) => Poll::Ready(Err(io_error(reason))),
Poll::Ready(Err(_task_dropped)) => Poll::Ready(Ok(())),
Expand Down
Loading
Loading