diff --git a/samples/community/Sandeep-1507/template.json b/samples/community/Sandeep-1507/template.json index 217b905d..34bc45d1 100644 --- a/samples/community/Sandeep-1507/template.json +++ b/samples/community/Sandeep-1507/template.json @@ -1,7 +1,5 @@ { "Globals": { - "window_width": 1189, - "window_height": 1682, "display_width": 1189, "display_height": 1682 }, diff --git a/samples/sample5/config.json b/samples/sample5/config.json index adee33f8..2d5466db 100644 --- a/samples/sample5/config.json +++ b/samples/sample5/config.json @@ -1,5 +1,5 @@ { "threshold_params": { - "MIN_JUMP": 200 + "MIN_JUMP": 20 } } \ No newline at end of file diff --git a/samples/sample5/template.json b/samples/sample5/template.json index 7c50dc30..597bef99 100644 --- a/samples/sample5/template.json +++ b/samples/sample5/template.json @@ -1,7 +1,5 @@ { "Globals": { - "window_width": 1189, - "window_height": 1682, "display_width": 1189, "display_height": 1682 }, diff --git a/samples/sample6/template.json b/samples/sample6/template.json index 811b99c3..3c28eb48 100644 --- a/samples/sample6/template.json +++ b/samples/sample6/template.json @@ -1,7 +1,5 @@ { "Globals": { - "window_width": 2480, - "window_height": 3508, "display_width": 2480, "display_height": 3508 }, diff --git a/samples/sample6/template_fb_align.json b/samples/sample6/template_fb_align.json index f65ed7ed..3f0923a1 100644 --- a/samples/sample6/template_fb_align.json +++ b/samples/sample6/template_fb_align.json @@ -1,7 +1,5 @@ { "Globals": { - "window_width": 2480, - "window_height": 3508, "display_width": 2480 , "display_height": 3508 }, diff --git a/samples/sample6/template_no_fb_align.json b/samples/sample6/template_no_fb_align.json index 811b99c3..3c28eb48 100644 --- a/samples/sample6/template_no_fb_align.json +++ b/samples/sample6/template_no_fb_align.json @@ -1,7 +1,5 @@ { "Globals": { - "window_width": 2480, - "window_height": 3508, "display_width": 2480, "display_height": 3508 }, diff --git a/src/constants.py b/src/constants.py index 669db0b3..560a12ab 100644 --- a/src/constants.py +++ b/src/constants.py @@ -20,11 +20,11 @@ EVALUATION_FILENAME = "evaluation.json" CONFIG_FILENAME = "config.json" -SCHEMA_NAMES = { +SCHEMA_NAMES = DotMap({ "template": "template", "evaluation": "evaluation", "config": "config", -} +}) # ERROR_CODES = DotMap( @@ -57,9 +57,6 @@ GLOBAL_PAGE_THRESHOLD_WHITE = 200 GLOBAL_PAGE_THRESHOLD_BLACK = 100 -# Filepaths - object is better - - class Paths: def __init__(self, output_dir): self.output_dir = output_dir diff --git a/src/core.py b/src/core.py index 779060a3..fe7fc9c4 100644 --- a/src/core.py +++ b/src/core.py @@ -26,12 +26,7 @@ from src.template import Template # TODO: further break utils down and separate the imports -from src.utils.imgutils import ( - ImageUtils, - MainOperations, - draw_template_layout, - setup_dirs, -) +from src.utils.image import ImageUtils, MainOperations, draw_template_layout, setup_dirs from src.utils.parsing import ( evaluate_concatenated_response, get_concatenated_response, @@ -265,10 +260,10 @@ def process_files(omr_files, template, tuning_config, evaluation_config, args, o ) # TODO: Get rid of saveImgList - for i in range(ImageUtils.save_image_level): - ImageUtils.reset_save_img(i + 1) + for i in range(InstanceImageUtils.save_image_level): + InstanceImageUtils.reset_save_img(i + 1) - ImageUtils.append_save_img(1, in_omr) + InstanceImageUtils.append_save_img(1, in_omr) # resize to conform to template in_omr = ImageUtils.resize_util( diff --git a/src/defaults/__init__.py b/src/defaults/__init__.py index 2d45045c..24f5107d 100644 --- a/src/defaults/__init__.py +++ b/src/defaults/__init__.py @@ -1,4 +1,6 @@ # https://docs.python.org/3/tutorial/modules.html#:~:text=The%20__init__.py,on%20the%20module%20search%20path. -from .config import * # NOQA -from .evaluation import * # NOQA -from .template import * # NOQA +# Use all imports relative to root directory +# (https://chrisyeh96.github.io/2017/08/08/definitive-guide-python-imports.html) +from src.defaults.config import * # NOQA +from src.defaults.evaluation import * # NOQA +from src.defaults.template import * # NOQA diff --git a/src/defaults/config.py b/src/defaults/config.py index 35ac9898..5e66f6e8 100644 --- a/src/defaults/config.py +++ b/src/defaults/config.py @@ -7,8 +7,6 @@ "display_width": 640, "processing_height": 820, "processing_width": 666, - "window_width": 1280, - "window_height": 720, }, "threshold_params": { "GAMMA_LOW": 0.7, diff --git a/src/processors/CropOnMarkers.py b/src/processors/CropOnMarkers.py index 865f7766..f6e7f17b 100644 --- a/src/processors/CropOnMarkers.py +++ b/src/processors/CropOnMarkers.py @@ -4,15 +4,14 @@ import numpy as np from src.logger import logger -from src.utils.imgutils import ( +from src.processors.interfaces.ImagePreprocessor import ImagePreprocessor +from src.utils.image import ( ImageUtils, MainOperations, four_point_transform, normalize_util, ) -from .interfaces.ImagePreprocessor import ImagePreprocessor - class CropOnMarkers(ImagePreprocessor): def __init__(self, *args, **kwargs): @@ -45,7 +44,7 @@ def __init__(self, *args, **kwargs): # TODO: processing_width should come through proper channel marker = ImageUtils.resize_util( marker, - self.tuning_config.dimensions.processing_width + config.dimensions.processing_width / int(marker_ops["sheetToMarkerWidthRatio"]), ) marker = cv2.GaussianBlur(marker, (5, 5), 0) @@ -101,7 +100,7 @@ def getBestMatch(self, image_eroded_sub): logger.warning( "\tTemplate matching too low! Consider rechecking preProcessors applied before this." ) - if self.tuning_config.outputs.show_image_level >= 1: + if config.outputs.show_image_level >= 1: MainOperations.show("res", res, 1, 0) if best_scale is None: @@ -111,6 +110,7 @@ def getBestMatch(self, image_eroded_sub): return best_scale, all_max_t def apply_filter(self, image, args): + config = self.tuning_config image_eroded_sub = normalize_util( image if self.apply_erode_subtract @@ -133,7 +133,7 @@ def apply_filter(self, image, args): best_scale, all_max_t = self.getBestMatch(image_eroded_sub) if best_scale is None: # TODO: Plot and see performance of marker_rescale_range - if self.tuning_config.outputs.show_image_level >= 1: + if config.outputs.show_image_level >= 1: MainOperations.show("Quads", image_eroded_sub) return None @@ -167,7 +167,7 @@ def apply_filter(self, image, args): "\t all_max_t", all_max_t, ) - if self.tuning_config.outputs.show_image_level >= 1: + if config.outputs.show_image_level >= 1: MainOperations.show( "no_pts_" + args["current_file"].name, image_eroded_sub, 0 ) @@ -202,15 +202,15 @@ def apply_filter(self, image, args): # appendSaveImg(1,image_eroded_sub) # appendSaveImg(1,image_norm) - ImageUtils.append_save_img(2, image_eroded_sub) + InstanceImageUtils.append_save_img(2, image_eroded_sub) # Debugging image - # res = cv2.matchTemplate(image_eroded_sub,optimal_marker,cv2.TM_CCOEFF_NORMED) # res[ : , midw:midw+2] = 255 # res[ midh:midh+2, : ] = 255 # show("Markers Matching",res) if ( - self.tuning_config.outputs.show_image_level >= 2 - and self.tuning_config.outputs.show_image_level < 4 + config.outputs.show_image_level >= 2 + and config.outputs.show_image_level < 4 ): image_eroded_sub = ImageUtils.resize_util_h( image_eroded_sub, image.shape[0] @@ -220,7 +220,7 @@ def apply_filter(self, image, args): MainOperations.show( "Warped: " + args["current_file"].name, ImageUtils.resize_util( - h_stack, int(self.tuning_config.dimensions.display_width * 1.6) + h_stack, int(config.dimensions.display_width * 1.6) ), 0, 0, diff --git a/src/processors/CropPage.py b/src/processors/CropPage.py index f0eeb408..6e826e28 100644 --- a/src/processors/CropPage.py +++ b/src/processors/CropPage.py @@ -7,9 +7,8 @@ import numpy as np from src.logger import logger -from src.utils.imgutils import ImageUtils, four_point_transform - -from .interfaces.ImagePreprocessor import ImagePreprocessor +from src.processors.interfaces.ImagePreprocessor import ImagePreprocessor +from src.utils.image import four_point_transform, grab_contours MIN_PAGE_AREA = 80000 @@ -86,7 +85,7 @@ def find_page(self, image): # findContours returns outer boundaries in CW and inner boundaries in ACW # order. - cnts = ImageUtils.grab_contours( + cnts = grab_contours( cv2.findContours(edge, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) ) # hullify to resolve disordered curves due to noise diff --git a/src/processors/FeatureBasedAlignment.py b/src/processors/FeatureBasedAlignment.py index 1548d6a5..250d7786 100644 --- a/src/processors/FeatureBasedAlignment.py +++ b/src/processors/FeatureBasedAlignment.py @@ -6,9 +6,8 @@ import cv2 import numpy as np -from src.utils.imgutils import MainOperations - -from .interfaces.ImagePreprocessor import ImagePreprocessor +from src.processors.interfaces.ImagePreprocessor import ImagePreprocessor +from src.utils.image import MainOperations class FeatureBasedAlignment(ImagePreprocessor): diff --git a/src/processors/builtins.py b/src/processors/builtins.py index daaedf72..28345eec 100644 --- a/src/processors/builtins.py +++ b/src/processors/builtins.py @@ -1,7 +1,7 @@ import cv2 import numpy as np -from .interfaces.ImagePreprocessor import ImagePreprocessor +from src.processors.interfaces.ImagePreprocessor import ImagePreprocessor class Levels(ImagePreprocessor): diff --git a/src/schemas/config-schema.json b/src/schemas/config-schema.json index 08d2dd98..18525675 100644 --- a/src/schemas/config-schema.json +++ b/src/schemas/config-schema.json @@ -9,9 +9,7 @@ "display_height": { "type": "integer" }, "display_width": { "type": "integer" }, "processing_height": { "type": "integer" }, - "processing_width": { "type": "integer" }, - "window_width": { "type": "integer" }, - "window_height": { "type": "integer" } + "processing_width": { "type": "integer" } }, "threshold_params": { "GAMMA_LOW": { "type": "number", "minimum": 0, "maximum": 1 }, diff --git a/src/utils/imgutils.py b/src/utils/image.py similarity index 88% rename from src/utils/imgutils.py rename to src/utils/image.py index bc289240..6f8bd11b 100644 --- a/src/utils/imgutils.py +++ b/src/utils/image.py @@ -8,8 +8,9 @@ """ # TODO: refactor this file. -# Use all imports relative to root directory -# (https://chrisyeh96.github.io/2017/08/08/definitive-guide-python-imports.html) +# Image-processing utils +plt.rcParams["figure.figsize"] = (10.0, 8.0) + import os import sys from dataclasses import dataclass @@ -18,46 +19,90 @@ import matplotlib.pyplot as plt import numpy as np -# TODO: pass config in runtime later -import src.constants as constants from src.logger import logger +import src.constants as constants +from collections import defaultdict +from typing import Any +def grab_contours(cnts): + # source: imutils package + + # if the length the contours tuple returned by cv2.findContours + # is '2' then we are using either OpenCV v2.4, v4-beta, or + # v4-official + if len(cnts) == 2: + cnts = cnts[0] + + # if the length of the contours tuple is '3' then we are using + # either OpenCV v3, v4-pre, or v4-alpha + elif len(cnts) == 3: + cnts = cnts[1] + + # otherwise OpenCV has changed their cv2.findContours return + # signature yet again and I have no idea WTH is going on + else: + raise Exception( + ( + "Contours tuple must have length 2 or 3, " + "otherwise OpenCV changed their cv2.findContours return " + "signature yet again. Refer to OpenCV's documentation " + "in that case" + ) + ) + # return the actual contours array + return cnts class ImageUtils: + @staticmethod + def resize_util(img, u_width, u_height=None): + if u_height is None: + h, w = img.shape[:2] + u_height = int(h * u_width / w) + return cv2.resize(img, (int(u_width), int(u_height))) + + @staticmethod + def resize_util_h(img, u_height, u_width=None): + if u_width is None: + h, w = img.shape[:2] + u_width = int(w * u_height / h) + return cv2.resize(img, (int(u_width), int(u_height))) + +class InstanceImageUtils: """Class to hold indicators of images and save images.""" - save_image_level = self.tuning_config.outputs.save_image_level - save_img_list = {} + save_img_list:Any = defaultdict(list) + def __init__(self, tuning_config): + super().__init__() + self.tuning_config = tuning_config + self.save_image_level = tuning_config.outputs.save_image_level @staticmethod def reset_save_img(key): - ImageUtils.save_img_list[key] = [] + InstanceImageUtils.save_img_list[key] = [] # TODO: why is this static @staticmethod def append_save_img(key, img): - if ImageUtils.save_image_level >= int(key): - if key not in ImageUtils.save_img_list: - ImageUtils.save_img_list[key] = [] - ImageUtils.save_img_list[key].append(img.copy()) + if InstanceImageUtils.save_image_level >= int(key): + InstanceImageUtils.save_img_list[key].append(img.copy()) @staticmethod def save_img(path, final_marked): logger.info("Saving Image to " + path) cv2.imwrite(path, final_marked) - @staticmethod - def save_or_show_stacks(key, filename, save_dir=None, pause=1): + def save_or_show_stacks(self, key, filename, save_dir=None, pause=1): + config = self.tuning_config if ( - ImageUtils.save_image_level >= int(key) - and ImageUtils.save_img_list[key] != [] + InstanceImageUtils.save_image_level >= int(key) + and InstanceImageUtils.save_img_list[key] != [] ): name = os.path.splitext(filename)[0] result = np.hstack( tuple( [ ImageUtils.resize_util_h(img, config.dimensions.display_height) - for img in ImageUtils.save_img_list[key] + for img in InstanceImageUtils.save_img_list[key] ] ) ) @@ -77,72 +122,27 @@ def save_or_show_stacks(key, filename, save_dir=None, pause=1): else: MainOperations.show(name + "_" + str(key), result, pause, 0) - @staticmethod - def resize_util(img, u_width, u_height=None): - if u_height is None: - h, w = img.shape[:2] - u_height = int(h * u_width / w) - return cv2.resize(img, (int(u_width), int(u_height))) - - @staticmethod - def resize_util_h(img, u_height, u_width=None): - if u_width is None: - h, w = img.shape[:2] - u_width = int(w * u_height / h) - return cv2.resize(img, (int(u_width), int(u_height))) - - @staticmethod - def grab_contours(cnts): - # source: imutils package - - # if the length the contours tuple returned by cv2.findContours - # is '2' then we are using either OpenCV v2.4, v4-beta, or - # v4-official - if len(cnts) == 2: - cnts = cnts[0] - - # if the length of the contours tuple is '3' then we are using - # either OpenCV v3, v4-pre, or v4-alpha - elif len(cnts) == 3: - cnts = cnts[1] - - # otherwise OpenCV has changed their cv2.findContours return - # signature yet again and I have no idea WTH is going on - else: - raise Exception( - ( - "Contours tuple must have length 2 or 3, " - "otherwise OpenCV changed their cv2.findContours return " - "signature yet again. Refer to OpenCV's documentation " - "in that case" - ) - ) - - # return the actual contours array - return cnts - @dataclass class ImageMetrics: - resetpos = [0, 0] + window_width: int + window_height: int + reset_pos = [0, 0] # for positioning image windows window_x, window_y = 0, 0 clahe = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8, 8)) # TODO Fill these for stats + # Move qbox_vals here? # badThresholds = [] # veryBadPoints = [] -plt.rcParams["figure.figsize"] = (10.0, 8.0) - -# Image-processing utils - - def normalize_util(img, alpha=0, beta=255): return cv2.normalize(img, alpha, beta, norm_type=cv2.NORM_MINMAX) -def put_label(img, label, size): +def put_label(self, img, label, size): + config = self.tuning_config scale = img.shape[1] / config.dimensions.display_width bg_val = int(np.mean(img)) pos = (int(scale * 80), int(scale * 30)) @@ -227,9 +227,6 @@ def dist(p_1, p_2): return np.linalg.norm(np.array(p_1) - np.array(p_2)) -# These are used inside multiple extensions - - def order_points(pts): rect = np.zeros((4, 2), dtype="float32") @@ -538,75 +535,7 @@ def setup_dirs(paths): # logger.info("Present : " + _dir) -class MainOperations: - """Perform primary functions such as displaying images and reading responses""" - - image_metrics = ImageMetrics() - - def __init__(self, tuning_config=None): - self.tuning_config = tuning_config - - @staticmethod - def wait_q(): - esc_key = 27 - while cv2.waitKey(1) & 0xFF not in [ord("q"), esc_key]: - pass - # TODO: why this inside MainOperations? - MainOperations.image_metrics.window_x = 0 - MainOperations.image_metrics.window_y = 0 - cv2.destroyAllWindows() - - @staticmethod - def show(name, orig, pause=1, resize=False, resetpos=None): - if orig is None: - logger.info(name, " NoneType image to show!") - if pause: - cv2.destroyAllWindows() - return - # origDim = orig.shape[:2] - img = ( - ImageUtils.resize_util(orig, config.dimensions.display_width) - if resize - else orig - ) - cv2.imshow(name, img) - if resetpos: - MainOperations.image_metrics.window_x = resetpos[0] - MainOperations.image_metrics.window_y = resetpos[1] - cv2.moveWindow( - name, - MainOperations.image_metrics.window_x, - MainOperations.image_metrics.window_y, - ) - - h, w = img.shape[:2] - - # Set next window position - margin = 25 - w += margin - h += margin - if MainOperations.image_metrics.window_x + w > config.dimensions.window_width: - MainOperations.image_metrics.window_x = 0 - if ( - MainOperations.image_metrics.window_y + h - > config.dimensions.window_height - ): - MainOperations.image_metrics.window_y = 0 - else: - MainOperations.image_metrics.window_y += h - else: - MainOperations.image_metrics.window_x += w - - if pause: - logger.info( - "Showing '" - + name - + "'\n\tPress Q on image to continue Press Ctrl + C in terminal to exit" - ) - MainOperations.wait_q() - - @staticmethod - def read_response(template, image, name, save_dir=None, auto_align=False): + def read_response(self, template, image, name, save_dir=None, auto_align=False): config = self.tuning_config try: img = image.copy() @@ -623,19 +552,19 @@ def read_response(template, image, name, save_dir=None, auto_align=False): # str(origDim[0])+"x"+str(origDim[1]) + " "+name, size=1) morph = img.copy() - ImageUtils.append_save_img(3, morph) + InstanceImageUtils.append_save_img(3, morph) # TODO: evaluate if CLAHE is really req if auto_align: # Note: clahe is good for morphology, bad for thresholding morph = MainOperations.image_metrics.clahe.apply(morph) - ImageUtils.append_save_img(3, morph) + InstanceImageUtils.append_save_img(3, morph) # Remove shadows further, make columns/boxes darker (less gamma) morph = adjust_gamma(morph, config.threshold_params.GAMMA_LOW) # TODO: all numbers should come from either constants or config _, morph = cv2.threshold(morph, 220, 220, cv2.THRESH_TRUNC) morph = normalize_util(morph) - ImageUtils.append_save_img(3, morph) + InstanceImageUtils.append_save_img(3, morph) if config.outputs.show_image_level >= 4: MainOperations.show("morph1", morph, 0, 1) @@ -671,14 +600,14 @@ def read_response(template, image, name, save_dir=None, auto_align=False): # MainOperations.show("morph1",morph,0,1) # MainOperations.show("morphed_vertical",morph_v,0,1) - ImageUtils.append_save_img(3, morph_v) + InstanceImageUtils.append_save_img(3, morph_v) morph_thr = 60 # for Mobile images, 40 for scanned Images _, morph_v = cv2.threshold(morph_v, morph_thr, 255, cv2.THRESH_BINARY) # kernel best tuned to 5x5 now morph_v = cv2.erode(morph_v, np.ones((5, 5), np.uint8), iterations=2) - ImageUtils.append_save_img(3, morph_v) + InstanceImageUtils.append_save_img(3, morph_v) # h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 2)) # morph_h = cv2.morphologyEx(morph, cv2.MORPH_OPEN, h_kernel, iterations=3) # ret, morph_h = cv2.threshold(morph_h,200,200,cv2.THRESH_TRUNC) @@ -689,7 +618,7 @@ def read_response(template, image, name, save_dir=None, auto_align=False): if config.outputs.show_image_level >= 3: MainOperations.show("morph_thr_eroded", morph_v, 0, 1) - ImageUtils.append_save_img(6, morph_v) + InstanceImageUtils.append_save_img(6, morph_v) # template relative alignment code for q_block in template.q_blocks: @@ -763,12 +692,12 @@ def read_response(template, image, name, save_dir=None, auto_align=False): img, template, shifted=True, draw_qvals=True ) # appendSaveImg(4,mean_vals) - ImageUtils.append_save_img(2, initial_align) - ImageUtils.append_save_img(2, final_align) + InstanceImageUtils.append_save_img(2, initial_align) + InstanceImageUtils.append_save_img(2, final_align) if auto_align: final_align = np.hstack((initial_align, final_align)) - ImageUtils.append_save_img(5, img) + InstanceImageUtils.append_save_img(5, img) # Get mean vals n other stats all_q_vals, all_q_strip_arrs, all_q_std_vals = [], [], [] @@ -1010,12 +939,12 @@ def read_response(template, image, name, save_dir=None, auto_align=False): if config.outputs.save_detections and save_dir is not None: if multi_roll: save_dir = save_dir + "_MULTI_/" - ImageUtils.save_img(save_dir + name, final_marked) + InstanceImageUtils.save_img(save_dir + name, final_marked) - ImageUtils.append_save_img(2, final_marked) + InstanceImageUtils.append_save_img(2, final_marked) for i in range(config.outputs.save_image_level): - ImageUtils.save_or_show_stacks(i + 1, name, save_dir) + InstanceImageUtils.save_or_show_stacks(i + 1, name, save_dir) return omr_response, final_marked, multi_marked, multi_roll diff --git a/src/utils/interaction.py b/src/utils/interaction.py new file mode 100644 index 00000000..54efff57 --- /dev/null +++ b/src/utils/interaction.py @@ -0,0 +1,69 @@ +import cv2 + +def wait_q(): + esc_key = 27 + while cv2.waitKey(1) & 0xFF not in [ord("q"), esc_key]: + pass + cv2.destroyAllWindows() + + + +class MainOperations: + """Perform primary functions such as displaying images and reading responses""" + + image_metrics = ImageMetrics() + + def __init__(self, tuning_config=None): + self.tuning_config = tuning_config + + def show(self, name, orig, pause=1, resize=False, reset_pos=None): + image_metrics = MainOperations.image_metrics + config = self.tuning_config + if orig is None: + logger.info(name, " NoneType image to show!") + if pause: + cv2.destroyAllWindows() + return + # origDim = orig.shape[:2] + img = ( + ImageUtils.resize_util(orig, config.display_width) + if resize + else orig + ) + cv2.imshow(name, img) + if reset_pos: + image_metrics.window_x = reset_pos[0] + image_metrics.window_y = reset_pos[1] + cv2.moveWindow( + name, + image_metrics.window_x, + image_metrics.window_y, + ) + + h, w = img.shape[:2] + + # Set next window position + margin = 25 + w += margin + h += margin + if image_metrics.window_x + w > image_metrics.window_width: + image_metrics.window_x = 0 + if ( + image_metrics.window_y + h + > image_metrics.window_height + ): + image_metrics.window_y = 0 + else: + image_metrics.window_y += h + else: + image_metrics.window_x += w + + if pause: + logger.info( + "Showing '" + + name + + "'\n\tPress Q on image to continue Press Ctrl + C in terminal to exit" + ) + wait_q() + MainOperations.image_metrics.window_x = 0 + MainOperations.image_metrics.window_y = 0 diff --git a/src/utils/validations.py b/src/utils/validations.py index db127a2a..80661dd7 100644 --- a/src/utils/validations.py +++ b/src/utils/validations.py @@ -38,7 +38,7 @@ def parse_validation_error(error): def validate_evaluation_json(json_data, evaluation_path): logger.info("Validating evaluation.json...") try: - validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES["evaluation"]]) + validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES.evaluation]) except jsonschema.exceptions.ValidationError as _err: # NOQA table = Table(show_lines=True) @@ -46,7 +46,7 @@ def validate_evaluation_json(json_data, evaluation_path): table.add_column("Error", style="magenta") errors = sorted( - SCHEMA_VALIDATORS[SCHEMA_NAMES["evaluation"]].iter_errors(json_data), + SCHEMA_VALIDATORS[SCHEMA_NAMES.evaluation].iter_errors(json_data), key=lambda e: e.path, ) for error in errors: @@ -70,7 +70,7 @@ def validate_evaluation_json(json_data, evaluation_path): def validate_template_json(json_data, template_path): logger.info("Validating template.json...") try: - validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES["template"]]) + validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES.template]) except jsonschema.exceptions.ValidationError as _err: # NOQA table = Table(show_lines=True) @@ -78,7 +78,7 @@ def validate_template_json(json_data, template_path): table.add_column("Error", style="magenta") errors = sorted( - SCHEMA_VALIDATORS[SCHEMA_NAMES["template"]].iter_errors(json_data), + SCHEMA_VALIDATORS[SCHEMA_NAMES.template].iter_errors(json_data), key=lambda e: e.path, ) for error in errors: @@ -109,13 +109,13 @@ def validate_template_json(json_data, template_path): def validate_config_json(json_data, config_path): logger.info("Validating config.json...") try: - validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES["config"]]) + validate(instance=json_data, schema=SCHEMA_JSONS[SCHEMA_NAMES.config]) except jsonschema.exceptions.ValidationError as _err: # NOQA table = Table(show_lines=True) table.add_column("Key", style="cyan", no_wrap=True) table.add_column("Error", style="magenta") errors = sorted( - SCHEMA_VALIDATORS[SCHEMA_NAMES["config"]].iter_errors(json_data), + SCHEMA_VALIDATORS[SCHEMA_NAMES.config].iter_errors(json_data), key=lambda e: e.path, ) for error in errors: