Skip to content

Commit

Permalink
Visualization: more plotting options and handling of missing .shx file (#85)
Browse files Browse the repository at this point in the history

* fix handling of .shx and add command-line option to plot deaths

* fix whitespace
  • Loading branch information
arnav-singhal authored Oct 3, 2024
1 parent 7bf437e commit 673646e
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions utilities/plotMovie/generate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
ffmpeg -framerate 1 -r 30 -i frames/frame%05d.png -pix_fmt yuv420p movie.mp4
Can adjust inputs directly at bottom of file or as command-line input:
python generate_frames.py [sim results dir] [.shx dir] [output dir (optional)]
python generate_frames.py [sim results dir] [.shp dir] [output dir] [plotted element]
The quantity being plotted can be changed in the get_raw() helpers defined
inside get_raw_data() and get_raw_data_hdf5(), e.g. to plot cumulative
infections, proportions of infections, or raw counts.
Endpoints of color range can be set in or passed into generate_plot,
or in main function, as necessary.
The plotted element can be any of the fields in the county-level data
(infected, never_infected, susceptible, immune), or 'deaths' to plot the
number of deaths (computed as total minus those four fields).
Depending on the choice, adjusting the color scale may be advisable; this can be
done by changing the vmin and vmax values on line 198, in the call to generate_plot().
Since US census tracts include territories around the globe, some sort of
cropping is necessary when using this data (e.g. to the 48-state mainland).
Expand Down Expand Up @@ -87,26 +91,27 @@ def get_raw(county):
ds.close()
return raw_df

def get_raw_data_hdf5(name: str):
def get_raw_data_hdf5(name: str, plot_option: str = "infected"):
f = h5py.File(name, 'r')
found = 0
i = 0
while found < 2:
if f.attrs['component_' + str(i)] == b'FIPS':
fips_idx = i
found += 1
if f.attrs['component_' + str(i)] == b'infected':
inf_idx = i
found += 1
i += 1

fips = f['level_0']['data:datatype=' + str(fips_idx)][()]
infs = f['level_0']['data:datatype=' + str(inf_idx)][()]
comm_indices = {}
for i in range(f.attrs['num_components'][0]):
comm_indices[f.attrs['component_' + str(i)]] = str(i)

fips = f['level_0']['data:datatype=' + comm_indices[b'FIPS']][()]

if plot_option == "deaths":
plts = (f['level_0']['data:datatype=' + comm_indices[b'total']][()]
- f['level_0']['data:datatype=' + comm_indices[b'infected']][()]
- f['level_0']['data:datatype=' + comm_indices[b'never_infected']][()]
- f['level_0']['data:datatype=' + comm_indices[b'immune']][()]
- f['level_0']['data:datatype=' + comm_indices[b'susceptible']][()])
else:
plts = f['level_0']['data:datatype=' + comm_indices[bytes(plot_option, "utf-8")]][()]
unique_fips = np.unique(fips).astype(int)

def get_raw(county):
mask = fips == county
return np.log(1 + infs[mask].sum())
return np.log(1 + plts[mask].sum())

raw_df = pd.DataFrame()
raw_df["FIPS"] = unique_fips
Expand All @@ -124,10 +129,12 @@ def get_gdf(prefix: str):
# can be done by running code inside a
# with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
# block (remember to import fiona explicitly)
gdf = gpd.read_file(prefix + ".shp", driver="esri")

# with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
# gdf = gpd.read_file(prefix + ".shp", driver="esri")
if os.path.isfile(prefix + ".shx"):
gdf = gpd.read_file(prefix + ".shp", driver="esri")
else:
import fiona
with fiona.Env(SHAPE_RESTORE_SHX = "YES"):
gdf = gpd.read_file(prefix + ".shp", driver="esri")

cols = list(gdf.columns)

Expand Down Expand Up @@ -179,9 +186,12 @@ def generate_plot(per_df, gdf, vmin = None, vmax = None, crop_usa = False):
crop_usa = "_us_" in prefix

output_dir = sys.argv[3] if argc > 3 else "./frames_usa/"

plot_option = sys.argv[4] if argc > 4 else "infected"

for i in range(len(data_names)):
# vmin and vmax are endpoints for color range; 16 > log(population of LA) is a safe upper bound
# for per-capita, endpoints should be set to much less
fig = generate_plot(get_raw_data_hdf5(data_names[i]), gdf, vmin=0, vmax=16, crop_usa = crop_usa)
fig = generate_plot(get_raw_data_hdf5(data_names[i], plot_option), gdf, vmin=0, vmax=16, crop_usa = crop_usa)
fig.savefig(output_dir + "frame{:05d}".format(i))
plt.close(fig)

0 comments on commit 673646e

Please sign in to comment.