Add adam with fp32
zwzz2019 committed Sep 19, 2024
1 parent 3d35a26 commit f3c845f
Showing 6 changed files with 143 additions and 32 deletions.
50 changes: 50 additions & 0 deletions bmtrain/optim/_function.py
@@ -216,3 +216,53 @@ def adam_bf16(
bias_correction2,
stream,
)

def adam_fp32(
param_fp32: torch.Tensor,
g_fp32: torch.Tensor,
m_fp32: torch.Tensor,
v_fp32: torch.Tensor,
beta1: float,
beta2: float,
eps: float,
lr: float,
scale: float,
weight_decay: float,
step: int,
) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
    assert CHECK_INPUT(g_fp32), "g_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
    assert g_fp32.dtype == torch.float32, "g_fp32 must be float32 tensor"
    assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert (
param_fp32.numel() == g_fp32.numel()
), "param_fp32 and g_fp16 must have the same number of elements"
assert (
param_fp32.numel() == m_fp32.numel()
), "param_fp32 and m_fp32 must have the same number of elements"
assert (
param_fp32.numel() == v_fp32.numel()
), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
stream = torch.cuda.current_stream().cuda_stream
C.adam_fp32_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
g_fp32.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1,
beta2,
eps,
lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
stream,
)
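
For reference, a minimal sketch of calling the new wrapper directly (assuming a CUDA-enabled build of the C extension and that the module patched above is importable as bmtrain.optim._function; tensor sizes and hyperparameters below are illustrative):

import torch
from bmtrain.optim import _function as F  # assumed import of the module patched above

n = 1024
param = torch.randn(n, device="cuda", dtype=torch.float32)   # fp32 parameters
grad = torch.randn(n, device="cuda", dtype=torch.float32)    # fp32 gradients
m = torch.zeros(n, device="cuda", dtype=torch.float32)       # first moment (exp_avg)
v = torch.zeros(n, device="cuda", dtype=torch.float32)       # second moment (exp_avg_sq)

# One Adam step; scale=1.0 since fp32 gradients are typically not loss-scaled.
F.adam_fp32(param, grad, m, v,
            beta1=0.9, beta2=0.999, eps=1e-8,
            lr=1e-3, scale=1.0, weight_decay=0.01, step=1)

All four tensors must be contiguous fp32 CUDA tensors with the same number of elements, per the assertions above.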
42 changes: 14 additions & 28 deletions bmtrain/optim/adam.py
@@ -10,7 +10,6 @@
from itertools import chain
from collections import defaultdict


class AdamOptimizer(torch.optim.Optimizer):
"""
    Adam optimizer supporting fp16 and bf16.
@@ -112,34 +111,21 @@ def step(self, closure=None, scale=1):
grad = p.grad

if p.dtype == torch.float32:
other_kwargs = {}
if (
"maximize"
in inspect.signature(
torch.optim._functional.adam
).parameters
):
other_kwargs["maximize"] = False
torch.optim._functional.adam(
[p],
[grad / scale],
[state["exp_avg"]],
[state["exp_avg_sq"]],
[],
(
[state["step"]]
if check_torch_version("1.12.0") < 0
else [torch.tensor(state["step"])]
),
amsgrad=False,
beta1=group["betas"][0],
beta2=group["betas"][1],
lr=0.0 if state["step"] < self._hold_steps else group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
**other_kwargs
f = F.adam_fp32
state["step"] += 1
f(
p, # fp32
grad, # fp32
state["exp_avg"], # fp32: m
state["exp_avg_sq"], # fp32: v
group["betas"][0],
group["betas"][1],
group["eps"],
0.0 if state["step"] < self._hold_steps else group["lr"],
scale,
group["weight_decay"],
state["step"],
)
state["step"] += 1
else:
f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
state["step"] += 1
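For context, the fp32 branch above now calls the fused CUDA kernel through F.adam_fp32 instead of torch.optim._functional.adam. A rough pure-PyTorch sketch of the element-wise update that kernel performs (mirroring adam_fp32_accum_fp32 further down; an illustration only, not the shipped code):

import torch

def adam_fp32_reference(p, g, m, v, beta1, beta2, eps, lr, scale, weight_decay, step):
    """Illustrative reference of the fused kernel's math; updates p, m, v in place."""
    g = g / scale                                    # unscale the gradient
    m.mul_(beta1).add_(g, alpha=1 - beta1)           # m = beta1*m + (1-beta1)*g
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)    # v = beta2*v + (1-beta2)*g*g
    bc1 = 1 - beta1 ** step                          # bias corrections, as in adam_fp32 above
    bc2 = 1 - beta2 ** step
    p -= lr * (m / bc1) / ((v / bc2).sqrt() + eps) + lr * weight_decay * p
    return p, m, v

Note that weight decay is decoupled (applied directly to the parameter, AdamW-style) rather than folded into the gradient as torch.optim._functional.adam does, which, together with the different operation order, is consistent with the test below comparing against torch's fp32 Adam within a small tolerance instead of exact equality.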
5 changes: 3 additions & 2 deletions csrc/bind.cpp
@@ -6,8 +6,9 @@ PYBIND11_MODULE(C, m) {
m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function");
m.def("adam_fp32_launcher", &adam_fp32_launcher, "adam function");
m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward");
55 changes: 55 additions & 0 deletions csrc/cuda/adam_cuda.cu
@@ -2,6 +2,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include "bfloat16.cuh"
#include <stdio.h>

namespace {
// blocks <n // 1024>, threads<min(n, 1024)>
@@ -71,6 +72,35 @@ __global__ void adam_fp32_accum_bf16(
#endif
}

__global__ void adam_fp32_accum_fp32(
int32_t n,
const float *g, // (n)
float *m, // (n)
float *v, // (n)
float *param, // (n)
float beta1,
float beta2,
float eps,
float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2
) {
int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
if (col < n) {
float local_g = g[col] / scale;
float local_m = beta1 * m[col] + (1 - beta1) * local_g;
float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g;
float local_p = param[col];
local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p;

param[col] = local_p;
v[col] = local_v;
m[col] = local_m;
}
}

}

void adam_fp16_launcher(
Expand Down Expand Up @@ -124,3 +154,28 @@ void adam_bf16_launcher(
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}

void adam_fp32_launcher(
int n,
std::uintptr_t param_fp32,
std::uintptr_t g_fp32,
std::uintptr_t m_fp32,
std::uintptr_t v_fp32,
float beta1, float beta2,
float eps, float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2,
uintptr_t stream
) {
if (n <= 0) return;
auto g_ptr = reinterpret_cast<float*>(g_fp32);
auto m_ptr = reinterpret_cast<float*>(m_fp32);
auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
int32_t threads = 1024;
dim3 block_size = dim3(threads, 1, 1);
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_fp32<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}
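
In equation form, each element in the new kernel follows the standard Adam update with the gradient unscaled by scale and decoupled weight decay, using the bias corrections computed on the Python side:

\hat{g} = g / \mathrm{scale}, \qquad
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,\hat{g}, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\,\hat{g}^{2}

\theta_t = \theta_{t-1}
  - \mathrm{lr} \cdot \frac{m_t / (1 - \beta_1^{t})}{\sqrt{v_t / (1 - \beta_2^{t})} + \epsilon}
  - \mathrm{lr} \cdot \lambda\,\theta_{t-1}

Here \lambda is weight_decay and t is step. The launcher assigns one thread per element with blockDim = 1024 and gridDim = ceil(n / 1024); for example, n = 2500 launches 3 blocks of 1024 threads, and the 572 surplus threads are filtered out by the col < n guard.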
14 changes: 14 additions & 0 deletions csrc/include/bind.hpp
@@ -109,3 +109,17 @@ void adam_bf16_launcher(
float bias_correction2,
uintptr_t stream
);
void adam_fp32_launcher(
int n,
std::uintptr_t param_fp32,
std::uintptr_t g_fp32,
std::uintptr_t m_fp32,
std::uintptr_t v_fp32,
float beta1, float beta2,
float eps, float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2,
uintptr_t stream
);
9 changes: 7 additions & 2 deletions tests/test_optim.py
@@ -57,9 +57,14 @@ def main(dtype):
optim_manager.add_optimizer(opt4)
optim_manager.add_optimizer(opt5)

# fp16 bmt.adam
# fp16 bmt.Offload
# fp32 torch.adam
# fp32 bmt.adam
# fp32 bmt.Offload

for _ in range(100):
optim_manager.zero_grad()

for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()):
grad = torch.randn_like(p1)
p1.grad = grad.to(dtype)
@@ -81,7 +86,7 @@ def main(dtype):
assert_lt(diff1, 1)
assert_lt(diff2, 1)
assert_lt(diff3, 1)
assert_eq(diff4, 0)
assert_lt(diff4, 0.001)
assert_lt(diff5, 0.00001)

if __name__ == "__main__":
