Skip to content

Commit

Permalink
Fix torch error on misc.CovarianceMatrix (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisoutin authored Jan 24, 2024
1 parent 341b34d commit 2d0d831
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]
python-version: ['3.10', '3.11']
steps:
- name: Checkout most recent commit
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion deepdow/layers/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def compute_covariance(m, shrinkage_strategy=None, shrinkage_coef=0.5):
"""
fact = 1.0 / (m.size(1) - 1)
m -= torch.mean(m, dim=1, keepdim=True) # !!!!!!!!!!! INPLACE
m = m - torch.mean(m, dim=1, keepdim=True)
mt = m.t()

s = fact * m.matmul(mt) # sample covariance matrix
Expand Down
9 changes: 9 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
Changelog
=========

Unreleased
----------
Fixes
*****
- Fix PyTorch in-place issue and matplotib plotting problem - #147

v0.2.2
------

v0.2.1
------
Added
Expand Down
4 changes: 2 additions & 2 deletions examples/layers/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@

axs[i, 0].plot(tform.numpy().squeeze(), linewidth=3, color='red')
axs[i, 1].plot(x_warped.numpy().squeeze(), linewidth=3, color='blue')
axs[i, 0].set_title(r'$\bf{}$ tform'.format(tform_name))
axs[i, 1].set_title(r'$\bf{}$ warped'.format(tform_name))
axs[i, 0].set_title('{} tform'.format(tform_name))
axs[i, 1].set_title('{} warped'.format(tform_name))

0 comments on commit 2d0d831

Please sign in to comment.