diff --git a/backtesting/_plotting.py b/backtesting/_plotting.py index 844318aa..139a709b 100644 --- a/backtesting/_plotting.py +++ b/backtesting/_plotting.py @@ -3,7 +3,7 @@ import sys import warnings from colorsys import hls_to_rgb, rgb_to_hls -from itertools import cycle, combinations +from itertools import cycle, combinations, repeat from functools import partial from typing import Callable, List, Union @@ -537,10 +537,20 @@ def __eq__(self, other): colors = value._opts['color'] colors = colors and cycle(_as_list(colors)) or ( cycle([next(ohlc_colors)]) if is_overlay else colorgen()) - legend_label = LegendStr(value.name) - for j, arr in enumerate(value, 1): + + tooltip_label = value.name if isinstance(value.name, str) else ", ".join(value.name) + + if isinstance(value.name, str) and len(value) > 1: + legend_labels = [ + LegendStr(f"{name}[{index}]") + for index, name in enumerate(repeat(value.name, len(value))) + ] + else: + legend_labels = [LegendStr(item) for item in _as_list(value.name)] + + for j, arr in enumerate(value): color = next(colors) - source_name = f'{legend_label}_{i}_{j}' + source_name = f'{legend_labels[j]}_{i}_{j}' if arr.dtype == bool: arr = arr.astype(int) source.add(arr, source_name) @@ -550,24 +560,24 @@ def __eq__(self, other): if is_scatter: fig.scatter( 'index', source_name, source=source, - legend_label=legend_label, color=color, + legend_label=legend_labels[j], color=color, line_color='black', fill_alpha=.8, marker='circle', radius=BAR_WIDTH / 2 * 1.5) else: fig.line( 'index', source_name, source=source, - legend_label=legend_label, line_color=color, + legend_label=legend_labels[j], line_color=color, line_width=1.3) else: if is_scatter: r = fig.scatter( 'index', source_name, source=source, - legend_label=LegendStr(legend_label), color=color, + legend_label=legend_labels[j], color=color, marker='circle', radius=BAR_WIDTH / 2 * .9) else: r = fig.line( 'index', source_name, source=source, - legend_label=LegendStr(legend_label), line_color=color, + legend_label=legend_labels[j], line_color=color, line_width=1.3) # Add dashed centerline just because mean = float(pd.Series(arr).mean()) @@ -578,9 +588,9 @@ def __eq__(self, other): line_color='#666666', line_dash='dashed', line_width=.5)) if is_overlay: - ohlc_tooltips.append((legend_label, NBSP.join(tooltips))) + ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips))) else: - set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r]) + set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r]) # If the sole indicator line on this figure, # have the legend only contain text without the glyph if len(value) == 1: diff --git a/backtesting/backtesting.py b/backtesting/backtesting.py index 9c168703..3f8a1e0d 100644 --- a/backtesting/backtesting.py +++ b/backtesting/backtesting.py @@ -90,7 +90,9 @@ def I(self, # noqa: E743 same length as `backtesting.backtesting.Strategy.data`. In the plot legend, the indicator is labeled with - function name, unless `name` overrides it. + function name, unless `name` overrides it. If `func` returns + multiple arrays, `name` can be a sequence of strings, and + its size must agree with the number of arrays returned. If `plot` is `True`, the indicator is plotted on the resulting `backtesting.backtesting.Backtest.plot`. @@ -115,13 +117,21 @@ def I(self, # noqa: E743 def init(): self.sma = self.I(ta.SMA, self.data.Close, self.n_sma) """ + def _format_name(name: str) -> str: + return name.format(*map(_as_str, args), + **dict(zip(kwargs.keys(), map(_as_str, kwargs.values())))) + if name is None: params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values())))) func_name = _as_str(func) name = (f'{func_name}({params})' if params else f'{func_name}') + elif isinstance(name, str): + name = _format_name(name) + elif try_(lambda: all(isinstance(item, str) for item in name), False): + name = [_format_name(item) for item in name] else: - name = name.format(*map(_as_str, args), - **dict(zip(kwargs.keys(), map(_as_str, kwargs.values())))) + raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or ' + '`Sequence[str]`') try: value = func(*args, **kwargs) @@ -139,6 +149,11 @@ def init(): if is_arraylike and np.argmax(value.shape) == 0: value = value.T + if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)): + raise ValueError( + f'Length of `name=` ({len(name)}) must agree with the number ' + f'of arrays the indicator returns ({value.shape[0]}).') + if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close): raise ValueError( 'Indicators must return (optionally a tuple of) numpy.arrays of same ' diff --git a/backtesting/test/_test.py b/backtesting/test/_test.py index d8d87814..af711d09 100644 --- a/backtesting/test/_test.py +++ b/backtesting/test/_test.py @@ -755,6 +755,37 @@ def test_resample(self): # Give browser time to open before tempfile is removed time.sleep(1) + def test_indicator_name(self): + test_self = self + + class S(Strategy): + def init(self): + def _SMA(): + return SMA(self.data.Close, 5), SMA(self.data.Close, 10) + + test_self.assertRaises(TypeError, self.I, _SMA, name=42) + test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", )) + test_self.assertRaises( + ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three")) + + for overlay in (True, False): + self.I(SMA, self.data.Close, 5, overlay=overlay) + self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay) + self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay) + self.I(_SMA, overlay=overlay) + self.I(_SMA, name="My SMA", overlay=overlay) + self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay) + + def next(self): + pass + + bt = Backtest(GOOG, S) + bt.run() + with _tempfile() as f: + bt.plot(filename=f, + plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False, + open_browser=False) + def test_indicator_color(self): class S(Strategy): def init(self):