-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hybrid grad check #27
base: main
Are you sure you want to change the base?
Changes from all commits
8f590f5
85e4932
30cbd19
21ba0f2
71449f4
b384821
481401d
45acfb9
e75f852
ee666db
fd106c5
20e8a1d
c6e0d9a
52827e6
8fad9b7
f2d8c23
6cacf43
9bc3d18
beb2fb9
9d9747a
51279b5
6b6c327
e0789e3
75e4cf9
0922139
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import abc | ||
from typing import Any, Callable, Dict, List, Union | ||
from itertools import chain | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import math | ||
|
||
from .constants import Type | ||
from .derivative import Derivative | ||
|
@@ -95,17 +97,20 @@ class NumpyIsCloseDerivativeCheck(DerivativeCheck): | |
|
||
def method(self, *args, **kwargs): | ||
directional_derivative_check_results = [] | ||
for direction_index, directional_derivative in enumerate( | ||
expected_values, test_values = get_expected_and_test_values( | ||
self.derivative.directional_derivatives | ||
): | ||
test_value = directional_derivative.value | ||
|
||
expected_value = [] | ||
for output_index in np.ndindex(self.output_indices): | ||
element = self.expectation[output_index][direction_index] | ||
expected_value.append(element) | ||
expected_value = np.array(expected_value).reshape(test_value.shape) | ||
) | ||
|
||
for ( | ||
direction_index, | ||
directional_derivative, | ||
expected_value, | ||
test_value, | ||
) in enumerate(zip( | ||
self.derivative.directional_derivatives, | ||
expected_values, | ||
test_values, | ||
)): | ||
test_result = np.isclose( | ||
test_value, | ||
expected_value, | ||
|
@@ -137,3 +142,120 @@ def method(self, *args, **kwargs): | |
success=success, | ||
) | ||
return derivative_check_result | ||
|
||
|
||
def get_expected_and_test_values(directional_derivatives): | ||
expected_values = [] | ||
test_values = [] | ||
for direction_index, directional_derivative in enumerate( | ||
directional_derivatives | ||
): | ||
test_value = directional_derivative.value | ||
test_values.append(test_value) | ||
|
||
expected_value = [] | ||
for output_index in np.ndindex(self.output_indices): | ||
element = self.expectation[output_index][direction_index] | ||
expected_value.append(element) | ||
expected_value = np.array(expected_value).reshape(test_value.shape) | ||
expected_values.append(expected_value) | ||
|
||
return expected_values, test_values | ||
|
||
|
||
class HybridDerivativeCheck(DerivativeCheck): | ||
"""HybridDerivativeCheck. | ||
|
||
The method checks, if gradients are in finite differences range [min, max], | ||
using forward, backward and central finite differences for potential | ||
multiple stepsizes eps. If true, gradients will be checked for each | ||
parameter and assessed whether or not gradients are within acceptable | ||
absolute tolerances. | ||
.. math:: | ||
\\frac{|\\mu - \\kappa|}{\\lambda} < \\epsilon | ||
""" | ||
|
||
method_id = "hybrid" | ||
|
||
def method(self, *args, **kwargs): | ||
success = True | ||
expected_values, test_values = get_expected_and_test_values( | ||
self.derivative.directional_derivatives | ||
) | ||
|
||
results_all = [] | ||
directional_derivative_check_results = [] | ||
for step_size in range(0, len(expected_values)): | ||
approxs_for_param = [] | ||
grads_for_param = [] | ||
results = [] | ||
for diff_index, directional_derivative in enumerate( | ||
self.derivative.directional_derivatives | ||
): | ||
try: | ||
for grad, approx in zip( | ||
expected_values[diff_index - 1][step_size - 1], | ||
test_values[diff_index - 1][step_size - 1], | ||
): | ||
approxs_for_param.append(approx) | ||
grads_for_param.append(grad) | ||
fd_range = np.percentile(approxs_for_param, [0, 100]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this just the min and max of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes |
||
fd_mean = np.mean(approxs_for_param) | ||
grad_mean = np.mean(grads_for_param) | ||
if not (fd_range[0] <= grad_mean <= fd_range[1]): | ||
if np.any( | ||
[ | ||
abs(x - y) > kwargs["atol"] | ||
for i, x in enumerate(approxs_for_param) | ||
for j, y in enumerate(approxs_for_param) | ||
if i != j | ||
] | ||
): | ||
fd_range = abs(fd_range[1] - fd_range[0]) | ||
if ( | ||
abs(grad_mean - fd_mean) | ||
/ abs(fd_range + np.finfo(float).eps) | ||
) > kwargs["rtol"]: | ||
results.append(False) | ||
else: | ||
results.append(True) | ||
else: | ||
results.append( | ||
None | ||
) # can't judge consistency / questionable grad approxs | ||
else: | ||
fd_range = abs(fd_range[1] - fd_range[0]) | ||
if not np.isfinite([fd_range, fd_mean]).all(): | ||
results.append(None) | ||
else: | ||
result = True | ||
results.append(result) | ||
Comment on lines
+228
to
+232
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to do |
||
except (IndexError, TypeError) as err: | ||
raise ValueError( | ||
f"Unexpected error encountered (This should never happen!)" | ||
) from err | ||
|
||
directional_derivative_check_result = ( | ||
DirectionalDerivativeCheckResult( | ||
direction_id=directional_derivative.id, | ||
method_id=self.method_id, | ||
test=test_value, | ||
expectation=expected_value, | ||
output={"return": results}, | ||
success=all(results), | ||
) | ||
) | ||
directional_derivative_check_results.append( | ||
directional_derivative_check_result | ||
) | ||
results_all.append(results) | ||
|
||
success = all(chain(*results_all)) | ||
derivative_check_result = DerivativeCheckResult( | ||
method_id=self.method_id, | ||
directional_derivative_check_results=directional_derivative_check_results, | ||
test=self.derivative.value, | ||
expectation=self.expectation, | ||
success=success, | ||
) | ||
return derivative_check_result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be reset in the inner loop, instead of here in the outer loop?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be okay as-is.