Skip to content

Commit

Permalink
Always use native Python types for list outputs (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtoews authored Jul 3, 2024
1 parent 3023543 commit 81ef345
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 81 deletions.
34 changes: 17 additions & 17 deletions swn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,12 @@ def route_segnums(self, start, end, *, allow_indirect=False):
raise IndexError(f"invalid end segnum {end}")
if start == end:
return [start]
to_segnums = dict(self.to_segnums)
to_segnums_d = self.to_segnums.to_dict()

def go_downstream(segnum):
yield segnum
if segnum in to_segnums:
yield from go_downstream(to_segnums[segnum])
if segnum in to_segnums_d:
yield from go_downstream(to_segnums_d[segnum])

con1 = list(go_downstream(start))
try:
Expand Down Expand Up @@ -942,17 +942,17 @@ def go_upstream(segnum):

def go_downstream(segnum):
yield segnum
if segnum in to_segnums:
yield from go_downstream(to_segnums[segnum])
if segnum in to_segnums_d:
yield from go_downstream(to_segnums_d[segnum])

to_segnums = dict(self.to_segnums)
to_segnums_d = self.to_segnums.to_dict()
from_segnums = self.from_segnums
for barrier in check_and_return_list(barrier, "barrier"):
try:
del from_segnums[barrier]
except KeyError: # this is a tributary, remove value
from_segnums[to_segnums[barrier]].remove(barrier)
del to_segnums[barrier]
from_segnums[to_segnums_d[barrier]].remove(barrier)
del to_segnums_d[barrier]

segnums = []
for segnum in check_and_return_list(upstream, "upstream"):
Expand Down Expand Up @@ -1125,7 +1125,7 @@ def locate_geoms(
match_s.index.name = "gidx"
match = match_s.reset_index()
if min_stream_order is not None:
to_segnums = dict(self.to_segnums)
to_segnums_d = self.to_segnums.to_dict()

def find_downstream_in_min_stream_order(segnum):
while True:
Expand All @@ -1134,8 +1134,8 @@ def find_downstream_in_min_stream_order(segnum):
>= min_stream_order
):
return segnum
elif segnum in to_segnums:
segnum = to_segnums[segnum]
elif segnum in to_segnums_d:
segnum = to_segnums_d[segnum]
else: # nothing found with stream order criteria
return segnum

Expand Down Expand Up @@ -1366,15 +1366,15 @@ def aggregate(self, segnums, follow_up="upstream_length"):
)
self.logger.debug("aggregating at least %d segnums (junctions)", len(junctions))
from_segnums = self.from_segnums
to_segnums = dict(self.to_segnums)
to_segnums_d = self.to_segnums.to_dict()

# trace down from each segnum to the outlet - keep this step simple
traced_segnums = list()

def trace_down(segnum):
if segnum is not None and segnum not in traced_segnums:
traced_segnums.append(segnum)
trace_down(to_segnums.get(segnum))
trace_down(to_segnums_d.get(segnum))

for segnum in junctions:
trace_down(segnum)
Expand Down Expand Up @@ -1467,7 +1467,7 @@ def up_path_headwater_segnums(segnum):
# segnum, up_segnums, up_segnum)
yield from up_path_headwater_segnums(up_segnum)

junctions_goto = {s: to_segnums.get(s) for s in junctions}
junctions_goto = {s: to_segnums_d.get(s) for s in junctions}
agg_patch = pd.Series(dtype=object)
agg_path = pd.Series(dtype=object)
agg_unpath = pd.Series(dtype=object)
Expand Down Expand Up @@ -2019,7 +2019,7 @@ def adjust_elevation_profile(self, min_slope=1.0 / 1000):

geom_name = self.segments.geometry.name
from_segnums = self.from_segnums
to_segnums = dict(self.to_segnums)
to_segnums_d = self.to_segnums.to_dict()
modified_d = {} # key is segnum, value is drop amount (+ve is down)
self.messages = []

Expand Down Expand Up @@ -2066,8 +2066,8 @@ def adjust_elevation_profile(self, min_slope=1.0 / 1000):
# print('adj', z0 + drop0, dx * min_slope[segnum], drop)
z0 = z1
# Ensure last coordinate matches other segments that end here
if segnum in to_segnums:
beside_segnums = from_segnums[to_segnums[segnum]]
if segnum in to_segnums_d:
beside_segnums = from_segnums[to_segnums_d[segnum]]
if beside_segnums:
last_zs = [profile_d[n][-1][1] for n in beside_segnums]
last_zs_min = min(last_zs)
Expand Down
4 changes: 3 additions & 1 deletion swn/modflow/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,9 @@ def do_linemerge(ij, df, drop_reach_ids):
has_sjoin_nearest = False
for divn in diversions_in_model.itertuples():
# Use the last upstream reach as a template for a new reach
reach_d = dict(reaches.loc[reaches.segnum == divn.from_segnum].iloc[-1])
reach_d = (
reaches.loc[reaches.segnum == divn.from_segnum].iloc[-1].to_dict()
)
reach_d.update(
{
"segnum": swn.END_SEGNUM,
Expand Down
22 changes: 13 additions & 9 deletions swn/modflow/_swnmf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,12 +2020,14 @@ def route_reaches(self, start, end, *, allow_indirect=False):
if start == end:
return [start]
to_ridxname = f"to_{self.reach_index_name}"
to_ridxs = dict(self.reaches.loc[self.reaches[to_ridxname] != 0, to_ridxname])
to_ridxs_d = self.reaches.loc[
self.reaches[to_ridxname] != 0, to_ridxname
].to_dict()

def go_downstream(ridx):
yield ridx
if ridx in to_ridxs:
yield from go_downstream(to_ridxs[ridx])
if ridx in to_ridxs_d:
yield from go_downstream(to_ridxs_d[ridx])

con1 = list(go_downstream(start))
try:
Expand Down Expand Up @@ -2145,19 +2147,21 @@ def go_upstream(ridx):

def go_downstream(ridx):
yield ridx
if ridx in to_ridxs:
yield from go_downstream(to_ridxs[ridx])
if ridx in to_ridxs_d:
yield from go_downstream(to_ridxs_d[ridx])

to_ridx_name = f"to_{self.reach_index_name}"
to_ridxs = dict(self.reaches.loc[self.reaches[to_ridx_name] != 0, to_ridx_name])
to_ridxs_d = self.reaches.loc[
self.reaches[to_ridx_name] != 0, to_ridx_name
].to_dict()
from_ridxs = self.reaches[f"from_{self.reach_index_name}s"]
# Note that `.copy(deep=True)` does not work; use deepcopy
from_ridxs = from_ridxs[from_ridxs.apply(len) > 0].apply(deepcopy)
for barrier in check_and_return_list(barrier, "barrier"):
for ridx in from_ridxs.get(barrier, []):
del to_ridxs[ridx]
from_ridxs[to_ridxs[barrier]].remove(barrier)
del to_ridxs[barrier]
del to_ridxs_d[ridx]
from_ridxs[to_ridxs_d[barrier]].remove(barrier)
del to_ridxs_d[barrier]

ridxs = []
for ridx in check_and_return_list(upstream, "upstream"):
Expand Down
18 changes: 10 additions & 8 deletions swn/modflow/_swnmodflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,19 +1747,21 @@ def route_reaches(self, start, end, *, allow_indirect=False):
.iseg
)

to_reachids = {}
to_reachids_d = {}
for segnum, iseg in segnum_iseg.items():
sel = self.reaches.index[self.reaches.iseg == iseg]
to_reachids.update(dict(zip(sel[0:-1], sel[1:])))
sel_l = self.reaches.index[self.reaches.iseg == iseg].to_list()
to_reachids_d.update(dict(zip(sel_l[0:-1], sel_l[1:])))
next_segnum = self.segments.to_segnum[segnum]
next_reachids = self.reaches.index[self.reaches.segnum == next_segnum]
if len(next_reachids) > 0:
to_reachids[sel[-1]] = next_reachids[0]
next_reachids_l = self.reaches.index[
self.reaches.segnum == next_segnum
].to_list()
if len(next_reachids_l) > 0:
to_reachids_d[sel_l[-1]] = next_reachids_l[0]

def go_downstream(rid):
yield rid
if rid in to_reachids:
yield from go_downstream(to_reachids[rid])
if rid in to_reachids_d:
yield from go_downstream(to_reachids_d[rid])

con1 = list(go_downstream(start))
try:
Expand Down
29 changes: 17 additions & 12 deletions swn/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def find_location_pairs(loc_df, n, *, all_pairs=False, exclude_branches=False):
if not segnum_is_in_index.all():
raise ValueError("loc_df has segnum values not found in surface water network")

to_segnums = dict(n.to_segnums)
to_segnums_d = n.to_segnums.to_dict()
if exclude_branches:
from_segnums = dict(n.from_segnums)
from_segnums_d = n.from_segnums.to_dict()
loc_df = loc_df[["segnum", "seg_ndist"]].assign(_="") # also does .copy()
loc_segnum_s = set(loc_df.segnum)
loc_df["sequence"] = n.segments.sequence[loc_df.segnum].values
Expand All @@ -418,9 +418,12 @@ def find_location_pairs(loc_df, n, *, all_pairs=False, exclude_branches=False):
# continue searching downstream
cur_segnum = us_segnum
while True:
if cur_segnum in to_segnums:
next_segnum = to_segnums[cur_segnum]
if exclude_branches and len(from_segnums.get(next_segnum, [])) > 1:
if cur_segnum in to_segnums_d:
next_segnum = to_segnums_d[cur_segnum]
if (
exclude_branches
and len(from_segnums_d.get(next_segnum, [])) > 1
):
break # stop searching due to branch
sel = loc_df["segnum"] == next_segnum
for ds_idx in sel[sel].index:
Expand All @@ -433,22 +436,24 @@ def find_location_pairs(loc_df, n, *, all_pairs=False, exclude_branches=False):
# First case that the downstream segnum is in the same segnum
next_loc = loc_df.iloc[next_iloc]
if next_loc.segnum == us_segnum:
ds_idx = next_loc.name
ds_idx = next_loc.name.item()
else:
# otherwise search downstream
cur_segnum = us_segnum
while True:
if cur_segnum in to_segnums:
next_segnum = to_segnums[cur_segnum]
if cur_segnum in to_segnums_d:
next_segnum = to_segnums_d[cur_segnum]
if (
exclude_branches
and len(from_segnums.get(next_segnum, [])) > 1
and len(from_segnums_d.get(next_segnum, [])) > 1
):
break # no pair due to branch
if next_segnum in loc_segnum_s:
ds_idx = loc_df.segnum[loc_df.segnum == next_segnum].index[
0
]
ds_idx = (
loc_df.segnum[loc_df.segnum == next_segnum]
.index[0]
.item()
)
break # found pair
else:
break # no pair due to no downstream location
Expand Down
Loading

0 comments on commit 81ef345

Please sign in to comment.