diff --git a/composer/utils/parallelism.py b/composer/utils/parallelism.py index a22dd8f8db..d3c97689d9 100644 --- a/composer/utils/parallelism.py +++ b/composer/utils/parallelism.py @@ -4,7 +4,7 @@ """Parallelism configs.""" from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any, Optional from torch.distributed._tensor.device_mesh import DeviceMesh