[rust] Make cublaslt wrapper non static (#3434)
xyang16 authored Aug 21, 2024
1 parent 5042fc4 commit 29dea2a
Showing 3 changed files with 21 additions and 26 deletions.
38 changes: 16 additions & 22 deletions extensions/tokenizers/rust/src/layers/cublaslt.rs
@@ -5,29 +5,23 @@ use std::sync::Once;
 #[cfg(feature = "cuda")]
 use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt};
 
-static INIT: Once = Once::new();
-static mut CUBLASLT: Option<CublasLtWrapper> = None;
-
-pub fn get_cublas_lt_wrapper() -> Option<&'static CublasLtWrapper> {
+pub fn get_cublas_lt_wrapper(device: &Device) -> Option<CublasLtWrapper> {
     unsafe {
-        INIT.call_once(|| {
-            CUBLASLT = match Device::cuda_if_available(0) {
-                Ok(device) => {
-                    #[cfg(feature = "cuda")]
-                    {
-                        Some(CublasLtWrapper {
-                            cublaslt: CublasLt::new(&device).unwrap(),
-                        })
-                    }
-                    #[cfg(not(feature = "cuda"))]
-                    {
-                        None
-                    }
-                }
-                Err(_) => None,
-            };
-        });
-        CUBLASLT.as_ref()
+        let cublaslt = if device.is_cuda() {
+            #[cfg(feature = "cuda")]
+            {
+                Some(CublasLtWrapper {
+                    cublaslt: CublasLt::new(&device).unwrap(),
+                })
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                None
+            }
+        } else {
+            None
+        };
+        cublaslt.clone()
     }
 }

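The net effect of this hunk: the wrapper is no longer cached in a `static mut` behind a `Once`; callers pass the device they already hold and receive an owned `Option<CublasLtWrapper>` built for that device. A minimal sketch of the new call pattern, assuming it compiles inside this crate where `layers::cublaslt` is in scope; the `has_cublaslt` helper and the test are illustrative, not part of the commit:

```rust
use candle::Device;

use crate::layers::cublaslt::get_cublas_lt_wrapper;

/// Illustrative helper: report whether a fused cuBLASLt path exists for `device`.
pub fn has_cublaslt(device: &Device) -> bool {
    // The getter now takes the device explicitly and returns an owned Option
    // instead of a &'static reference initialized once per process.
    get_cublas_lt_wrapper(device).is_some()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_device_has_no_cublaslt() {
        // On the CPU (or without the "cuda" feature) the wrapper is None.
        assert!(!has_cublaslt(&Device::Cpu));
    }
}
```

Because the wrapper is resolved from the caller's device rather than from `Device::cuda_if_available(0)` at first use, the result now tracks whichever device the caller actually passes in.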
7 changes: 5 additions & 2 deletions extensions/tokenizers/rust/src/layers/linear.rs
@@ -1,4 +1,4 @@
-use crate::layers::cublaslt::get_cublas_lt_wrapper;
+use crate::layers::cublaslt::{get_cublas_lt_wrapper, CublasLtWrapper};
 use candle::{Device, Result, Tensor};
 use candle_nn::VarBuilder;
 use serde::Deserialize;
@@ -17,6 +17,7 @@ pub struct Linear {
     weight: Tensor,
     bias: Option<Tensor>,
     act: Option<HiddenAct>,
+    cublaslt: Option<CublasLtWrapper>,
     span: tracing::Span,
 }

@@ -31,10 +32,12 @@ impl Linear {
             Ok(w) => Some(w),
             Err(_) => None,
         };
+        let cublaslt = get_cublas_lt_wrapper(&vb.device());
         Ok(Self {
             weight: vb.get((out_dim, in_dim), "weight")?,
             bias,
             act,
+            cublaslt,
             span: tracing::span!(tracing::Level::TRACE, "linear"),
         })
     }
@@ -43,7 +46,7 @@ impl Linear {
         let _enter = self.span.enter();
 
         #[allow(unused)]
-        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), get_cublas_lt_wrapper()) {
+        if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), self.cublaslt.clone()) {
            match x.dims() {
                &[bsize, _, _] => cublaslt.batch_matmul(
                    &self.weight.broadcast_left(bsize)?,
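With the new field, `Linear::load` resolves the wrapper once from the VarBuilder's device and `forward` consults `self.cublaslt` instead of the old global getter. A rough sketch of that shape, assuming it lives inside this crate; `SketchLinear` and its methods are illustrative, and the fused call is left as a placeholder because its full argument list is cut off in this diff:

```rust
use candle::{Device, Result, Tensor};

use crate::layers::cublaslt::{get_cublas_lt_wrapper, CublasLtWrapper};

/// Illustrative, stripped-down version of the pattern in linear.rs.
pub struct SketchLinear {
    weight: Tensor,
    // Captured once at construction time from the weight's device.
    cublaslt: Option<CublasLtWrapper>,
}

impl SketchLinear {
    pub fn new(weight: Tensor) -> Self {
        let cublaslt = get_cublas_lt_wrapper(weight.device());
        Self { weight, cublaslt }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if let (Device::Cuda(_), Some(_cublaslt)) = (x.device(), self.cublaslt.clone()) {
            // Fused cuBLASLt path; the real argument list is elided in the diff above.
            unimplemented!("cublaslt matmul dispatch")
        } else {
            // Plain candle fallback for 2-D inputs: x @ W^T.
            x.matmul(&self.weight.t()?)
        }
    }
}
```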
2 changes: 0 additions & 2 deletions extensions/tokenizers/rust/src/layers/mod.rs
@@ -5,8 +5,6 @@ mod linear;
 #[allow(dead_code, unused)]
 mod rms_norm;
 
-#[allow(unused_imports)]
-pub use cublaslt::get_cublas_lt_wrapper;
 pub use layer_norm::LayerNorm;
 pub use linear::{HiddenAct, Linear};
 pub use rms_norm::RmsNorm;
