def fix_and_load_json(s):
    """Parse a JSON string, repairing common comma mistakes if needed.

    The input is first parsed as-is. Only when that fails are two textual
    repairs applied before parsing again:

    * trailing commas directly before a closing ``}`` or ``]`` are removed;
    * a comma is inserted when a value is followed by a newline and the next
      non-blank character is a double quote (i.e. the next property or string
      element starts on a new line without a separating comma).

    Args:
        s: The (possibly slightly malformed) JSON text.

    Returns:
        The decoded Python object, as produced by ``json.loads``.

    Raises:
        ValueError: If the text still cannot be parsed after the repairs. The
            underlying ``json.JSONDecodeError`` is attached as ``__cause__``.
    """
    # Well-formed input must come back untouched. The regex repairs below are
    # not aware of string literals, so running them unconditionally could
    # rewrite text inside a value (e.g. the ", }" in {"a": "x, }"}).
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        pass

    # Remove trailing commas before } or ]
    s = re.sub(r",(\s*[}\]])", r"\1", s)

    # Insert missing commas between properties.
    # The lookbehind is a set of single characters, not words: a JSON value can
    # only end in }, ], a digit, 'e' (true/false), 'l' (null) or a closing quote.
    # Group 1 keeps the newline and indentation in front of the next quote.
    pattern = r'(?<=[}\]0-9el"])\s*(\n\s*)"'
    replacement = r',\1"'
    s = re.sub(pattern, replacement, s)

    # Now try to parse the repaired JSON
    try:
        return json.loads(s)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse JSON after fixing: {e}") from e