From 516a2c7ba18e962d990e5b576f390146e4d6628e Mon Sep 17 00:00:00 2001 From: Kevin Scott <151596+thekevinscott@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:32:04 -0500 Subject: [PATCH] Refactor error handling to surface the error (#1270) * Refactor error handling to surface the error --- packages/shared/src/constants.test.ts | 63 +++++------------ packages/shared/src/constants.ts | 23 +----- packages/shared/src/types.ts | 6 -- .../src/browser/loadModel.browser.test.ts | 38 ++++------ .../src/browser/loadModel.browser.ts | 25 ++----- .../src/node/loadModel.node.test.ts | 37 +++++----- .../upscalerjs/src/node/loadModel.node.ts | 20 ++---- .../src/shared/errors-and-warnings.ts | 20 ++---- .../upscalerjs/src/shared/model-utils.test.ts | 30 +++----- packages/upscalerjs/src/shared/utils.test.ts | 70 ++++++++++++------- packages/upscalerjs/src/shared/utils.ts | 28 ++++++-- 11 files changed, 138 insertions(+), 222 deletions(-) diff --git a/packages/shared/src/constants.test.ts b/packages/shared/src/constants.test.ts index a176d9c7a..10d36a32e 100644 --- a/packages/shared/src/constants.test.ts +++ b/packages/shared/src/constants.test.ts @@ -1,16 +1,11 @@ import * as tf from '@tensorflow/tfjs-node'; import { vi } from 'vitest'; -import { - ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, -} from './types'; -import { +import { makeIsNDimensionalTensor, - isFourDimensionalTensor, - isThreeDimensionalTensor, + isFourDimensionalTensor, + isThreeDimensionalTensor, isTensor, isString, - isValidModelDefinition, hasValidChannels, isValidRange, isNumber, @@ -54,7 +49,7 @@ describe('isFourDimensionalTensor', () => { expect(isFourDimensionalTensor(tf.tensor([[[1,],],]))).toEqual(false); }); - expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false); + expect(isFourDimensionalTensor({} as tf.Tensor)).toEqual(false); }); describe('isThreeDimensionalTensor', () => { @@ -90,35 +85,13 @@ describe('isString', () => { }); }); -describe('isValidModelDefinition', () => { - it('throws error if given an undefined', () => { - expect(() => isValidModelDefinition(undefined)).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); - }); - - it('throws error if given no path', () => { - expect(() => isValidModelDefinition({ path: undefined, scale: 2 } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH); - }); - - it('throws error if given invalid model type', () => { - expect(() => isValidModelDefinition({ path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition )).toThrow(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE); - }); - - it('returns true if given scale and path', () => { - expect(isValidModelDefinition({ - path: 'foo', - scale: 2, - modelType: 'layers', - })).toEqual(true); - }); -}); - describe('hasValidChannels', () => { it('returns true if a tensor has valid channels', () => { - expect(hasValidChannels(tf.ones([4,4,3]))).toEqual(true); + expect(hasValidChannels(tf.ones([4, 4, 3]))).toEqual(true); }); it('returns false if a tensor does not have valid channels', () => { - expect(hasValidChannels(tf.ones([4,4,4]))).toEqual(false); + expect(hasValidChannels(tf.ones([4, 4, 4]))).toEqual(false); }); }); @@ -154,15 +127,15 @@ describe('isValidRange', () => { }); it('returns false if it gets an array with three numbers', () => { - expect(isValidRange([1,2,3])).toEqual(false); + expect(isValidRange([1, 2, 3])).toEqual(false); }); it('returns false if it gets an array with a number and a string', () => { - expect(isValidRange([1,'foo'])).toEqual(false); + expect(isValidRange([1, 'foo'])).toEqual(false); }); it('returns true if it gets an array with two numbers', () => { - expect(isValidRange([1,2])).toEqual(true); + expect(isValidRange([1, 2])).toEqual(true); }); }); @@ -176,19 +149,19 @@ describe('isShape4D', () => { }); it('returns false if given an array of 3 numbers', () => { - expect(isShape4D([1,2,3])).toEqual(false); + expect(isShape4D([1, 2, 3])).toEqual(false); }); it('returns false if given an array of 5 numbers', () => { - expect(isShape4D([1,2,3,4,5])).toEqual(false); + expect(isShape4D([1, 2, 3, 4, 5])).toEqual(false); }); it('returns false if given an array of not all numbers', () => { - expect(isShape4D([1,null,3,'foo'])).toEqual(false); + expect(isShape4D([1, null, 3, 'foo'])).toEqual(false); }); it('returns true if given an array of all numbers', () => { - expect(isShape4D([1,2,3,4])).toEqual(true); + expect(isShape4D([1, 2, 3, 4])).toEqual(true); }); it('returns true if given an array containing nulls', () => { @@ -201,9 +174,9 @@ describe('isFixedShape4D', () => { [[null, null, null, 3], false], [[null, -1, -1, 3], false], [[null, 2, 2, 3], true], - ])('%s | %s',(args, expectation) => { - expect(isFixedShape4D(args)).toEqual(expectation); - }); + ])('%s | %s', (args, expectation) => { + expect(isFixedShape4D(args)).toEqual(expectation); + }); }); describe('isDynamicShape', () => { @@ -212,7 +185,7 @@ describe('isDynamicShape', () => { [[null, -1, -1, 3], true], [[null, 2, 2, 3], false], ])('%s | %s', (args, expectation) => { - expect(isDynamicShape4D(args)).toEqual(expectation); - }); + expect(isDynamicShape4D(args)).toEqual(expectation); + }); }); diff --git a/packages/shared/src/constants.ts b/packages/shared/src/constants.ts index c1e2f6ce4..661e144b7 100644 --- a/packages/shared/src/constants.ts +++ b/packages/shared/src/constants.ts @@ -1,6 +1,6 @@ import * as tf from '@tensorflow/tfjs-core'; import { Tensor, Tensor3D, Tensor4D, } from '@tensorflow/tfjs-core'; -import { DynamicShape4D, FixedShape4D, IsTensor, MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, ModelDefinition, ModelType, Shape4D } from './types'; +import { DynamicShape4D, FixedShape4D, IsTensor, ModelType, Shape4D } from './types'; export const isShape4D = (shape?: unknown): shape is Shape4D => { if (!Boolean(shape) || !Array.isArray(shape) || shape.length !== 4) { @@ -29,27 +29,6 @@ export const isTensor = (input: unknown): input is tf.Tensor => input instanceof export const isString = (el: unknown): el is string => typeof el === 'string'; export const isValidModelType = (modelType: unknown): modelType is ModelType => typeof modelType === 'string' && ['layers', 'graph',].includes(modelType); -export class ModelDefinitionValidationError extends Error { - type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE; - - constructor(type: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE) { - super(type); - this.type = type; - } -} - -export const isValidModelDefinition = (modelDefinition?: ModelDefinition): modelDefinition is ModelDefinition => { - if (modelDefinition === undefined) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); - } - if (!isValidModelType(modelDefinition.modelType ?? 'layers')) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE); - } - if (!modelDefinition.path && !modelDefinition._internals?.path) { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH); - } - return true; -}; export const hasValidChannels = (tensor: tf.Tensor): boolean => tensor.shape.slice(-1)[0] === 3; diff --git a/packages/shared/src/types.ts b/packages/shared/src/types.ts index d7d6b31e4..9db8fa645 100644 --- a/packages/shared/src/types.ts +++ b/packages/shared/src/types.ts @@ -101,9 +101,3 @@ export type ModelDefinitionFn = (tf: TF) => ModelDefinition; export type ModelDefinitionObjectOrFn = ModelDefinitionFn | ModelDefinition; export type IsTensor = (pixels: Tensor) => pixels is T; - -export enum MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE { - UNDEFINED = 'undefined', - INVALID_MODEL_TYPE = 'invalidModelType', - MISSING_PATH = 'missingPath', -} diff --git a/packages/upscalerjs/src/browser/loadModel.browser.test.ts b/packages/upscalerjs/src/browser/loadModel.browser.test.ts index 181d77b82..45ddaca1c 100644 --- a/packages/upscalerjs/src/browser/loadModel.browser.test.ts +++ b/packages/upscalerjs/src/browser/loadModel.browser.test.ts @@ -13,20 +13,17 @@ import { import * as tf from '@tensorflow/tfjs-node'; import { - getModelDefinitionError, ERROR_MODEL_DEFINITION_BUG, } from '../shared/errors-and-warnings'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '../../../shared/src/types'; import { - ModelDefinitionValidationError, - isValidModelDefinition, -} from '../../../shared/src/constants'; + checkModelDefinition, +} from '../shared/utils'; -import type * as sharedConstants from '../../../shared/src/constants'; +import type * as sharedUtils from '../shared/utils'; import type * as modelUtils from '../shared/model-utils'; import type * as errorsAndWarnings from '../shared/errors-and-warnings'; import type * as loadModelBrowser from './loadModel.browser'; @@ -47,18 +44,17 @@ vi.mock('../shared/model-utils', async () => { }); vi.mock('../shared/errors-and-warnings', async () => { - const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; + const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; return { ...rest, - getModelDefinitionError: vi.fn(getModelDefinitionError), } }); -vi.mock('../../../shared/src/constants', async () => { - const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants; +vi.mock('../shared/utils', async () => { + const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils; return { ...rest, - isValidModelDefinition: vi.fn(isValidModelDefinition), + checkModelDefinition: vi.fn(checkModelDefinition), } }); @@ -202,22 +198,18 @@ describe('loadModel browser tests', () => { }); describe('loadModel', () => { - it('throws if not a valid model definition', async () => { - const e = new Error(ERROR_MODEL_DEFINITION_BUG); - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); + it('throws if given a bad model definition', async () => { + vi.mocked(checkModelDefinition).mockImplementation(() => { + throw new Error(); }); - vi.mocked(vi).mocked(getModelDefinitionError).mockImplementation(() => e); - await expect(() => loadModel(tf, Promise.resolve({ - path: 'foo', - scale: 2, - modelType: 'layers', - }))).rejects.toThrowError(e); + await expect(loadModel(tf, Promise.resolve({}) as Promise)) + .rejects + .toThrow(); }); it('loads a valid layers model successfully', async () => { - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true); const model = 'foo' as unknown as LayersModel; vi.mocked(loadTfModel).mockImplementation(async () => model); expect(loadTfModel).toHaveBeenCalledTimes(0); @@ -240,7 +232,7 @@ describe('loadModel browser tests', () => { }); it('loads a valid graph model successfully', async () => { - vi.mocked(vi).mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(vi).mocked(checkModelDefinition).mockImplementation(() => true); const model = 'foo' as unknown as GraphModel; const modelDefinition: ModelDefinition = { diff --git a/packages/upscalerjs/src/browser/loadModel.browser.ts b/packages/upscalerjs/src/browser/loadModel.browser.ts index 954cc4356..efc0f096a 100644 --- a/packages/upscalerjs/src/browser/loadModel.browser.ts +++ b/packages/upscalerjs/src/browser/loadModel.browser.ts @@ -7,17 +7,11 @@ import { } from '../shared/model-utils'; import { ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, } from '../shared/errors-and-warnings'; import type { TF, } from '../../../shared/src/types'; -import { - isValidModelDefinition, -} from '../../../shared/src/constants'; -import { - errIsModelDefinitionValidationError, -} from '../shared/utils'; +import { checkModelDefinition, } from '../shared/utils.js'; type CDN = 'jsdelivr' | 'unpkg'; @@ -38,7 +32,7 @@ export const CDNS: CDN[] = [ export const getLoadModelErrorMessage = (errs: Errors, modelPath: string, internals: ModelConfigurationInternals): Error => new Error([ `Could not resolve URL ${modelPath} for package ${internals?.name}@${internals?.version}`, 'Errors include:', - ...errs.map(([cdn, err, ]) => `- ${cdn}: ${err.message}`), + ...errs.map(([cdn, err,]) => `- ${cdn}: ${err.message}`), ].join('\n')); export async function fetchModel(tf: TF, modelConfiguration: { @@ -50,7 +44,7 @@ export async function fetchModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; - - try { - isValidModelDefinition(modelDefinition); - } catch (err: unknown) { - if (errIsModelDefinitionValidationError(err)) { - throw getModelDefinitionError(err.type, modelDefinition); - } - throw new Error(ERROR_MODEL_DEFINITION_BUG); - } + + checkModelDefinition(modelDefinition); const parsedModelDefinition = parseModelDefinition(modelDefinition); diff --git a/packages/upscalerjs/src/node/loadModel.node.test.ts b/packages/upscalerjs/src/node/loadModel.node.test.ts index 1ff53b2f4..2d337c658 100644 --- a/packages/upscalerjs/src/node/loadModel.node.test.ts +++ b/packages/upscalerjs/src/node/loadModel.node.test.ts @@ -1,4 +1,4 @@ -import { +import { loadModel, getModelPath, getModuleFolder, @@ -9,7 +9,6 @@ import path from 'path'; import { resolver, } from './resolver'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from "../../../shared/src/types"; import * as tf from '@tensorflow/tfjs-node'; import { @@ -19,11 +18,10 @@ import { loadTfModel, } from '../shared/model-utils'; import { - isValidModelDefinition, - ModelDefinitionValidationError, -} from '../../../shared/src/constants'; + checkModelDefinition, +} from '../shared/utils'; -import type * as sharedConstants from '../../../shared/src/constants'; +import type * as sharedUtils from '../shared/utils'; import type * as modelUtils from '../shared/model-utils'; import type * as errorsAndWarnings from '../shared/errors-and-warnings'; import type * as resolverModule from './resolver'; @@ -37,18 +35,17 @@ vi.mock('../shared/model-utils', async () => { }); vi.mock('../shared/errors-and-warnings', async () => { - const { getModelDefinitionError, ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; + const { ...rest } = await vi.importActual('../shared/errors-and-warnings') as typeof errorsAndWarnings; return { ...rest, - getModelDefinitionError: vi.fn(getModelDefinitionError), } }); -vi.mock('../../../shared/src/constants', async () => { - const { isValidModelDefinition, ...rest } = await vi.importActual('../../../shared/src/constants') as typeof sharedConstants; +vi.mock('../shared/utils', async () => { + const { checkModelDefinition, ...rest } = await vi.importActual('../shared/utils') as typeof sharedUtils; return { ...rest, - isValidModelDefinition: vi.fn(isValidModelDefinition), + checkModelDefinition: vi.fn(checkModelDefinition), } }); vi.mock('./resolver', async () => { @@ -91,8 +88,8 @@ describe('loadModel.node', () => { describe('getModelPath', () => { it('returns model path if provided a path', () => { vi.mocked(resolver).mockImplementation(getResolver(() => '')); - expect(getModelPath({ - path: 'foo', + expect(getModelPath({ + path: 'foo', _internals: { path: 'some-model', name: 'baz', @@ -100,7 +97,7 @@ describe('loadModel.node', () => { }, scale: 2, modelType: 'layers', - })).toEqual('foo'); + })).toEqual('foo'); }); it('returns model path if not provided a path', () => { @@ -118,21 +115,21 @@ describe('loadModel.node', () => { }); describe('loadModel', () => { - it('throws if given an undefined model definition', async () => { + it('throws if given a bad model definition', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); const error = ERROR_MODEL_DEFINITION_BUG; - vi.mocked(isValidModelDefinition).mockImplementation(() => { - throw new ModelDefinitionValidationError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED); + vi.mocked(checkModelDefinition).mockImplementation(() => { + throw new Error(); }); await expect(loadModel(tf, Promise.resolve({}) as Promise)) .rejects - .toThrow(error); + .toThrow(); }); it('loads a valid layers model', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); - vi.mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(checkModelDefinition).mockImplementation(() => true); vi.mocked(loadTfModel).mockImplementation(async () => 'layers model' as any); const path = 'foo'; @@ -148,7 +145,7 @@ describe('loadModel.node', () => { it('loads a valid graph model', async () => { vi.mocked(resolver).mockImplementation(getResolver(() => './node_modules/baz')); - vi.mocked(isValidModelDefinition).mockImplementation(() => true); + vi.mocked(checkModelDefinition).mockImplementation(() => true); vi.mocked(loadTfModel).mockImplementation(async () => 'graph model' as any); const path = 'foo'; diff --git a/packages/upscalerjs/src/node/loadModel.node.ts b/packages/upscalerjs/src/node/loadModel.node.ts index 1a6ed1db9..9f6d9903f 100644 --- a/packages/upscalerjs/src/node/loadModel.node.ts +++ b/packages/upscalerjs/src/node/loadModel.node.ts @@ -2,19 +2,13 @@ import path from 'path'; import { loadTfModel, parseModelDefinition, } from '../shared/model-utils'; import { resolver, } from './resolver'; import { ParsedModelDefinition, LoadModel, } from '../shared/types'; -import { - isValidModelDefinition, -} from '../../../shared/src/constants'; import type { TF, } from '../../../shared/src/types'; import { ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, } from '../shared/errors-and-warnings'; -import { - errIsModelDefinitionValidationError, -} from '../shared/utils'; +import { checkModelDefinition, } from '../shared/utils.js'; export const getMissingMatchesError = (moduleEntryPoint: string): Error => new Error( `No matches could be found for module entry point ${moduleEntryPoint}` @@ -37,7 +31,7 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string const { _internals, } = modelConfiguration; if (!_internals) { // This should never happen. This should have been caught by isValidModelDefinition. - throw new Error(ERROR_MODEL_DEFINITION_BUG); + throw ERROR_MODEL_DEFINITION_BUG('Missing internals'); } const moduleFolder = getModuleFolder(_internals.name); return `file://${path.resolve(moduleFolder, _internals.path)}`; @@ -45,14 +39,8 @@ export const getModelPath = (modelConfiguration: ParsedModelDefinition): string export const loadModel: LoadModel = async (tf, _modelDefinition) => { const modelDefinition = await _modelDefinition; - try { - isValidModelDefinition(modelDefinition); - } catch(err: unknown) { - if (errIsModelDefinitionValidationError(err)) { - throw getModelDefinitionError(err.type, modelDefinition); - } - throw new Error(ERROR_MODEL_DEFINITION_BUG); - } + + checkModelDefinition(modelDefinition); const parsedModelDefinition = parseModelDefinition(modelDefinition); diff --git a/packages/upscalerjs/src/shared/errors-and-warnings.ts b/packages/upscalerjs/src/shared/errors-and-warnings.ts index 0cc768104..9432895f6 100644 --- a/packages/upscalerjs/src/shared/errors-and-warnings.ts +++ b/packages/upscalerjs/src/shared/errors-and-warnings.ts @@ -1,7 +1,7 @@ import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from "../../../shared/src/types"; +// import { ModelDefinitionValidationError, } from "../constants"; const WARNING_DEPRECATED_MODEL_DEFINITION_URL = 'https://upscalerjs.com/documentation/troubleshooting#deprecated-model-definition-function'; @@ -31,7 +31,7 @@ export const WARNING_PROGRESS_WITHOUT_PATCH_SIZE = [ `For more information, see ${WARNING_PROGRESS_WITHOUT_PATCH_SIZE_URL}.`, ].join(' '); -const ERROR_INVALID_TENSOR_PREDICTED_URL = +const ERROR_INVALID_TENSOR_PREDICTED_URL = 'https://upscalerjs.com/documentation/troubleshooting#invalid-predicted-tensor'; export const ERROR_INVALID_TENSOR_PREDICTED = (shape: number[]) => [ @@ -40,7 +40,7 @@ export const ERROR_INVALID_TENSOR_PREDICTED = (shape: number[]) => [ `For more information, see ${ERROR_INVALID_TENSOR_PREDICTED_URL}.`, ].join(' '); -const ERROR_INVALID_MODEL_PREDICTION_URL = +const ERROR_INVALID_MODEL_PREDICTION_URL = 'https://upscalerjs.com/documentation/troubleshooting#invalid-model-prediction'; export const ERROR_INVALID_MODEL_PREDICTION = [ @@ -59,11 +59,12 @@ const ERROR_INVALID_MODEL_TYPE_URL = 'https://upscalerjs.com/documentation/troub const WARNING_INPUT_SIZE_AND_PATCH_SIZE_URL = 'https://upscalerjs.com/documentation/troubleshooting#input-size-and-patch-size'; const ERROR_WITH_MODEL_INPUT_SHAPE_URL = 'https://upscalerjs.com/documentation/troubleshooting#error-with-model-input-shape'; +export const ERROR_UNDEFINED_MODEL = new Error('An undefined model was provided to UpscalerJS'); export const ERROR_INVALID_MODEL_TYPE = (modelType: unknown) => ([ `You've provided an invalid model type: ${JSON.stringify(modelType)}. Accepted types are "layers" and "graph".`, `For more information, see ${ERROR_INVALID_MODEL_TYPE_URL}.`, ].join(' ')); -export const ERROR_MODEL_DEFINITION_BUG = 'There is a bug with the upscaler code. Please report this.'; +export const ERROR_MODEL_DEFINITION_BUG = (err?: string) => new Error(`There is a bug with the upscaler code. Please report this. ${err ? `Error: ${err}` : ''}`.trim()); export const WARNING_INPUT_SIZE_AND_PATCH_SIZE = [ 'You have provided a patchSize, but the model definition already includes an input size.', 'Your patchSize will be ignored.', @@ -112,14 +113,3 @@ export const GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS = (modelConfigur `For more information, see ${MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS_URL}.`, `The model configuration provided was: ${JSON.stringify(modelConfiguration)}`, ].join(' '); - -export function getModelDefinitionError(error: MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, modelDefinition?: ModelDefinition): Error { - switch(error) { - case MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE: - return new Error(ERROR_INVALID_MODEL_TYPE(modelDefinition?.modelType)); - case MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.MISSING_PATH: - return new Error(GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDefinition)); - default: - return new Error(ERROR_MODEL_DEFINITION_BUG); - } -} diff --git a/packages/upscalerjs/src/shared/model-utils.test.ts b/packages/upscalerjs/src/shared/model-utils.test.ts index c94dfa6bb..4f5ad250d 100644 --- a/packages/upscalerjs/src/shared/model-utils.test.ts +++ b/packages/upscalerjs/src/shared/model-utils.test.ts @@ -1,5 +1,5 @@ import { vi } from 'vitest'; -import { +import { parseModelDefinition, getModel, loadTfModel, @@ -7,29 +7,27 @@ import { getModelInputShape, getPatchSizeAsMultiple, } from './model-utils'; -import type * as utils from './utils'; +import type * as utils from './utils'; import { warn, } from './utils'; import * as isLayersModel from './isLayersModel'; -import { - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, +import { ModelDefinition, ModelDefinitionFn, - } from '../../../shared/src/types'; +} from '../../../shared/src/types'; import type * as sharedConstants from '../../../shared/src/constants'; -import { +import { isShape4D, - } from '../../../shared/src/constants'; +} from '../../../shared/src/constants'; import { ModelPackage } from './types'; import { ERROR_INVALID_MODEL_TYPE, - ERROR_MODEL_DEFINITION_BUG, - ERROR_WITH_MODEL_INPUT_SHAPE, + ERROR_MODEL_DEFINITION_BUG, + ERROR_WITH_MODEL_INPUT_SHAPE, GET_INVALID_PATCH_SIZE, WARNING_INPUT_SIZE_AND_PATCH_SIZE, WARNING_UNDEFINED_PADDING, - getModelDefinitionError, MODEL_INPUT_SIZE_MUST_BE_SQUARE, GET_INVALID_PATCH_SIZE_AND_PADDING, GET_WARNING_PATCH_SIZE_INDIVISIBLE_BY_DIVISIBILITY_FACTOR, @@ -81,18 +79,6 @@ describe('model-utils', () => { vi.clearAllMocks(); }); - describe('getModelDefinitionError', () => { - it('returns an error if invalid model type is provided', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_INVALID_MODEL_TYPE('foo')); - }); - - it('returns a generic error otherwise', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_MODEL_DEFINITION_BUG); - }); - }) - describe('getModel', () => { describe('ModelDefinition', () => { it('returns model definition', async () => { diff --git a/packages/upscalerjs/src/shared/utils.test.ts b/packages/upscalerjs/src/shared/utils.test.ts index 67ac8d67f..bf2f9554c 100644 --- a/packages/upscalerjs/src/shared/utils.test.ts +++ b/packages/upscalerjs/src/shared/utils.test.ts @@ -1,22 +1,23 @@ import { Tensor3D } from '@tensorflow/tfjs-node'; import { vi } from 'vitest'; import * as tf from '@tensorflow/tfjs-node'; -import { +import { processAndDisposeOfTensor, - wrapGenerator, - isSingleArgProgress, - isMultiArgTensorProgress, - warn, + wrapGenerator, + isSingleArgProgress, + isMultiArgTensorProgress, + warn, isAborted, + checkModelDefinition, } from './utils'; import { ModelDefinition, - MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE, } from '../../../shared/src/types'; import { ERROR_INVALID_MODEL_TYPE, - ERROR_MODEL_DEFINITION_BUG, - getModelDefinitionError, + ERROR_MODEL_DEFINITION_BUG, + ERROR_UNDEFINED_MODEL, + GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS, } from './errors-and-warnings'; describe('isAborted', () => { @@ -109,7 +110,7 @@ describe('wrapGenerator', () => { return 'baz'; } - const callback = vi.fn(async () => {}); + const callback = vi.fn(async () => { }); await wrapGenerator(foo(), callback); expect(callback).toHaveBeenCalledTimes(2); expect(callback).toHaveBeenCalledWith('foo'); @@ -141,57 +142,72 @@ describe('wrapGenerator', () => { describe('isSingleArgProgress', () => { it('returns true for function', () => { - expect(isSingleArgProgress(() => {})).toEqual(true); + expect(isSingleArgProgress(() => { })).toEqual(true); }); it('returns true for a single arg function', () => { - expect(isSingleArgProgress((_1: any) => {})).toEqual(true); + expect(isSingleArgProgress((_1: any) => { })).toEqual(true); }); it('returns false for a double arg function', () => { - expect(isSingleArgProgress((_1: any, _2: any) => {})).toEqual(false); + expect(isSingleArgProgress((_1: any, _2: any) => { })).toEqual(false); }); }); describe('isMultiArgProgress', () => { it('returns false for a single arg function', () => { - expect(isMultiArgTensorProgress((_1: any) => {}, undefined, undefined)).toEqual(false); + expect(isMultiArgTensorProgress((_1: any) => { }, undefined, undefined)).toEqual(false); }); it('returns false for a zero arg function', () => { - expect(isMultiArgTensorProgress(() => {}, undefined, undefined, )).toEqual(false); + expect(isMultiArgTensorProgress(() => { }, undefined, undefined,)).toEqual(false); }); it('returns false for a multi arg tensor string function', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'base64', 'base64')).toEqual(false); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'base64', 'base64')).toEqual(false); }); it('returns false for a multi arg tensor string function with overloaded outputs', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', 'base64')).toEqual(false); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', 'base64')).toEqual(false); }); it('returns true for a multi arg tensor function', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', 'tensor')).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', 'tensor')).toEqual(true); }); it('returns true for a multi arg tensor function with conflicting outputs', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'base64', 'tensor')).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'base64', 'tensor')).toEqual(true); }); it('returns true for a multi arg tensor function with conflicting outputs with an undefined progressOutput', () => { - expect(isMultiArgTensorProgress((_1: any, _2: any) => {}, 'tensor', undefined)).toEqual(true); + expect(isMultiArgTensorProgress((_1: any, _2: any) => { }, 'tensor', undefined)).toEqual(true); }); }); -describe('getModelDefinitionError', () => { - it('returns an error if invalid model type is provided', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.INVALID_MODEL_TYPE, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_INVALID_MODEL_TYPE('foo')); +describe('checkModelDefinition', () => { + it('throws if an undefined model is provided', () => { + expect(() => checkModelDefinition(undefined)).toThrowError(ERROR_UNDEFINED_MODEL); + }); + + it('throws if an invalid model is provided', () => { + const modelDef = { + modelType: 'foo', + } as unknown as ModelDefinition; + expect(() => checkModelDefinition(modelDef)).toThrowError(ERROR_INVALID_MODEL_TYPE(modelDef)); }); - it('returns a generic error otherwise', () => { - const err = getModelDefinitionError(MODEL_DEFINITION_VALIDATION_CHECK_ERROR_TYPE.UNDEFINED, { path: 'foo', scale: 2, modelType: 'foo' } as unknown as ModelDefinition); - expect(err.message).toEqual(ERROR_MODEL_DEFINITION_BUG); + it('throws if a model is missing a path and _internals', () => { + const modelDef = { + modelType: 'layers', + } as unknown as ModelDefinition; + expect(() => checkModelDefinition(modelDef)).toThrowError(GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDef)); + }); + + it('passes with a valid model', () => { + checkModelDefinition({ + modelType: 'layers', + path: '/foo', + }); }); }) @@ -202,7 +218,7 @@ describe('processAndDisposeOfTensor', () => { isDisposed = false; value?: number; - mockDispose: typeof vi.fn = vi.fn().mockImplementation(() => {}); + mockDispose: typeof vi.fn = vi.fn().mockImplementation(() => { }); constructor({ mockDispose, diff --git a/packages/upscalerjs/src/shared/utils.ts b/packages/upscalerjs/src/shared/utils.ts index 1818fcee2..3f0bc6da2 100644 --- a/packages/upscalerjs/src/shared/utils.ts +++ b/packages/upscalerjs/src/shared/utils.ts @@ -1,12 +1,16 @@ import type { Tensor, } from '@tensorflow/tfjs-core'; import type { Progress, SingleArgProgress, ResultFormat, MultiArgTensorProgress, } from './types'; -import type { - ProcessFn, - TF, +import { + type ModelDefinition, + type ProcessFn, + type TF, } from '../../../shared/src/types'; import { - ModelDefinitionValidationError, -} from '../../../shared/src/constants'; + ERROR_INVALID_MODEL_TYPE, + ERROR_UNDEFINED_MODEL, + GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS, +} from './errors-and-warnings'; +import { isValidModelType, } from '../../../shared/src/constants'; export const warn = (msg: string | string[]): void => { console.warn(Array.isArray(msg) ? msg.join('\n') : msg);// skipcq: JS-0002 @@ -34,7 +38,7 @@ export const isAborted = (abortSignal?: AbortSignal): boolean => { type PostNext = ((value: T) => (void | Promise)); /* eslint-disable @typescript-eslint/no-explicit-any */ export async function wrapGenerator( - gen: Generator | AsyncGenerator, + gen: Generator | AsyncGenerator, postNext?: PostNext ): Promise { let result: undefined | IteratorResult; @@ -68,4 +72,14 @@ export function processAndDisposeOfTensor( return tensor; } -export const errIsModelDefinitionValidationError = (err: unknown): err is ModelDefinitionValidationError => err instanceof ModelDefinitionValidationError; +export const checkModelDefinition = (modelDefinition?: ModelDefinition): void => { + if (modelDefinition === undefined) { + throw ERROR_UNDEFINED_MODEL; + } + if (!isValidModelType(modelDefinition.modelType ?? 'layers')) { + throw ERROR_INVALID_MODEL_TYPE(modelDefinition); + } + if (!modelDefinition.path && !modelDefinition._internals?.path) { + throw GET_MODEL_CONFIGURATION_MISSING_PATH_AND_INTERNALS(modelDefinition); + } +};