Skip to content

Commit

Permalink
add keys for the states inside Stats class
Browse files Browse the repository at this point in the history
  • Loading branch information
Apolline El-Baz committed Apr 8, 2024
1 parent 8764da9 commit fdcb106
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 33 deletions.
6 changes: 4 additions & 2 deletions smash/core/simulation/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import TYPE_CHECKING

from smash._constant import (
INTERNAL_FLUXES
INTERNAL_FLUXES,
STRUCTURE_RR_STATES,
)

import numpy as np
Expand Down Expand Up @@ -163,8 +164,9 @@ def _forward_run(
return_options["nmts"],
return_options["fkeys"],
)

wrap_returns.stats.fluxes_keys = INTERNAL_FLUXES[model.setup.hydrological_module]
wrap_returns.stats.rr_states_keys = STRUCTURE_RR_STATES[model.setup.structure]

# % Map cost_options dict to derived type
_map_dict_to_fortran_derived_type(cost_options, wrap_options.cost)
Expand Down
28 changes: 20 additions & 8 deletions smash/core/stats/stats_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,27 @@
def _get_idx(stats:StatsDT, name):
for i, key in enumerate(stats.fluxes_keys):
if key == name:
return i
draw = "fluxes"
return i, draw
for i, key in enumerate(stats.rr_states_keys):
if key == name:
draw = "states"
return i, draw

def get_fluxes(stats:StatsDT, name):
i = _get_idx(stats, name)
mean = stats.fluxes_values[:, :, 0, i]
var = stats.fluxes_values[:, :, 1, i]
minimum = stats.fluxes_values[:, :, 2, i]
maximum = stats.fluxes_values[:, :, 3, i]
median = stats.fluxes_values[:, :, 4, i]
def get(stats:StatsDT, name):
i, draw = _get_idx(stats, name)
if draw == "fluxes":
mean = stats.fluxes_values[:, :, 0, i]
var = stats.fluxes_values[:, :, 1, i]
minimum = stats.fluxes_values[:, :, 2, i]
maximum = stats.fluxes_values[:, :, 3, i]
median = stats.fluxes_values[:, :, 4, i]
if draw == "states":
mean = stats.rr_states_values[:, :, 0, i]
var = stats.rr_states_values[:, :, 1, i]
minimum = stats.rr_states_values[:, :, 2, i]
maximum = stats.rr_states_values[:, :, 3, i]
median = stats.rr_states_values[:, :, 4, i]
return mean, var, minimum, maximum, median

'''
Expand Down
6 changes: 5 additions & 1 deletion smash/fcore/derived_type/mwd_stats.f90
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ module mwd_stats

real(sp), dimension(:, :, :), allocatable :: internal_fluxes
character(lchar), dimension(:), allocatable :: fluxes_keys !$F90W char-array
real(sp), dimension(:, :, :, :), allocatable :: fluxes_values
real(sp), dimension(:, :, :, :), allocatable :: fluxes_values
character(lchar), dimension(:), allocatable :: rr_states_keys !$F90W char-array
real(sp), dimension(:, :, :, :), allocatable :: rr_states_values ! rr_states_keys in Rr_StatesDT class

end type StatsDT
Expand All @@ -53,6 +54,9 @@ subroutine StatsDT_initialise(this, setup, mesh)

allocate (this%fluxes_values(mesh%ng, setup%ntime_step, 5, setup%nfx))

allocate (this%rr_states_keys(setup%nrrs))
this%rr_states_keys = "..."

allocate (this%rr_states_values(mesh%ng, setup%ntime_step, 5, setup%nrrs))


Expand Down
18 changes: 0 additions & 18 deletions smash/fcore/forward/forward_db.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5693,24 +5693,6 @@ SUBROUTINE GET_SERR_SIGMA_PARAMETERS(serr_sigma_parameters, key, vle)
END DO
END SUBROUTINE GET_SERR_SIGMA_PARAMETERS

SUBROUTINE GET_STATS(stats, key, vle)
IMPLICIT NONE
! Should be unreachable
TYPE(STATSDT), INTENT(IN) :: stats
CHARACTER(len=*), INTENT(IN) :: key
REAL(sp), DIMENSION(:, :, :), INTENT(INOUT) :: vle
INTEGER :: i
INTRINSIC SIZE
INTRINSIC TRIM
! Linear search on fluxes_keys
DO i=1,SIZE(stats%fluxes_keys)
IF (TRIM(stats%fluxes_keys(i)) .EQ. key) THEN
vle = stats%fluxes_values(:, :, :, i)
RETURN
END IF
END DO
END SUBROUTINE GET_STATS

SUBROUTINE SET_RR_PARAMETERS(rr_parameters, key, vle)
IMPLICIT NONE
! Should be unreachable
Expand Down
5 changes: 1 addition & 4 deletions smash/fcore/forward/md_simulation.f90
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ subroutine compute_fluxes_stats(mesh, t, idx, returns)
returns%stats%fluxes_values(j, t, 2, idx) = sum((fx - m) * (fx - m), mask = mask) / npos_val
returns%stats%fluxes_values(j, t, 3, idx) = minval(fx, mask = mask)
returns%stats%fluxes_values(j, t, 4, idx) = maxval(fx, mask = mask)

!~ print *, returns%stats%fluxes_values(j, t, 1, idx)


if (.not. allocated(fx_flat)) allocate (fx_flat(npos_val))
fx_flat = pack(fx, mask .eqv. .True.)

Expand All @@ -174,7 +172,6 @@ subroutine compute_fluxes_stats(mesh, t, idx, returns)
end if

end do
!~ print *, returns%stats%fluxes_values
!$AD end-exclude
end subroutine

Expand Down

0 comments on commit fdcb106

Please sign in to comment.