diff --git a/relu.test.ts b/relu.test.ts new file mode 100644 index 0000000..cd8e859 --- /dev/null +++ b/relu.test.ts @@ -0,0 +1,35 @@ +import { relu } from './relu'; // Import your relu function from your implementation file + +describe('relu', () => { + it('should apply ReLU activation correctly for positive input', () => { + // Define input and expected output for positive input + const input = 3; // Positive input + const expectedOutput = 3; // Expected output (same as input) + + // Convert input to a Field (assuming you have a way to represent numbers as Fields) + const inputField = ...; // Convert input to a Field + + // Call the relu function + const result = relu(inputField); + + // Assert that the result matches the expected output + expect(result.toNumber()).toEqual(expectedOutput); + }); + + it('should apply ReLU activation correctly for negative input', () => { + // Define input and expected output for negative input + const input = -2; // Negative input + const expectedOutput = 0; // Expected output (ReLU turns negative input to zero) + + // Convert input to a Field (assuming you have a way to represent numbers as Fields) + const inputField = ...; // Convert input to a Field + + // Call the relu function + const result = relu(inputField); + + // Assert that the result matches the expected output + expect(result.toNumber()).toEqual(expectedOutput); + }); + + // Add more test cases as needed +}); diff --git a/relu.ts b/relu.ts new file mode 100644 index 0000000..5e315dc --- /dev/null +++ b/relu.ts @@ -0,0 +1,4 @@ +function relu(input) { + // If input is greater than or equal to zero, return input; otherwise, return zero. + return input.greaterThanOrEqual(0).ifElse(input, new Field(0)); +}