Skip to content

Commit

Permalink
feat: allow benchmarking fft algorithms for javascript wasm targets
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Sep 4, 2024
1 parent 8da0c7b commit 55336ec
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 8 deletions.
3 changes: 3 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Accessed by wasm-bindgen when testing for the wasm target
[target.wasm32-unknown-unknown]
runner = 'wasm-bindgen-test-runner'
10 changes: 10 additions & 0 deletions .github/workflows/cargo_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,13 @@ jobs:
run: |
make test_no_std_nightly
make test_no_std_nightly FFT128_SUPPORT=ON
cargo-tests-node-js:
runs-on: "ubuntu-latest"
steps:
- uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c

- name: Test node js
run: |
make install_node
make test_node_js
16 changes: 13 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ num-complex = { version = "0.4", features = ["bytemuck"] }
pulp = { version = "0.18.22", default-features = false }
serde = { version = "1.0", optional = true, default-features = false }

[target.'cfg(target_arch = "wasm32")'.dependencies]
js-sys = "0.3"

[features]
default = ["std"]
fft128 = []
Expand All @@ -26,17 +29,24 @@ std = ["pulp/std"]
serde = ["dep:serde", "num-complex/serde"]

[dev-dependencies]
criterion = "0.4"
rustfft = "6.0"
fftw-sys = { version = "0.6", default-features = false, features = ["system"] }
rand = "0.8"
bincode = "1.3"
more-asserts = "0.3.1"
serde_json = "1.0.96"

[target.'cfg(not(target_os = "windows"))'.dev-dependencies]
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"
wasm-bindgen = "0.2.86"
getrandom = { version = "0.2", features = ["js"] }

[target.'cfg(all(not(target_os = "windows"), not(target_arch = "wasm32")))'.dev-dependencies]
rug = "1.19.1"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = "0.4"
fftw-sys = { version = "0.6", default-features = false, features = ["system"] }

[[bench]]
name = "fft"
harness = false
Expand Down
43 changes: 41 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ CARGO_RS_CHECK_TOOLCHAIN:=+$(RS_CHECK_TOOLCHAIN)
RS_BUILD_TOOLCHAIN:=stable
CARGO_RS_BUILD_TOOLCHAIN:=+$(RS_BUILD_TOOLCHAIN)
MIN_RUST_VERSION:=1.65
WASM_BINDGEN_VERSION:=$(shell grep '^wasm-bindgen[[:space:]]*=' Cargo.toml | cut -d '=' -f 2 | xargs)
NODE_VERSION=22.6
AVX512_SUPPORT?=OFF
FFT128_SUPPORT?=OFF
# This is done to avoid forgetting it, we still precise the RUSTFLAGS in the commands to be able to
Expand Down Expand Up @@ -47,6 +49,35 @@ install_rs_build_toolchain:
( echo "Unable to install $(RS_BUILD_TOOLCHAIN) toolchain, check your rustup installation. \
Rustup can be downloaded at https://rustup.rs/" && exit 1 )

.PHONY: install_build_wasm32_target # Install the wasm32 toolchain used for builds
install_build_wasm32_target: install_rs_build_toolchain
rustup +$(RS_BUILD_TOOLCHAIN) target add wasm32-unknown-unknown || \
( echo "Unable to install wasm32-unknown-unknown target toolchain, check your rustup installation. \
Rustup can be downloaded at https://rustup.rs/" && exit 1 )

# The installation uses the ^ symbol because we need the matching version of wasm-bindgen in the
# Cargo.toml, as we don't lock those dependencies, this allows to get the matching CLI
.PHONY: install_wasm_bindgen_cli # Install wasm-bindgen-cli to get access to the test runner
install_wasm_bindgen_cli: install_rs_build_toolchain
cargo +$(RS_BUILD_TOOLCHAIN) install --locked wasm-bindgen-cli --version ^$(WASM_BINDGEN_VERSION)

.PHONY: install_node # Install last version of NodeJS via nvm
install_node:
curl -o nvm_install.sh https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh
@echo "2ed5e94ba12434370f0358800deb69f514e8bce90f13beb0e1b241d42c6abafd nvm_install.sh" > nvm_checksum
@sha256sum -c nvm_checksum
@rm nvm_checksum
$(SHELL) nvm_install.sh
@rm nvm_install.sh
source ~/.bashrc
$(SHELL) -i -c 'nvm install $(NODE_VERSION)' || \
( echo "Unable to install node, unknown error." && exit 1 )

.PHONY: check_nvm_installed # Check if Node Version Manager is installed
check_nvm_installed:
@source ~/.nvm/nvm.sh && nvm --version > /dev/null 2>&1 || \
( echo "Unable to locate Node. Run 'make install_node'" && exit 1 )

.PHONY: fmt # Format rust code
fmt: install_rs_check_toolchain
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt
Expand Down Expand Up @@ -85,7 +116,7 @@ test: install_rs_build_toolchain

.PHONY: test_serde
test_serde: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \
--features=serde

.PHONY: test_nightly
Expand All @@ -105,8 +136,16 @@ test_no_std_nightly: install_rs_check_toolchain
--no-default-features \
--features=nightly,$(FFT128_FEATURE)

.PHONY: test_node_js
test_node_js: install_rs_build_toolchain install_build_wasm32_target install_wasm_bindgen_cli check_nvm_installed
source ~/.nvm/nvm.sh && \
nvm install $(NODE_VERSION) && \
nvm use $(NODE_VERSION) && \
RUSTFLAGS="" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \
--features=serde --target wasm32-unknown-unknown

.PHONY: test_all
test_all: test test_serde test_nightly test_no_std test_no_std_nightly
test_all: test test_serde test_nightly test_no_std test_no_std_nightly test_node_js

.PHONY: doc # Build rust doc
doc: install_rs_check_toolchain
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ macro_rules! izip {
mod fft_simd;
mod nat;

#[cfg(feature = "std")]
pub(crate) mod time;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;

Expand Down
12 changes: 10 additions & 2 deletions src/ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ fn measure_n_runs(
let (scratch, _) = stack.make_aligned_raw::<c64>(n, CACHELINE_ALIGN);
let [fwd, _] = get_fn_ptr(algo, n);

use std::time::Instant;
// For wasm we have a dedicated implementation going through js-sys
use crate::time::Instant;
let now = Instant::now();

for _ in 0..n_runs {
Expand Down Expand Up @@ -101,7 +102,13 @@ pub(crate) fn measure_fastest(
stack: PodStack,
) -> (FftAlgo, Duration) {
const N_ALGOS: usize = 8;
const MIN_DURATION: Duration = Duration::from_millis(1);
const MIN_DURATION: Duration = if cfg!(target_arch = "wasm32") {
// This is to account for the fact the js-sys based time measurement has a resolution of 1ms
// on chrome, this will slow down the fft benchmarking somewhat, but it's barely noticeable
Duration::from_millis(10)
} else {
Duration::from_millis(1)
};

assert!(n.is_power_of_two());

Expand Down Expand Up @@ -443,6 +450,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fft() {
test_fft_simd(crate::fft_simd::Scalar);
Expand Down
11 changes: 11 additions & 0 deletions src/time/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//! The standard API for Instant is not available in Wasm runtimes.
//! This module replaces the Instant type from std to a custom implementation.

#[cfg(target_arch = "wasm32")]
mod wasm;

#[cfg(target_arch = "wasm32")]
pub(crate) use wasm::Instant;

#[cfg(not(target_arch = "wasm32"))]
pub(crate) use std::time::Instant;
18 changes: 18 additions & 0 deletions src/time/wasm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
pub(crate) struct Instant {
start: f64,
}

impl Instant {
/// This function only has a millisecond resolution on some platforms like the chrome browser
pub fn now() -> Self {
let now = js_sys::Date::new_0().get_time();
Self { start: now }
}

/// This function only has a millisecond resolution on some platforms like the chrome browser,
/// which means it can easily return 0 when called on quick code
pub fn elapsed(&self) -> core::time::Duration {
let now = js_sys::Date::new_0().get_time();
core::time::Duration::from_millis((now - self.start) as u64)
}
}
8 changes: 7 additions & 1 deletion src/unordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ fn measure_fastest(

let n_runs = n_runs.ceil() as u32;

use std::time::Instant;
// For wasm we have a dedicated implementation going through js-sys
use crate::time::Instant;
let now = Instant::now();
for _ in 0..n_runs {
fwd_depth(
Expand Down Expand Up @@ -1067,6 +1068,7 @@ mod tests {

extern crate alloc;

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fwd() {
for n in [128, 256, 512, 1024] {
Expand Down Expand Up @@ -1101,6 +1103,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fwd_monomial() {
for n in [256, 512, 1024] {
Expand Down Expand Up @@ -1133,6 +1136,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_roundtrip() {
for n in [32, 64, 256, 512, 1024] {
Expand Down Expand Up @@ -1167,6 +1171,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_equivalency() {
use num_complex::Complex;
Expand Down Expand Up @@ -9401,6 +9406,7 @@ mod tests_serde {

extern crate alloc;

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_serde() {
for n in [64, 128, 256, 512, 1024] {
Expand Down

0 comments on commit 55336ec

Please sign in to comment.