Skip to content

Commit

Permalink
Merge branch 'dev' into checkpoint_saver
Browse files (browse the repository at this point in the history)
  • Loading branch information
bigning authored Jun 21, 2024
2 parents c87f36c + 62c5b1f commit 1ebf5a7
Show file tree
Hide file tree
Showing 26 changed files with 966 additions and 244 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.11-2.3
container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-doctest
Expand Down
4 changes: 2 additions & 2 deletions composer/algorithms/augmix/augmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def _augmix_pil_image(
aug = np.random.choice(augmentation_set)
augmented_image = aug(augmented_image, severity)
augmented_combination += chain_weights[chain_i] * np.asarray(augmented_image)
mixed = (1 - mixing_weight) * np.asarray(img_pil) + mixing_weight * augmented_combination
mixed = Image.fromarray(np.uint8(mixed))
mixed = (1 - mixing_weight) * np.asarray(img_pil, dtype=np.float32) + mixing_weight * augmented_combination
mixed = Image.fromarray(np.uint8(mixed)) # type: ignore
return mixed

f_pil = functools.partial(
Expand Down
31 changes: 26 additions & 5 deletions composer/algorithms/utils/augmentation_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import numpy as np
from PIL import Image, ImageEnhance, ImageOps
from PIL.Image import Resampling, Transform

AugmentationFn = Callable[[Image.Image, float], Image.Image]

Expand Down Expand Up @@ -155,7 +156,7 @@ def rotate(pil_img: Image.Image, level: float):
degrees = _int_parameter(_sample_level(level), 30)
if np.random.uniform() > 0.5:
degrees = -degrees
return pil_img.rotate(degrees, resample=Image.BILINEAR)
return pil_img.rotate(degrees, resample=Resampling.BILINEAR)


def solarize(pil_img: Image.Image, level: float):
Expand Down Expand Up @@ -183,7 +184,12 @@ def shear_x(pil_img: Image.Image, level: float):
level = _float_parameter(_sample_level(level), 0.3)
if np.random.uniform() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, level, 0, 0, 1, 0),
resample=Resampling.BILINEAR,
)


def shear_y(pil_img: Image.Image, level: float):
Expand All @@ -197,7 +203,12 @@ def shear_y(pil_img: Image.Image, level: float):
level = _float_parameter(_sample_level(level), 0.3)
if np.random.uniform() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, 0, level, 1, 0),
resample=Resampling.BILINEAR,
)


def translate_x(pil_img: Image.Image, level: float):
Expand All @@ -211,7 +222,12 @@ def translate_x(pil_img: Image.Image, level: float):
level = _int_parameter(_sample_level(level), pil_img.size[0] / 3)
if np.random.random() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, level, 0, 1, 0),
resample=Resampling.BILINEAR,
)


def translate_y(pil_img: Image.Image, level: float):
Expand All @@ -225,7 +241,12 @@ def translate_y(pil_img: Image.Image, level: float):
level = _int_parameter(_sample_level(level), pil_img.size[1] / 3)
if np.random.random() > 0.5:
level = -level
return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR)
return pil_img.transform(
pil_img.size,
Transform.AFFINE,
(1, 0, 0, 0, 1, level),
resample=Resampling.BILINEAR,
)


# The following augmentations overlap with corruptions in the ImageNet-C/CIFAR10-C test
Expand Down
3 changes: 1 addition & 2 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@
partial_format,
retry,
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
from composer.utils.compression import get_compressor, is_compressed_pt
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

__all__ = ['CheckpointSaver']

_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


def _upload_symlink_file(
remote_backend_name: str,
Expand Down
4 changes: 4 additions & 0 deletions composer/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
self.rows.extend(rows)

def eval_end(self, state: State, logger: Logger) -> None:
# eval_batch_end will have set these if there is anything to log
if self.name is None or self.columns is None:
return

list_of_rows = dist.all_gather_object(self.rows)
rows = [row for rows in list_of_rows for row in rows]
for dest_logger in logger.destinations:
Expand Down
Loading

0 comments on commit 1ebf5a7

Please sign in to comment.