Skip to content

Commit c69bfa7

Browse files
committed
refactor(sinktools): use pin_project in LazySinkSource
1 parent 68920b6 commit c69bfa7

1 file changed

Lines changed: 97 additions & 185 deletions

File tree

sinktools/src/lazy_sink_source.rs

Lines changed: 97 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! [`LazySinkSource`], and related items.
22
3+
use core::future::Future;
34
use core::marker::PhantomData;
45
use core::pin::Pin;
56
use core::task::{Context, Poll, Waker};
@@ -8,6 +9,7 @@ use std::task::Wake;
89

910
use futures_util::task::AtomicWaker;
1011
use futures_util::{Sink, Stream, ready};
12+
use pin_project_lite::pin_project;
1113

1214
#[derive(Default)]
1315
struct DualWaker {
@@ -34,80 +36,84 @@ impl Wake for DualWaker {
3436
}
3537
}
3638

37-
enum SharedState<Fut, St, Si, Item> {
38-
Uninit {
39-
future: Pin<Box<Fut>>,
40-
},
41-
Thunkulating {
42-
future: Pin<Box<Fut>>,
43-
item: Option<Item>,
44-
dual_waker_state: Arc<DualWaker>,
45-
dual_waker_waker: Waker,
46-
},
47-
Done {
48-
stream: Pin<Box<St>>,
49-
sink: Pin<Box<Si>>,
50-
buf: Option<Item>,
51-
},
52-
Taken,
39+
pin_project! {
40+
#[project = SharedStateProj]
41+
enum SharedState<Fut, St, Si, Item> {
42+
Uninit {
43+
// The future, always `Some` in this state.
44+
future: Option<Fut>,
45+
},
46+
Thunkulating {
47+
#[pin]
48+
future: Fut,
49+
item: Option<Item>,
50+
dual_waker_state: Arc<DualWaker>,
51+
dual_waker_waker: Waker,
52+
},
53+
Done {
54+
#[pin]
55+
stream: St,
56+
#[pin]
57+
sink: Si,
58+
buf: Option<Item>,
59+
},
60+
}
5361
}
5462

55-
/// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the
56-
/// source, or when the first item is sent to the sink. To split into separate source and sink halves, use
57-
/// [`futures_util::StreamExt::split`].
58-
pub struct LazySinkSource<Fut, St, Si, Item, Error> {
59-
state: SharedState<Fut, St, Si, Item>,
60-
_phantom: PhantomData<Error>,
63+
pin_project! {
64+
/// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the
65+
/// source, or when the first item is sent to the sink. To split into separate source and sink halves, use
66+
/// [`futures_util::StreamExt::split`].
67+
pub struct LazySinkSource<Fut, St, Si, Item, Error> {
68+
#[pin]
69+
state: SharedState<Fut, St, Si, Item>,
70+
_phantom: PhantomData<Error>,
71+
}
6172
}
6273

6374
impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
6475
/// Creates a new `LazySinkSource` with the given initialization future.
6576
pub fn new(future: Fut) -> Self {
6677
Self {
67-
state: SharedState::Uninit {
68-
future: Box::pin(future),
69-
},
78+
state: SharedState::Uninit { future: Some(future) },
7079
_phantom: PhantomData,
7180
}
7281
}
7382
}
7483

75-
impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
84+
impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error>
7685
where
77-
Self: Unpin,
7886
Fut: Future<Output = Result<(St, Si), Error>>,
7987
St: Stream,
8088
Si: Sink<Item>,
8189
Error: From<Si::Error>,
8290
{
83-
type Error = Error;
84-
85-
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86-
let state = &mut self.get_mut().state;
87-
88-
if let SharedState::Uninit { .. } = &*state {
91+
fn poll_sink_op(
92+
self: Pin<&mut Self>,
93+
cx: &mut Context<'_>,
94+
sink_op: impl FnOnce(Pin<&mut Si>, &mut Context<'_>) -> Poll<Result<(), Si::Error>>,
95+
) -> Poll<Result<(), Error>> {
96+
let mut this = self.project();
97+
98+
if let SharedStateProj::Uninit { .. } = this.state.as_mut().project() {
8999
return Poll::Ready(Ok(()));
90100
}
91101

92-
if let SharedState::Thunkulating {
102+
if let SharedStateProj::Thunkulating {
93103
future,
94104
item,
95105
dual_waker_state,
96106
dual_waker_waker,
97-
} = &mut *state
107+
} = this.state.as_mut().project()
98108
{
99109
dual_waker_state.sink.register(cx.waker());
100110

101111
let mut dual_context = Context::from_waker(dual_waker_waker);
102112

103-
match future.as_mut().poll(&mut dual_context) {
113+
match future.poll(&mut dual_context) {
104114
Poll::Ready(Ok((stream, sink))) => {
105115
let buf = item.take();
106-
*state = SharedState::Done {
107-
stream: Box::pin(stream),
108-
sink: Box::pin(sink),
109-
buf,
110-
};
116+
this.state.as_mut().set(SharedState::Done { stream, sink, buf });
111117
}
112118
Poll::Ready(Err(e)) => {
113119
return Poll::Ready(Err(e));
@@ -118,192 +124,104 @@ where
118124
}
119125
}
120126

121-
if let SharedState::Done { sink, buf, .. } = &mut *state {
127+
if let SharedStateProj::Done { mut sink, buf, .. } = this.state.as_mut().project() {
122128
if buf.is_some() {
123129
ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
124130
sink.as_mut().start_send(buf.take().unwrap())?;
125131
}
126-
let result = sink.as_mut().poll_ready(cx).map_err(From::from);
127-
return result;
132+
return (sink_op)(sink, cx).map_err(From::from);
128133
}
129134

130135
panic!("LazySinkSource in invalid state.");
131136
}
137+
}
138+
139+
impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
140+
where
141+
Fut: Future<Output = Result<(St, Si), Error>>,
142+
St: Stream,
143+
Si: Sink<Item>,
144+
Error: From<Si::Error>,
145+
{
146+
type Error = Error;
147+
148+
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
149+
self.poll_sink_op(cx, Sink::poll_ready)
150+
}
132151

133152
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
134-
let state = &mut self.get_mut().state;
135-
136-
if let SharedState::Uninit { .. } = &*state {
137-
let old_state = std::mem::replace(&mut *state, SharedState::Taken);
138-
if let SharedState::Uninit { future } = old_state {
139-
let (dual_waker_state, dual_waker_waker) = DualWaker::new();
140-
*state = SharedState::Thunkulating {
141-
future,
142-
item: Some(item),
143-
dual_waker_state,
144-
dual_waker_waker,
145-
};
146-
147-
return Ok(());
148-
}
153+
let mut this = self.project();
154+
155+
if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
156+
let future = future.take().unwrap();
157+
let (dual_waker_state, dual_waker_waker) = DualWaker::new();
158+
this.state.as_mut().set(SharedState::Thunkulating {
159+
future,
160+
item: Some(item),
161+
dual_waker_state,
162+
dual_waker_waker,
163+
});
164+
return Ok(());
149165
}
150166

151-
if let SharedState::Thunkulating { .. } = &mut *state {
167+
if let SharedStateProj::Thunkulating { .. } = this.state.as_mut().project() {
152168
panic!("LazySinkSource not ready.");
153169
}
154170

155-
if let SharedState::Done { sink, buf, .. } = &mut *state {
171+
if let SharedStateProj::Done { sink, buf, .. } = this.state.as_mut().project() {
156172
debug_assert!(buf.is_none());
157-
let result = sink.as_mut().start_send(item).map_err(From::from);
158-
return result;
173+
return sink.start_send(item).map_err(From::from);
159174
}
160175

161176
panic!("LazySinkSource not ready.");
162177
}
163178

164179
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165-
let state = &mut self.get_mut().state;
166-
167-
if let SharedState::Uninit { .. } = &*state {
168-
return Poll::Ready(Ok(()));
169-
}
170-
171-
if let SharedState::Thunkulating {
172-
future,
173-
item,
174-
dual_waker_state,
175-
dual_waker_waker,
176-
} = &mut *state
177-
{
178-
dual_waker_state.sink.register(cx.waker());
179-
180-
let mut new_context = Context::from_waker(dual_waker_waker);
181-
182-
match future.as_mut().poll(&mut new_context) {
183-
Poll::Ready(Ok((stream, sink))) => {
184-
let buf = item.take();
185-
*state = SharedState::Done {
186-
stream: Box::pin(stream),
187-
sink: Box::pin(sink),
188-
buf,
189-
};
190-
}
191-
Poll::Ready(Err(e)) => {
192-
return Poll::Ready(Err(e));
193-
}
194-
Poll::Pending => {
195-
return Poll::Pending;
196-
}
197-
}
198-
}
199-
200-
if let SharedState::Done { sink, buf, .. } = &mut *state {
201-
if buf.is_some() {
202-
ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
203-
sink.as_mut().start_send(buf.take().unwrap())?;
204-
}
205-
let result = sink.as_mut().poll_flush(cx).map_err(From::from);
206-
return result;
207-
}
208-
209-
panic!("LazySinkHalf in invalid state.");
180+
self.poll_sink_op(cx, Sink::poll_flush)
210181
}
211182

212183
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213-
let state = &mut self.get_mut().state;
214-
215-
if let SharedState::Uninit { .. } = &*state {
216-
return Poll::Ready(Ok(()));
217-
}
218-
219-
if let SharedState::Thunkulating {
220-
future,
221-
item,
222-
dual_waker_state,
223-
dual_waker_waker,
224-
} = &mut *state
225-
{
226-
dual_waker_state.sink.register(cx.waker());
227-
228-
let mut new_context = Context::from_waker(dual_waker_waker);
229-
230-
match future.as_mut().poll(&mut new_context) {
231-
Poll::Ready(Ok((stream, sink))) => {
232-
let buf = item.take();
233-
*state = SharedState::Done {
234-
stream: Box::pin(stream),
235-
sink: Box::pin(sink),
236-
buf,
237-
};
238-
}
239-
Poll::Ready(Err(e)) => {
240-
return Poll::Ready(Err(e));
241-
}
242-
Poll::Pending => {
243-
return Poll::Pending;
244-
}
245-
}
246-
}
247-
248-
if let SharedState::Done { sink, buf, .. } = &mut *state {
249-
if buf.is_some() {
250-
ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
251-
sink.as_mut().start_send(buf.take().unwrap())?;
252-
}
253-
let result = sink.as_mut().poll_close(cx).map_err(From::from);
254-
return result;
255-
}
256-
257-
panic!("LazySinkHalf in invalid state.");
184+
self.poll_sink_op(cx, Sink::poll_close)
258185
}
259186
}
260187

261188
impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
262189
where
263-
Self: Unpin,
264190
Fut: Future<Output = Result<(St, Si), Error>>,
265191
St: Stream,
266192
Si: Sink<Item>,
267193
{
268194
type Item = St::Item;
269195

270196
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271-
let state = &mut self.get_mut().state;
272-
273-
if let SharedState::Uninit { .. } = &*state {
274-
let old_state = std::mem::replace(&mut *state, SharedState::Taken);
275-
if let SharedState::Uninit { future } = old_state {
276-
let (dual_waker_state, dual_waker_waker) = DualWaker::new();
277-
*state = SharedState::Thunkulating {
278-
future,
279-
item: None,
280-
dual_waker_state,
281-
dual_waker_waker,
282-
};
283-
} else {
284-
unreachable!();
285-
}
197+
let mut this = self.project();
198+
199+
if let SharedStateProj::Uninit { future } = this.state.as_mut().project() {
200+
let future = future.take().unwrap();
201+
let (dual_waker_state, dual_waker_waker) = DualWaker::new();
202+
this.state.as_mut().set(SharedState::Thunkulating {
203+
future,
204+
item: None,
205+
dual_waker_state,
206+
dual_waker_waker,
207+
});
286208
}
287209

288-
if let SharedState::Thunkulating {
210+
if let SharedStateProj::Thunkulating {
289211
future,
290212
item,
291213
dual_waker_state,
292214
dual_waker_waker,
293-
} = &mut *state
215+
} = this.state.as_mut().project()
294216
{
295217
dual_waker_state.stream.register(cx.waker());
296218

297219
let mut new_context = Context::from_waker(dual_waker_waker);
298220

299-
match future.as_mut().poll(&mut new_context) {
221+
match future.poll(&mut new_context) {
300222
Poll::Ready(Ok((stream, sink))) => {
301223
let buf = item.take();
302-
*state = SharedState::Done {
303-
stream: Box::pin(stream),
304-
sink: Box::pin(sink),
305-
buf,
306-
};
224+
this.state.as_mut().set(SharedState::Done { stream, sink, buf });
307225
}
308226

309227
Poll::Ready(Err(_)) => {
@@ -316,14 +234,8 @@ where
316234
}
317235
}
318236

319-
if let SharedState::Done { stream, .. } = &mut *state {
320-
let result = stream.as_mut().poll_next(cx);
321-
match &result {
322-
Poll::Ready(Some(_)) => {}
323-
Poll::Ready(None) => {}
324-
Poll::Pending => {}
325-
}
326-
return result;
237+
if let SharedStateProj::Done { stream, .. } = this.state.as_mut().project() {
238+
return stream.poll_next(cx);
327239
}
328240

329241
panic!("LazySinkSource in invalid state.");

0 commit comments

Comments
 (0)