Skip to content

Commit

Permalink
Repair json files
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Sep 26, 2024
1 parent fe49d3d commit abf4e6f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
11 changes: 7 additions & 4 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
# bnb.optim.AdamW8bit
# grokadamw.GrokAdamW
# torch.optim.RMSprop

if isinstance(optimizer, str):
if "." in optimizer:
class_module, class_name = optimizer.rsplit(".", 1)
Expand All @@ -583,7 +583,7 @@ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):

valid_params = set(inspect.signature(optimizer_cls).parameters)
kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}

optimizer["init_args"].update(kwargs)
optimizer = instantiate_class(model_parameters, optimizer)
else:
Expand Down Expand Up @@ -681,6 +681,9 @@ def check_nvlink_connectivity(fabric=None):


def fix_and_load_json(json_string):
    """Repair common syntax slips in a JSON document and parse it.

    Two kinds of defects are repaired before parsing:

    * a trailing comma directly before a closing ``}`` or ``]``
    * a missing comma between two adjacent values or properties

    The text is scanned token by token, so the contents of string literals
    (for example ``"4.45.0.dev0"`` or ``"x,}"``) are never modified.

    Args:
        json_string: JSON text, possibly containing the defects above.

    Returns:
        The object produced by ``json.loads`` for the repaired text.

    Raises:
        json.JSONDecodeError: If the text is still invalid after the repairs.
    """
    # One token is a complete string literal, a structural character, or a
    # bare run of characters (numbers, true/false/null).
    token_re = re.compile(r'"(?:\\.|[^"\\])*"|[{}\[\],:]|[^\s{}\[\],:"]+')

    pieces = []
    previous = None  # most recent token, used to detect both defects
    last_comma = None  # index in ``pieces`` of the most recent "," token
    position = 0

    for match in token_re.finditer(json_string):
        token = match.group()
        # Text between tokens (whitespace, or stray characters such as an
        # unterminated quote) is kept so json.loads can still report it.
        gap = json_string[position:match.start()]
        position = match.end()

        if token in ("}", "]"):
            if previous == ",":
                # Trailing comma right before a closing bracket: drop it.
                pieces[last_comma] = ""
        elif (
            token not in (",", ":")
            and previous is not None
            and previous not in ("{", "[", ",", ":")
        ):
            # A value has ended and a new value/key starts with no separator
            # in between: supply the missing comma.
            pieces.append(",")

        pieces.append(gap)
        pieces.append(token)
        if token == ",":
            last_comma = len(pieces) - 1
        previous = token

    pieces.append(json_string[position:])
    return json.loads("".join(pieces))
58 changes: 41 additions & 17 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,16 +548,16 @@ def nvlink_all_gpu_connected_but_other_connected_output():
GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 SYS SYS SYS SYS SYS SYS SYS SYS PXB PXB 64-127,192-254 1 N/A
GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A
GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS 64-127,192-254 1 N/A
NIC0 SYS SYS PXB PXB SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS SYS SYS SYS
NIC1 SYS SYS PXB PXB SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS SYS SYS SYS
NIC2 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS X PXB SYS SYS SYS SYS SYS SYS
NIC3 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS PXB X SYS SYS SYS SYS SYS SYS
NIC4 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS X PXB SYS SYS SYS SYS
NIC5 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS PXB X SYS SYS SYS SYS
NIC6 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PIX SYS SYS
NIC7 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PIX X SYS SYS
NIC8 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PXB
NIC9 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PXB X
NIC0 SYS SYS PXB PXB SYS SYS SYS SYS X PIX SYS SYS SYS SYS SYS SYS SYS SYS
NIC1 SYS SYS PXB PXB SYS SYS SYS SYS PIX X SYS SYS SYS SYS SYS SYS SYS SYS
NIC2 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS X PXB SYS SYS SYS SYS SYS SYS
NIC3 PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS PXB X SYS SYS SYS SYS SYS SYS
NIC4 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS X PXB SYS SYS SYS SYS
NIC5 SYS SYS SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS PXB X SYS SYS SYS SYS
NIC6 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PIX SYS SYS
NIC7 SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PIX X SYS SYS
NIC8 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS X PXB
NIC9 SYS SYS SYS SYS PXB PXB SYS SYS SYS SYS SYS SYS SYS SYS SYS SYS PXB X
Legend:
Expand Down Expand Up @@ -594,8 +594,8 @@ def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_al


def test_fix_and_load_json():
# Invalid JSON string with a trailing comma
invalid_json = '''
# Test 1: Invalid JSON string with a trailing comma
invalid_json_trailing_comma = '''
{
"_from_model_config": true,
"bos_token_id": 128000,
Expand All @@ -607,8 +607,7 @@ def test_fix_and_load_json():
}
'''

# Expected valid Python dictionary after fixing the JSON
expected_output = {
expected_output_trailing_comma = {
"_from_model_config": True,
"bos_token_id": 128000,
"eos_token_id": 128001,
Expand All @@ -618,6 +617,31 @@ def test_fix_and_load_json():
"top_p": 0.9
}

# Run the function and compare the result
result = fix_and_load_json(invalid_json)
assert result == expected_output
result_trailing_comma = fix_and_load_json(invalid_json_trailing_comma)
assert result_trailing_comma == expected_output_trailing_comma

# Test 2: Invalid JSON string with missing commas between properties
invalid_json_missing_commas = '''
{
"_from_model_config": true,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0"
"do_sample": true,
"temperature": 0.6,
"top_p": 0.9
}
'''

expected_output_missing_commas = {
"_from_model_config": True,
"bos_token_id": 128000,
"eos_token_id": 128001,
"transformers_version": "4.45.0.dev0",
"do_sample": True,
"temperature": 0.6,
"top_p": 0.9
}

result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
assert result_missing_commas == expected_output_missing_commas

0 comments on commit abf4e6f

Please sign in to comment.