Skip to content

[ENH] Enable full plot range for fm.plot() #246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions fooof/core/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def get_description():
'settings' : ['peak_width_limits', 'max_n_peaks',
'min_peak_height', 'peak_threshold',
'aperiodic_mode'],
'data' : ['power_spectrum', 'freq_range', 'freq_res'],
'data' : ['power_spectrum', 'freq_range', 'freq_res', 'freqs_full', 'power_spectrum_full'],
'meta_data' : ['freq_range', 'freq_res'],
'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_',
'peak_params_', 'gaussian_params_'],
'peak_params_', 'gaussian_params_',
'freqs_full', 'power_spectrum_full'],
'model_components' : ['fooofed_spectrum_', '_spectrum_flat',
'_spectrum_peak_rm', '_ap_fit', '_peak_fit'],
'descriptors' : ['has_data', 'has_model', 'n_peaks_']
Expand Down
4 changes: 2 additions & 2 deletions fooof/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
###################################################################################################

@check_dependency(plt, 'matplotlib')
def save_report_fm(fm, file_name, file_path=None, plt_log=False):
def save_report_fm(fm, file_name, file_path=None, plt_log=False, plot_range=None):
"""Generate and save out a PDF report for a power spectrum model fit.

Parameters
Expand Down Expand Up @@ -51,7 +51,7 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False):

# Second - data plot
ax1 = plt.subplot(grid[1])
fm.plot(plt_log=plt_log, ax=ax1)
fm.plot(plt_log=plt_log, plot_range=plot_range, ax=ax1)

# Third - FOOOF settings
ax2 = plt.subplot(grid[2])
Expand Down
36 changes: 28 additions & 8 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ class FOOOF():
Whether data is loaded to the object.
has_model : bool
Whether model results are available in the object.
freqs_full : 1d array
Frequency values for the full power spectrum (entire frequency range).
power_spectrum_full : 1d array
Power values for the full power spectrum (entire frequency range).
Stored internally in log10 scale.

Notes
-----
Expand Down Expand Up @@ -270,9 +275,11 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res
self.freqs = None
self.freq_range = None
self.freq_res = None
self.freqs_full = None

if clear_spectrum:
self.power_spectrum = None
self.power_spectrum_full = None

if clear_results:

Expand Down Expand Up @@ -320,7 +327,8 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
clear_spectrum=self.has_data,
clear_results=self.has_model and clear_results)

self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
self.freqs, self.power_spectrum, self.freq_range, self.freq_res, \
self.freqs_full, self.power_spectrum_full = \
self._prepare_data(freqs, power_spectrum, freq_range, 1)


Expand Down Expand Up @@ -372,7 +380,8 @@ def add_results(self, fooof_result):
self._check_loaded_results(fooof_result._asdict())


def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False):
def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False,
plot_range=None):
"""Run model fit, and display a report, which includes a plot, and printed results.

Parameters
Expand All @@ -393,7 +402,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False
"""

self.fit(freqs, power_spectrum, freq_range)
self.plot(plt_log=plt_log)
self.plot(plt_log=plt_log, plot_range=plot_range)
self.print_results(concise=False)


Expand Down Expand Up @@ -633,18 +642,20 @@ def get_results(self):
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
add_legend=True, save_fig=False, file_name=None, file_path=None,
ax=None, data_kwargs=None, model_kwargs=None,
aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
aperiodic_kwargs=None, peak_kwargs=None, plot_range=None,
**plot_kwargs):

plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
add_legend=add_legend, save_fig=save_fig, file_name=file_name,
file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs)
aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, plot_range=plot_range,
**plot_kwargs)


@copy_doc_func_to_method(save_report_fm)
def save_report(self, file_name, file_path=None, plt_log=False):
def save_report(self, file_name, file_path=None, plt_log=False, plot_range=None):

save_report_fm(self, file_name, file_path, plt_log)
save_report_fm(self, file_name, file_path, plt_log, plot_range)


@copy_doc_func_to_method(save_fm)
Expand Down Expand Up @@ -1163,6 +1174,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
Minimum and maximum values of the frequency vector.
freq_res : float
Frequency resolution of the power spectrum.
freqs_full : 1d array
Frequency values for the full power_spectrum, in linear space.
power_spectrum_full : 1d or 2d array
Full power spectrum values, in log10 scale.

Raises
------
Expand Down Expand Up @@ -1197,6 +1212,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
if power_spectrum.dtype != 'float64':
power_spectrum = power_spectrum.astype('float64')

# Add full data to object for full frequency range plotting
freqs_full = freqs
power_spectrum_full = power_spectrum

# Check frequency range, trim the power_spectrum range if requested
if freq_range:
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range)
Expand All @@ -1215,6 +1234,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):

# Log power values
power_spectrum = np.log10(power_spectrum)
power_spectrum_full = np.log10(power_spectrum_full)

if self._check_data:
# Check if there are any infs / nans, and raise an error if so
Expand All @@ -1224,7 +1244,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
"One reason this can happen is if inputs are already logged. "
"Inputs data should be in linear spacing, not log.")

return freqs, power_spectrum, freq_range, freq_res
return freqs, power_spectrum, freq_range, freq_res, freqs_full, power_spectrum_full


def _add_from_dict(self, data):
Expand Down
8 changes: 7 additions & 1 deletion fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class FOOOFGroup(FOOOF):
The number of models that failed to fit.
failed_fit_inds_ : list of int
The indices of any models that failed to fit.
freqs_full : 1d array
Frequency values for the full power spectrum (entire frequency range).
power_spectra_full : 1d array
Power values for the full power spectrum (entire frequency range).
Stored internally in log10 scale.

Notes
-----
Expand Down Expand Up @@ -221,7 +226,8 @@ def add_data(self, freqs, power_spectra, freq_range=None):
self._reset_data_results(True, True, True, True)
self._reset_group_results()

self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
self.freqs, self.power_spectra, self.freq_range, self.freq_res, \
self.freqs_full, self.power_spectra_full = \
self._prepare_data(freqs, power_spectra, freq_range, 2)


Expand Down
14 changes: 12 additions & 2 deletions fooof/plts/fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
@check_dependency(plt, 'matplotlib')
def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True,
save_fig=False, file_name=None, file_path=None, ax=None, data_kwargs=None,
model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, plot_range=None,
**plot_kwargs):
"""Plot the power spectrum and model fit results from a FOOOF object.

Parameters
Expand All @@ -53,6 +54,8 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
Figure axes upon which to plot.
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
Keyword arguments to pass into the plot call for each plot element.
plot_range : tuple of (float, float), optional, default: None
Frequency range to plot. If None, plots the fitting range.
**plot_kwargs
Keyword arguments to pass into the ``style_plot``.

Expand All @@ -73,7 +76,14 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
'label' : 'Original Spectrum' if add_legend else None}
data_kwargs = check_plot_kwargs(data_kwargs, data_defaults)
plot_spectra(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs)
if plot_range is None or fm.power_spectrum_full is None:
freqs_plot = fm.freqs
powers_plot = fm.power_spectrum
else:
if plot_range[0] > fm.freq_range[0] or plot_range[1] < fm.freq_range[1]:
raise ValueError(f"Plot range must be larger than the fitting range {fm.freq_range}.")
freqs_plot, powers_plot = trim_spectrum(fm.freqs_full, fm.power_spectrum_full, plot_range)
plot_spectra(freqs_plot, powers_plot, log_freqs, log_powers, ax=ax, **data_kwargs)

# Add the full model fit, and components (if requested)
if fm.has_model:
Expand Down