From 673646eedaca4caf28a0295684679d78ad5cc602 Mon Sep 17 00:00:00 2001 From: arnav-singhal <43693924+arnav-singhal@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:20:44 -0400 Subject: [PATCH] Visualization: more plotting options and handling of missing .shx file (#85) * fix handling of .shx and add command-line option to plot deaths * fix whitespace --- utilities/plotMovie/generate_frames.py | 56 +++++++++++++++----------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/utilities/plotMovie/generate_frames.py b/utilities/plotMovie/generate_frames.py index 07fcf3f..6c349c6 100644 --- a/utilities/plotMovie/generate_frames.py +++ b/utilities/plotMovie/generate_frames.py @@ -6,14 +6,18 @@ ffmpeg -framerate 1 -r 30 -i frames/frame%05d.png -pix_fmt yuv420p movie.mp4 Can adjust inputs directly at bottom of file or as command-line input: -python generate_frames.py [sim results dir] [.shx dir] [output dir (optional)] +python generate_frames.py [sim results dir] [.shp dir] [output dir] [plotted element] The function being plotted can be altered in the get_raw() functions defined in the get_raw_data() and get_raw_data_hdf5() functions, such as plotting cumulative infections, proportions of infections, or raw counts. -Endpoints of color range can be set in or passed into generate_plot, -or in main function, as necessary. +The plotted element can be any of the data fields in the county-level data: +infected, never_infected, susceptible, immune, +or deaths to plot number of deaths. + +Depending on choice, adjusting color scale may be advisable, which can be +done by changing vmin and vmax values on line 198 in the call to generate_plot() Since US census tracts include territories around the globe, if we use this data, some sort of cropping is necessary (i.e. 48-state mainland) @@ -87,26 +91,27 @@ def get_raw(county): ds.close() return raw_df -def get_raw_data_hdf5(name: str): +def get_raw_data_hdf5(name: str, plot_option: str = "infected"): f = h5py.File(name, 'r') - found = 0 - i = 0 - while found < 2: - if f.attrs['component_' + str(i)] == b'FIPS': - fips_idx = i - found += 1 - if f.attrs['component_' + str(i)] == b'infected': - inf_idx = i - found += 1 - i += 1 - - fips = f['level_0']['data:datatype=' + str(fips_idx)][()] - infs = f['level_0']['data:datatype=' + str(inf_idx)][()] + comm_indices = {} + for i in range(f.attrs['num_components'][0]): + comm_indices[f.attrs['component_' + str(i)]] = str(i) + + fips = f['level_0']['data:datatype=' + comm_indices[b'FIPS']][()] + + if plot_option == "deaths": + plts = (f['level_0']['data:datatype=' + comm_indices[b'total']][()] + - f['level_0']['data:datatype=' + comm_indices[b'infected']][()] + - f['level_0']['data:datatype=' + comm_indices[b'never_infected']][()] + - f['level_0']['data:datatype=' + comm_indices[b'immune']][()] + - f['level_0']['data:datatype=' + comm_indices[b'susceptible']][()]) + else: + plts = f['level_0']['data:datatype=' + comm_indices[bytes(plot_option, "utf-8")]][()] unique_fips = np.unique(fips).astype(int) def get_raw(county): mask = fips == county - return np.log(1 + infs[mask].sum()) + return np.log(1 + plts[mask].sum()) raw_df = pd.DataFrame() raw_df["FIPS"] = unique_fips @@ -124,10 +129,12 @@ def get_gdf(prefix: str): # can be done by running code inside a # with fiona.Env(SHAPE_RESTORE_SHX = "YES"): # block (remember to import fiona explicitly) - gdf = gpd.read_file(prefix + ".shp", driver="esri") - - # with fiona.Env(SHAPE_RESTORE_SHX = "YES"): - # gdf = gpd.read_file(prefix + ".shp", driver="esri") + if os.path.isfile(prefix + ".shx"): + gdf = gpd.read_file(prefix + ".shp", driver="esri") + else: + import fiona + with fiona.Env(SHAPE_RESTORE_SHX = "YES"): + gdf = gpd.read_file(prefix + ".shp", driver="esri") cols = list(gdf.columns) @@ -179,9 +186,12 @@ def generate_plot(per_df, gdf, vmin = None, vmax = None, crop_usa = False): crop_usa = "_us_" in prefix output_dir = sys.argv[3] if argc > 3 else "./frames_usa/" + + plot_option = sys.argv[4] if argc > 4 else "infected" + for i in range(len(data_names)): # vmin and vmax are endpoints for color range; 16 > log(population of LA) is a safe upper bound # for per-capita, endpoints should be set to much less - fig = generate_plot(get_raw_data_hdf5(data_names[i]), gdf, vmin=0, vmax=16, crop_usa = crop_usa) + fig = generate_plot(get_raw_data_hdf5(data_names[i], plot_option), gdf, vmin=0, vmax=16, crop_usa = crop_usa) fig.savefig(output_dir + "frame{:05d}".format(i)) plt.close(fig)