Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid grad check #27

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8f590f5
Adding HybridDerivativeCheck class.
stephanmg May 5, 2023
85e4932
Reformatting with black.
stephanmg May 5, 2023
30cbd19
Reverting existing test case to prior state.
stephanmg May 5, 2023
21ba0f2
Adding rtol and atol.
stephanmg May 5, 2023
71449f4
Fix.
stephanmg May 5, 2023
b384821
Merged main into hybrid_grad_check.
stephanmg May 5, 2023
481401d
Merge branch 'ICB-DCM:main' into hybrid_grad_check
stephanmg May 9, 2023
45acfb9
Merged and adding test case.
stephanmg May 10, 2023
e75f852
Formatting.
stephanmg May 10, 2023
ee666db
Clarified comment on new hybrid derivative check test case.
stephanmg May 10, 2023
fd106c5
Merge remote-tracking branch 'upstream/main' into hybrid_grad_check
stephanmg May 10, 2023
20e8a1d
Update fiddy/derivative_check.py
stephanmg Jun 30, 2023
c6e0d9a
Update fiddy/derivative_check.py
stephanmg Jun 30, 2023
52827e6
Update fiddy/derivative_check.py
stephanmg Jun 30, 2023
8fad9b7
Addressing review comments.
stephanmg Oct 5, 2023
f2d8c23
Merged.
stephanmg Oct 5, 2023
6cacf43
Merged.
stephanmg Oct 5, 2023
9bc3d18
Fix RTD.
stephanmg Nov 6, 2023
beb2fb9
Update fiddy/derivative_check.py
dilpath Nov 7, 2023
9d9747a
use `get_expected_and_test_values` in `NumpyIsCloseDerivativeCheck`
dilpath Nov 7, 2023
51279b5
Update fiddy/derivative_check.py
stephanmg Nov 13, 2023
6b6c327
Addressing DP's review.
stephanmg Nov 13, 2023
e0789e3
Merge branch 'hybrid_grad_check' of github.com:stephanmg/fiddy into h…
stephanmg Nov 13, 2023
75e4cf9
Fix indent.
stephanmg Nov 13, 2023
0922139
Fix check?
stephanmg Nov 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 129 additions & 12 deletions fiddy/derivative_check.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import abc
from typing import Any, Callable, Dict, List, Union
from itertools import chain
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import math

from .constants import Type
from .derivative import Derivative
Expand Down Expand Up @@ -54,17 +56,17 @@ class DerivativeCheck(abc.ABC):
"""Check whether a derivative is correct.

Args:
derivative:
The test derivative.
expectation:
The expected derivative.
point:
The point where the test derivative was computed.
output_indices:
The derivative can be a multi-dimensional object that has dimensions
associated with the multiple outputs of a function, and dimensions
associated with the derivative of these multiple outputs with respect
to multiple directions.
derivative:
The test derivative.
expectation:
The expected derivative.
point:
The point where the test derivative was computed.
output_indices:
The derivative can be a multi-dimensional object that has dimensions
associated with the multiple outputs of a function, and dimensions
associated with the derivative of these multiple outputs with respect
to multiple directions.
dilpath marked this conversation as resolved.
Show resolved Hide resolved
"""

method_id: str
Expand Down Expand Up @@ -137,3 +139,118 @@ def method(self, *args, **kwargs):
success=success,
)
return derivative_check_result


def get_expected_and_test_values(directional_derivatives):
expected_values = []
test_values = []
for direction_index, directional_derivative in enumerate(
directional_derivatives
):
test_value = directional_derivative.value
test_values.append(test_value)

expected_value = []
for output_index in np.ndindex(self.output_indices):
element = self.expectation[output_index][direction_index]
expected_value.append(element)
expected_value = np.array(expected_value).reshape(test_value.shape)
expected_values.append(expected_value)

return expected_values, test_values


class HybridDerivativeCheck(DerivativeCheck):
"""HybridDerivativeCheck.

The method checks, if gradients are in finite differences range [min, max],
using forward, backward and central finite differences for potential
multiple stepsizes eps. If true, gradients will be checked for each
parameter and assessed whether or not gradients are within acceptable
absolute tolerances.
"""

method_id = "hybrid"

def method(self, *args, **kwargs):
success = True
expected_values, test_values = get_expected_and_test_values(
self.derivative.directional_derivatives
)

results_all = []
directional_derivative_check_results = []
for step_size in range(0, len(expected_values)):
approxs_for_param = []
grads_for_param = []
Comment on lines +189 to +190
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be reset in the inner loop, instead of here in the outer loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be okay as-is.

results = []
for diff_index, directional_derivative in enumerate(
self.derivative.directional_derivatives
):
try:
for grad, approx in zip(
expected_values[diff_index - 1][step_size - 1],
test_values[diff_index - 1][step_size - 1],
):
approxs_for_param.append(approx)
grads_for_param.append(grad)
fd_range = np.percentile(approxs_for_param, [0, 100])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just the min and max of approxs_for_param?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

fd_mean = np.mean(approxs_for_param)
grad_mean = np.mean(grads_for_param)
if not (fd_range[0] <= grad_mean <= fd_range[1]):
if np.any(
[
abs(x - y) > kwargs["atol"]
for i, x in enumerate(approxs_for_param)
for j, y in enumerate(approxs_for_param)
if i != j
]
):
fd_range = abs(fd_range[1] - fd_range[0])
if (
abs(grad_mean - fd_mean)
/ abs(fd_range + np.finfo(float).eps)
) > kwargs["rtol"]:
results.append(False)
else:
results.append(False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The handling of results needs some work. For example, here both cases of the if-else have results.append(False).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

else:
results.append(
None
) # can't judge consistency / questionable grad approxs
else:
fd_range = abs(fd_range[1] - fd_range[0])
if not np.isfinite([fd_range, fd_mean]).all():
results.append(None)
else:
result = True
results.append(result)
Comment on lines +228 to +232
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to do result = ... everywhere (if going with this coding style), and finally results.append(result) at the end (outside of this else statement). I can do this, will wait until the other comments are resolved.

except (IndexError, TypeError) as err:
print(
f"Unexpected error encountered: {err} (This should never happen!)"
)
stephanmg marked this conversation as resolved.
Show resolved Hide resolved

directional_derivative_check_result = (
DirectionalDerivativeCheckResult(
direction_id=directional_derivative.id,
method_id=self.method_id,
test=test_value,
expectation=expected_value,
output={"return": results},
success=all(results),
)
)
directional_derivative_check_results.append(
directional_derivative_check_result
)
results_all.append(results)

success = all(chain(*results_all))
derivative_check_result = DerivativeCheckResult(
method_id=self.method_id,
directional_derivative_check_results=directional_derivative_check_results,
test=self.derivative.value,
expectation=self.expectation,
success=success,
)
return derivative_check_result
Loading