Skip to content

Commit

Permalink
Implement custom get/setstate for CheckpointPath
Browse files Browse the repository at this point in the history
Reviewed By: JKSenthil

Differential Revision: D56654810

fbshipit-source-id: e1bf8573d007e11b0eee444f54e82e62381dbe60
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Apr 27, 2024
1 parent 698d4d0 commit e1135d6
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict
import os
import pickle
import shutil
import tempfile
import unittest
Expand Down Expand Up @@ -173,6 +174,21 @@ def test_compare_by_optimality(self) -> None:
self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
self.assertFalse(larger.more_optimal_than(smaller, mode="min"))

def test_pickling(self) -> None:
    """
    Check that a CheckpointPath survives a pickle round trip and that the
    pickled payload embeds the original path string.
    """
    for path in (
        "foo/epoch_0_step_1",
        "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
    ):
        # subTest reports which path failed instead of stopping at the first one
        with self.subTest(path=path):
            ckpt = CheckpointPath.from_str(path)

            pickled = pickle.dumps(ckpt)

            # Don't test equality because of custom protocol. Compare bytes
            # with bytes rather than searching the repr produced by str().
            self.assertIn(path.encode("utf-8"), pickled)

            unpickled = pickle.loads(pickled)
            self.assertEqual(unpickled, ckpt)


class CheckpointUtilsTest(unittest.TestCase):
@staticmethod
Expand Down
76 changes: 76 additions & 0 deletions tests/utils/test_checkpoint_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import shutil
import tempfile
import unittest

import torch.distributed as dist
from torchtnt.utils import init_from_env
from torchtnt.utils.checkpoint import get_checkpoint_dirpaths
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TestCheckpointUtilsGPU(unittest.TestCase):
    """Distributed GPU tests for `get_checkpoint_dirpaths`."""

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_get_checkpoint_dirpaths_distributed(self) -> None:
        """Run `_test_get_checkpoint_dirpaths` on 2 processes with the NCCL backend."""
        spawn_multi_process(
            2,
            "nccl",
            self._test_get_checkpoint_dirpaths,
        )

    @staticmethod
    def _test_get_checkpoint_dirpaths() -> None:
        """
        Tests retrieving checkpoint directories from a given root directory
        using NCCL on GPUs with custom state for pickling.

        Only rank 0 creates the root directory; every rank then checks that
        `get_checkpoint_dirpaths` returns the full set of checkpoint paths.
        """
        init_from_env()
        paths = [
            "epoch_0_step_10",
            "epoch_1_step_10_val_loss=10.5",
            "epoch_2_step_10",
            "epoch_0_step_5",
            "epoch_0_step_6_acc=0.03",
            "epoch_0_step_3",
        ]

        is_rank_zero = get_global_rank() == 0

        # Only rank 0 will know about the root directory
        root_dir = tempfile.mkdtemp() if is_rank_zero else None
        try:
            if is_rank_zero:
                for path in paths:
                    os.mkdir(os.path.join(root_dir, path))

            tc = unittest.TestCase()
            if not is_rank_zero:
                tc.assertIsNone(root_dir)

            ckpt_dirpaths = get_checkpoint_dirpaths(
                root_dir, process_group=dist.group.WORLD
            )

            # Broadcast the root directory to verify successful execution.
            # On non-zero ranks the buffer starts as [None] and is filled in
            # place by the broadcast.
            broadcast_buffer = [root_dir]
            dist.broadcast_object_list(
                broadcast_buffer, src=0, group=dist.group.WORLD
            )
            shared_dir = broadcast_buffer[0]
            tc.assertIsNotNone(shared_dir)

            tc.assertEqual(
                {str(x) for x in ckpt_dirpaths},
                {os.path.join(shared_dir, path) for path in paths},
            )
        finally:
            # Remove the directory even when an assertion above fails, so a
            # failing run does not leak it.
            if is_rank_zero:
                shutil.rmtree(root_dir)
19 changes: 19 additions & 0 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,25 @@ def __eq__(self, other: "CheckpointPath") -> bool:
def __gt__(self, other: "CheckpointPath") -> bool:
    """Back the `>` operator by delegating to `newer_than(other)`."""
    is_newer = self.newer_than(other)
    return is_newer

def __getstate__(self) -> str:
    """Return the path string, which is the entire pickled state of this object."""
    # Pickling only the path keeps the payload lightweight, to avoid
    # broadcast errors.
    state = self.path
    return state

def __setstate__(self, state: str) -> None:
    """
    Restore this instance from the path string produced by `__getstate__`.

    Args:
        state: checkpoint path string, parsed with `PATH_REGEX`.

    Raises:
        AssertionError: if `state` does not match `PATH_REGEX`.
    """
    # Match regex directly to avoid creating a new instance with `from_str`
    path_match = self.PATH_REGEX.match(state)
    if path_match is None:
        # Raised explicitly instead of through an `assert` statement, which
        # is stripped under `python -O`. The exception type and message are
        # unchanged.
        raise AssertionError(
            f"Malformed checkpoint found when unpickling: {state}"
        )

    dirpath, epoch, step, metric_name, metric_value = path_match.groups()
    self.dirpath = dirpath.rstrip("/")
    self.epoch = int(epoch)
    self.step = int(step)
    # Metric data is only set when both the name and the value were captured
    self.metric_data = (
        MetricData(name=metric_name, value=float(metric_value))
        if metric_name and metric_value
        else None
    )


@rank_zero_read_and_broadcast
def get_latest_checkpoint_path(
Expand Down

0 comments on commit e1135d6

Please sign in to comment.