Skip to content

Commit 64adcfe

Browse files
committed
switch to multitask
Signed-off-by: Marc-Antoine Perennou <Marc-Antoine@Perennou.com>
1 parent a6f6d04 commit 64adcfe

File tree

11 files changed

+130
-35
lines changed

11 files changed

+130
-35
lines changed

Cargo.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ default = [
2929
"blocking",
3030
"kv-log-macro",
3131
"log",
32+
"multitask",
3233
"num_cpus",
3334
"pin-project-lite",
34-
"smol",
3535
]
3636
docs = ["attributes", "unstable", "default"]
3737
unstable = [
@@ -56,7 +56,7 @@ alloc = [
5656
"futures-core/alloc",
5757
"pin-project-lite",
5858
]
59-
tokio02 = ["smol/tokio02"]
59+
tokio02 = ["tokio"]
6060

6161
[dependencies]
6262
async-attributes = { version = "1.1.1", optional = true }
@@ -81,7 +81,7 @@ surf = { version = "1.0.3", optional = true }
8181
[target.'cfg(not(target_os = "unknown"))'.dependencies]
8282
async-io = { version = "0.1.2", optional = true }
8383
blocking = { version = "0.4.6", optional = true }
84-
smol = { version = "0.1.17", optional = true }
84+
multitask = { version = "0.2.0", optional = true }
8585

8686
[target.'cfg(target_arch = "wasm32")'.dependencies]
8787
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
@@ -91,6 +91,12 @@ futures-channel = { version = "0.3.4", optional = true }
9191
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
9292
wasm-bindgen-test = "0.3.10"
9393

94+
[dependencies.tokio]
95+
version = "0.2"
96+
default-features = false
97+
features = ["rt-threaded"]
98+
optional = true
99+
94100
[dev-dependencies]
95101
femme = "1.3.0"
96102
rand = "0.7.3"

examples/line-count.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use async_std::task;
1010
fn main() -> io::Result<()> {
1111
let path = args().nth(1).expect("missing path argument");
1212

13-
task::block_on(async {
13+
task::block_on(async move {
1414
let file = File::open(&path).await?;
1515
let mut lines = BufReader::new(file).lines();
1616
let mut count = 0u64;

examples/list-dir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use async_std::task;
1010
fn main() -> io::Result<()> {
1111
let path = args().nth(1).expect("missing path argument");
1212

13-
task::block_on(async {
13+
task::block_on(async move {
1414
let mut dir = fs::read_dir(&path).await?;
1515

1616
while let Some(res) = dir.next().await {

examples/print-file.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const LEN: usize = 16 * 1024; // 16 Kb
1212
fn main() -> io::Result<()> {
1313
let path = args().nth(1).expect("missing path argument");
1414

15-
task::block_on(async {
15+
task::block_on(async move {
1616
let mut file = File::open(&path).await?;
1717
let mut stdout = io::stdout();
1818
let mut buf = vec![0u8; LEN];

src/task/block_on.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ use crate::task::Builder;
2828
#[cfg(not(target_os = "unknown"))]
2929
pub fn block_on<F, T>(future: F) -> T
3030
where
31-
F: Future<Output = T>,
31+
F: Future<Output = T> + 'static,
32+
T: 'static,
3233
{
3334
Builder::new().blocking(future)
3435
}

src/task/builder.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ use std::task::{Context, Poll};
77
use pin_project_lite::pin_project;
88

99
use crate::io;
10-
use crate::task::{JoinHandle, Task, TaskLocalsWrapper};
10+
use crate::task::{self, JoinHandle, Task, TaskLocalsWrapper};
1111

1212
/// Task builder that configures the settings of a new task.
13-
#[derive(Debug, Default)]
13+
#[derive(Default, Debug)]
1414
pub struct Builder {
1515
pub(crate) name: Option<String>,
1616
}
@@ -61,9 +61,9 @@ impl Builder {
6161
});
6262

6363
let task = wrapped.tag.task().clone();
64-
let smol_task = smol::Task::spawn(wrapped).into();
64+
let handle = task::executor::spawn(wrapped);
6565

66-
Ok(JoinHandle::new(smol_task, task))
66+
Ok(JoinHandle::new(handle, task))
6767
}
6868

6969
/// Spawns a task locally with the configured settings.
@@ -81,9 +81,9 @@ impl Builder {
8181
});
8282

8383
let task = wrapped.tag.task().clone();
84-
let smol_task = smol::Task::local(wrapped).into();
84+
let handle = task::executor::local(wrapped);
8585

86-
Ok(JoinHandle::new(smol_task, task))
86+
Ok(JoinHandle::new(handle, task))
8787
}
8888

8989
/// Spawns a task locally with the configured settings.
@@ -141,7 +141,8 @@ impl Builder {
141141
#[cfg(not(target_os = "unknown"))]
142142
pub fn blocking<F, T>(self, future: F) -> T
143143
where
144-
F: Future<Output = T>,
144+
F: Future<Output = T> + 'static,
145+
T: 'static,
145146
{
146147
let wrapped = self.build(future);
147148

@@ -166,8 +167,8 @@ impl Builder {
166167
unsafe {
167168
TaskLocalsWrapper::set_current(&wrapped.tag, || {
168169
let res = if should_run {
169-
// The first call should use run.
170-
smol::run(wrapped)
170+
// The first call should run the executor
171+
task::executor::run(wrapped)
171172
} else {
172173
blocking::block_on(wrapped)
173174
};

src/task/executor.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use std::cell::RefCell;
2+
use std::future::Future;
3+
4+
static GLOBAL_EXECUTOR: once_cell::sync::Lazy<multitask::Executor> = once_cell::sync::Lazy::new(multitask::Executor::new);
5+
6+
struct Executor {
7+
local_executor: multitask::LocalExecutor,
8+
parker: async_io::parking::Parker,
9+
}
10+
11+
thread_local! {
12+
static EXECUTOR: RefCell<Executor> = RefCell::new({
13+
let (parker, unparker) = async_io::parking::pair();
14+
let local_executor = multitask::LocalExecutor::new(move || unparker.unpark());
15+
Executor { local_executor, parker }
16+
});
17+
}
18+
19+
pub(crate) fn spawn<F, T>(future: F) -> multitask::Task<T>
20+
where
21+
F: Future<Output = T> + Send + 'static,
22+
T: Send + 'static,
23+
{
24+
GLOBAL_EXECUTOR.spawn(future)
25+
}
26+
27+
#[cfg(feature = "unstable")]
28+
pub(crate) fn local<F, T>(future: F) -> multitask::Task<T>
29+
where
30+
F: Future<Output = T> + 'static,
31+
T: 'static,
32+
{
33+
EXECUTOR.with(|executor| executor.borrow().local_executor.spawn(future))
34+
}
35+
36+
pub(crate) fn run<F, T>(future: F) -> T
37+
where
38+
F: Future<Output = T> + 'static,
39+
T: 'static,
40+
{
41+
enter(|| EXECUTOR.with(|executor| {
42+
let executor = executor.borrow();
43+
let (sender, receiver) = std::sync::mpsc::channel();
44+
executor.local_executor.spawn(async move {
45+
sender.send(future.await).unwrap()
46+
}).detach();
47+
let unparker = executor.parker.unparker();
48+
let global_ticker = GLOBAL_EXECUTOR.ticker(move || unparker.unpark());
49+
loop {
50+
let advanced = executor.local_executor.tick() || std::panic::catch_unwind(|| global_ticker.tick()).unwrap_or(true);
51+
if let Ok(res) = receiver.try_recv() {
52+
return res;
53+
}
54+
if !advanced {
55+
executor.parker.park();
56+
}
57+
}
58+
}))
59+
}
60+
61+
/// Enters the tokio context if the `tokio` feature is enabled.
62+
fn enter<T>(f: impl FnOnce() -> T) -> T {
63+
#[cfg(not(feature = "tokio02"))]
64+
return f();
65+
66+
#[cfg(feature = "tokio02")]
67+
{
68+
use std::cell::Cell;
69+
use tokio::runtime::Runtime;
70+
71+
thread_local! {
72+
/// The level of nested `enter` calls we are in, to ensure that the outermost always
73+
/// has a runtime spawned.
74+
static NESTING: Cell<usize> = Cell::new(0);
75+
}
76+
77+
/// The global tokio runtime.
78+
static RT: once_cell::sync::Lazy<Runtime> = once_cell::sync::Lazy::new(|| Runtime::new().expect("cannot initialize tokio"));
79+
80+
NESTING.with(|nesting| {
81+
let res = if nesting.get() == 0 {
82+
nesting.replace(1);
83+
RT.enter(f)
84+
} else {
85+
nesting.replace(nesting.get() + 1);
86+
f()
87+
};
88+
nesting.replace(nesting.get() - 1);
89+
res
90+
})
91+
}
92+
}

src/task/join_handle.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub struct JoinHandle<T> {
1818
}
1919

2020
#[cfg(not(target_os = "unknown"))]
21-
type InnerHandle<T> = async_task::JoinHandle<T, ()>;
21+
type InnerHandle<T> = multitask::Task<T>;
2222
#[cfg(target_arch = "wasm32")]
2323
type InnerHandle<T> = futures_channel::oneshot::Receiver<T>;
2424

@@ -54,8 +54,7 @@ impl<T> JoinHandle<T> {
5454
#[cfg(not(target_os = "unknown"))]
5555
pub async fn cancel(mut self) -> Option<T> {
5656
let handle = self.handle.take().unwrap();
57-
handle.cancel();
58-
handle.await
57+
handle.cancel().await
5958
}
6059

6160
/// Cancel this task.
@@ -71,11 +70,6 @@ impl<T> Future for JoinHandle<T> {
7170
type Output = T;
7271

7372
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
74-
match Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx) {
75-
Poll::Pending => Poll::Pending,
76-
Poll::Ready(output) => {
77-
Poll::Ready(output.expect("cannot await the result of a panicked task"))
78-
}
79-
}
73+
Pin::new(&mut self.handle.as_mut().unwrap()).poll(cx)
8074
}
8175
}

src/task/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ cfg_default! {
148148
mod block_on;
149149
mod builder;
150150
mod current;
151+
#[cfg(not(target_os = "unknown"))]
152+
mod executor;
151153
mod join_handle;
152154
mod sleep;
153155
#[cfg(not(target_os = "unknown"))]

tests/addr.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,12 @@ use async_std::task;
77

88
fn blocking_resolve<A>(a: A) -> Result<Vec<SocketAddr>, String>
99
where
10-
A: ToSocketAddrs,
10+
A: ToSocketAddrs + 'static,
1111
A::Iter: Send,
1212
{
13-
let socket_addrs = task::block_on(a.to_socket_addrs());
14-
match socket_addrs {
15-
Ok(a) => Ok(a.collect()),
16-
Err(e) => Err(e.to_string()),
17-
}
13+
task::block_on(async move {
14+
Ok(a.to_socket_addrs().await.map_err(|e| e.to_string())?.collect())
15+
})
1816
}
1917

2018
#[test]
@@ -71,7 +69,7 @@ fn to_socket_addr_string() {
7169
let s: &str = "77.88.21.11:24352";
7270
assert_eq!(Ok(vec![a]), blocking_resolve(s));
7371

74-
let s: &String = &"77.88.21.11:24352".to_string();
72+
let s: String = "77.88.21.11:24352".to_string();
7573
assert_eq!(Ok(vec![a]), blocking_resolve(s));
7674

7775
let s: String = "77.88.21.11:24352".to_string();

tests/uds.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@ fn socket_ping_pong() {
5656
let iter_cnt = 16;
5757

5858
let listener =
59-
task::block_on(async { UnixListener::bind(&sock_path).await.expect("Socket bind") });
59+
task::block_on(async move { UnixListener::bind(&sock_path).await.expect("Socket bind") });
6060

6161
let server_handle = std::thread::spawn(move || {
62-
task::block_on(async { ping_pong_server(listener, iter_cnt).await }).unwrap()
62+
task::block_on(async move { ping_pong_server(listener, iter_cnt).await }).unwrap()
6363
});
6464

65+
let sock_path = tmp_dir.as_ref().join("sock");
6566
let client_handle = std::thread::spawn(move || {
66-
task::block_on(async { ping_pong_client(&sock_path, iter_cnt).await }).unwrap()
67+
task::block_on(async move { ping_pong_client(&sock_path, iter_cnt).await }).unwrap()
6768
});
6869

6970
client_handle.join().unwrap();

0 commit comments

Comments
 (0)