Merge pull request #299 from informatics-lab/earth-networks-datashader
Earth networks optimisation
andrewgryan authored Mar 11, 2020
2 parents 79dc025 + dce4b66 commit 149a1e2
Showing 14 changed files with 165 additions and 50 deletions.
2 changes: 1 addition & 1 deletion forest/__init__.py
@@ -25,7 +25,7 @@
.. automodule:: forest.presets
"""
__version__ = '0.12.6'
__version__ = '0.12.7'

from .config import *
from . import (
7 changes: 1 addition & 6 deletions forest/data.py
@@ -18,13 +18,8 @@
pass
from forest import (
gridded_forecast,
saf,
satellite,
rdt,
earth_networks,
geo,
disk,
nearcast)
disk)
import bokeh.models
from collections import OrderedDict
from functools import partial
16 changes: 13 additions & 3 deletions forest/db/control.py
@@ -301,14 +301,24 @@ def _pressure(self, store, action):
yield set_value(key, value)

def _pattern(self, store, action):
value = action["payload"]["value"]
variables = self.navigator.variables(pattern=value)
initial_times = self.navigator.initial_times(pattern=value)
pattern = action["payload"]["value"]
variables = self.navigator.variables(pattern=pattern)
initial_times = self.navigator.initial_times(pattern=pattern)
initial_times = list(reversed(initial_times))
yield action
yield set_value("variables", variables)
yield set_value("initial_times", initial_times)

# Set valid_times if pattern, variable and initial_time present
kwargs = {
"pattern": pattern,
"variable": store.state.get("variable"),
"initial_time": store.state.get("initial_time"),
}
if all(kwargs[k] is not None for k in ["variable", "initial_time"]):
valid_times = self.navigator.valid_times(**kwargs)
yield set_value("valid_times", valid_times)

def _variable(self, store, action):
for attr in ["pattern", "initial_time"]:
if attr not in store.state:
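The new block only asks the navigator for valid times once a variable and an initial time are already selected. A minimal sketch of that guard, using a stand-in navigator and a plain dict for store state (both hypothetical; only the valid_times keyword signature follows the diff):

class FakeNavigator:
    """Stand-in for the real navigator used by forest.db.control."""
    def valid_times(self, pattern=None, variable=None, initial_time=None):
        return ["2020-03-11 00:00:00", "2020-03-11 03:00:00"]

def maybe_valid_times(navigator, state, pattern):
    kwargs = {
        "pattern": pattern,
        "variable": state.get("variable"),
        "initial_time": state.get("initial_time"),
    }
    # Only query once both selections exist, mirroring the middleware above
    if all(kwargs[k] is not None for k in ["variable", "initial_time"]):
        return navigator.valid_times(**kwargs)
    return None

assert maybe_valid_times(FakeNavigator(), {}, "*.nc") is None
assert maybe_valid_times(
    FakeNavigator(),
    {"variable": "air_temperature", "initial_time": "2020-03-11 00:00:00"},
    "*.nc") is not None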
33 changes: 33 additions & 0 deletions forest/drivers/__init__.py
@@ -1 +1,34 @@
from importlib import import_module
from forest.exceptions import DriverNotFound
from functools import wraps


_CACHE = {}


def _cache(f):
# Ensure per-server dataset instances
def wrapped(driver_name, settings=None):
uid = _uid(driver_name, settings)
if uid not in _CACHE:
_CACHE[uid] = f(driver_name, settings)
return _CACHE[uid]
return wrapped


def _uid(driver_name, settings):
if settings is None:
return (driver_name,)
return (driver_name,) + tuple(settings[k] for k in sorted(settings.keys()))


@_cache
def get_dataset(driver_name, settings=None):
"""Find Dataset related to file type"""
if settings is None:
settings = {}
try:
module = import_module(f"forest.drivers.{driver_name}")
except ModuleNotFoundError:
raise DriverNotFound(driver_name)
return module.Dataset(**settings)
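A short usage sketch of the new module-level cache; the behaviour matches test/test_drivers.py, and the settings values are illustrative only:

from forest import drivers

# Identical driver name and settings hit the same cache entry,
# so every part of the server shares one Dataset instance
a = drivers.get_dataset("earth_networks", {"pattern": "*.txt"})
b = drivers.get_dataset("earth_networks", {"pattern": "*.txt"})
assert a is b

# Different settings build a different key, hence a fresh Dataset
c = drivers.get_dataset("earth_networks", {"pattern": "other/*.txt"})
assert a is not c

Because _uid sorts the settings keys, keyword order in the settings dict does not affect the cache key.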
27 changes: 25 additions & 2 deletions forest/earth_networks.py → forest/drivers/earth_networks.py
@@ -11,27 +11,50 @@
import numpy as np


class Dataset:
"""High-level class to relate navigators, loaders and views"""
def __init__(self, pattern=None):
self.pattern = pattern
if pattern is not None:
self._paths = glob.glob(pattern)
else:
self._paths = []

def navigator(self):
"""Construct navigator"""
return Navigator(self._paths)

def map_view(self):
"""Construct view"""
return View(Loader(self._paths))


class View(object):
def __init__(self, loader):
self.loader = loader
palette = bokeh.palettes.all_palettes['Spectral'][11][::-1]
self.color_mapper = bokeh.models.LinearColorMapper(low=-1000, high=0, palette=palette)
self.source = bokeh.models.ColumnDataSource({
self.empty_image = {
"x": [],
"y": [],
"date": [],
"longitude": [],
"latitude": [],
"flash_type": [],
"time_since_flash": []
})
}
self.source = bokeh.models.ColumnDataSource(self.empty_image)

@old_state
@unique
def render(self, state):
if state.valid_time is None:
return

valid_time = _to_datetime(state.valid_time)
frame = self.loader.load_date(valid_time)
if len(frame) == 0:
return self.empty_image
x, y = geo.web_mercator(
frame.longitude,
frame.latitude)
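A rough sketch of how the relocated Dataset ties a glob pattern to its navigator and map view; the pattern is a placeholder, and add_figure is the existing View method that main.py calls when wiring renderers:

import bokeh.plotting
from forest.drivers import earth_networks

dataset = earth_networks.Dataset(pattern="EarthNetworks*.txt")
navigator = dataset.navigator()     # Navigator over the matched file paths
view = dataset.map_view()           # View(Loader(paths)) as defined above

figure = bokeh.plotting.figure()
renderer = view.add_figure(figure)  # attach the lightning glyphs to a map figure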
8 changes: 8 additions & 0 deletions forest/exceptions.py
@@ -1,3 +1,11 @@
class DriverNotFound(Exception):
pass


class UnknownFileType(Exception):
pass


class FileNotFound(Exception):
pass

6 changes: 2 additions & 4 deletions forest/load.py
@@ -22,9 +22,9 @@
import os
from forest.export import export
from forest import (
exceptions,
data,
db,
earth_networks,
gridded_forecast,
unified_model,
rdt,
@@ -134,8 +134,6 @@ def file_loader(file_type, pattern, label=None, locator=None):
return rdt.Loader(pattern)
elif file_type == 'gpm':
return data.GPM(pattern)
elif file_type == 'earthnetworks':
return earth_networks.Loader.pattern(pattern)
elif file_type == 'eida50':
return satellite.EIDA50(pattern)
elif file_type == 'griddedforecast':
@@ -151,4 +149,4 @@
elif file_type == 'nearcast':
return nearcast.NearCast(pattern)
else:
raise Exception("unrecognised file_type: {}".format(file_type))
raise exceptions.UnknownFileType("unrecognised file_type: {}".format(file_type))
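With a dedicated exception in place of the bare Exception, callers such as main.py can catch the unknown-file-type case precisely; a small sketch (the "mystery" file type is made up):

from forest import exceptions, load

try:
    loader = load.file_loader("mystery", "*.nc")
except exceptions.UnknownFileType:
    # Unrecognised file types can now be routed to the forest.drivers
    # interface instead of crashing the server at start-up
    loader = None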
56 changes: 34 additions & 22 deletions forest/main.py
@@ -7,14 +7,15 @@
import glob
from forest import _profile as profile
from forest import (
drivers,
exceptions,
satellite,
screen,
tools,
series,
data,
load,
view,
earth_networks,
rdt,
nearcast,
geo,
@@ -95,32 +96,43 @@ def main(argv=None):
database = None
if group.locator == "database":
database = db.get_database(group.database_path)
loader = load.Loader.group_args(
group, args, database=database)
try:
loader = load.Loader.group_args(
group, args, database=database)
except exceptions.UnknownFileType:
# TODO: Deprecate load.Loader.group_args()
continue
data.add_loader(group.label, loader)

renderers = {}
viewers = {}
for name, loader in data.LOADERS.items():
if isinstance(loader, rdt.Loader):
viewer = rdt.View(loader)
elif isinstance(loader, earth_networks.Loader):
viewer = earth_networks.View(loader)
elif isinstance(loader, data.GPM):
viewer = view.GPMView(loader, color_mapper)
elif isinstance(loader, satellite.EIDA50):
viewer = view.EIDA50(loader, color_mapper)
elif isinstance(loader, nearcast.NearCast):
viewer = view.NearCast(loader, color_mapper)
viewer.set_hover_properties(nearcast.NEARCAST_TOOLTIPS)
elif isinstance(loader, intake_loader.IntakeLoader):
viewer = view.UMView(loader, color_mapper)
viewer.set_hover_properties(intake_loader.INTAKE_TOOLTIPS,
intake_loader.INTAKE_FORMATTERS)
for group in config.file_groups:
if group.label in data.LOADERS:
loader = data.LOADERS[group.label]
if isinstance(loader, rdt.Loader):
viewer = rdt.View(loader)
elif isinstance(loader, data.GPM):
viewer = view.GPMView(loader, color_mapper)
elif isinstance(loader, satellite.EIDA50):
viewer = view.EIDA50(loader, color_mapper)
elif isinstance(loader, nearcast.NearCast):
viewer = view.NearCast(loader, color_mapper)
viewer.set_hover_properties(nearcast.NEARCAST_TOOLTIPS)
elif isinstance(loader, intake_loader.IntakeLoader):
viewer = view.UMView(loader, color_mapper)
viewer.set_hover_properties(intake_loader.INTAKE_TOOLTIPS,
intake_loader.INTAKE_FORMATTERS)
else:
viewer = view.UMView(loader, color_mapper)
else:
viewer = view.UMView(loader, color_mapper)
viewers[name] = viewer
renderers[name] = [
# Use dataset interface
settings = {
"pattern": group.pattern
}
dataset = drivers.get_dataset(group.file_type, settings)
viewer = dataset.map_view()
viewers[group.label] = viewer
renderers[group.label] = [
viewer.add_figure(f)
for f in figures]

17 changes: 13 additions & 4 deletions forest/navigate.py
@@ -7,7 +7,8 @@
ValidTimesNotFound,
PressuresNotFound)
from forest import (
earth_networks,
exceptions,
drivers,
db,
gridded_forecast,
unified_model,
@@ -17,7 +18,8 @@
saf,
nearcast)

from forest.drivers import ghrsstl4
from forest.drivers import (
ghrsstl4)


class Navigator:
@@ -80,6 +82,15 @@ def __init__(self, paths, coordinates=None):

@classmethod
def from_file_type(cls, paths, file_type, pattern=None):
try:
settings = {
"pattern": pattern}
dataset = drivers.get_dataset(file_type, settings)
return dataset.navigator()
except exceptions.DriverNotFound:
# TODO: Migrate all file types to forest.drivers
pass

if file_type.lower() == "rdt":
coordinates = rdt.Coordinates()
elif file_type.lower() == "eida50":
@@ -95,8 +106,6 @@ def from_file_type(cls, paths, file_type, pattern=None):
coordinates = unified_model.Coordinates()
elif file_type.lower() == "saf":
coordinates = saf.Coordinates()
elif file_type.lower() == "earth_networks":
return earth_networks.Navigator(paths)
elif file_type.lower() == "nearcast":
return nearcast.Navigator(pattern)
else:
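The classmethod above now prefers the forest.drivers interface and only reaches the hard-coded branches when DriverNotFound is raised. A condensed sketch of that fallback as a free function (the helper name is hypothetical; the earth_networks call mirrors test_earth_networks.py):

from forest import drivers
from forest.exceptions import DriverNotFound

def navigator_for(file_type, pattern=None):
    try:
        dataset = drivers.get_dataset(file_type, {"pattern": pattern})
        return dataset.navigator()
    except DriverNotFound:
        # Legacy file types are still handled by the branches below
        return None

navigator = navigator_for("earth_networks", "*.txt")  # driver-backed navigator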
6 changes: 4 additions & 2 deletions forest/old_state.py
@@ -31,12 +31,14 @@ def wrapper(*args):

if len(args) == 2:
self, value = args
key = (id(self), value) # Distinguish wrapped methods
else:
value, = args
key = value

if (not called) or (value != previous):
if (not called) or (key != previous):
called = True
previous = value
previous = key
if len(args) == 2:
result = f(self, value)
else:
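The (id(self), value) key matters because the decorator's closure is shared by every instance of a decorated class; without the id, two views rendering the same state would suppress each other. A simplified, self-contained illustration (not FOREST's actual @old_state/@unique pair):

def unique(f):
    previous, called = None, False
    def wrapper(self, value):
        nonlocal previous, called
        key = (id(self), value)   # distinguish instances sharing one closure
        if (not called) or (key != previous):
            called, previous = True, key
            return f(self, value)
    return wrapper

class Render:
    calls = []
    @unique
    def render(self, state):
        Render.calls.append((id(self), state))

a, b = Render(), Render()
a.render("2020-03-11 00:00:00")
b.render("2020-03-11 00:00:00")   # would be skipped if the key were the value alone
assert len(Render.calls) == 2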
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
bokeh
bokeh==1.4.0 # Port to 2.0.0 in future
datashader
iris
intake
9 changes: 9 additions & 0 deletions test/test_drivers.py
@@ -0,0 +1,9 @@
from forest import drivers


def test_singleton_dataset():
driver_name = "earth_networks"
datasets = (
drivers.get_dataset(driver_name),
drivers.get_dataset(driver_name))
assert id(datasets[0]) == id(datasets[1])
18 changes: 17 additions & 1 deletion test/test_earth_networks.py
@@ -1,7 +1,8 @@
import datetime as dt
import numpy as np
import glob
from forest import earth_networks
import forest.drivers
from forest.drivers import earth_networks


LINES = [
@@ -28,3 +29,18 @@ def test_earth_networks(tmpdir):
assert result["flash_type"] == "IC"
assert abs(result["latitude"] - 2.75144) < atol
assert abs(result["longitude"] - 31.92064) < atol


def test_dataset():
dataset = forest.drivers.get_dataset("earth_networks")
assert isinstance(dataset, forest.drivers.earth_networks.Dataset)


def test_dataset_navigator():
settings = {
"pattern": "*.txt"
}
dataset = forest.drivers.get_dataset("earth_networks", settings)
navigator = dataset.navigator()
assert isinstance(navigator,
forest.drivers.earth_networks.Navigator)
8 changes: 4 additions & 4 deletions test/test_load.py
@@ -1,11 +1,11 @@
import yaml
import forest
from forest import main
from forest import main, rdt


def test_earth_networks_loader_given_pattern():
loader = forest.Loader.from_pattern("Label", "EarthNetworks*.txt", "earth_networks")
assert isinstance(loader, forest.earth_networks.Loader)
def test_rdt_loader_given_pattern():
loader = forest.Loader.from_pattern("Label", "RDT*.json", "rdt")
assert isinstance(loader, rdt.Loader)


def test_build_loader_given_files():