diff --git a/playground/cubecl-test/Cargo.toml b/playground/cubecl-test/Cargo.toml new file mode 100644 index 0000000..ebf0ba9 --- /dev/null +++ b/playground/cubecl-test/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "cubecl-test" +edition.workspace = true +license.workspace = true +readme.workspace = true +version.workspace = true + +[features] +default = [] +wgpu = ["cubecl/wgpu"] + +[dependencies] +cubecl = { git = "https://github.com/tracel-ai/cubecl", features = ["wgpu"] } +half = { version = "2.4.1", features = [ + "alloc", + "num-traits", + "serde", +], default-features = false } diff --git a/playground/cubecl-test/src/main.rs b/playground/cubecl-test/src/main.rs new file mode 100644 index 0000000..664ca54 --- /dev/null +++ b/playground/cubecl-test/src/main.rs @@ -0,0 +1,45 @@ +use cubecl::prelude::*; + +#[cube(launch_unchecked)] +fn gelu_array(input: &Array>, output: &mut Array>) { + if ABSOLUTE_POS < input.len() { + output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); + } +} + +#[cube] +fn gelu_scalar(x: Line) -> Line { + // Execute the sqrt function at comptime. + let sqrt2 = F::new(comptime!(2.0f32.sqrt())); + let tmp = x / Line::new(sqrt2); + + x * (Line::erf(tmp) + 1.0) / 2.0 +} + +pub fn launch(device: &R::Device) { + let client = R::client(device); + let input = &[-1., 0., 1., 5.]; + let vectorization = 4; + let output_handle = client.empty(input.len() * core::mem::size_of::()); + let input_handle = client.create(f32::as_bytes(input)); + + unsafe { + gelu_array::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input.len() as u32 / vectorization, 1, 1), + ArrayArg::from_raw_parts(&input_handle, input.len(), vectorization as u8), + ArrayArg::from_raw_parts(&output_handle, input.len(), vectorization as u8), + ) + }; + + let bytes = client.read(output_handle.binding()); + let output = f32::from_bytes(&bytes); + + // Should be [-0.1587, 0.0000, 0.8413, 5.0000] + println!("Executed gelu with runtime {:?} => {output:?}", R::name()); +} + +fn main() { + launch::(&Default::default()); +}