Add ruff pre-commit (mosaicml#2414)
* Add ruff pre-commit hook and apply some sane autofixes

* Reduce number of autofixes

* Fix pyproject

* Update ruff hook

---------

Co-authored-by: Daniel King <[email protected]>
Skylion007 and dakinggg committed Aug 4, 2023
1 parent 8d4dfaf commit cc35953
Showing 24 changed files with 68 additions and 42 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -1,6 +1,13 @@
default_language_version:
  python: python3
repos:
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.0.282
+  hooks:
+  - id: ruff
+    args: [--fix, --exit-non-zero-on-fix]
+
- repo: https://github.com/google/yapf
  rev: v0.32.0
  hooks:
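With the ruff hook registered ahead of yapf, contributors can enable it locally with pre-commit install and lint the whole tree with pre-commit run ruff --all-files; the --fix and --exit-non-zero-on-fix arguments make the hook apply safe autofixes while still failing the run so the rewrites get reviewed. (These are standard pre-commit and ruff invocations, not something this diff adds.)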
4 changes: 2 additions & 2 deletions composer/core/state.py
@@ -1217,7 +1217,7 @@ def load_state_dict(
algorithm_passes=algorithm_passes,
)

-for attribute_name in sorted(list(state.keys())): # Sort so all ranks load in the same order
+for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order
serialized_value = state[attribute_name]
# Skip removed attributes as well as algorithms and model, which was already loaded
if attribute_name not in self.serialized_attributes or attribute_name == 'model':
@@ -1235,7 +1235,7 @@ def load_state_dict(

# Restructure algorithms serialized_value from list to dict
if attribute_name == 'algorithms' and isinstance(serialized_value, list):
-serialized_value = {algo_name: algo_serialized for algo_name, algo_serialized in serialized_value}
+serialized_value = dict(serialized_value)

if attribute_name == 'dataset_state':
self._load_dataset_state(serialized_value)
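The second hunk leans on the fact that dict() over an iterable of 2-item pairs builds the same mapping as the comprehension it replaces. A minimal sketch with hypothetical names, not taken from the repository:

# Illustration only -- hypothetical algorithm names, not from the commit.
pairs = [('algo_a', {'repr': 'A()'}), ('algo_b', {'repr': 'B()'})]
assert {name: serialized for name, serialized in pairs} == dict(pairs)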
2 changes: 1 addition & 1 deletion composer/datasets/in_context_learning_evaluation.py
@@ -223,7 +223,7 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:

max_answer_length = max(
max_answer_length,
-max(map(lambda x: len(self.tokenizer(x)['input_ids']), self.samples[sample_idx]['aliases'])))
+max((len(self.tokenizer(x)['input_ids']) for x in self.samples[sample_idx]['aliases'])))

self.max_answer_length = max_answer_length
return examples
2 changes: 1 addition & 1 deletion composer/metrics/nlp.py
@@ -268,7 +268,7 @@ def replace_underscore(text: str) -> str:
def update(self, outputs: List[str], labels: List[List[str]]):
for sample_output, sample_labels in zip(outputs, labels):
cleaned_sample_output = self.normalize_answer(sample_output)
-cleaned_sample_labels = set(self.normalize_answer(label) for label in sample_labels)
+cleaned_sample_labels = {self.normalize_answer(label) for label in sample_labels}
if any(cleaned_sample_output.startswith(label) for label in cleaned_sample_labels):
self.correct += torch.tensor(1.0)
self.total += torch.tensor(1.0)
2 changes: 1 addition & 1 deletion composer/models/huggingface.py
@@ -558,7 +558,7 @@ def _is_registered_causal_lm(model: transformers.PreTrainedModel) -> bool:
conda_package='transformers',
conda_channel='conda-forge') from e
causal_lm_classes = list(MODEL_FOR_CAUSAL_LM_MAPPING.values())
-return any([isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes])
+return any(isinstance(model, causal_lm_class) for causal_lm_class in causal_lm_classes)


def get_hf_config_from_composer_state_dict(state_dict: Dict[str, Any],
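Besides dropping the intermediate list, the generator form short-circuits: any() stops consuming it at the first truthy result, while the list version materializes every element before any() runs. A small sketch with made-up values:

# Illustration only -- not from the commit.
classes = (int, float, str)
value = 3
assert any(isinstance(value, cls) for cls in classes)  # stops after the first match (int)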
12 changes: 6 additions & 6 deletions composer/trainer/trainer.py
@@ -1249,8 +1249,8 @@ def __init__(
self._scheduler_step_frequency = TimeUnit.BATCH if step_schedulers_every_batch else TimeUnit.EPOCH

# Some algorithms require specific settings
-self._backwards_create_graph = any(map(lambda x: x.backwards_create_graph, self.state.algorithms))
-self._find_unused_parameters = any(map(lambda x: x.find_unused_parameters, self.state.algorithms))
+self._backwards_create_graph = any((x.backwards_create_graph for x in self.state.algorithms))
+self._find_unused_parameters = any((x.find_unused_parameters for x in self.state.algorithms))
self._ddp_sync_strategy = _get_ddp_sync_strategy(ddp_sync_strategy, self._find_unused_parameters)

# Suppressing GradScaler warnings as they are always created
@@ -2159,7 +2159,7 @@ def _eval_train_metrics(self, device_batch):
model_eval_mode(self.state.model),\
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
-for _, metric in self.state.train_metrics.items():
+for metric in self.state.train_metrics.values():
self._original_model.update_metric(
device_batch,
eval_outputs,
@@ -2209,7 +2209,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
# Reset train_metrics on every batch
# Placing reset here ensures that if auto grad accum catches an OOM, incomplete metric state is cleared
if self.state.train_metrics is not None:
-for _, metric in self.state.train_metrics.items():
+for metric in self.state.train_metrics.values():
metric.reset()

total_loss_dict = {'loss/train/total': self.state.device.tensor_to_device(torch.zeros(size=(1,)))}
@@ -2801,7 +2801,7 @@ def _eval_loop(

metrics = self._ensure_metrics_device_and_dtype(metrics)

-for _, metric in metrics.items():
+for metric in metrics.values():
metric.reset()

dataloader = self.state.dataloader
@@ -2899,7 +2899,7 @@ def _eval_loop(
else:
outputs = self.state.outputs

-for _, metric in metrics.items():
+for metric in metrics.values():
self._original_model.update_metric(
self.state.batch,
outputs,
2 changes: 1 addition & 1 deletion composer/utils/auto_log_hparams.py
@@ -52,7 +52,7 @@ def _get_obj_repr(obj: Any):
obj if obj is None or it is a int, float, str, bool type. Otherwise
returns obj.__class__.__name__.
"""
-if any([isinstance(obj, type_) for type_ in [int, float, str, bool]]) or obj is None:
+if any(isinstance(obj, type_) for type_ in [int, float, str, bool]) or obj is None:
return obj
else:
return obj.__class__.__name__
2 changes: 1 addition & 1 deletion composer/utils/batch_helpers.py
@@ -190,7 +190,7 @@ def _batch_set_tuple(batch: Any, key: Union[int, str], value: Any) -> Any:


def _is_key_get_and_set_fn_pair(key):
-if all([callable(key_element) for key_element in key]):
+if all(callable(key_element) for key_element in key):
if len(key) == 2:
return True
else:
4 changes: 2 additions & 2 deletions docker/generate_build_matrix.py
@@ -250,7 +250,7 @@ def _main():
entry['PYTORCH_VERSION'], # Pytorch version
cuda_version, # Cuda version
entry['PYTHON_VERSION'], # Python version,
-', '.join(reversed(list(f'`{x}`' for x in entry['TAGS']))), # Docker tags
+', '.join(reversed([f'`{x}`' for x in entry['TAGS']])), # Docker tags
])
table.sort(
key=lambda x: x[3].replace('Infiniband', '1').replace('EFA', '2')) # cuda version, put infiniband ahead of EFA
@@ -272,7 +272,7 @@ def _main():
table.append([
entry['TAGS'][0].split(':')[1].replace('_cpu', ''), # Composer version, or 'latest'
'No' if entry['BASE_IMAGE'].startswith('ubuntu:') else 'Yes', # Whether there is Cuda support
-', '.join(reversed(list(f'`{x}`' for x in entry['TAGS']))), # Docker tags
+', '.join(reversed([f'`{x}`' for x in entry['TAGS']])), # Docker tags
])
table.sort(key=lambda x: x[1], reverse=True) # cuda support
table.sort(key=lambda x: packaging.version.parse('9999999999999'
19 changes: 19 additions & 0 deletions pyproject.toml
@@ -10,6 +10,25 @@ line_length = 120
skip = [ "env", "wandb", "runs", "build", "node_modules" ]
known_third_party = ["wandb"]

+[tool.ruff]
+select = [
+    "C4",
+    # TODO port pydocstyle
+    # "D", # pydocstyle
+    "PERF",
+]
+
+ignore = [
+    "C408",
+    "PERF2",
+    "PERF4",
+]
+exclude = [
+    "build/**",
+    "docs/**",
+    "node_modules/**",
+]

# Pyright
[tool.pyright]
include = [
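The select list turns on ruff's flake8-comprehensions (C4) and Perflint (PERF) rule families, which drive most of the autofixes in this commit, while C408 and the PERF2/PERF4 rule prefixes are ignored and build, docs, and node_modules are excluded. A minimal sketch of the kind of rewrite the C4 rules perform, with hypothetical data:

# Illustration only -- not from the commit.
values = ['a', 'b', 'a']
unique_generator = set(v.upper() for v in values)   # flagged as an unnecessary generator (C401)
unique_comprehension = {v.upper() for v in values}  # the autofixed form
assert unique_generator == unique_comprehension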
2 changes: 1 addition & 1 deletion setup.py
@@ -218,7 +218,7 @@ def package_files(prefix: str, directory: str, extension: str):
'mlflow>=2.0.1,<3.0',
]

-extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
+extra_deps['all'] = {dep for deps in extra_deps.values() for dep in deps}

composer_data_files = ['py.typed']
composer_data_files += package_files('composer', 'yamls', '.yaml')
4 changes: 2 additions & 2 deletions tests/algorithms/test_ema.py
@@ -32,10 +32,10 @@ def validate_model(model1, model2):
model1_params, model1_buffers = dict(model1.named_parameters()), dict(model1.named_buffers())
model2_params, model2_buffers = dict(model2.named_parameters()), dict(model2.named_buffers())

-for name, _ in model1_params.items():
+for name in model1_params.keys():
torch.testing.assert_close(model1_params[name].data, model2_params[name].data)

-for name, _ in model1_buffers.items():
+for name in model1_buffers.keys():
torch.testing.assert_close(model1_buffers[name].data, model2_buffers[name].data)


2 changes: 1 addition & 1 deletion tests/algorithms/test_required_on_load.py
@@ -67,7 +67,7 @@ def compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module, is_equal:
model_2_modules = list(model_2.module.modules())
assert len(model_1_modules) == len(model_2_modules)
for module_1, module_2 in zip(model_1_modules, model_2_modules):
-assert sorted(list(module_1.__dict__.keys())) == sorted(list(module_2.__dict__.keys()))
+assert sorted(module_1.__dict__.keys()) == sorted(module_2.__dict__.keys())
# Compare model parameters
for (name0, tensor0), (name1, tensor1) in zip(model_1.state_dict().items(), model_2.state_dict().items()):
assert name0 == name1
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -22,8 +22,8 @@ def test_callbacks_map_to_events():
# exception for private methods
cb = Callback()
excluded_methods = ['state_dict', 'load_state_dict', 'run_event', 'close', 'post_close']
-methods = set(m for m in dir(cb) if (m not in excluded_methods and not m.startswith('_')))
-event_names = set(e.value for e in Event)
+methods = {m for m in dir(cb) if (m not in excluded_methods and not m.startswith('_'))}
+event_names = {e.value for e in Event}
assert methods == event_names


2 changes: 1 addition & 1 deletion tests/common/markers.py
@@ -20,7 +20,7 @@ def device(*args, precision=False):
also returns the parameter "precision".
"""
# convert cpu-fp32 and gpu-fp32 to cpu, gpu
-if not precision and any(['-' in arg for arg in args]):
+if not precision and any('-' in arg for arg in args):
raise ValueError('-fp32 and -amp tags must be removed if precision=False')
args = [arg.replace('-fp32', '') for arg in args]

2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -84,7 +84,7 @@ def pytest_collection_modifyitems(config: pytest.Config, items: List[pytest.Item
remaining = []
deselected = []
for item in items:
-if all([condition(item) for condition in conditions]):
+if all(condition(item) for condition in conditions):
remaining.append(item)
else:
deselected.append(item)
2 changes: 1 addition & 1 deletion tests/datasets/test_cifar.py
@@ -16,7 +16,7 @@ def test_cifar10_shape_length(is_train, synthetic):
else:
dataspec = build_cifar10_dataloader(datadir='/tmp', global_batch_size=batch_size, is_train=is_train)

-samples = [_ for _ in dataspec.dataloader]
+samples = list(dataspec.dataloader)
if is_train:
assert len(samples) == 50000 // batch_size
else:
20 changes: 10 additions & 10 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -31,13 +31,13 @@ def test_fewshot_sample_idxs():
rng = random.Random(1234)

fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=4, sample_idx=4, rng=rng)
-assert fewshot_idxs == set([0, 1, 2, 3])
+assert fewshot_idxs == {0, 1, 2, 3}

fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=5, sample_idx=4, rng=rng)
-assert fewshot_idxs == set([0, 1, 2, 3])
+assert fewshot_idxs == {0, 1, 2, 3}

fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=5, num_fewshot=500, sample_idx=4, rng=rng)
-assert fewshot_idxs == set([0, 1, 2, 3])
+assert fewshot_idxs == {0, 1, 2, 3}

fewshot_idxs = _get_fewshot_sample_idxs(dataset_size=10, num_fewshot=7, sample_idx=4, rng=rng)
assert len(fewshot_idxs) == 7 and 4 not in fewshot_idxs
@@ -549,11 +549,11 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
-assert all([item.count('Q: ') == num_fewshot + 1 for item in decoded_batch])
-assert all([item.count('\nA:') == num_fewshot + 1 for item in decoded_batch])
+assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch)
+assert all(item.count('\nA:') == num_fewshot + 1 for item in decoded_batch)

if len(prompt_string) > 0:
-assert all([item.count('I am a prompt') == 1 for item in decoded_batch])
+assert all(item.count('I am a prompt') == 1 for item in decoded_batch)

assert batch['labels'] == [['David Seville'], ['Scorpio', 'Skorpio']]

@@ -711,10 +711,10 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
-assert all([item.count('Code start: \n') == num_fewshot + 1 for item in decoded_batch])
+assert all(item.count('Code start: \n') == num_fewshot + 1 for item in decoded_batch)

if len(prompt_string) > 0:
-assert all([item.count('Please code:\n') == 1 for item in decoded_batch])
+assert all(item.count('Please code:\n') == 1 for item in decoded_batch)

assert batch['labels'] == [
" result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n",
@@ -847,10 +847,10 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
-assert all([item.count('Code start: \n') == num_fewshot + 1 for item in decoded_batch])
+assert all(item.count('Code start: \n') == num_fewshot + 1 for item in decoded_batch)

if len(prompt_string) > 0:
-assert all([item.count('Please code:\n') == 1 for item in decoded_batch])
+assert all(item.count('Please code:\n') == 1 for item in decoded_batch)

assert batch['labels'] == [
" result = []\n current_string = []\n current_depth = 0\n\n for c in paren_string:\n if c == '(':\n current_depth += 1\n current_string.append(c)\n elif c == ')':\n current_depth -= 1\n current_string.append(c)\n\n if current_depth == 0:\n result.append(''.join(current_string))\n current_string.clear()\n\n return result\n",
2 changes: 1 addition & 1 deletion tests/datasets/test_mnist.py
@@ -16,7 +16,7 @@ def test_mnist_shape_length(is_train, synthetic):
else:
loader = build_mnist_dataloader(datadir='/tmp', global_batch_size=batch_size, is_train=is_train)

-samples = [_ for _ in loader]
+samples = list(loader)
if is_train:
assert len(samples) == 60000 // batch_size
else:
2 changes: 1 addition & 1 deletion tests/loggers/test_cometml_logger.py
@@ -267,7 +267,7 @@ def test_comet_ml_log_metrics_and_hyperparameters(monkeypatch, tmp_path):
# those written to offline dump.
assert [msg['metricValue'] for msg in metric_msgs] == metric_values
assert [msg['step'] for msg in metric_msgs] == steps
-assert all([msg['metricName'] == metric_name for msg in metric_msgs])
+assert all(msg['metricName'] == metric_name for msg in metric_msgs)

# Assert dummy params input to log_hyperparameters are the same as
# those written to offline dump
2 changes: 1 addition & 1 deletion tests/loggers/test_mlflow_logger.py
@@ -216,7 +216,7 @@ def test_mlflow_logging_works(tmp_path, device):
metric_file = run_file_path / Path('metrics') / Path(metric_name)
with open(metric_file) as f:
csv_reader = csv.reader(f, delimiter=' ')
-lines = [line for line in csv_reader]
+lines = list(csv_reader)

assert len(lines) == num_batches

4 changes: 2 additions & 2 deletions tests/test_engine.py
@@ -115,14 +115,14 @@ def test_engine_trace_all(self, event: Event, dummy_state: State, always_match_a
dummy_state.algorithms = always_match_algorithms
trace = run_event(event, dummy_state, dummy_logger)

-assert all([tr.run for tr in trace.values()])
+assert all(tr.run for tr in trace.values())

def test_engine_trace_never(self, event: Event, dummy_state: State, never_match_algorithms: List[Algorithm],
dummy_logger: Logger):
dummy_state.algorithms = never_match_algorithms
trace = run_event(event, dummy_state, dummy_logger)

-assert all([tr.run is False for tr in trace.values()])
+assert all(tr.run is False for tr in trace.values())


def test_engine_is_dead_after_close(dummy_state: State, dummy_logger: Logger):
2 changes: 1 addition & 1 deletion tests/test_split_batch.py
@@ -106,7 +106,7 @@ def test_split_without_error(batch):
@pytest.mark.parametrize('batch', [dummy_tensor_batch(i) for i in [12, 13, 14, 15]])
def test_tensor_vs_list_chunking(batch):
tensor_microbatches = _split_tensor(batch, microbatch_size=4)
-list_microbatches = _split_list([t for t in batch], microbatch_size=4)
+list_microbatches = _split_list(list(batch), microbatch_size=4)

assert len(tensor_microbatches) == len(list_microbatches)
assert all(torch.equal(t1, torch.stack(t2, dim=0)) for t1, t2 in zip(tensor_microbatches, list_microbatches))
4 changes: 2 additions & 2 deletions tests/test_state.py
@@ -142,10 +142,10 @@ def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureReques
torch.save(state.state_dict(), _tmp_file)

loaded_state_dict = torch.load(save_path)
-expected_env_info_keys = set([
+expected_env_info_keys = {
'composer_version', 'composer_commit_hash', 'node_world_size', 'host_processor_model_name',
'host_processor_core_count', 'local_world_size', 'accelerator_model_name', 'cuda_device_count'
-])
+}
actual_env_info_keys = set(loaded_state_dict['metadata']['composer_env_info'].keys())
assert expected_env_info_keys == actual_env_info_keys
assert loaded_state_dict['metadata']['composer_env_info']['composer_version'] == composer.__version__
