Skip to content

Commit

Permalink
Add CubeCL test, work through dependency issues
Browse files Browse the repository at this point in the history
  • Loading branch information
utensil committed Oct 8, 2024
1 parent aa62c78 commit 434d488
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
18 changes: 18 additions & 0 deletions playground/cubecl-test/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "cubecl-test"
edition.workspace = true
license.workspace = true
readme.workspace = true
version.workspace = true

[features]
default = []
wgpu = ["cubecl/wgpu"]

[dependencies]
cubecl = { git = "https://github.com/tracel-ai/cubecl", features = ["wgpu"] }
half = { version = "2.4.1", features = [
"alloc",
"num-traits",
"serde",
], default-features = false }
45 changes: 45 additions & 0 deletions playground/cubecl-test/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use cubecl::prelude::*;

#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
if ABSOLUTE_POS < input.len() {
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
}
}

#[cube]
fn gelu_scalar<F: Float>(x: Line<F>) -> Line<F> {
// Execute the sqrt function at comptime.
let sqrt2 = F::new(comptime!(2.0f32.sqrt()));
let tmp = x / Line::new(sqrt2);

x * (Line::erf(tmp) + 1.0) / 2.0
}

pub fn launch<R: Runtime>(device: &R::Device) {
let client = R::client(device);
let input = &[-1., 0., 1., 5.];
let vectorization = 4;
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
let input_handle = client.create(f32::as_bytes(input));

unsafe {
gelu_array::launch_unchecked::<f32, R>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new(input.len() as u32 / vectorization, 1, 1),
ArrayArg::from_raw_parts(&input_handle, input.len(), vectorization as u8),
ArrayArg::from_raw_parts(&output_handle, input.len(), vectorization as u8),
)
};

let bytes = client.read(output_handle.binding());
let output = f32::from_bytes(&bytes);

// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
println!("Executed gelu with runtime {:?} => {output:?}", R::name());
}

fn main() {
launch::<cubecl::wgpu::WgpuRuntime>(&Default::default());
}

0 comments on commit 434d488

Please sign in to comment.