Skip to content

Commit

Permalink
[GraphBolt][PyG] Refine examples. (#7806)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Sep 24, 2024
1 parent 5ae6400 commit 55d66fe
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
6 changes: 5 additions & 1 deletion examples/graphbolt/pyg/hetero/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def create_dataloader(
node_feature_keys["institute"] = ["feat"]
node_feature_keys["fos"] = ["feat"]
# Fetch node features for the sampled subgraph.
datapipe = datapipe.fetch_feature(features, node_feature_keys)
datapipe = datapipe.fetch_feature(
features,
node_feature_keys,
overlap_fetch=args.overlap_feature_fetch,
)

# Copy the data to the specified device.
if need_copy:
Expand Down
11 changes: 4 additions & 7 deletions examples/graphbolt/pyg/multigpu/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def weighted_reduce(tensor, weight, dst=0):


@torch.compile
def train_step(minibatch, optimizer, model, loss_fn, cooperative):
def train_step(minibatch, optimizer, model, loss_fn):
node_features = minibatch.node_features["feat"]
labels = minibatch.labels
optimizer.zero_grad()
Expand All @@ -211,9 +211,7 @@ def train_step(minibatch, optimizer, model, loss_fn, cooperative):
return loss.detach(), num_correct, labels.size(0)


def train_helper(
rank, dataloader, model, optimizer, loss_fn, device, cooperative
):
def train_helper(rank, dataloader, model, optimizer, loss_fn, device):
model.train() # Set the model to training mode
total_loss = torch.zeros(1, device=device) # Accumulator for the total loss
# Accumulator for the total number of correct predictions
Expand All @@ -223,7 +221,7 @@ def train_helper(
start = time.time()
for minibatch in tqdm(dataloader, "Training") if rank == 0 else dataloader:
loss, num_correct, num_samples = train_step(
minibatch, optimizer, model, loss_fn, cooperative
minibatch, optimizer, model, loss_fn
)
total_loss += loss
total_correct += num_correct
Expand Down Expand Up @@ -263,7 +261,6 @@ def train(args, rank, train_dataloader, valid_dataloader, model, device):
optimizer,
loss_fn,
device,
args.cooperative,
)
val_acc = evaluate(rank, model, valid_dataloader, device)
if rank == 0:
Expand Down Expand Up @@ -381,7 +378,7 @@ def parse_args():
default=1,
help="The number of accesses after which a vertex neighborhood will be cached.",
)
parser.add_argument("--precision", type=str, default="high")
parser.add_argument("--precision", type=str, default="medium")
parser.add_argument(
"--cooperative",
action="store_true",
Expand Down

0 comments on commit 55d66fe

Please sign in to comment.