Skip to content

Commit

Permalink
Make turbo an optional dependency (#964)
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Feb 14, 2024
1 parent bce5374 commit a62ee55
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
14 changes: 12 additions & 2 deletions llmfoundry/optim/lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,12 @@ def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True):
self._f_encode = None
self._f_decode = None
if self._try_quantize:
from turbo import dequantize_signed, quantize_signed
try:
from turbo import dequantize_signed, quantize_signed
except ModuleNotFoundError:
raise NotImplementedError(
'The Lion 8b optimizer requires installing mosaicml-turbo. ',
'Please `pip install llm-foundry[turbo]` to install it.')
self._f_encode = quantize_signed
self._f_decode = dequantize_signed

Expand Down Expand Up @@ -396,7 +401,12 @@ def lion8b_step_fused(grads: torch.Tensor,
f'Weights must be f32 or match grad dtype {grads.dtype}')

# ------------------------------------------------ actual function call
from turbo import lion8b_step_cuda
try:
from turbo import lion8b_step_cuda
except ModuleNotFoundError:
raise NotImplementedError(
'The Lion 8b optimizer requires installing mosaicml-turbo. ',
'Please `pip install llm-foundry[turbo]` to install it.')
return lion8b_step_cuda(grads=grads,
weights=weights,
momentums=momentums,
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,16 @@

extra_deps['gpu'] = [
'flash-attn==1.0.9',
'mosaicml-turbo==0.0.8',
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]

extra_deps['turbo'] = [
'mosaicml-turbo==0.0.8',
]

extra_deps['gpu-flash2'] = [
'flash-attn==2.5.0',
'mosaicml-turbo==0.0.8',
]

extra_deps['peft'] = [
Expand Down

0 comments on commit a62ee55

Please sign in to comment.