From f3c845fee1f93827a96a1b6715b7b73cdc9f41ab Mon Sep 17 00:00:00 2001
From: zwzz2019 <17629213401@163.com>
Date: Fri, 20 Sep 2024 00:59:58 +0800
Subject: [PATCH] Add adam with fp32

---
 bmtrain/optim/_function.py | 50 ++++++++++++++++++++++++++++++++++
 bmtrain/optim/adam.py      | 42 ++++++++++-------------------
 csrc/bind.cpp              |  5 ++--
 csrc/cuda/adam_cuda.cu     | 55 ++++++++++++++++++++++++++++++++++++++
 csrc/include/bind.hpp      | 14 ++++++++++
 tests/test_optim.py        |  9 +++++--
 6 files changed, 143 insertions(+), 32 deletions(-)

diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py
index f9e0ce9d..999b021c 100644
--- a/bmtrain/optim/_function.py
+++ b/bmtrain/optim/_function.py
@@ -216,3 +216,53 @@ def adam_bf16(
         bias_correction2,
         stream,
     )
+
+def adam_fp32(
+    param_fp32: torch.Tensor,
+    g_fp32: torch.Tensor,
+    m_fp32: torch.Tensor,
+    v_fp32: torch.Tensor,
+    beta1: float,
+    beta2: float,
+    eps: float,
+    lr: float,
+    scale: float,
+    weight_decay: float,
+    step: int,
+) -> None:
+    assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
+    assert CHECK_INPUT(g_fp32), "g_fp32 must be contiguous and on cuda"
+    assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
+    assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
+    assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+    assert g_fp32.dtype == torch.float32, "g_fp32 must be float32 tensor"
+    assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
+    assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+    assert (
+        param_fp32.numel() == g_fp32.numel()
+    ), "param_fp32 and g_fp32 must have the same number of elements"
+    assert (
+        param_fp32.numel() == m_fp32.numel()
+    ), "param_fp32 and m_fp32 must have the same number of elements"
+    assert (
+        param_fp32.numel() == v_fp32.numel()
+    ), "param_fp32 and v_fp32 must have the same number of elements"
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
+    stream = torch.cuda.current_stream().cuda_stream
+    C.adam_fp32_launcher(
+        param_fp32.numel(),
+        param_fp32.data_ptr(),
+        g_fp32.data_ptr(),
+        m_fp32.data_ptr(),
+        v_fp32.data_ptr(),
+        beta1,
+        beta2,
+        eps,
+        lr,
+        scale,
+        weight_decay,
+        bias_correction1,
+        bias_correction2,
+        stream,
+    )
\ No newline at end of file
diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py
index f99c483c..1d12e999 100644
--- a/bmtrain/optim/adam.py
+++ b/bmtrain/optim/adam.py
@@ -10,7 +10,6 @@
 from itertools import chain
 from collections import defaultdict
 
-
 class AdamOptimizer(torch.optim.Optimizer):
     """
     Adam optimizer support fp16 and bf16.
@@ -112,34 +111,21 @@ def step(self, closure=None, scale=1):
                     grad = p.grad
 
                     if p.dtype == torch.float32:
-                        other_kwargs = {}
-                        if (
-                            "maximize"
-                            in inspect.signature(
-                                torch.optim._functional.adam
-                            ).parameters
-                        ):
-                            other_kwargs["maximize"] = False
-                        torch.optim._functional.adam(
-                            [p],
-                            [grad / scale],
-                            [state["exp_avg"]],
-                            [state["exp_avg_sq"]],
-                            [],
-                            (
-                                [state["step"]]
-                                if check_torch_version("1.12.0") < 0
-                                else [torch.tensor(state["step"])]
-                            ),
-                            amsgrad=False,
-                            beta1=group["betas"][0],
-                            beta2=group["betas"][1],
-                            lr=0.0 if state["step"] < self._hold_steps else group["lr"],
-                            weight_decay=group["weight_decay"],
-                            eps=group["eps"],
-                            **other_kwargs
+                        f = F.adam_fp32
+                        state["step"] += 1
+                        f(
+                            p,  # fp32
+                            grad,  # fp32
+                            state["exp_avg"],  # fp32: m
+                            state["exp_avg_sq"],  # fp32: v
+                            group["betas"][0],
+                            group["betas"][1],
+                            group["eps"],
+                            0.0 if state["step"] < self._hold_steps else group["lr"],
+                            scale,
+                            group["weight_decay"],
+                            state["step"],
                         )
-                        state["step"] += 1
                     else:
                         f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
                         state["step"] += 1
diff --git a/csrc/bind.cpp b/csrc/bind.cpp
index b8f6fa85..7b8e3084 100644
--- a/csrc/bind.cpp
+++ b/csrc/bind.cpp
@@ -6,8 +6,9 @@ PYBIND11_MODULE(C, m) {
     m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
     m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
     m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
-    m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
-    m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
+    m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function");
+    m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function");
+    m.def("adam_fp32_launcher", &adam_fp32_launcher, "adam function");
    m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
    m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
    m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward");
diff --git a/csrc/cuda/adam_cuda.cu b/csrc/cuda/adam_cuda.cu
index 0510ac12..f807c0b8 100644
--- a/csrc/cuda/adam_cuda.cu
+++ b/csrc/cuda/adam_cuda.cu
@@ -2,6 +2,7 @@
 #include 
 #include 
 #include "bfloat16.cuh"
+#include 
 
 namespace {
 // blocks , threads
@@ -71,6 +72,35 @@ __global__ void adam_fp32_accum_bf16(
 #endif
 }
 
+__global__ void adam_fp32_accum_fp32(
+    int32_t n,
+    const float *g,  // (n)
+    float *m,        // (n)
+    float *v,        // (n)
+    float *param,    // (n)
+    float beta1,
+    float beta2,
+    float eps,
+    float lr,
+    float scale,
+    float weight_decay,
+    float bias_correction1,
+    float bias_correction2
+) {
+    int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
+    if (col < n) {
+        float local_g = g[col] / scale;
+        float local_m = beta1 * m[col] + (1 - beta1) * local_g;
+        float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g;
+        float local_p = param[col];
+        local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p;
+
+        param[col] = local_p;
+        v[col] = local_v;
+        m[col] = local_m;
+    }
+}
+
 }
 
 void adam_fp16_launcher(
@@ -124,3 +154,28 @@ void adam_bf16_launcher(
     dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
     adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
 }
+
+void adam_fp32_launcher(
+    int n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t g_fp32,
+    std::uintptr_t m_fp32,
+    std::uintptr_t v_fp32,
+    float beta1, float beta2,
+    float eps, float lr,
+    float scale,
+    float weight_decay,
+    float bias_correction1,
+    float bias_correction2,
+    uintptr_t stream
+) {
+    if (n <= 0) return;
+    auto g_ptr = reinterpret_cast<float*>(g_fp32);
+    auto m_ptr = reinterpret_cast<float*>(m_fp32);
+    auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+    auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+    int32_t threads = 1024;
+    dim3 block_size = dim3(threads, 1, 1);
+    dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+    adam_fp32_accum_fp32<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
\ No newline at end of file
diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp
index 3ff967fd..6a799b79 100644
--- a/csrc/include/bind.hpp
+++ b/csrc/include/bind.hpp
@@ -109,3 +109,17 @@ void adam_bf16_launcher(
     float bias_correction2,
     uintptr_t stream
 );
+void adam_fp32_launcher(
+    int n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t g_fp32,
+    std::uintptr_t m_fp32,
+    std::uintptr_t v_fp32,
+    float beta1, float beta2,
+    float eps, float lr,
+    float scale,
+    float weight_decay,
+    float bias_correction1,
+    float bias_correction2,
+    uintptr_t stream
+);
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 0aca8c31..1568e1a5 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -57,9 +57,14 @@ def main(dtype):
     optim_manager.add_optimizer(opt4)
     optim_manager.add_optimizer(opt5)
 
+    # fp16 bmt.adam
+    # fp16 bmt.Offload
+    # fp32 torch.adam
+    # fp32 bmt.adam
+    # fp32 bmt.Offload
+
     for _ in range(100):
         optim_manager.zero_grad()
-
         for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()):
             grad = torch.randn_like(p1)
             p1.grad = grad.to(dtype)
@@ -81,7 +86,7 @@ def main(dtype):
     assert_lt(diff1, 1)
     assert_lt(diff2, 1)
    assert_lt(diff3, 1)
-    assert_eq(diff4, 0)
+    assert_lt(diff4, 0.001)
     assert_lt(diff5, 0.00001)
 
 if __name__ == "__main__":
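
A minimal usage sketch of the new fused fp32 path (not part of the diff above), assuming a CUDA device is available and the C extension has been rebuilt so that C.adam_fp32_launcher is exported; the tensor size and hyperparameter values are illustrative only:

    import torch
    from bmtrain.optim import _function as F

    n = 1024
    param = torch.randn(n, dtype=torch.float32, device="cuda").contiguous()
    grad = torch.randn(n, dtype=torch.float32, device="cuda").contiguous()
    m = torch.zeros(n, dtype=torch.float32, device="cuda")  # exp_avg
    v = torch.zeros(n, dtype=torch.float32, device="cuda")  # exp_avg_sq

    # One fused Adam step entirely in fp32; scale=1.0 means no loss scaling,
    # and weight decay is applied in the decoupled (AdamW) form used by the kernel.
    F.adam_fp32(
        param, grad, m, v,
        0.9,    # beta1
        0.999,  # beta2
        1e-8,   # eps
        1e-3,   # lr
        1.0,    # scale
        0.01,   # weight_decay
        1,      # step (1-based, used for bias correction)
    )

In AdamOptimizer.step() this path is selected automatically for float32 parameters, so callers keep using bmt.optim.AdamOptimizer as before; fp32 parameters simply go through the fused CUDA kernel instead of torch.optim._functional.adam.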