Skip to content

Remove the Send bound from block_on #195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from Sep 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions src/task/block_on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable};
use std::thread::{self, Thread};

use super::log_utils;
use super::pool;
use super::Builder;
use super::task;
use crate::future::Future;
use crate::task::{Context, Poll, Waker};
use crate::utils::abort_on_panic;

/// Spawns a task and blocks the current thread on its result.
///
Expand All @@ -32,8 +34,7 @@ use crate::task::{Context, Poll, Waker};
/// ```
pub fn block_on<F, T>(future: F) -> T
where
F: Future<Output = T> + Send,
T: Send,
F: Future<Output = T>,
{
unsafe {
// A place on the stack where the result will be stored.
Expand All @@ -51,17 +52,48 @@ where
}
};

// Create a tag for the task.
let tag = task::Tag::new(None);

// Log this `block_on` operation.
let child_id = tag.task_id().as_u64();
let parent_id = pool::get_task(|t| t.id().as_u64()).unwrap_or(0);
log_utils::print(
format_args!("block_on"),
log_utils::LogData {
parent_id,
child_id,
},
);

// Wrap the future into one that drops task-local variables on exit.
let future = async move {
let res = future.await;

// Abort on panic because thread-local variables behave the same way.
abort_on_panic(|| pool::get_task(|task| task.metadata().local_map.clear()));

log_utils::print(
format_args!("block_on completed"),
log_utils::LogData {
parent_id,
child_id,
},
);
res
};

// Pin the future onto the stack.
pin_utils::pin_mut!(future);

// Transmute the future into one that is static and sendable.
// Transmute the future into one that is static.
let future = mem::transmute::<
Pin<&mut dyn Future<Output = ()>>,
Pin<&'static mut (dyn Future<Output = ()> + Send)>,
Pin<&'_ mut dyn Future<Output = ()>>,
Pin<&'static mut dyn Future<Output = ()>>,
>(future);

// Spawn the future and wait for it to complete.
block(pool::spawn_with_builder(Builder::new(), future, "block_on"));
// Block on the future and and wait for it to complete.
pool::set_tag(&tag, || block(future));

// Take out the result.
match (*out.get()).take().unwrap() {
Expand All @@ -87,7 +119,10 @@ impl<F: Future + UnwindSafe> Future for CatchUnwindFuture<F> {
}
}

fn block<F: Future>(f: F) -> F::Output {
fn block<F, T>(f: F) -> T
where
F: Future<Output = T>,
{
thread_local! {
static ARC_THREAD: Arc<Thread> = Arc::new(thread::current());
}
Expand Down
32 changes: 32 additions & 0 deletions src/task/log_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::fmt::Arguments;

/// This struct only exists because kv logging isn't supported from the macros right now.
pub(crate) struct LogData {
pub parent_id: u64,
pub child_id: u64,
}

impl<'a> log::kv::Source for LogData {
fn visit<'kvs>(
&'kvs self,
visitor: &mut dyn log::kv::Visitor<'kvs>,
) -> Result<(), log::kv::Error> {
visitor.visit_pair("parent_id".into(), self.parent_id.into())?;
visitor.visit_pair("child_id".into(), self.child_id.into())?;
Ok(())
}
}

pub fn print(msg: Arguments<'_>, key_values: impl log::kv::Source) {
log::logger().log(
&log::Record::builder()
.args(msg)
.key_values(&key_values)
.level(log::Level::Trace)
.target(module_path!())
.module_path(Some(module_path!()))
.file(Some(file!()))
.line(Some(line!()))
.build(),
);
}
1 change: 1 addition & 0 deletions src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub use task::{JoinHandle, Task, TaskId};

mod block_on;
mod local;
mod log_utils;
mod pool;
mod sleep;
mod task;
Expand Down
106 changes: 36 additions & 70 deletions src/task/pool.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::cell::Cell;
use std::fmt::Arguments;
use std::mem;
use std::ptr;
use std::thread;

use crossbeam_channel::{unbounded, Sender};
use lazy_static::lazy_static;

use super::log_utils;
use super::task;
use super::{JoinHandle, Task};
use crate::future::Future;
use crate::io;
use crate::utils::abort_on_panic;

/// Returns a handle to the current task.
///
Expand Down Expand Up @@ -64,7 +64,7 @@ where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
spawn_with_builder(Builder::new(), future, "spawn")
spawn_with_builder(Builder::new(), future)
}

/// Task builder that configures the settings of a new task.
Expand All @@ -91,15 +91,11 @@ impl Builder {
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
Ok(spawn_with_builder(self, future, "spawn"))
Ok(spawn_with_builder(self, future))
}
}

pub(crate) fn spawn_with_builder<F, T>(
builder: Builder,
future: F,
fn_name: &'static str,
) -> JoinHandle<T>
pub(crate) fn spawn_with_builder<F, T>(builder: Builder, future: F) -> JoinHandle<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
Expand All @@ -117,13 +113,9 @@ where
thread::Builder::new()
.name("async-task-driver".to_string())
.spawn(|| {
TAG.with(|tag| {
for job in receiver {
tag.set(job.tag());
abort_on_panic(|| job.run());
tag.set(ptr::null());
}
});
for job in receiver {
set_tag(job.tag(), || abort_on_panic(|| job.run()))
}
})
.expect("cannot start a thread driving tasks");
}
Expand All @@ -135,11 +127,12 @@ where
let tag = task::Tag::new(name);
let schedule = |job| QUEUE.send(job).unwrap();

// Log this `spawn` operation.
let child_id = tag.task_id().as_u64();
let parent_id = get_task(|t| t.id().as_u64()).unwrap_or(0);
print(
format_args!("{}", fn_name),
LogData {
log_utils::print(
format_args!("spawn"),
log_utils::LogData {
parent_id,
child_id,
},
Expand All @@ -152,9 +145,9 @@ where
// Abort on panic because thread-local variables behave the same way.
abort_on_panic(|| get_task(|task| task.metadata().local_map.clear()));

print(
format_args!("{} completed", fn_name),
LogData {
log_utils::print(
format_args!("spawn completed"),
log_utils::LogData {
parent_id,
child_id,
},
Expand All @@ -171,61 +164,34 @@ thread_local! {
static TAG: Cell<*const task::Tag> = Cell::new(ptr::null_mut());
}

pub(crate) fn get_task<F: FnOnce(&Task) -> R, R>(f: F) -> Option<R> {
let res = TAG.try_with(|tag| unsafe { tag.get().as_ref().map(task::Tag::task).map(f) });

match res {
Ok(Some(val)) => Some(val),
Ok(None) | Err(_) => None,
}
}

/// Calls a function and aborts if it panics.
///
/// This is useful in unsafe code where we can't recover from panics.
#[inline]
fn abort_on_panic<T>(f: impl FnOnce() -> T) -> T {
struct Bomb;
pub(crate) fn set_tag<F, R>(tag: *const task::Tag, f: F) -> R
where
F: FnOnce() -> R,
{
struct ResetTag<'a>(&'a Cell<*const task::Tag>);

impl Drop for Bomb {
impl Drop for ResetTag<'_> {
fn drop(&mut self) {
std::process::abort();
self.0.set(ptr::null());
}
}

let bomb = Bomb;
let t = f();
mem::forget(bomb);
t
}
TAG.with(|t| {
t.set(tag);
let _guard = ResetTag(t);

/// This struct only exists because kv logging isn't supported from the macros right now.
struct LogData {
parent_id: u64,
child_id: u64,
f()
})
}

impl<'a> log::kv::Source for LogData {
fn visit<'kvs>(
&'kvs self,
visitor: &mut dyn log::kv::Visitor<'kvs>,
) -> Result<(), log::kv::Error> {
visitor.visit_pair("parent_id".into(), self.parent_id.into())?;
visitor.visit_pair("child_id".into(), self.child_id.into())?;
Ok(())
}
}
pub(crate) fn get_task<F, R>(f: F) -> Option<R>
where
F: FnOnce(&Task) -> R,
{
let res = TAG.try_with(|tag| unsafe { tag.get().as_ref().map(task::Tag::task).map(f) });

fn print(msg: Arguments<'_>, key_values: impl log::kv::Source) {
log::logger().log(
&log::Record::builder()
.args(msg)
.key_values(&key_values)
.level(log::Level::Trace)
.target(module_path!())
.module_path(Some(module_path!()))
.file(Some(file!()))
.line(Some(line!()))
.build(),
);
match res {
Ok(Some(val)) => Some(val),
Ok(None) | Err(_) => None,
}
}