Skip to content

Commit

Permalink
[FEAT] Compute pool for native executor (#2986)
Browse files Browse the repository at this point in the history
Create a multithreaded compute runtime for swordfish compute tasks.
Switch query runtime to be single threaded, and use IO pool for scan
task streams.

Additionally, adds in a `tokio_select` together with the
`tokio::signal::ctrlc` and main async execution loop so that queries can
be cancelled.

```
import os
import daft
import numpy
import time
import psutil

current_process = psutil.Process(os.getpid())

daft.set_execution_config(enable_native_executor=True, default_morsel_size=1)
dfs = [
    iter(
        daft.from_pydict({"a": numpy.random.rand(10)}).with_column(
            "plus_one", daft.col("a") + 1
        )
    )
    for _ in range(10)
]
while True:
    for i, df in enumerate(dfs):
        time.sleep(0.1)
        try:
            print("threads: ", current_process.num_threads())
            print(next(df))
        except StopIteration:
            dfs.pop(i)
    if not dfs:
        break
```
If you run this script you can see that the number of threads increases
by only 1 per dataframe.

TODO:
- replace rayon with this ->
#3076

---------

Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
4 people authored Oct 23, 2024
1 parent 4ec76ce commit 6569cb6
Show file tree
Hide file tree
Showing 44 changed files with 604 additions and 356 deletions.
22 changes: 21 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ common-display = {path = "src/common/display", default-features = false}
common-file-formats = {path = "src/common/file-formats", default-features = false}
common-hashable-float-wrapper = {path = "src/common/hashable-float-wrapper", default-features = false}
common-resource-request = {path = "src/common/resource-request", default-features = false}
common-runtime = {path = "src/common/runtime", default-features = false}
common-system-info = {path = "src/common/system-info", default-features = false}
common-tracing = {path = "src/common/tracing", default-features = false}
common-version = {path = "src/common/version", default-features = false}
Expand Down
15 changes: 15 additions & 0 deletions src/common/runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[dependencies]
common-error = {path = "../error", default-features = false}
futures = {workspace = true}
lazy_static = {workspace = true}
log = {workspace = true}
oneshot = "0.1.8"
tokio = {workspace = true}

[lints]
workspace = true

[package]
edition = {workspace = true}
name = "common-runtime"
version = {workspace = true}
185 changes: 185 additions & 0 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use std::{
future::Future,
panic::AssertUnwindSafe,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
};

use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
static ref THREADED_IO_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS);
static ref COMPUTE_RUNTIME_NUM_WORKER_THREADS: usize = *NUM_CPUS;
static ref COMPUTE_RUNTIME_MAX_BLOCKING_THREADS: usize = 1; // Compute thread should not use blocking threads, limit this to the minimum, i.e. 1
}

static THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static SINGLE_THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static COMPUTE_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();

pub type RuntimeRef = Arc<Runtime>;

#[derive(Debug, Clone, Copy)]
enum PoolType {
Compute,
IO,
}

pub struct Runtime {
runtime: tokio::runtime::Runtime,
pool_type: PoolType,
}

impl Runtime {
pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef {
Arc::new(Self { runtime, pool_type })
}

// TODO: figure out a way to cancel the Future if this output is dropped.
async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
(*s).to_string()
} else {
"unknown internal error".to_string()
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in the {:?} runtime: {})",
pool_type, s
))
})
}

/// Spawns a task on the runtime and blocks the current thread until the task is completed.
/// Similar to tokio's Runtime::block_on but requires static lifetime + Send
/// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor
pub fn block_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.recv().expect("Spawned task transmitter dropped")
}

/// Spawn a task on the runtime and await on it.
/// You should use this when you are spawning compute or IO tasks from the Executor.
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
}
}

fn init_compute_runtime() -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
.worker_threads(*COMPUTE_RUNTIME_NUM_WORKER_THREADS)
.enable_all()
.thread_name_fn(move || {
static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("Compute-Thread-{}", id)
})
.max_blocking_threads(*COMPUTE_RUNTIME_MAX_BLOCKING_THREADS);
Runtime::new(builder.build().unwrap(), PoolType::Compute)
})
.join()
.unwrap()
}

fn init_io_runtime(multi_thread: bool) -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
.worker_threads(if multi_thread {
*THREADED_IO_RUNTIME_NUM_WORKER_THREADS
} else {
1
})
.enable_all()
.thread_name_fn(move || {
static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("IO-Thread-{}", id)
});
Runtime::new(builder.build().unwrap(), PoolType::IO)
})
.join()
.unwrap()
}

pub fn get_compute_runtime() -> RuntimeRef {
COMPUTE_RUNTIME.get_or_init(init_compute_runtime).clone()
}

pub fn get_io_runtime(multi_thread: bool) -> RuntimeRef {
if !multi_thread {
SINGLE_THREADED_IO_RUNTIME
.get_or_init(|| init_io_runtime(false))
.clone()
} else {
THREADED_IO_RUNTIME
.get_or_init(|| init_io_runtime(true))
.clone()
}
}

#[must_use]
pub fn get_io_pool_num_threads() -> Option<usize> {
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
match handle.runtime_flavor() {
RuntimeFlavor::CurrentThread => Some(1),
RuntimeFlavor::MultiThread => Some(*THREADED_IO_RUNTIME_NUM_WORKER_THREADS),
// RuntimeFlavor is #non_exhaustive, so we default to 1 here to be conservative
_ => Some(1),
}
}
Err(_) => None,
}
}
1 change: 1 addition & 0 deletions src/daft-csv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ async-compat = {workspace = true}
async-stream = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
csv-async = "1.3.0"
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
Expand Down
7 changes: 4 additions & 3 deletions src/daft-csv/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::{collections::HashSet, sync::Arc};
use arrow2::io::csv::read_async::{AsyncReader, AsyncReaderBuilder};
use async_compat::CompatExt;
use common_error::DaftResult;
use common_runtime::get_io_runtime;
use csv_async::ByteRecord;
use daft_compression::CompressionCodec;
use daft_core::prelude::Schema;
use daft_decoding::inference::infer;
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_io::{GetResult, IOClient, IOStatsRef};
use futures::{StreamExt, TryStreamExt};
use snafu::ResultExt;
use tokio::{
Expand Down Expand Up @@ -58,7 +59,7 @@ pub fn read_csv_schema(
io_client: Arc<IOClient>,
io_stats: Option<IOStatsRef>,
) -> DaftResult<(Schema, CsvReadStats)> {
let runtime_handle = get_runtime(true)?;
let runtime_handle = get_io_runtime(true);
runtime_handle.block_on_current_thread(async {
read_csv_schema_single(
uri,
Expand All @@ -80,7 +81,7 @@ pub async fn read_csv_schema_bulk(
io_stats: Option<IOStatsRef>,
num_parallel_tasks: usize,
) -> DaftResult<Vec<(Schema, CsvReadStats)>> {
let runtime_handle = get_runtime(true)?;
let runtime_handle = get_io_runtime(true);
let result = runtime_handle
.block_on_current_thread(async {
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
Expand Down
7 changes: 4 additions & 3 deletions src/daft-csv/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use arrow2::{
};
use async_compat::{Compat, CompatExt};
use common_error::{DaftError, DaftResult};
use common_runtime::get_io_runtime;
use csv_async::AsyncReader;
use daft_compression::CompressionCodec;
use daft_core::{prelude::*, utils::arrow::cast_array_for_daft_if_needed};
use daft_decoding::deserialize::deserialize_column;
use daft_dsl::optimization::get_required_columns;
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_io::{GetResult, IOClient, IOStatsRef};
use daft_table::Table;
use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt};
use rayon::{
Expand Down Expand Up @@ -53,7 +54,7 @@ pub fn read_csv(
multithreaded_io: bool,
max_chunks_in_flight: Option<usize>,
) -> DaftResult<Table> {
let runtime_handle = get_runtime(multithreaded_io)?;
let runtime_handle = get_io_runtime(multithreaded_io);
runtime_handle.block_on_current_thread(async {
read_csv_single_into_table(
uri,
Expand All @@ -80,7 +81,7 @@ pub fn read_csv_bulk(
max_chunks_in_flight: Option<usize>,
num_parallel_tasks: usize,
) -> DaftResult<Vec<Table>> {
let runtime_handle = get_runtime(multithreaded_io)?;
let runtime_handle = get_io_runtime(multithreaded_io);
let tables = runtime_handle.block_on_current_thread(async move {
// Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks.
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
Expand Down
1 change: 1 addition & 0 deletions src/daft-functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ base64 = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-io-config = {path = "../common/io-config", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
daft-image = {path = "../daft-image", default-features = false}
Expand Down
Loading

0 comments on commit 6569cb6

Please sign in to comment.