
Commit

gpt2-nano training
xl0 committed Dec 20, 2023
1 parent f56092b commit d742103
Showing 14 changed files with 821 additions and 95 deletions.
50 changes: 43 additions & 7 deletions nbs/01_tensor.ipynb
@@ -33,7 +33,8 @@
"from lovely_numpy import lovely\n",
"\n",
"import tidygrad.ops as ops\n",
"import tidygrad.tensor_helpers as helpers"
"import tidygrad.tensor_helpers as helpers\n",
"import traceback"
]
},
{
@@ -67,21 +68,54 @@
"class Tensor:\n",
" pass\n",
"\n",
"def simplify_trace(trace):\n",
" return ' -> '.join(f'{frame.name} at {frame.filename}:{frame.lineno}' for frame in trace if '/python' not in frame.filename)\n",
"\n",
"alloc_log = {}\n",
"\n",
"class Tensor:\n",
" # op = \"L\"\n",
" name: str = \"\"\n",
"\n",
" def __init__(self, data, name=None, op=None, eps=1e-8, requires_grad=False):\n",
" global _num_tensors\n",
" _num_tensors += 1\n",
" self.data = np.asarray(data, dtype=np.float64) # , dtype=np.float32\n",
"\n",
" self.grad = (np.zeros_like(self.data, dtype=np.float64) if requires_grad else None)\n",
" trace = traceback.extract_stack()\n",
" simplified_trace = simplify_trace(trace)\n",
" alloc_log[id(self)] = simplified_trace\n",
" \n",
" # Increment allocation count\n",
"\n",
" # if _num_tensors > 620:\n",
" # raise Exception(\"Too many tensors\")\n",
"\n",
" self.data = np.asarray(data) # , dtype=np.float32\n",
" if self.data.dtype == np.float64:\n",
" self.data = self.data.astype(np.float32)\n",
"\n",
" self.grad = (np.zeros_like(self.data, dtype=np.float32) if requires_grad else None)\n",
" self.eps = eps\n",
" self.op = op or ops.Load(name=name)\n",
" self.name = name or self.op.name\n",
" self.requires_grad = requires_grad\n",
" self._requires_grad = requires_grad\n",
"\n",
" def __del__(self):\n",
" # print(f\"Tensor {self.name} deleted\")\n",
" del alloc_log[id(self)]\n",
" global _num_tensors\n",
" _num_tensors -= 1\n",
"\n",
" @property\n",
" def requires_grad(self):\n",
" return self._requires_grad\n",
"\n",
" @requires_grad.setter\n",
" def requires_grad(self, requires_grad):\n",
" if requires_grad and self.grad is None:\n",
" self.grad = np.zeros_like(self.data)\n",
" \n",
" self._requires_grad = requires_grad\n",
" \n",
" def __repr__(self):\n",
" value_str = f\"v={lovely(self.data)}\"\n",
" grad_str = f\"∇={lovely(self.grad)}\" if self.grad is not None else \"\"\n",
@@ -90,7 +124,7 @@
" return f'Tensor{list(self.data.shape)}(name=\"{self.name}\" op={type(self.op).__name__}{parents}):\\n {value_str}\\n {grad_str}'\n",
"\n",
" def accum_grad(self, grad):\n",
" if not self.requires_grad:\n",
" if not self._requires_grad:\n",
" return\n",
"\n",
" if self.grad is None:\n",
@@ -230,9 +264,11 @@
" for n in nodes[::-1]:\n",
" if hasattr(n.op, \"backward\"):\n",
" n.op.backward()\n",
" n.op = None\n",
"\n",
"\n",
" def zero_grad(self):\n",
" assert self.requires_grad, \"Cannot zero grad on non-differentiable tensor\"\n",
" assert self._requires_grad, \"Cannot zero grad on non-differentiable tensor\"\n",
" self.grad.fill(0)"
]
}
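Note on the change above: the new alloc_log dict pairs every live Tensor with a simplified stack trace captured in __init__, so lingering tensors can be traced back to the call site that allocated them, and __del__ removes the entry again. A minimal standalone sketch of the same pattern (the Tracked class and make_one helper below are illustrative, not part of tidygrad):

import traceback

alloc_log = {}

def simplify_trace(trace):
    # Keep only frames outside the Python installation itself.
    return " -> ".join(
        f"{frame.name} at {frame.filename}:{frame.lineno}"
        for frame in trace
        if "/python" not in frame.filename
    )

class Tracked:
    """Toy object that records where it was allocated."""

    def __init__(self):
        alloc_log[id(self)] = simplify_trace(traceback.extract_stack())

    def __del__(self):
        # pop() avoids a KeyError if __del__ runs more than once
        alloc_log.pop(id(self), None)

def make_one():
    return Tracked()

t = make_one()
print(alloc_log[id(t)])  # something like "<module> at example.py:25 -> make_one at example.py:22"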
1 change: 1 addition & 0 deletions nbs/02_func.ipynb
@@ -215,6 +215,7 @@
" target = Tensor(target)\n",
" sm = softmax(logits)\n",
" loss = -target * sm.log()\n",
"\n",
" if reduction == \"mean\":\n",
" return loss.mean(axis=-1, keepdims=True)\n",
" if reduction == \"sum\":\n",
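The cell above computes cross-entropy as the elementwise product of the (one-hot) target with the log of the softmax of the logits, followed by an optional mean or sum reduction. A rough NumPy-only sketch of the same computation (plain arrays instead of tidygrad Tensors; the helper names below are assumptions for illustration):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract the max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def cross_entropy(logits, target, reduction="mean"):
    # target is expected to be one-hot (or a probability distribution) over the last axis
    loss = -target * np.log(softmax(logits))
    if reduction == "mean":
        return loss.mean(axis=-1, keepdims=True)
    if reduction == "sum":
        return loss.sum(axis=-1, keepdims=True)
    return loss

logits = np.array([[2.0, 0.5, -1.0]])
target = np.array([[1.0, 0.0, 0.0]])
print(cross_entropy(logits, target, reduction="sum"))  # ~0.241, the -log probability of the correct class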
9 changes: 9 additions & 0 deletions nbs/02_ops.conv.ipynb
@@ -1,5 +1,14 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"skip_exec: true\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": null,
60 changes: 25 additions & 35 deletions nbs/06_training.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion nbs/10_utils.grad_check.ipynb
@@ -190,7 +190,7 @@
"\n",
"loss.backward()\n",
"\n",
"grad_check(NN, (x, y), (w1, b1, w2))"
"# grad_check(NN, (x, y), (w1, b1, w2))"
]
}
],
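The commented-out grad_check call above compares the gradients produced by backward() against a numerical estimate. tidygrad's own implementation is not shown in this diff; below is a generic central-difference sketch of the idea (function and parameter names are illustrative only):

import numpy as np

def numerical_grad(f, x, eps=1e-5):
    """Central-difference estimate of df/dx for a scalar-valued f that reads x."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f()
        x[idx] = orig - eps
        f_minus = f()
        x[idx] = orig  # restore the original value
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

# Compare against an analytic gradient: for f(w) = sum(w**2) the gradient is 2*w.
w = np.random.randn(3, 2)
num = numerical_grad(lambda: (w ** 2).sum(), w)
print(np.allclose(num, 2 * w, atol=1e-4))  # True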
