From 507238aa4ca72e4ef71a5f93dc61b3e170dd3bd8 Mon Sep 17 00:00:00 2001 From: Aditya NG Date: Fri, 3 May 2024 08:32:29 +0530 Subject: [PATCH] feat(KAN): device type_as use --- kan_gpt/kan/KANLayer.py | 10 +++++----- kan_gpt/train.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/kan_gpt/kan/KANLayer.py b/kan_gpt/kan/KANLayer.py index df99079..2ca7487 100644 --- a/kan_gpt/kan/KANLayer.py +++ b/kan_gpt/kan/KANLayer.py @@ -214,7 +214,7 @@ def forward(self, x): x, torch.ones( self.out_dim, - ).to(x.device), + ).type_as(x), ) .reshape(batch, self.size) .permute(1, 0) @@ -276,7 +276,7 @@ def update_grid_from_samples(self, x): x, torch.ones( self.out_dim, - ).to(x.device), + ).type_as(x), ) .reshape(batch, self.size) .permute(1, 0) @@ -344,7 +344,7 @@ def initialize_grid_from_parent(self, parent, x): x, torch.ones( self.out_dim, - ).to(x.device), + ).type_as(x), ) .reshape(batch, self.size) .permute(1, 0) @@ -356,12 +356,12 @@ def initialize_grid_from_parent(self, parent, x): k=1, num=x_pos.shape[1] - 1, scale_base=0.0, - ).to(x.device) + ).type_as(x) sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1) y_eval = coef2curve( x_eval, parent.grid, parent.coef, parent.k, device=x.device ) - percentile = torch.linspace(-1, 1, self.num + 1).to(x.device) + percentile = torch.linspace(-1, 1, self.num + 1).type_as(x) self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) self.coef.data = curve2coef( x_eval, y_eval, self.grid, self.k, x.device diff --git a/kan_gpt/train.py b/kan_gpt/train.py index 6cdf5a7..150d792 100644 --- a/kan_gpt/train.py +++ b/kan_gpt/train.py @@ -8,9 +8,6 @@ from kan_gpt.mingpt.utils import set_seed from kan_gpt.model import GPT as KAN_GPT -set_seed(3407) - - def eval_split( trainer, split, max_batches, batch_size, model, train_dataset, test_dataset ): @@ -31,14 +28,12 @@ def eval_split( if max_batches is not None and b + 1 >= max_batches: break rt = torch.tensor(results, dtype=torch.float) - print("%s loss: %.2f%%" % (split, rt.mean())) + print("%s loss: %.2f" % (split, rt.mean())) return rt.mean() def main(args): - wandb.init(project="KAN-GPT") - - wandb.config = { + config = { "model_type": args.model_type, "batch_size": args.batch_size, "dummy_dataset": args.dummy_dataset, @@ -48,6 +43,11 @@ def main(args): "architecture": args.architecture, } + wandb.init( + project="KAN-GPT", + config=config + ) + model_type = args.model_type # print an example instance of the dataset