Skip to content

Commit

Permalink
refactor fit spectrum plot
Browse files Browse the repository at this point in the history
  • Loading branch information
MyPyDavid committed Feb 18, 2024
1 parent b836bd9 commit a589d63
Showing 1 changed file with 198 additions and 129 deletions.
327 changes: 198 additions & 129 deletions src/raman_fitting/exports/plotting_fit_results.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,196 @@
# ruff: noqa
from typing import Dict

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.ticker import AutoMinorLocator, FormatStrFormatter, MultipleLocator
from matplotlib.axes import Axes

from matplotlib.text import Text
from matplotlib.ticker import AutoMinorLocator

from raman_fitting.imports.samples.models import SampleMetaData
from raman_fitting.models.fit_models import SpectrumFitModel

matplotlib.rcParams.update({"font.size": 14})

from raman_fitting.config.settings import ExportPathSettings
from raman_fitting.models.splitter import WindowNames
from raman_fitting.delegating.models import AggregatedSampleSpectrumFitResult

from loguru import logger
# TODO fix big spectrum plot


matplotlib.rcParams.update({"font.size": 14})
# TODO fix big spectrum plot, reduce complexity if-statements


def fit_spectrum_plot(
aggregated_spectra: Dict[WindowNames, AggregatedSampleSpectrumFitResult],
export_paths: ExportPathSettings | None = None,
plot_Annotation=True,
plot_Residuals=True,
plot_annotation=True,
plot_residuals=True,
): # pragma: no cover
# %%
first_order = aggregated_spectra[WindowNames.first_order]
second_order = aggregated_spectra[WindowNames.second_order]

sources = first_order.aggregated_spectrum.sources
sample = sources[0].file_info.sample
# first_model_name = "4peaks"
second_model_name = "2nd_4peaks"
# first_model = [first_model_name]
second_model = second_order.fit_model_results.get(second_model_name)
# breakpoint()
for first_model_name, first_model in first_order.fit_model_results.items():
# for second_model_name, second_model in second_order.fit_model_results.items():
prepare_combined_spectrum_fit_result_plot(
first_model,
second_model,
sample,
export_paths,
plot_annotation=plot_annotation,
plot_residuals=plot_residuals,
)


def prepare_combined_spectrum_fit_result_plot(
first_model: SpectrumFitModel,
second_model: SpectrumFitModel,
sample: SampleMetaData,
export_paths: ExportPathSettings,
plot_annotation=True,
plot_residuals=True,
):
plt.figure(figsize=(28, 24))
gs = gridspec.GridSpec(4, 1, height_ratios=[4, 1, 4, 1])
ax = plt.subplot(gs[0])
ax_res = plt.subplot(gs[1])
ax.set_title(f"{sample.id}")
# breakpoint()

first_model_name = first_model.model.name

fit_plot_first(ax, ax_res, first_model, plot_residuals=plot_residuals)

ax2nd = plt.subplot(gs[2])
ax2nd_res = plt.subplot(gs[3])

if second_model is not None:
fit_plot_second(ax2nd, ax2nd_res, second_model, plot_residuals=plot_residuals)

_bbox_artists = tuple()
if plot_annotation:
annotate_report_first = prepare_annotate_fit_report_first(
ax, first_model.fit_result
)
_bbox_artists = (annotate_report_first,)
if second_model is not None:
annotate_report_second = prepare_annotate_fit_report_second(
ax2nd, second_model.fit_result
)
if annotate_report_second is not None:
_bbox_artists = (annotate_report_first, annotate_report_second)

first_model = first_order.fit_model_results["4peaks"]
# set axes labels and legend
set_axes_labels_and_legend(ax)

plot_special_si_components(ax, first_model, first_model_name)

if export_paths is not None:
savepath = export_paths.plots.joinpath(f"Model_{first_model_name}").with_suffix(
".png"
)
plt.savefig(
savepath,
dpi=100,
bbox_extra_artists=_bbox_artists,
bbox_inches="tight",
)
logger.debug(f"Plot saved to {savepath}")
plt.close()


def fit_plot_first(
ax, ax_res, first_model: SpectrumFitModel, plot_residuals: bool = True
) -> matplotlib.text.Text | None:
first_result = first_model.fit_result
first_components = first_model.fit_result.components
first_eval_comps = first_model.fit_result.eval_components()
first_model_name = first_model.model.name
first_pars = first_model.fit_result.best_values

if second_order.fit_model_results:
second_model = second_order.fit_model_results["2nd_4peaks"]
second_result = second_model.fit_result
second_components = second_model.fit_result.components
second_eval_comps = second_model.fit_result.eval_components()
second_model_name = second_model.model.name
second_pars = second_model.fit_result.best_values
else:
second_order = None
second_components = []
second_result = None

plt.figure(figsize=(28, 24))
gs = gridspec.GridSpec(4, 1, height_ratios=[4, 1, 4, 1])
ax = plt.subplot(gs[0])
ax.grid(True, "both")
axRes = plt.subplot(gs[1])
axRes.grid(True, "both")
if second_order:
ax2nd = plt.subplot(gs[2])
ax2nd.grid(True)
ax2ndRes = plt.subplot(gs[3])
ax2ndRes.grid(True)
ax_res.grid(True, "both")
ax.get_yaxis().set_tick_params(direction="in")
ax.get_xaxis().set_tick_params(direction="in")
ax.set_title(f"{sample.id}")

ax.xaxis.set_minor_locator(AutoMinorLocator(2))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
ax.tick_params(which="both", direction="in")
ax.set_facecolor("oldlace")
axRes.set_facecolor("oldlace")
if second_order:
ax_res.set_facecolor("oldlace")
ax.plot(
first_model.spectrum.ramanshift,
first_result.best_fit,
label=first_model_name,
lw=3,
c="r",
)
ax.plot(
first_model.spectrum.ramanshift,
first_result.data,
label="Data",
lw=3,
c="grey",
alpha=0.8,
)

if plot_residuals:
ax_res.plot(
first_model.spectrum.ramanshift,
first_result.residual,
label="Residual",
lw=3,
c="k",
alpha=0.8,
)

for _component in first_components: # automatic color cycle 'cyan' ...
peak_name = _component.prefix.rstrip("_")
ax.plot(
first_model.spectrum.ramanshift,
first_eval_comps[_component.prefix],
ls="--",
lw=4,
label=peak_name,
)
center_col = _component.prefix + "center"
ax.annotate(
f"{peak_name}:\n {first_result.best_values[center_col]:.0f}",
xy=(
first_result.best_values[center_col] * 0.97,
0.7 * first_result.params[_component.prefix + "height"].value,
),
xycoords="data",
)


def fit_plot_second(
ax2nd, ax2nd_res, second_model: SpectrumFitModel, plot_residuals: bool = True
) -> None:
if second_model:
second_result = second_model.fit_result
second_components = second_model.fit_result.components
second_eval_comps = second_model.fit_result.eval_components()
second_model_name = second_model.model.name
else:
second_components = []
second_result = None
if second_model:
ax2nd.grid(True)
ax2nd_res.grid(True)
ax2nd.xaxis.set_minor_locator(AutoMinorLocator(2))
ax2nd.yaxis.set_minor_locator(AutoMinorLocator(2))
ax2nd.tick_params(which="both", direction="in")
ax2nd.set_facecolor("oldlace")
ax2ndRes.set_facecolor("oldlace")
ax2nd_res.set_facecolor("oldlace")
ax2nd.plot(
second_model.spectrum.ramanshift,
second_result.best_fit,
Expand All @@ -89,8 +206,8 @@ def fit_spectrum_plot(
c="grey",
alpha=0.5,
)
if plot_Residuals:
ax2ndRes.plot(
if plot_residuals:
ax2nd_res.plot(
second_model.spectrum.ramanshift,
second_result.residual,
label="Residual",
Expand Down Expand Up @@ -118,52 +235,48 @@ def fit_spectrum_plot(
xycoords="data",
)
ax2nd.set_ylim(-0.02, second_result.data.max() * 1.5)
ax.plot(
first_model.spectrum.ramanshift,
first_result.best_fit,
label=first_model_name,
lw=3,
c="r",
)
ax.plot(
first_model.spectrum.ramanshift,
first_result.data,
label="Data",
lw=3,
c="grey",
alpha=0.8,

set_axes_labels_and_legend(ax2nd)


def prepare_annotate_fit_report_second(ax2nd, second_result) -> Text:
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
annotate_report_second = ax2nd.text(
1.01,
0.7,
second_result.fit_report(),
transform=ax2nd.transAxes,
fontsize=11,
verticalalignment="top",
bbox=props,
)

if plot_Residuals:
axRes.plot(
first_model.spectrum.ramanshift,
first_result.residual,
label="Residual",
lw=3,
c="k",
alpha=0.8,
)
return annotate_report_second

for _component in first_components: # automatic color cycle 'cyan' ...
peak_name = _component.prefix.rstrip("_")
ax.plot(
first_model.spectrum.ramanshift,
first_eval_comps[_component.prefix],
ls="--",
lw=4,
label=peak_name,
)
center_col = _component.prefix + "center"
ax.annotate(
f"{peak_name}:\n {first_result.best_values[center_col]:.0f}",
xy=(
first_result.best_values[center_col] * 0.97,
0.7 * first_result.params[_component.prefix + "height"].value,
),
xycoords="data",
)

for si_comp in [i for i in first_components if i.prefix.startswith("Si")]:
def prepare_annotate_fit_report_first(ax, first_result):
fit_report = first_result.fit_report()
if len(fit_report) > -1: # TODO remove
fit_report = fit_report.replace("prefix='D3_'", "prefix='D3_' \n")
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)

annotate_report_first = ax.text(
1.01,
1,
fit_report,
transform=ax.transAxes,
fontsize=11,
verticalalignment="top",
bbox=props,
)
return annotate_report_first


def plot_special_si_components(ax, first_model, model_name: str):
first_result = first_model.fit_result
si_components = filter(lambda x: x.prefix.startswith("Si"), first_result.components)
first_eval_comps = first_model.fit_result.eval_components()
for si_comp in si_components:
si_result = si_comp
# TODO should be si_fit_results
ax.plot(
Expand All @@ -182,54 +295,10 @@ def fit_spectrum_plot(
),
xycoords="data",
)
if plot_Annotation:
fit_report = first_result.fit_report()
if len(fit_report) > -1: # TODO remove
fit_report = fit_report.replace("prefix='D3_'", "prefix='D3_' \n")
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)

Report1 = ax.text(
1.01,
1,
fit_report,
transform=ax.transAxes,
fontsize=11,
verticalalignment="top",
bbox=props,
)
_bbox_artists = (Report1,)
if second_result:
Report2 = ax2nd.text(
1.01,
0.7,
second_result.fit_report(),
transform=ax2nd.transAxes,
fontsize=11,
verticalalignment="top",
bbox=props,
)
_bbox_artists = (Report1, Report2)

(
ax.legend(loc=1),
ax.set_xlabel("Raman shift (cm$^{-1}$)"),
ax.set_ylabel("normalized I / a.u."),
)
if second_order:
(
ax2nd.legend(loc=1),
ax2nd.set_xlabel("Raman shift (cm$^{-1}$)"),
ax2nd.set_ylabel("normalized I / a.u."),
)
if export_paths is not None:
savepath = export_paths.plots.joinpath(f"Model_{first_model_name}").with_suffix(
".png"
)
plt.savefig(
savepath,
dpi=100,
bbox_extra_artists=_bbox_artists,
bbox_inches="tight",
)
logger.debug(f"Plot saved to {savepath}")
plt.close()

def set_axes_labels_and_legend(ax: Axes):
# set axes labels and legend
ax.legend(loc=1)
ax.set_xlabel("Raman shift (cm$^{-1}$)")
ax.set_ylabel("normalized I / a.u.")

0 comments on commit a589d63

Please sign in to comment.