Skip to content

Commit

Permalink
feat(KAN): device type_as use
Browse files (browse the repository at this point in the history)
  • Loading branch information
AdityaNG committed May 3, 2024
1 parent 63d1edf commit 507238a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions kan_gpt/kan/KANLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def forward(self, x):
x,
torch.ones(
self.out_dim,
).to(x.device),
).type_as(x),
)
.reshape(batch, self.size)
.permute(1, 0)
Expand Down Expand Up @@ -276,7 +276,7 @@ def update_grid_from_samples(self, x):
x,
torch.ones(
self.out_dim,
).to(x.device),
).type_as(x),
)
.reshape(batch, self.size)
.permute(1, 0)
Expand Down Expand Up @@ -344,7 +344,7 @@ def initialize_grid_from_parent(self, parent, x):
x,
torch.ones(
self.out_dim,
).to(x.device),
).type_as(x),
)
.reshape(batch, self.size)
.permute(1, 0)
Expand All @@ -356,12 +356,12 @@ def initialize_grid_from_parent(self, parent, x):
k=1,
num=x_pos.shape[1] - 1,
scale_base=0.0,
).to(x.device)
).type_as(x)
sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
y_eval = coef2curve(
x_eval, parent.grid, parent.coef, parent.k, device=x.device
)
percentile = torch.linspace(-1, 1, self.num + 1).to(x.device)
percentile = torch.linspace(-1, 1, self.num + 1).type_as(x)
self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
self.coef.data = curve2coef(
x_eval, y_eval, self.grid, self.k, x.device
Expand Down
14 changes: 7 additions & 7 deletions kan_gpt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from kan_gpt.mingpt.utils import set_seed
from kan_gpt.model import GPT as KAN_GPT

set_seed(3407)


def eval_split(
trainer, split, max_batches, batch_size, model, train_dataset, test_dataset
):
Expand All @@ -31,14 +28,12 @@ def eval_split(
if max_batches is not None and b + 1 >= max_batches:
break
rt = torch.tensor(results, dtype=torch.float)
print("%s loss: %.2f%%" % (split, rt.mean()))
print("%s loss: %.2f" % (split, rt.mean()))
return rt.mean()


def main(args):
wandb.init(project="KAN-GPT")

wandb.config = {
config = {
"model_type": args.model_type,
"batch_size": args.batch_size,
"dummy_dataset": args.dummy_dataset,
Expand All @@ -48,6 +43,11 @@ def main(args):
"architecture": args.architecture,
}

wandb.init(
project="KAN-GPT",
config=config
)

model_type = args.model_type

# print an example instance of the dataset
Expand Down

0 comments on commit 507238a

Please sign in to comment.