2024-10-24 nightly release (7855789)
pytorchbot committed Oct 24, 2024
1 parent de5f406 commit 047654d
Showing 7 changed files with 49 additions and 36 deletions.
71 changes: 41 additions & 30 deletions .github/scripts/validate_binaries.sh
@@ -40,40 +40,33 @@ else
   export CUDA_VERSION="cpu"
 fi
 
-if [[ ${MATRIX_CHANNEL} = 'pypi_release' ]]; then
-    echo "checking pypi release"
-    pip install torch
-    pip install fbgemm-gpu
-    pip install torchrec
-else
-    # figure out URL
-    if [[ ${MATRIX_CHANNEL} = 'nightly' ]]; then
-        export PYTORCH_URL="https://download.pytorch.org/whl/nightly/${CUDA_VERSION}"
-    elif [[ ${MATRIX_CHANNEL} = 'test' ]]; then
-        export PYTORCH_URL="https://download.pytorch.org/whl/test/${CUDA_VERSION}"
-    elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then
-        export PYTORCH_URL="https://download.pytorch.org/whl/${CUDA_VERSION}"
-    fi
+# figure out URL
+if [[ ${MATRIX_CHANNEL} = 'nightly' ]]; then
+    export PYTORCH_URL="https://download.pytorch.org/whl/nightly/${CUDA_VERSION}"
+elif [[ ${MATRIX_CHANNEL} = 'test' ]]; then
+    export PYTORCH_URL="https://download.pytorch.org/whl/test/${CUDA_VERSION}"
+elif [[ ${MATRIX_CHANNEL} = 'release' ]]; then
+    export PYTORCH_URL="https://download.pytorch.org/whl/${CUDA_VERSION}"
+fi
 
-    # install pytorch
-    # switch back to conda once torch nightly is fixed
-    # if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
-    #     export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}"
-    # fi
-    conda run -n build_binary pip install torch --index-url "$PYTORCH_URL"
+# install pytorch
+# switch back to conda once torch nightly is fixed
+# if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then
+#     export PYTORCH_CUDA_PKG="pytorch-cuda=${MATRIX_GPU_ARCH_VERSION}"
+# fi
+conda run -n build_binary pip install torch --index-url "$PYTORCH_URL"
 
-    # install fbgemm
-    conda run -n build_binary pip install fbgemm-gpu --index-url "$PYTORCH_URL"
+# install fbgemm
+conda run -n build_binary pip install fbgemm-gpu --index-url "$PYTORCH_URL"
 
-    # install requirements from pypi
-    conda run -n build_binary pip install torchmetrics==1.0.3
+# install requirements from pypi
+conda run -n build_binary pip install torchmetrics==1.0.3
 
-    # install torchrec
-    conda run -n build_binary pip install torchrec --index-url "$PYTORCH_URL"
+# install torchrec
+conda run -n build_binary pip install torchrec --index-url "$PYTORCH_URL"
 
-    # Run small import test
-    conda run -n build_binary python -c "import torch; import fbgemm_gpu; import torchrec"
-fi
+# Run small import test
+conda run -n build_binary python -c "import torch; import fbgemm_gpu; import torchrec"
 
 # check directory
 ls -R
@@ -98,13 +91,22 @@ fi
 
 if [[ ${MATRIX_CHANNEL} != 'release' ]]; then
     exit 0
+else
+    # Check version matches only for release binaries
+    torchrec_version=$(conda run -n build_binary pip show torchrec | grep Version | cut -d' ' -f2)
+    fbgemm_version=$(conda run -n build_binary pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+    if [ "$torchrec_version" != "$fbgemm_version" ]; then
+        echo "Error: TorchRec package version does not match FBGEMM package version"
+        exit 1
+    fi
 fi
 
 conda create -y -n build_binary python="${MATRIX_PYTHON_VERSION}"
 
 conda run -n build_binary python --version
 
-if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.1' ]]; then
+if [[ ${MATRIX_GPU_ARCH_VERSION} != '12.4' ]]; then
     exit 0
 fi
 
@@ -113,6 +115,15 @@ conda run -n build_binary pip install torch
 conda run -n build_binary pip install fbgemm-gpu
 conda run -n build_binary pip install torchrec
 
+# Check version matching again for PyPI
+torchrec_version=$(conda run -n build_binary pip show torchrec | grep Version | cut -d' ' -f2)
+fbgemm_version=$(conda run -n build_binary pip show fbgemm_gpu | grep Version | cut -d' ' -f2)
+
+if [ "$torchrec_version" != "$fbgemm_version" ]; then
+    echo "Error: TorchRec package version does not match FBGEMM package version"
+    exit 1
+fi
+
 # check directory
 ls -R
 
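The new checks in validate_binaries.sh compare the installed TorchRec and FBGEMM wheel versions by parsing pip show output and fail the job on a mismatch. For reference, a minimal standalone sketch of the same check in Python (an illustration, not part of this commit; it assumes both wheels are installed in the current environment and uses importlib.metadata instead of pip show):

import sys
from importlib.metadata import PackageNotFoundError, version

def check_torchrec_fbgemm_versions_match() -> None:
    # Mirrors the shell check: both packages must report the same version string.
    # The distribution is published as fbgemm-gpu; recent importlib.metadata
    # normalizes the fbgemm_gpu / fbgemm-gpu spellings to the same package.
    try:
        torchrec_version = version("torchrec")
        fbgemm_version = version("fbgemm_gpu")
    except PackageNotFoundError as exc:
        print(f"Error: required package is not installed: {exc}", file=sys.stderr)
        sys.exit(1)

    if torchrec_version != fbgemm_version:
        print(
            "Error: TorchRec package version does not match FBGEMM package version",
            file=sys.stderr,
        )
        sys.exit(1)

if __name__ == "__main__":
    check_torchrec_fbgemm_versions_match()
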
3 changes: 2 additions & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -74,7 +74,8 @@
 except ImportError:
     logger.warning("torchrec_use_sync_collectives is not available")
 
-torch.ops.import_module("fbgemm_gpu.sparse_ops")
+if not torch._running_with_deploy():
+    torch.ops.import_module("fbgemm_gpu.sparse_ops")
 
 
 class ModelDetachedException(Exception):
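The guard added above keeps the fbgemm_gpu custom-op registration out of torch::deploy interpreters, where importing extension modules this way is not supported. A minimal sketch of the same pattern (the warning fallback is an illustrative addition, not part of this commit):

import logging

import torch

logger = logging.getLogger(__name__)

# Register fbgemm_gpu's sparse ops only in a regular Python process;
# torch._running_with_deploy() is True inside torch::deploy interpreters.
if not torch._running_with_deploy():
    try:
        # torch.ops.import_module imports the module so its op registrations run.
        torch.ops.import_module("fbgemm_gpu.sparse_ops")
    except ImportError:
        logger.warning("fbgemm_gpu.sparse_ops is not available")
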
@@ -58,7 +58,7 @@ class GPUExecutor {
       std::shared_ptr<IGPUExecutorObserver>
           observer, // shared_ptr because used in completion executor callback
       std::function<void()> warmupFn = {},
-      std::optional<size_t> numThreadsPerGPU = c10::nullopt,
+      std::optional<size_t> numThreadsPerGPU = std::nullopt,
       std::unique_ptr<GCConfig> gcConfig = std::make_unique<GCConfig>());
   GPUExecutor(GPUExecutor&& executor) noexcept = default;
   GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default;
2 changes: 1 addition & 1 deletion torchrec/inference/include/torchrec/inference/Validation.h
@@ -20,7 +20,7 @@ namespace torchrec {
 bool validateSparseFeatures(
     at::Tensor& values,
     at::Tensor& lengths,
-    std::optional<at::Tensor> maybeWeights = c10::nullopt);
+    std::optional<at::Tensor> maybeWeights = std::nullopt);
 
 // Returns whether dense features are valid.
 // Currently validates:
@@ -58,7 +58,7 @@ class GPUExecutor {
       std::shared_ptr<IGPUExecutorObserver>
           observer, // shared_ptr because used in completion executor callback
       std::function<void()> warmupFn = {},
-      std::optional<size_t> numThreadsPerGPU = c10::nullopt,
+      std::optional<size_t> numThreadsPerGPU = std::nullopt,
       std::unique_ptr<GCConfig> gcConfig = std::make_unique<GCConfig>());
   GPUExecutor(GPUExecutor&& executor) noexcept = default;
   GPUExecutor& operator=(GPUExecutor&& executor) noexcept = default;
@@ -20,7 +20,7 @@ namespace torchrec {
 bool validateSparseFeatures(
     at::Tensor& values,
     at::Tensor& lengths,
-    std::optional<at::Tensor> maybeWeights = c10::nullopt);
+    std::optional<at::Tensor> maybeWeights = std::nullopt);
 
 // Returns whether dense features are valid.
 // Currently validates:
3 changes: 2 additions & 1 deletion torchrec/optim/clipping.py
@@ -136,7 +136,7 @@ def step(self, closure: Any = None) -> None:
         self._step_num += 1
 
     @torch.no_grad()
-    def clip_grad_norm_(self) -> None:
+    def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
         norm_type = float(self._norm_type)
@@ -224,6 +224,7 @@ def clip_grad_norm_(self) -> None:
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
+        return total_grad_norm
 
 
 def _batch_cal_norm(
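With this change, clip_grad_norm_ returns the total gradient norm (computed before clipping) instead of None, so callers can log or monitor it. A small self-contained illustration of that convention using core PyTorch's torch.nn.utils.clip_grad_norm_, not the TorchRec method changed here:

import torch

# Build a toy model and produce gradients.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_grad_norm_ clips in place and returns the pre-clip total norm,
# the same convention the TorchRec clip_grad_norm_ method adopts in this commit.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {float(total_norm):.4f}")
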
