Skip to content

Commit

Permalink
Merge pull request mila-iqia#28 from jbornschein/plot
Browse files Browse the repository at this point in the history
Various minor fixes for blocks-plot
  • Loading branch information
rizar committed Oct 19, 2015
2 parents a6dd519 + 02cd0f5 commit d1daab2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
5 changes: 3 additions & 2 deletions bin/blocks-plot
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def plot_dataframe(dataframe):
print("Plotting {} channels:".format(len(dataframe.columns)))
for cname, series in iteritems(dataframe):
print(" {}".format(cname))
pylab.plot(t, series.interpolate(), label=cname)
pylab.plot(t, series.interpolate(method='nearest'), label=cname)
pylab.legend()
pylab.show(block=True)


def main(args):
import readline
import blocks.scripts.plot as plot
import blocks.extras.scripts.plot as plot

from six import iteritems
from six.moves import input
Expand Down Expand Up @@ -121,6 +121,7 @@ def main(args):
break

column_specs = column_spec.split(',')
column_specs = [s.strip() for s in column_specs]
matched = plot.match_column_specs(experiments, column_specs)

if len(matched.columns) == 0:
Expand Down
8 changes: 5 additions & 3 deletions blocks/extras/scripts/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from blocks.serialization import load

try:
from pandas import DataFrame
import pandas
PANDAS_AVAILABLE = True
except ImportError:
PANDAS_AVAILABLE = False
Expand Down Expand Up @@ -95,7 +95,7 @@ def match_column_specs(experiments, column_specs):
" install it with pip.")
# We iterate over all column and match each spec to the
# channels of all experiments.
df = DataFrame()
df = pandas.DataFrame()
for spec in column_specs:
if ":" in spec:
exp_spec, column_spec = spec.split(":")
Expand All @@ -111,6 +111,8 @@ def match_column_specs(experiments, column_specs):
continue

column_name = "{}:{}".format(i, column)
df[column_name] = exp[column]

exp = exp.rename(columns={column: column_name})
df = pandas.concat((df, exp[column_name]), axis=1)

return df
25 changes: 23 additions & 2 deletions tests/scripts/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict
from tests import silence_printing, skip_if_not_available
from numpy import nan, isfinite

from blocks.log import TrainingLog
from blocks.main_loop import MainLoop
Expand All @@ -22,8 +23,8 @@ def some_experiments():
experiments['exp0']['col0'] = (0, 1, 2)
experiments['exp0']['col1'] = (3, 4, 5)
experiments['exp1'] = DataFrame()
experiments['exp1']['col0'] = (6, 7, 8)
experiments['exp1']['col1'] = (9, 9, 9)
experiments['exp1']['col0'] = (6, 7, 8, 9)
experiments['exp1']['col1'] = (9, 9, 9, 9)
return experiments


Expand Down Expand Up @@ -66,3 +67,23 @@ def test_match_column_specs():

assert isinstance(df, DataFrame)
assert list(df.columns) == ['0:col0', '0:col1', '1:col1']
assert list(df.index) == [0, 1, 2, 3]


def test_interpolate():
skip_if_not_available(modules=['pandas'])
""" Ensure tha DataFrame.interpolate(method='nearest') has the
desired properties.
It is used by blocks-plot and should:
* interpolate missing/NaN datapoints between valid ones
* not replace any NaN before/after the first/last finite datapoint
"""
y = [nan, nan, 2., 3., nan, 5, nan, nan]
df = DataFrame(y)
df_ = df.interpolate(method='nearest')[0]

assert all(isfinite(df_[2:6]))
assert all(~isfinite(df_[0:2]))
assert all(~isfinite(df_[6:8]))

0 comments on commit d1daab2

Please sign in to comment.