Skip to content

Commit

Permalink
Fixes some typing issues (#3418)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jun 21, 2024
1 parent 16e2862 commit 62c5b1f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
4 changes: 4 additions & 0 deletions composer/callbacks/eval_output_logging_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def eval_batch_end(self, state: State, logger: Logger) -> None:
self.rows.extend(rows)

def eval_end(self, state: State, logger: Logger) -> None:
# eval_batch_end will have set these if there is anything to log
if self.name is None or self.columns is None:
return

list_of_rows = dist.all_gather_object(self.rows)
rows = [row for rows in list_of_rows for row in rows]
for dest_logger in logger.destinations:
Expand Down
10 changes: 5 additions & 5 deletions composer/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Evaluator:
When specifying ``eval_interval``, the evaluator(s) are also run at the ``Event.FIT_END`` if it doesn't
evenly divide the training duration.
device_eval_microbatch_size (int, optional): The number of samples to use for each microbatch when evaluating.
device_eval_microbatch_size (str | int | float, optional): The number of samples to use for each microbatch when evaluating.
If set to ``auto``, dynamically decreases device_eval_microbatch_size if microbatch is too large for GPU.
If None, sets `device_eval_microbatch_size` to per rank batch size. (default: ``None``)
"""
Expand All @@ -80,7 +80,7 @@ def __init__(
metric_names: Optional[list[str]] = None,
subset_num_batches: Optional[int] = None,
eval_interval: Optional[Union[int, str, Time, Callable[[State, Event], bool]]] = None,
device_eval_microbatch_size: Optional[Union[int, str]] = None,
device_eval_microbatch_size: Optional[Union[int, str, float]] = None,
):
self.label = label
self.dataloader = ensure_data_spec(dataloader)
Expand Down Expand Up @@ -142,7 +142,7 @@ def ensure_evaluator(evaluator: Union[Evaluator, DataSpec, Iterable, dict[str, A
)


def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str]]):
def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str, float]]):
if device_eval_microbatch_size == 'auto':
warnings.warn((
"Setting `device_eval_microbatch_size='auto'` is an experimental feature which may cause "
Expand All @@ -155,10 +155,10 @@ def _is_auto_microbatching(device_eval_microbatch_size: Optional[Union[int, str]


def _get_initial_device_eval_microbatch_size(
device_eval_microbatch_size: Optional[Union[int, str]],
device_eval_microbatch_size: Optional[Union[int, str, float]],
auto_microbatching: bool,
dataloader: Iterable,
) -> int:
) -> Union[int, float]:
"""Sets initial value of device_eval_microbatch_size.
If auto_microbatching, sets initial `device_eval_microbatch_size` to per rank batch size.
Expand Down
4 changes: 3 additions & 1 deletion composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __init__(
def _start_mlflow_run(self, state):
import mlflow

# This function is only called if self._enabled is True, and therefore self._experiment_id is not None.
assert self._experiment_id is not None

env_run_id = os.getenv(
mlflow.environment_variables.MLFLOW_RUN_ID.name, # pyright: ignore[reportGeneralTypeIssues]
None,
Expand All @@ -193,7 +196,6 @@ def _start_mlflow_run(self, state):
self._run_id = env_run_id
elif self.resume:
# Search for an existing run tagged with this Composer run if `self.resume=True`.
assert self._experiment_id is not None
run_name = self.tags['run_name']
existing_runs = mlflow.search_runs(
experiment_ids=[self._experiment_id],
Expand Down

0 comments on commit 62c5b1f

Please sign in to comment.