Skip to content

Commit

Permalink
Go back to from syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
paddyroddy committed Nov 8, 2024
1 parent 0a83b2e commit 663f80a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 32 deletions.
14 changes: 4 additions & 10 deletions glass/lensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
if typing.TYPE_CHECKING:
import collections.abc

import cosmology.api
from cosmology.api import StandardCosmology

from glass.shells import RadialWindow

Expand Down Expand Up @@ -271,9 +271,7 @@ class MultiPlaneConvergence:

def __init__(
self,
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> None:
"""Create a new instance to iteratively compute the convergence."""
self.cosmo = cosmo
Expand Down Expand Up @@ -382,9 +380,7 @@ def wlens(self) -> float:

def multi_plane_matrix(
shells: collections.abc.Sequence[RadialWindow],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
"""Compute the matrix of lensing contributions from each shell."""
mpc = MultiPlaneConvergence(cosmo)
Expand All @@ -398,9 +394,7 @@ def multi_plane_matrix(
def multi_plane_weights(
weights: npt.NDArray[np.float64],
shells: collections.abc.Sequence[RadialWindow],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
"""
Compute effective weights for multi-plane convergence.
Expand Down
16 changes: 4 additions & 12 deletions glass/shells.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,23 @@

def distance_weight(
z: npt.NDArray[np.float64],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
"""Uniform weight in comoving distance."""
return 1 / cosmo.H_over_H0(z) # type: ignore[no-any-return]


def volume_weight(
z: npt.NDArray[np.float64],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
"""Uniform weight in comoving volume."""
return cosmo.xm(z) ** 2 / cosmo.H_over_H0(z) # type: ignore[no-any-return]


def density_weight(
z: npt.NDArray[np.float64],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
) -> npt.NDArray[np.float64]:
"""Uniform weight in matter density."""
return ( # type: ignore[no-any-return]
Expand Down Expand Up @@ -661,9 +655,7 @@ def redshift_grid(


def distance_grid(
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
zmin: float,
zmax: float,
*,
Expand Down
14 changes: 4 additions & 10 deletions tests/test_lensing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from glass.shells import RadialWindow

if typing.TYPE_CHECKING:
import cosmology.api
from cosmology.api import StandardCosmology


@pytest.fixture
Expand All @@ -31,9 +31,7 @@ def shells() -> list[RadialWindow]:


@pytest.fixture
def cosmo() -> (
cosmology.api.StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]]
):
def cosmo() -> StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
class MockCosmology:
@property
def Omega_m0(self) -> float: # noqa: N802
Expand Down Expand Up @@ -99,9 +97,7 @@ def test_deflect_many(rng: np.random.Generator) -> None:

def test_multi_plane_matrix(
shells: list[RadialWindow],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
rng: np.random.Generator,
) -> None:
mat = multi_plane_matrix(shells, cosmo)
Expand All @@ -123,9 +119,7 @@ def test_multi_plane_matrix(

def test_multi_plane_weights(
shells: list[RadialWindow],
cosmo: cosmology.api.StandardCosmology[
npt.NDArray[np.float64], npt.NDArray[np.float64]
],
cosmo: StandardCosmology[npt.NDArray[np.float64], npt.NDArray[np.float64]],
rng: np.random.Generator,
) -> None:
w_in = np.eye(len(shells))
Expand Down

0 comments on commit 663f80a

Please sign in to comment.