diff --git a/doc/api.rst b/doc/api.rst index d6e685d8..33f16f1d 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -144,6 +144,7 @@ The following functions operate on arrays of peak parameters, which may be usefu get_band_peak_group_arr get_highest_peak threshold_peaks + sort_peaks Measures -------- diff --git a/examples/analyses/plot_dev_demo.py b/examples/analyses/plot_dev_demo.py index 8f633ada..dffde294 100644 --- a/examples/analyses/plot_dev_demo.py +++ b/examples/analyses/plot_dev_demo.py @@ -127,9 +127,9 @@ ################################################################################################### # Access the model fit parameters & related attributes from the model object -print('Aperiodic parameters: \n', fm.results.aperiodic_params_, '\n') -print('Peak parameters: \n', fm.results.peak_params_, '\n') -print('Number of fit peaks: \n', fm.results.n_peaks_) +print('Aperiodic parameters: \n', fm.results.params.aperiodic.params, '\n') +print('Peak parameters: \n', fm.results.params.periodic.params, '\n') +print('Number of fit peaks: \n', fm.results.n_peaks) ################################################################################################### @@ -148,21 +148,21 @@ ################################################################################################### -# Extract aperiodic and periodic parameter -aps = fm.get_params('aperiodic_params') -peaks = fm.get_params('peak_params') +# Extract aperiodic and periodic parameters +aps = fm.get_params('aperiodic') +peaks = fm.get_params('peak') ################################################################################################### -# Extract goodness of fit information -err = fm.get_params('metrics', 'error_mae') -r2s = fm.get_params('metrics', 'gof_rsquared') +# Extract specific parameters +exp = fm.get_params('aperiodic', 'exponent') +cfs = fm.get_params('peak', 'CF') ################################################################################################### -# Extract specific parameters -exp = fm.get_params('aperiodic_params', 'exponent') -cfs = fm.get_params('peak_params', 'CF') +# Extract goodness of fit information +err = fm.get_metrics('error') +r2s = fm.get_metrics('gof') ################################################################################################### @@ -323,19 +323,19 @@ ################################################################################################### # Extract aperiodic and full periodic parameters -aps = fg.get_params('aperiodic_params') -per = fg.get_params('peak_params') +aps = fg.get_params('aperiodic') +per = fg.get_params('peak') ################################################################################################### # Extract group fit information -err = fg.get_params('metrics', 'error_mae') -r2s = fg.get_params('metrics', 'gof_rsquared') +err = fg.get_metrics('error') +r2s = fg.get_metrics('gof') ################################################################################################### # Check the average number of fit peaks, per model -print('Average number of fit peaks: ', np.mean(fg.results.n_peaks_)) +print('Average number of fit peaks: ', np.mean(fg.results.n_peaks)) ################################################################################################### @@ -551,7 +551,7 @@ ################################################################################################### # Find the index of the worst model fit from the group -worst_fit_ind = np.argmax(fg.get_params('metrics', 'error_mae')) +worst_fit_ind = np.argmax(fg.get_metrics('error')) # Extract this model fit from the group fm = fg.get_model(worst_fit_ind, regenerate=True) @@ -669,7 +669,7 @@ ################################################################################################### # Drop poor model fits based on MAE -fg.results.drop(fg.get_params('metrics', 'error_mae') > 0.10) +fg.results.drop(fg.get_metrics('error', 'mae') > 0.10) ################################################################################################### # Conclusions diff --git a/examples/analyses/plot_mne_example.py b/examples/analyses/plot_mne_example.py index 7cc626ae..7546507b 100644 --- a/examples/analyses/plot_mne_example.py +++ b/examples/analyses/plot_mne_example.py @@ -260,7 +260,7 @@ def check_nans(data, nan_policy='zero'): ################################################################################################### # Extract aperiodic exponent values -exps = fg.get_params('aperiodic_params', 'exponent') +exps = fg.get_params('aperiodic', 'exponent') ################################################################################################### diff --git a/examples/manage/plot_failed_fits.py b/examples/manage/plot_failed_fits.py index bbf54aa3..02039d14 100644 --- a/examples/manage/plot_failed_fits.py +++ b/examples/manage/plot_failed_fits.py @@ -59,15 +59,15 @@ # # These attributes are: # -# - ``n_null_`` : the number of model results that are null -# - ``null_inds_`` : the indices of any null model results +# - ``n_null`` : the number of model results that are null +# - ``null_inds`` : the indices of any null model results # ################################################################################################### # Check for failed model fits -print('Number of Null models : \t', fg.results.n_null_) -print('Indices of Null models : \t', fg.results.null_inds_) +print('Number of Null models : \t', fg.results.n_null) +print('Indices of Null models : \t', fg.results.null_inds) ################################################################################################### # Inducing Model Fit Failures @@ -86,7 +86,7 @@ ################################################################################################### # Hack the object to induce model failures -fg._maxfev = 50 +fg.algorithm._cf_settings.maxfev = 50 ################################################################################################### @@ -102,8 +102,8 @@ ################################################################################################### # Check how many model fit failures we have failed model fits -print('Number of Null models : \t', fg.results.n_null_) -print('Indices of Null models : \t', fg.results.null_inds_) +print('Number of Null models : \t', fg.results.n_null) +print('Indices of Null models : \t', fg.results.null_inds) ################################################################################################### # Debug Mode diff --git a/examples/manage/plot_fit_models_3d.py b/examples/manage/plot_fit_models_3d.py index 850c3a74..8786fd25 100644 --- a/examples/manage/plot_fit_models_3d.py +++ b/examples/manage/plot_fit_models_3d.py @@ -205,7 +205,7 @@ # Compare the aperiodic exponent results across conditions for ind, fg in enumerate(fgs): print("Aperiodic exponent for condition {} is {:1.4f}".format( - ind, np.mean(fg.get_params('aperiodic_params', 'exponent')))) + ind, np.mean(fg.get_params('aperiodic', 'exponent')))) ################################################################################################### # Managing Model Objects diff --git a/examples/manage/plot_manipulating_models.py b/examples/manage/plot_manipulating_models.py index 216de87c..eb9214fb 100644 --- a/examples/manage/plot_manipulating_models.py +++ b/examples/manage/plot_manipulating_models.py @@ -168,7 +168,7 @@ ################################################################################################### # Drop all model fits above an error threshold -fg.results.drop(fg.get_params('metrics', 'error_mae') > 0.01) +fg.results.drop(fg.get_metrics('error') > 0.01) ################################################################################################### # Note on Dropped or Failed Fits @@ -186,8 +186,8 @@ ################################################################################################### # Check information on null models (dropped models) -print('Number of Null models : \t', fg.results.n_null_) -print('Indices of Null models : \t', fg.results.null_inds_) +print('Number of Null models : \t', fg.results.n_null) +print('Indices of Null models : \t', fg.results.null_inds) # Despite the dropped model, the total number of models in the object is the same # This means that the indices are still the same as before dropping models @@ -196,7 +196,7 @@ ################################################################################################### # Null models are defined as all NaN (not a number) -for ind in fg.results.null_inds_: +for ind in fg.results.null_inds: print(fg.results[ind]) ################################################################################################### diff --git a/examples/models/plot_aperiodic_params.py b/examples/models/plot_aperiodic_params.py index 79c6ceb3..56884294 100644 --- a/examples/models/plot_aperiodic_params.py +++ b/examples/models/plot_aperiodic_params.py @@ -40,7 +40,7 @@ ################################################################################################### # Check the aperiodic parameters -fm.results.aperiodic_params_ +fm.results.params.aperiodic.params ################################################################################################### @@ -140,7 +140,7 @@ ################################################################################################### # Check the measured aperiodic parameters -fm.results.aperiodic_params_ +fm.results.params.aperiodic.params ################################################################################################### # Knee Frequency @@ -158,7 +158,7 @@ ################################################################################################### # Compute the knee frequency from aperiodic parameters -knee_frequency = compute_knee_frequency(*fm.results.aperiodic_params_[1:]) +knee_frequency = compute_knee_frequency(*fm.results.params.aperiodic.params[1:]) print('Knee frequency: ', knee_frequency) ################################################################################################### diff --git a/examples/models/plot_data_components.py b/examples/models/plot_data_components.py index eef17f23..aa41f554 100644 --- a/examples/models/plot_data_components.py +++ b/examples/models/plot_data_components.py @@ -57,7 +57,7 @@ ################################################################################################### # Plot the power spectrum model from the object -plot_spectra(fm.data.freqs, fm.results.modeled_spectrum_, color='red') +plot_spectra(fm.data.freqs, fm.results.model.modeled_spectrum, color='red') ################################################################################################### # Isolated Components @@ -70,7 +70,7 @@ # To access these components, we can use the following `getter` methods: # # - :meth:`~specparam.SpectralModel.get_data`: allows for accessing data components -# - :meth:`~specparam.SpectralModel.results.get_component`: allows for accessing model components +# - :meth:`~specparam.SpectralModel.results.model.get_component`: allows for accessing model components # ################################################################################################### @@ -91,7 +91,7 @@ ################################################################################################### # Plot the peak removed spectrum, with the model aperiodic fit -plot_spectra(fm.data.freqs, [fm.get_data('aperiodic'), fm.results.get_component('aperiodic')], +plot_spectra(fm.data.freqs, [fm.get_data('aperiodic'), fm.results.model.get_component('aperiodic')], colors=['black', 'blue'], linestyle=['-', '--']) ################################################################################################### @@ -113,7 +113,7 @@ ################################################################################################### # Plot the flattened spectrum data with the model peak fit -plot_spectra(fm.data.freqs, [fm.get_data('peak'), fm.results.get_component('peak')], +plot_spectra(fm.data.freqs, [fm.get_data('peak'), fm.results.model.get_component('peak')], colors=['black', 'green'], linestyle=['-', '--']) ################################################################################################### @@ -128,7 +128,7 @@ # Plot the full model fit, as the combination of the aperiodic and peak model components plot_spectra(fm.data.freqs, - [fm.results.get_component('aperiodic') + fm.results.get_component('peak')], + [fm.results.model.get_component('aperiodic') + fm.results.model.get_component('peak')], color='red') ################################################################################################### @@ -157,7 +157,7 @@ # Plot the peak removed spectrum, with the model aperiodic fit plot_spectra(fm.data.freqs, [fm.get_data('aperiodic', 'linear'), - fm.results.get_component('aperiodic', 'linear')], + fm.results.model.get_component('aperiodic', 'linear')], colors=['black', 'blue'], linestyle=['-', '--']) ################################################################################################### @@ -171,7 +171,7 @@ # Plot the flattened spectrum data with the model peak fit plot_spectra(fm.data.freqs, - [fm.get_data('peak', 'linear'), fm.results.get_component('peak', 'linear')], + [fm.get_data('peak', 'linear'), fm.results.model.get_component('peak', 'linear')], colors=['black', 'green'], linestyle=['-', '--']) ################################################################################################### @@ -199,8 +199,8 @@ # Plot the linear model, showing the combination of peak + aperiodic matches the full model plot_spectra(fm.data.freqs, - [fm.results.get_component('full', 'linear'), - fm.results.get_component('aperiodic', 'linear') + fm.results.get_component('peak', 'linear')], + [fm.results.model.get_component('full', 'linear'), + fm.results.model.get_component('aperiodic', 'linear') + fm.results.model.get_component('peak', 'linear')], linestyle=['-', 'dashed'], colors=['black', 'red'], alpha=[0.3, 0.75]) ################################################################################################### diff --git a/examples/plots/plot_model_components.py b/examples/plots/plot_model_components.py index c6ea795b..2e95e292 100644 --- a/examples/plots/plot_model_components.py +++ b/examples/plots/plot_model_components.py @@ -181,8 +181,8 @@ ################################################################################################### # Extract the aperiodic parameters for each group -aps1 = fg1.get_params('aperiodic_params') -aps2 = fg2.get_params('aperiodic_params') +aps1 = fg1.get_params('aperiodic') +aps2 = fg2.get_params('aperiodic') ################################################################################################### # Plotting Aperiodic Parameters diff --git a/examples/sims/plot_transforms.py b/examples/sims/plot_transforms.py index f0fe9374..0b7f9375 100644 --- a/examples/sims/plot_transforms.py +++ b/examples/sims/plot_transforms.py @@ -78,9 +78,9 @@ # Check the measured exponent values print("Original exponent value:\t {:1.2f}".format(\ - fm1.get_params('aperiodic_params', 'exponent'))) + fm1.get_params('aperiodic', 'exponent'))) print("Rotated exponent value:\t{:1.2f}".format(\ - fm2.get_params('aperiodic_params', 'exponent'))) + fm2.get_params('aperiodic', 'exponent'))) ################################################################################################### # Rotation Related Offset Changes diff --git a/motivations/measurements/plot_BandByBand.py b/motivations/measurements/plot_BandByBand.py index f84689d0..2c7b4235 100644 --- a/motivations/measurements/plot_BandByBand.py +++ b/motivations/measurements/plot_BandByBand.py @@ -107,8 +107,8 @@ def compare_exp(fm1, fm2): """Compare exponent values.""" - exp1 = fm1.get_params('aperiodic_params', 'exponent') - exp2 = fm2.get_params('aperiodic_params', 'exponent') + exp1 = fm1.get_params('aperiodic', 'exponent') + exp2 = fm2.get_params('aperiodic', 'exponent') return exp1 - exp2 @@ -203,7 +203,7 @@ def compare_band_pw(fm1, fm2, band_def): # Plot the power spectra differences plot_spectra_shading(freqs, - [fm_bands_g1.results._spectrum_flat, fm_bands_g2.results._spectrum_flat], + [fm_bands_g1.get_data('peak'), fm_bands_g2.get_data('peak')], log_powers=False, linewidth=3, shades=bands.definitions, shade_colors=shade_cols, labels=labels) @@ -294,7 +294,7 @@ def compare_band_pw(fm1, fm2, band_def): # Plot the power spectra differences plot_spectra_shading(freqs, - [fm_pa_g1.results._spectrum_flat, fm_pa_g2.results._spectrum_flat], + [fm_pa_g1.get_data('peak'), fm_pa_g2.get_data('peak')], log_powers=False, linewidth=3, shades=bands.definitions, shade_colors=shade_cols, labels=labels) diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 15c10587..9f46e23a 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -1,13 +1,15 @@ """Define object to manage algorithm implementations.""" +import numpy as np + from specparam.utils.checks import check_input_options -from specparam.algorithms.settings import SettingsDefinition +from specparam.algorithms.settings import SettingsDefinition, SettingsValues +from specparam.modutils.docs import docs_get_section, replace_docstring_sections ################################################################################################### ################################################################################################### -FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] - +DATA_FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] class Algorithm(): """Template object for defining a fit algorithm. @@ -18,25 +20,43 @@ class Algorithm(): Name of the fitting algorithm. description : str Description of the fitting algorithm. - settings : dict - Name and description of settings for the fitting algorithm. - format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} - Set base format of data model can be applied to. + public_settings : SettingsDefinition or dict + Name and description of public settings for the fitting algorithm. + private_settings : SettingsDefinition or dict, optional + Name and description of private settings for the fitting algorithm. + data_format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} + Set base data format the model can be applied to. + modes : Modes + Modes object with fit mode definitions. + data : Data + Data object with spectral data and metadata. + results : Results + Results object with model fit results and metrics. + debug : bool + Whether to run in debug state, raising an error if encountered during fitting. """ - def __init__(self, name, description, settings, format, - modes=None, data=None, results=None, debug=False): + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): """Initialize Algorithm object.""" self.name = name self.description = description - if not isinstance(settings, SettingsDefinition): - settings = SettingsDefinition(settings) - self.settings = settings + if not isinstance(public_settings, SettingsDefinition): + public_settings = SettingsDefinition(public_settings) + self.public_settings = public_settings + self.settings = SettingsValues(self.public_settings.names) + + if private_settings is None: + private_settings = {} + if not isinstance(private_settings, SettingsDefinition): + private_settings = SettingsDefinition(private_settings) + self.private_settings = private_settings + self._settings = SettingsValues(self.private_settings.names) - check_input_options(format, FORMATS, 'format') - self.format = format + check_input_options(data_format, DATA_FORMATS, 'data_format') + self.data_format = data_format self.modes = None self.data = None @@ -60,13 +80,11 @@ def add_settings(self, settings): Parameters ---------- settings : ModelSettings - A data object containing the settings for a power spectrum model. + A data object containing model settings. """ for setting in settings._fields: - setattr(self, setting, getattr(settings, setting)) - - self._check_loaded_settings(settings._asdict()) + setattr(self.settings, setting, getattr(settings, setting)) def get_settings(self): @@ -78,8 +96,8 @@ def get_settings(self): Object containing the settings from the current object. """ - return self.settings.make_model_settings()(\ - **{key : getattr(self, key) for key in self.settings.names}) + return self.public_settings.make_model_settings()(\ + **{key : getattr(self.settings, key) for key in self.public_settings.names}) def get_debug(self): @@ -100,32 +118,6 @@ def set_debug(self, debug): self._debug = debug - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(self.settings.names).issubset(set(data.keys())): - - # Reset all public settings to None - for setting in self.settings.names: - setattr(self, setting, None) - - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() - - - def _reset_internal_settings(self): - """"Can be overloaded if any resetting needed for internal settings.""" - - def _reset_subobjects(self, modes=None, data=None, results=None): """Reset links to sub-objects (mode / data / results). @@ -145,3 +137,83 @@ def _reset_subobjects(self, modes=None, data=None, results=None): self.data = data if results is not None: self.results = results + + +## AlgorithmCF + +CURVE_FIT_SETTINGS = SettingsDefinition({ + 'maxfev' : { + 'type' : 'int', + 'description' : 'The maximum number of calls to the curve fitting function.', + }, + 'tol' : { + 'type' : 'float', + 'description' : \ + 'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).' + }, +}) + +@replace_docstring_sections([docs_get_section(Algorithm.__doc__, 'Parameters')]) +class AlgorithmCF(Algorithm): + """Template object for defining a fit algorithm that uses `curve_fit`. + + Parameters + ---------- + % copied in from Algorithm + """ + + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): + """Initialize Algorithm object.""" + + Algorithm.__init__(self, name=name, description=description, + public_settings=public_settings, private_settings=private_settings, + data_format=data_format, modes=modes, data=data, results=results, + debug=debug) + + self._cf_settings_desc = CURVE_FIT_SETTINGS + self._cf_settings = SettingsValues(self._cf_settings_desc.names) + + + def _initialize_bounds(self, mode): + """Initialize a bounds definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + bounds : tuple of array + Bounds values. + + Notes + ----- + Output follows the needed bounds definition for curve_fit, which is: + ([low_bound_param1, low_bound_param2], + [high_bound_param1, high_bound_param2]) + """ + + n_params = getattr(self.modes, mode).n_params + bounds = (np.array([-np.inf] * n_params), np.array([np.inf] * n_params)) + + return bounds + + def _initialize_guess(self, mode): + """Initialize a guess definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + guess : 1d array + Guess values. + """ + + guess = np.zeros([getattr(self.modes, mode).n_params]) + + return guess diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index 1fa601c0..14743188 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -5,48 +5,132 @@ ################################################################################################### ################################################################################################### +class SettingsValues(): + """Defines a set of algorithm settings values. + + Parameters + ---------- + names : list of str + Names of the settings to hold values for. + + Attributes + ---------- + values : dict of {str : object} + Settings values. + """ + + __slots__ = ('values',) + + def __init__(self, names): + """Initialize settings values.""" + + self.values = {name : None for name in names} + + + def __getattr__(self, name): + """Allow for accessing settings values as attributes.""" + + try: + return self.values[name] + except KeyError: + raise AttributeError(name) + + + def __setattr__(self, name, value): + """Allow for setting settings values as attributes.""" + + if name == 'values': + super().__setattr__(name, value) + else: + getattr(self, name) + self.values[name] = value + + + def __getstate__(self): + """Define how to get object state - for pickling.""" + + return self.values + + + def __setstate__(self, state): + """Define how to set object state - for pickling.""" + + self.values = state + + + @property + def names(self): + """Property attribute for settings names.""" + + return list(self.values.keys()) + + + def clear(self): + """Clear all settings - resetting to None.""" + + for setting in self.names: + self.values[setting] = None + + class SettingsDefinition(): """Defines a set of algorithm settings. Parameters ---------- - settings : dict + definitions : dict Settings definition. Each key should be a str name of a setting. Each value should be a dictionary with keys 'type' and 'description', with str values. + + Attributes + ---------- + names : list of str + Names of the settings defined in the object. + descriptions : dict of {str : str} + Description of each setting. + types : dict of {str : str} + Type for each setting. + values : dict of {str : object} + Value of each setting. """ - def __init__(self, settings): + def __init__(self, definitions): """Initialize settings definition.""" - self._settings = settings + self._definitions = definitions + + def __len__(self): + """Define the length of the object as the number of settings.""" - def _get_settings_subdict(self, field): - """Helper function to select from settings dictionary.""" + return len(self._definitions) - return {label : self._settings[label][field] for label in self._settings.keys()} + + def _get_definitions_subdict(self, field): + """Helper function to select from definitions dictionary.""" + + return {label : self._definitions[label][field] for label in self._definitions.keys()} @property def names(self): """Make property alias for setting names.""" - return list(self._settings.keys()) + return list(self._definitions.keys()) @property def types(self): """Make property alias for setting types.""" - return self._get_settings_subdict('type') + return self._get_definitions_subdict('type') @property def descriptions(self): """Make property alias for setting descriptions.""" - return self._get_settings_subdict('description') + return self._get_definitions_subdict('description') def make_setting_str(self, name): @@ -91,6 +175,11 @@ def make_model_settings(self): class ModelSettings(namedtuple('ModelSettings', self.names)): __slots__ = () + + @property + def names(self): + return list(self._fields) + ModelSettings.__doc__ = self.make_docstring() return ModelSettings diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 4d76918f..fe31c892 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -1,6 +1,7 @@ """Define original spectral fitting algorithm object.""" import warnings +from itertools import repeat import numpy as np from numpy.linalg import LinAlgError @@ -8,15 +9,17 @@ from specparam.modutils.errors import FitError from specparam.utils.select import groupby +from specparam.data.periodic import sort_peaks from specparam.reports.strings import gen_width_warning_str +from specparam.measures.estimates import estimate_fwhm from specparam.measures.params import compute_gauss_std -from specparam.algorithms.algorithm import Algorithm +from specparam.algorithms.algorithm import AlgorithmCF from specparam.algorithms.settings import SettingsDefinition ################################################################################################### ################################################################################################### -SPECTRAL_FIT_SETTINGS = SettingsDefinition({ +SPECTRAL_FIT_SETTINGS_DEF = SettingsDefinition({ 'peak_width_limits' : { 'type' : 'tuple of (float, float), optional, default: (0.5, 12.0)', 'description' : 'Limits on possible peak width, in Hz, as (lower_bound, upper_bound).', @@ -28,64 +31,75 @@ 'min_peak_height' : { 'type' : 'float, optional, default: 0', 'description' : \ - 'Absolute threshold for detecting peaks.\n ' \ + 'Absolute threshold for detecting peaks.' + '\n ' 'This threshold is defined in absolute units of the power spectrum (log power).', }, 'peak_threshold' : { 'type' : 'float, optional, default: 2.0', 'description' : \ - 'Relative threshold for detecting peaks.\n ' \ + 'Relative threshold for detecting peaks.' + '\n ' 'Threshold is defined in relative units of the power spectrum (standard deviation).', }, }) -class SpectralFitAlgorithm(Algorithm): +SPECTRAL_FIT_PRIVATE_SETTINGS_DEF = SettingsDefinition({ + 'ap_percentile_thresh' : { + 'type' : 'float', + 'description' : \ + 'Percentile threshold to select data from flat spectrum for an initial aperiodic fit.' + '\n ' + 'Points are selected at a low percentile value to restrict to non-peak points.', + }, + 'ap_guess' : { + 'type' : 'list of float', + 'description' : \ + 'Guess parameters for fitting the aperiodic component.' + '\n ' + 'The guess parameters should match the length and order of the aperiodic parameters.' + '\n ' + 'If \'offset\' is a parameter, default guess is the first value of the power spectrum.' + '\n ' + 'If \'exponent\' is a parameter, ' + 'default guess is the abs(log-log slope) of first & last points.' + }, + 'ap_bounds' : { + 'type' : 'tuple of tuple of float', + 'description' : \ + 'Bounds for aperiodic fitting, as ((param1_low_bound, ...) (param1_high_bound, ...)).' + '\n ' + 'By default, aperiodic fitting is unbound, but can be restricted here.', + }, + 'cf_bound' : { + 'type' : 'float', + 'description' : \ + 'Parameter bounds for center frequency when fitting peaks, as +/- std dev.', + }, + 'bw_std_edge' : { + 'type' : 'float', + 'description' : \ + 'Threshold for how far a peak has to be from edge to keep.' + '\n ' + 'This is defined in units of peak standard deviation.', + }, + 'gauss_overlap_thresh' : { + 'type' : 'float', + 'description' : \ + 'Degree of overlap between peak guesses for one to be dropped.' + '\n ' + 'This is defined in units of peak standard deviation.', + }, +}) + + +class SpectralFitAlgorithm(AlgorithmCF): """Base object defining model & algorithm for spectral parameterization. Parameters ---------- % public settings described in Spectral Fit Algorithm Settings - _ap_percentile_thresh : float - Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - Points are selected at a low percentile value to restrict to non-peak points. - _ap_guess : list of [float, float, float] - Guess parameters for fitting the aperiodic component, as [offset, knee, exponent]. - If offset guess is None, the first value of the power spectrum is used as offset guess - If exponent guess is None, the abs(log-log slope) of first & last points is used - _ap_bounds : tuple of tuple of float - Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), - (offset_high_bound, knee_high_bound, exp_high_bound)) - By default, aperiodic fitting is unbound, but can be restricted here. - Even if fitting without knee, leave bounds for knee (they are dropped later). - _cf_bound : float - Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev. - _bw_std_edge : float - Threshold for how far a peak has to be from edge to keep. - This is defined in units of gaussian standard deviation. - _gauss_overlap_thresh : float - Degree of overlap between gaussian guesses for one to be dropped. - This is defined in units of gaussian standard deviation. - _maxfev : int - The maximum number of calls to the curve fitting function. - _tol : float - The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol). - The default value reduce tolerance to speed fitting (as compared to curve_fit's default). - Set value to 1e-8 to match curve_fit default. - - Attributes - ---------- - _gauss_std_limits : list of [float, float] - Settings attribute: peak width limits, to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. - _spectrum_flat : 1d array - Data attribute: flattened power spectrum, with the aperiodic component removed. - _spectrum_peak_rm : 1d array - Data attribute: power spectrum, with peaks removed. - _ap_fit : 1d array - Model attribute: values of the isolated aperiodic fit. - _peak_fit : 1d array - Model attribute: values of the isolated peak fit. """ # pylint: disable=attribute-defined-outside-init @@ -99,29 +113,29 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h super().__init__( name='spectral fit', description='Original parameterizing neural power spectra algorithm.', - settings=SPECTRAL_FIT_SETTINGS, format='spectrum', + public_settings=SPECTRAL_FIT_SETTINGS_DEF, + private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF, modes=modes, data=data, results=results, debug=debug) ## Public settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold + self.settings.peak_width_limits = peak_width_limits + self.settings.max_n_peaks = max_n_peaks + self.settings.min_peak_height = min_peak_height + self.settings.peak_threshold = peak_threshold ## Private settings: model parameters related settings - self._ap_percentile_thresh = ap_percentile_thresh - self._ap_guess = ap_guess - self._set_ap_bounds(ap_bounds) - self._cf_bound = cf_bound - self._bw_std_edge = bw_std_edge - self._gauss_overlap_thresh = gauss_overlap_thresh + self._settings.ap_percentile_thresh = ap_percentile_thresh + self._settings.ap_guess = ap_guess + self._settings.ap_bounds = self._get_ap_bounds(ap_bounds) + self._settings.cf_bound = cf_bound + self._settings.bw_std_edge = bw_std_edge + self._settings.gauss_overlap_thresh = gauss_overlap_thresh - ## Private setting: curve_fit related settings - self._maxfev = maxfev - self._tol = tol - - ## Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() + ## curve_fit settings + # Note - default reduces tolerance to speed fitting (as compared to curve_fit's default). + # Set value to 1e-8 to match curve_fit default. + self._cf_settings.maxfev = maxfev + self._cf_settings.tol = tol def _fit_prechecks(self, verbose=True): @@ -134,8 +148,9 @@ def _fit_prechecks(self, verbose=True): """ if verbose: - if 1.5 * self.data.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.data.freq_res, self.peak_width_limits[0])) + if 1.5 * self.data.freq_res >= self.settings.peak_width_limits[0]: + print(gen_width_warning_str(self.data.freq_res, + self.settings.peak_width_limits[0])) def _fit(self): @@ -144,56 +159,38 @@ def _fit(self): ## FIT PROCEDURES # Take an initial fit of the aperiodic component - temp_aperiodic_params_ = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) - temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params_) + temp_aperiodic_params = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) + temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params) - # Find peaks from the flattened power spectrum, and fit them with gaussians + # Find peaks from the flattened power spectrum, and fit them temp_spectrum_flat = self.data.power_spectrum - temp_ap_fit - self.results.gaussian_params_ = self._fit_peaks(temp_spectrum_flat) + self.results.params.periodic.add_params('fit', self._fit_peaks(temp_spectrum_flat)) # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - self.results._peak_fit = self.modes.periodic.func(\ - self.data.freqs, *np.ndarray.flatten(self.results.gaussian_params_)) + self.results.model._peak_fit = self.modes.periodic.func(\ + self.data.freqs, *np.ndarray.flatten(self.results.params.periodic.get_params('fit'))) # Create peak-removed (but not flattened) power spectrum - self.results._spectrum_peak_rm = self.data.power_spectrum - self.results._peak_fit + self.results.model._spectrum_peak_rm = \ + self.data.power_spectrum - self.results.model._peak_fit # Run final aperiodic fit on peak-removed power spectrum - self.results.aperiodic_params_ = self._simple_ap_fit(\ - self.data.freqs, self.results._spectrum_peak_rm) - self.results._ap_fit = self.modes.aperiodic.func(\ - self.data.freqs, *self.results.aperiodic_params_) + self.results.params.aperiodic.add_params('fit', \ + self._simple_ap_fit(self.data.freqs, self.results.model._spectrum_peak_rm)) + self.results.model._ap_fit = self.modes.aperiodic.func(\ + self.data.freqs, *self.results.params.aperiodic.params) # Create remaining model components: flatspec & full power_spectrum model fit - self.results._spectrum_flat = self.data.power_spectrum - self.results._ap_fit - self.results.modeled_spectrum_ = self.results._peak_fit + self.results._ap_fit + self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit + self.results.model.modeled_spectrum = \ + self.results.model._peak_fit + self.results.model._ap_fit ## PARAMETER UPDATES - # Convert gaussian definitions to peak parameters - self.results.peak_params_ = self._create_peak_params(self.results.gaussian_params_) - - - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. - - Notes - ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. - """ - - # Only update these settings if other relevant settings are available - if self.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None + # Convert fit peak parameters to updated values + self.results.params.periodic.add_params('converted', \ + self._create_peak_params(self.results.params.periodic.get_params('fit'))) def _get_ap_guess(self, freqs, power_spectrum): @@ -206,32 +203,35 @@ def _get_ap_guess(self, freqs, power_spectrum): ToDo - Could be updated to fill in missing guesses. """ - if not self._ap_guess: + if not self._settings.ap_guess: + + ap_guess = self._initialize_guess('aperiodic') - ap_guess = [] - for label in self.modes.aperiodic.params.labels: + for label, ind in self.modes.aperiodic.params.indices.items(): if label == 'offset': # Offset guess is the power value for lowest available frequency - ap_guess.append(power_spectrum[0]) + ap_guess[ind] = power_spectrum[0] elif 'exponent' in label: # Exponent guess is a quick calculation of the log-log slope - ap_guess.append(np.abs((power_spectrum[-1] - power_spectrum[0]) / - (np.log10(freqs[-1]) - np.log10(freqs[0])))) - elif 'knee' in label: - # Knee guess set to zero (no real guess) - ap_guess.append(0) - else: - # Any other (un-anticipated) parameter set to guess of 0 - ap_guess.append(0) - - ap_guess = np.array(ap_guess) + ap_guess[ind] = np.abs((power_spectrum[-1] - power_spectrum[0]) / + (np.log10(freqs[-1]) - np.log10(freqs[0]))) return ap_guess - def _set_ap_bounds(self, ap_bounds): + def _get_ap_bounds(self, ap_bounds): """Set the default bounds for the aperiodic fit. + Parameters + ---------- + bounds : tuple of tuple or None + Bounds definition. If None, creates default bounds. + + Returns + ------- + bounds : tuple of tuple + Bounds definition. + Notes ----- The bounds for aperiodic parameters are set in general, and currently do not update @@ -240,12 +240,11 @@ def _set_ap_bounds(self, ap_bounds): if ap_bounds: msg = 'Provided aperiodic bounds do not have right length for fit function.' - assert len(self._ap_bounds[0]) == len(self._ap_bounds[1]) == \ - self.modes.aperiodic.n_params, msg - self._ap_bounds = ap_bounds + assert len(ap_bounds[0]) == len(ap_bounds[1]) == self.modes.aperiodic.n_params, msg else: - self._ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), - tuple([np.inf] * self.modes.aperiodic.n_params)) + ap_bounds = self._initialize_bounds('aperiodic') + + return ap_bounds def _simple_ap_fit(self, freqs, power_spectrum): @@ -275,9 +274,12 @@ def _simple_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, - p0=ap_guess, bounds=self._ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + p0=ap_guess, bounds=self._settings.ap_bounds, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " "the simple aperiodic component fit.") @@ -318,7 +320,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): flatspec[flatspec < 0] = 0 # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) + perc_thresh = np.percentile(flatspec, self._settings.ap_percentile_thresh) perc_mask = flatspec <= perc_thresh freqs_ignore = freqs[perc_mask] spectrum_ignore = power_spectrum[perc_mask] @@ -330,9 +332,12 @@ def _robust_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, - p0=popt, bounds=self._ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + p0=popt, bounds=self._settings.ap_bounds, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -355,9 +360,8 @@ def _fit_peaks(self, flatspec): Returns ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. + peak_params : 2d array + Parameters that define the peak fit(s). """ # Take a copy of the flattened spectrum to iterate across @@ -368,111 +372,114 @@ def _fit_peaks(self, flatspec): # Find peak: loop through, finding a candidate peak, & fit with a guess peak # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: + while len(guess) < self.settings.max_n_peaks: # Find candidate peak - the maximum point of the flattened spectrum max_ind = np.argmax(flat_iter) max_height = flat_iter[max_ind] # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): + if max_height <= self.settings.peak_threshold * np.std(flat_iter): break - # Set the guess parameters for gaussian fitting, specifying the mean and height + # Set the guess parameters for peak fitting, specifying the mean and height guess_freq = self.data.freqs[max_ind] guess_height = max_height # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: + if not guess_height > self.settings.min_peak_height: break - # Data-driven first guess at standard deviation - # Find half height index on each side of the center frequency - half_height = 0.5 * max_height - le_ind = next((val for val in range(max_ind - 1, 0, -1) - if flat_iter[val] <= half_height), None) - ri_ind = next((val for val in range(max_ind + 1, len(flat_iter), 1) - if flat_iter[val] <= half_height), None) - - # Guess bandwidth procedure: estimate the width of the peak - try: - # Get an estimated width from the shortest side of the peak - # We grab shortest to avoid estimating very large values from overlapping peaks - # Grab the shortest side, ignoring a side if the half max was not found - short_side = min([abs(ind - max_ind) \ - for ind in [le_ind, ri_ind] if ind is not None]) - - # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation - fwhm = short_side * 2 * self.data.freq_res - guess_std = compute_gauss_std(fwhm) - - except ValueError: - # This procedure can fail (very rarely), if both left & right inds end up as None - # In this case, default the guess to the average of the peak width limits - guess_std = np.mean(self.peak_width_limits) + # Estimate FWHM, and use to convert to an estimated Gaussian std + # If estimation process fails, then default guess to average of limits + fwhm = estimate_fwhm(flat_iter, max_ind, self.data.freq_res) + guess_std = compute_gauss_std(fwhm) if not np.isnan(fwhm) else \ + np.mean(self.settings.peak_width_limits) # Check that guess value isn't outside preset limits - restrict if so + # This also converts the peak_width_limits from 2-sided BW to 1-sided std # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] - - # Collect guess parameters and subtract this guess gaussian from the data - current_guess_params = (guess_freq, guess_height, guess_std) - - ## TEMP - if self.modes.periodic.name == 'skewnorm': - guess_skew = 0 - current_guess_params = (guess_freq, guess_height, guess_std, guess_skew) - - guess = np.vstack((guess, current_guess_params)) - peak_gauss = self.modes.periodic.func(self.data.freqs, *current_guess_params) - flat_iter = flat_iter - peak_gauss + if guess_std < self.settings.peak_width_limits[0] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 + if guess_std > self.settings.peak_width_limits[1] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 + + # Collect guess parameters + cur_guess = [0] * self.modes.periodic.n_params + cur_guess[self.modes.periodic.params.indices['cf']] = guess_freq + cur_guess[self.modes.periodic.params.indices['pw']] = guess_height + cur_guess[self.modes.periodic.params.indices['bw']] = guess_std + + # Fit and subtract guess peak from the spectrum + guess = np.vstack((guess, cur_guess)) + peak_fit = self.modes.periodic.func(self.data.freqs, *cur_guess) + flat_iter = flat_iter - peak_fit # Check peaks based on edges, and on overlap, dropping any that violate requirements guess = self._drop_peak_cf(guess) guess = self._drop_peak_overlap(guess) - # If there are peak guesses, fit the peaks, and sort results + # If there are peak guesses, fit the peaks, and sort results by CF if len(guess) > 0: - gaussian_params = self._fit_peak_guess(flatspec, guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] + peak_params = self._fit_peak_guess(flatspec, guess) + peak_params = sort_peaks(peak_params, 'CF', 'inc') + else: - gaussian_params = np.empty([0, self.modes.periodic.n_params]) + peak_params = np.empty([0, self.modes.periodic.n_params]) - return gaussian_params + return peak_params - ## TO GENERALIZE FOR MODES def _get_pe_bounds(self, guess): - """Get the bound for the peak fit.""" - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.data.freq_range[0] else \ - [self.data.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.data.freq_range[1] else \ - [self.data.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - return gaus_param_bounds + """Get the bound for the peak fit. + + Parameters + ---------- + guess : list + Guess parameters from initial peak search. + + Returns + ------- + pe_bounds : tuple of array + Bounds for periodic fit. + """ + + n_pe_params = self.modes.periodic.n_params + bounds = repeat(self._initialize_bounds('periodic')) + bounds_lo = np.empty(len(guess) * n_pe_params) + bounds_hi = np.empty(len(guess) * n_pe_params) + + for p_ind, peak in enumerate(guess): + for label, ind in self.modes.periodic.params.indices.items(): + + pbounds_lo, pbounds_hi = next(bounds) + + if label == 'cf': + # Set boundaries on CF, weighted by the bandwidth + peak_bw = peak[self.modes.periodic.params.indices['bw']] + lcf = peak[ind] - 2 * self._settings.cf_bound * peak_bw + hcf = peak[ind] + 2 * self._settings.cf_bound * peak_bw + # Check that CF bounds are within frequency range - if not restrict to range + pbounds_lo[ind] = lcf if lcf > self.data.freq_range[0] \ + else self.data.freq_range[0] + pbounds_hi[ind] = hcf if hcf < self.data.freq_range[1] \ + else self.data.freq_range[1] + + if label == 'pw': + # Enforce positive values for height + pbounds_lo[ind] = 0 + + if label == 'bw': + # Set bandwidth limits, converting limits from Hz to guess params in std + pbounds_lo[ind] = self.settings.peak_width_limits[0] / 2 + pbounds_hi[ind] = self.settings.peak_width_limits[1] / 2 + + bounds_lo[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_lo + bounds_hi[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_hi + + pe_bounds = (bounds_lo, bounds_hi) + + return pe_bounds def _fit_peak_guess(self, flatspec, guess): @@ -498,8 +505,11 @@ def _fit_peak_guess(self, flatspec, guess): p0=np.ndarray.flatten(guess), bounds=self._get_pe_bounds(guess), jac=self.modes.periodic.jacobian, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " @@ -517,7 +527,6 @@ def _fit_peak_guess(self, flatspec, guess): return pe_params - ## TO GENERALIZE FOR MODES def _drop_peak_cf(self, guess): """Check whether to drop peaks based on center's proximity to the edge of the spectrum. @@ -532,8 +541,8 @@ def _drop_peak_cf(self, guess): Guess parameters for periodic peak fits. Shape: [n_peaks, n_params_per_peak]. """ - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge + cf_params = guess[:, self.modes.periodic.params.indices['cf']] + bw_params = guess[:, self.modes.periodic.params.indices['bw']] * self._settings.bw_std_edge # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ @@ -547,7 +556,7 @@ def _drop_peak_cf(self, guess): def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. + """Checks whether to drop peaks based on amount of overlap. Parameters ---------- @@ -564,14 +573,17 @@ def _drop_peak_overlap(self, guess): For any peaks with an overlap > threshold, the lowest height guess peak is dropped. """ + inds = self.modes.periodic.params.indices + # Sort the peak guesses by increasing frequency # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) + guess = sorted(guess, key=lambda x: float(x[inds['cf']])) # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] + # The bounds are the center frequency +/- width (standard deviation) + bounds = [[peak[inds['cf']] - peak[inds['bw']] * self._settings.gauss_overlap_thresh, + peak[inds['cf']] + peak[inds['bw']] * self._settings.gauss_overlap_thresh]\ + for peak in guess] # Loop through peak bounds, comparing current bound to that of next peak # If the left peak's upper bound extends pass the right peaks lower bound, @@ -583,7 +595,7 @@ def _drop_peak_overlap(self, guess): # Check if bound of current peak extends into next peak if b_0[1] > b_1[0]: - # If so, get the index of the gaussian with the lowest height (to drop) + # If so, get the index of the peak with the lowest height (to drop) drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) # Drop any peaks guesses that overlap too much, based on threshold @@ -593,52 +605,50 @@ def _drop_peak_overlap(self, guess): return guess - ## TO GENERALIZE FOR MODES - def _create_peak_params(self, gaus_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. + def _create_peak_params(self, fit_peak_params): + """Copies over the fit peak parameters output parameters, updating as appropriate. Parameters ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. + fit_peak_params : 2d array + Parameters that define the peak parameters directly fit to the spectrum. Returns ------- peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. + Updated parameter values for the peaks. Notes ----- - The gaussian center is unchanged as the peak center frequency. + The center frequency estimate is unchanged as the peak center frequency. - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. + The peak height is updated to reflect the height of the peak above + the aperiodic fit. This is returned instead of the fit peak height, as + the fit height is harder to interpret, due to peak overlaps. - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. + The peak bandwidth is updated to be 'both-sided', to reflect the overal width + of the peak, as opposed to the fit parameter, which is 1-sided standard deviation. Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. + with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available. """ - peak_params = np.empty((len(gaus_params), self.modes.periodic.n_params)) + inds = self.modes.periodic.params.indices + + peak_params = np.empty((len(fit_peak_params), self.modes.periodic.n_params)) - for ii, peak in enumerate(gaus_params): + for ii, peak in enumerate(fit_peak_params): + + cpeak = peak.copy() # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.data.freqs - peak[0])) - - # Collect peak parameter data - if self.modes.periodic.name == 'gaussian': ## TEMP - peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], - peak[2] * 2] - - ## TEMP: - if self.modes.periodic.name == 'skewnorm': - peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], - peak[2] * 2, peak[3]] + cf_ind = np.argmin(np.abs(self.data.freqs - peak[inds['cf']])) + cpeak[inds['pw']] = \ + self.results.model.modeled_spectrum[cf_ind] - self.results.model._ap_fit[cf_ind] + + # Bandwidth is updated to be 'two-sided' (as opposed to one-sided std dev) + cpeak[inds['bw']] = peak[inds['bw']] * 2 + + peak_params[ii] = cpeak return peak_params diff --git a/specparam/bands/bands.py b/specparam/bands/bands.py index 1d3ba29a..6e04c5f6 100644 --- a/specparam/bands/bands.py +++ b/specparam/bands/bands.py @@ -60,7 +60,7 @@ def __getitem__(self, label): raise ValueError(message) from None - def __repr__(self): + def __str__(self): """Define the string representation as a printout of the band information.""" return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \ diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 98258cbf..a1b99d8e 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -2,7 +2,7 @@ import numpy as np -from specparam.bands.bands import Bands, check_bands +from specparam.bands.bands import check_bands from specparam.modutils.dependencies import safe_import, check_dependency from specparam.data.periodic import get_band_peak_arr from specparam.data.utils import flatten_results_dict @@ -31,16 +31,18 @@ def model_to_dict(fit_results, modes, bands): Model results organized into a dictionary. """ + # TODO / NOTE: current update assumes fit / converted + bands = check_bands(bands) fr_dict = {} # aperiodic parameters - for label, param in zip(modes.aperiodic.params.indices, fit_results.aperiodic_params): + for label, param in zip(modes.aperiodic.params.indices, fit_results.aperiodic_fit): fr_dict[label] = param # periodic parameters - peaks = fit_results.peak_params + peaks = fit_results.peak_converted if not bands.bands and bands.n_bands: # If bands if defined in terms of number of peaks diff --git a/specparam/data/data.py b/specparam/data/data.py index 849d4e7a..ed878b90 100644 --- a/specparam/data/data.py +++ b/specparam/data/data.py @@ -86,20 +86,20 @@ class SpectrumMetaData(namedtuple('SpectrumMetaData', ['freq_range', 'freq_res'] __slots__ = () -class FitResults(namedtuple('FitResults', ['aperiodic_params', 'peak_params', - 'gaussian_params', 'metrics'])): +class FitResults(namedtuple('FitResults', ['aperiodic_fit', 'aperiodic_converted', + 'peak_fit', 'peak_converted', 'metrics'])): """Model results from parameterizing a power spectrum. Parameters ---------- - aperiodic_params : 1d array - Parameters that define the aperiodic fit. As [Offset, (Knee), Exponent]. - The knee parameter is only included if aperiodic is fit with knee. - peak_params : 2d array - Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW]. - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. + aperiodic_fit : 1d array + Parameters that define the aperiodic fit. + aperiodic_fit : 1d array + Parameters for the aperiodic fit after any applied conversions. + peak_fit : 2d array + Parameters that define the peak(s) that make up the periodic fit. + peak_converted : 2d array + Parameters for the periodic fit after any applied conversions. metrics : dict Metrics results. diff --git a/specparam/data/periodic.py b/specparam/data/periodic.py index 4be5ae22..537d7c50 100644 --- a/specparam/data/periodic.py +++ b/specparam/data/periodic.py @@ -6,7 +6,7 @@ ################################################################################################### def get_band_peak(model, band, select_highest=True, threshold=None, - thresh_param='PW', attribute='peak_params'): + thresh_param='PW', attribute='converted'): """Extract peaks from a band of interest from a model object. Parameters @@ -23,8 +23,9 @@ def get_band_peak(model, band, select_highest=True, threshold=None, A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} - Which attribute of peak data to extract data from. + attribute : {'fit', 'converted'} + Which version of the peak parameters to extract data from. + TODO Returns ------- @@ -42,11 +43,11 @@ def get_band_peak(model, band, select_highest=True, threshold=None, >>> betas = get_band_peak(model, [13, 30], select_highest=False) # doctest:+SKIP """ - return get_band_peak_arr(getattr(model.results, attribute + '_'), band, + return get_band_peak_arr(getattr(model.results.params.periodic, '_' + attribute), band, select_highest, threshold, thresh_param) -def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='converted'): """Extract peaks from a band of interest from a group model object. Parameters @@ -60,8 +61,9 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} - Which attribute of peak data to extract data from. + attribute : {'fit', 'converted'} + Which version of the peak parameters to extract data from. + TODO Returns ------- @@ -81,8 +83,8 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut you can do something like: >>> peaks = np.empty((0, 3)) - >>> for res in group: # doctest:+SKIP - ... peaks = np.vstack((peaks, get_band_peak(res.peak_params, band, select_highest=False))) + >>> for res in group.results: # doctest:+SKIP + ... peaks = np.vstack((peaks, get_band_peak_arr(res.peak_params, band, select_highest=False))) Examples -------- @@ -95,11 +97,11 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut >>> betas = get_band_peak_group(group, [13, 30], threshold=0.1) # doctest:+SKIP """ - return get_band_peak_group_arr(group.results.get_params(attribute), band, len(group.results), + return get_band_peak_group_arr(group.results.get_params('peak'), band, len(group.results), threshold, thresh_param) -def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='converted'): """Extract peaks from a band of interest from an event model object. Parameters @@ -116,8 +118,8 @@ def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} - Which attribute of peak data to extract data from. + attribute : {'fit', 'converted'} + Which version of the peak parameters to extract data from. Returns ------- @@ -288,3 +290,36 @@ def threshold_peaks(peak_params, threshold, param='PW'): thresholded_peaks = peak_params[thresh_mask] return thresholded_peaks + + +def sort_peaks(peak_params, sort_param, direction='inc'): + """Sort peak parameters by specified parameter and direction. + + Parameters + ---------- + peak_params : 2d array + Peak parameters, with shape of [n_peaks, 3]. + sort_param : {'CF', 'PW', 'BW'} + Which parameter to sort the parameters by. + direction : {'inc', 'dec'} + Whether to sort as increasing (lowest -> highest) or decreasing (highest -> lowest). + + Returns + ------- + sorted_peaks : 2d array + Sorted peak parameters. + """ + + # Return nan array if empty input + if peak_params.size == 0: + return np.array([np.nan, np.nan, np.nan]) + + # NOTE - TEMP: interim hardcode for parameter index while updating for modes + param_ind = {'CF' : 0, 'PW' : 1, 'BW' : 2}[sort_param] + + peak_params = peak_params[peak_params[:, param_ind].argsort()] + + if direction == 'dec': + peak_params = np.flipud(peak_params) + + return peak_params diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 30efa160..c5ded16c 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -5,25 +5,57 @@ ################################################################################################### ################################################################################################### -def _get_params_helper(modes, name, field): - """Helper function for get_*_params functions.""" +def _get_field_ind(modes, component, field): + """Helper function to get the index for a specified field. - # Allow for shortcut alias, without adding `_params` - if name in ['aperiodic', 'peak', 'gaussian']: - name = name + '_params' + Parameters + ---------- + modes : Modes + Modes description. + component : {'aperiodic', 'peak'} + Component label. + field : str + Field label. + """ # If field specified as string, get mapping back to integer if isinstance(field, str): - if 'aperiodic' in name: + if component == 'aperiodic': field = modes.aperiodic.params.indices[field.lower()] - if 'peak' in name or 'gaussian' in name: + if component == 'peak': field = modes.periodic.params.indices[field.lower()] - return name, field + return field + + +def _get_metric_labels(metrics, category, measure): + """Get a selected set of metric labels. + + Parameters + ---------- + metrics : list of str + List of metric labels. + category : str or list of str + Category of metric to extract, e.g. 'error' or 'gof'. + If 'all', gets all metric labels. + measure : str or list of str + Name of the specific measure(s) to extract. + """ + + if category == 'all': + labels = metrics + elif category and not measure: + labels = [label for label in metrics if category in label] + elif isinstance(measure, list): + labels = [category + '_' + label for label in measure] + else: + labels = [category + '_' + measure] + + return labels -def get_model_params(fit_results, modes, name, field=None): - """Return model fit parameters for specified feature(s). +def get_model_params(fit_results, modes, component, field=None, version=None): + """Return model fit parameters for specified feature(s) from FitResults object. Parameters ---------- @@ -31,11 +63,11 @@ def get_model_params(fit_results, modes, name, field=None): Results of a model fit. modes : Modes Model modes definition. - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'metrics'} - Name of the data field to extract. + component : {'aperiodic', 'peak'} + Name of the component to extract. field : str or int, optional Column name / index to extract from selected data, if requested. - For example, {'CF', 'PW', 'BW'} (periodic) or {'offset', 'knee', 'exponent'} (aperiodic). + See `SpectralModel.modes.check_params` for a description of parameter field names. Returns ------- @@ -43,11 +75,18 @@ def get_model_params(fit_results, modes, name, field=None): Requested data. """ + # TEMP: + if not version: + version = 'converted' if component == 'peak' else 'fit' + component = 'peak' if component == 'periodic' else component + # Use helper function to sort out name and column selection - name, ind = _get_params_helper(modes, name, field) + ind = None + ind = _get_field_ind(modes, component, field) + component = component + '_' + version # Extract the requested data attribute from object - out = getattr(fit_results, name) + out = getattr(fit_results, component) # Periodic values can be empty arrays and if so, replace with NaN array if isinstance(out, np.ndarray) and out.size == 0: @@ -56,19 +95,14 @@ def get_model_params(fit_results, modes, name, field=None): # Select out a specific column, if requested if ind is not None: - if name == 'metrics': - out = out[ind] - - else: - - # Extract column, & if result is a single value in an array, unpack from array - out = out[ind] if out.ndim == 1 else out[:, ind] - out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out + # Extract column, & if result is a single value in an array, unpack from array + out = out[ind] if out.ndim == 1 else out[:, ind] + out = out[0] if isinstance(out, np.ndarray) and out.size == 1 else out return out -def get_group_params(group_results, modes, name, field=None): +def get_group_params(group_results, modes, component, field=None, version=None): """Extract a specified set of parameters from a set of group results. Parameters @@ -77,11 +111,13 @@ def get_group_params(group_results, modes, name, field=None): List of FitResults objects, reflecting model results across a group of power spectra. modes : Modes Model modes definition. - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + component : {'aperiodic', 'peak'} Name of the data field to extract across the group. field : str or int, optional Column name / index to extract from selected data, if requested. - For example, {'CF', 'PW', 'BW'} (periodic) or {'offset', 'knee', 'exponent'} (aperiodic). + See `SpectralModel.modes.check_params` for a description of parameter field names. + version : {'fit', 'converted'}, optional + TODO Returns ------- @@ -89,16 +125,24 @@ def get_group_params(group_results, modes, name, field=None): Requested data. """ + # TEMP: + if not version: + version = 'converted' if component == 'peak' else 'fit' + component = 'peak' if component == 'periodic' else component + # Use helper function to sort out name and column selection - name, ind = _get_params_helper(modes, name, field) + ind = None + ind = _get_field_ind(modes, component, field) + component = component + '_' + version # Pull out the requested data field from the group data # As a special case, peak_params are pulled out in a way that appends # an extra column, indicating which model each peak comes from - if name in ('peak_params', 'gaussian_params'): + #if name in ('peak_params', 'gaussian_params'): + if 'peak' in component: # Collect peak data, appending the index of the model it comes from - out = np.vstack([np.insert(getattr(data, name), modes.periodic.n_params, index, axis=1) + out = np.vstack([np.insert(getattr(data, component), modes.periodic.n_params, index, axis=1) for index, data in enumerate(group_results)]) # This updates index to grab selected column, and the last column @@ -106,19 +150,42 @@ def get_group_params(group_results, modes, name, field=None): if ind is not None: ind = [ind, -1] else: - out = np.array([getattr(data, name) for data in group_results]) + out = np.array([getattr(data, component) for data in group_results]) # Select out a specific column, if requested if ind is not None: - - if name == 'metrics': - out = np.array([cdict[ind] for cdict in out]) - else: - out = out[:, ind] + out = out[:, ind] return out +def get_group_metrics(group_results, category, measure=None): + """Extract metrics from a set of group results. + + Parameters + ---------- + group_results : list of FitResults + List of FitResults objects, reflecting model results across a group of power spectra. + category : str or list of str + Category of metric to extract, e.g. 'error' or 'gof'. + If 'all', returns all metrics. + measure : str or list of str, optional + Name of the specific measure(s) to extract. + + Returns + ------- + group_metrics : array + Requested metric(s). + """ + + group_metrics = [] + for label in _get_metric_labels(list(group_results[0].metrics.keys()), category, measure): + group_metrics.append(np.array([getattr(fres, 'metrics')[label] for fres in group_results])) + group_metrics = np.squeeze(np.array(group_metrics)) + + return group_metrics + + def get_results_by_ind(results, ind): """Get a specified index from a dictionary of results. diff --git a/specparam/io/files.py b/specparam/io/files.py index dde0166b..425fc798 100644 --- a/specparam/io/files.py +++ b/specparam/io/files.py @@ -69,7 +69,8 @@ def load_json(file_name, file_path): # Get dictionary of available attributes, and convert specified lists back into arrays arrays_to_convert = ['freqs', 'power_spectrum', - 'aperiodic_params_', 'peak_params_', 'gaussian_params_'] + 'aperiodic_fit', 'aperiodic_converted', + 'peak_fit', 'peak_converted'] data = dict_lst_to_array(data, arrays_to_convert) return data diff --git a/specparam/io/models.py b/specparam/io/models.py index f0ccb75d..9fce825a 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -48,17 +48,12 @@ def save_model(model, file_name, file_path=None, append=False, If the save file is not understood. """ - # Convert object to dictionary & convert all arrays to lists, for JSON serializing - # This 'flattens' the object, getting all relevant attributes in the same dictionary - obj_dict = dict_array_to_lst(model.__dict__) - data_dict = dict_array_to_lst(model.data.__dict__) - results_dict = dict_array_to_lst(model.results.__dict__) - algo_dict = dict_array_to_lst(model.algorithm.__dict__) - obj_dict = {**obj_dict, **data_dict, **results_dict, **algo_dict} + # 'Flatten' the model object by extracting relevant attributes to a dictionary + obj_dict = {**model.data.__dict__, **model.algorithm.settings.values} # Convert modes object to their saveable string name - obj_dict['aperiodic_mode'] = obj_dict['modes'].aperiodic.name - obj_dict['periodic_mode'] = obj_dict['modes'].periodic.name + obj_dict['aperiodic_mode'] = model.modes.aperiodic.name + obj_dict['periodic_mode'] = model.modes.periodic.name mode_labels = ['aperiodic_mode', 'periodic_mode'] # Add bands information to saveable information @@ -66,8 +61,16 @@ def save_model(model, file_name, file_path=None, append=False, if not model.results.bands._n_bands else model.results.bands._n_bands bands_label = ['bands'] if model.results.bands else [] - # Convert metrics results to saveable information - obj_dict['metrics'] = obj_dict['metrics'].results + # Add parameter results to information to saveable information + res_dict = model.results.params.asdict() + obj_dict = {**obj_dict, **res_dict} + results_labels = list(res_dict.keys()) + + # Add metrics to information to saveable information + obj_dict['metrics'] = model.results.metrics.results + + # Convert all arrays to list for JSON serialization + obj_dict = dict_array_to_lst(obj_dict) # Check for saving out base information / check if base only if save_base is None: @@ -79,7 +82,7 @@ def save_model(model, file_name, file_path=None, append=False, keep = set(\ (mode_labels + bands_label if save_base else []) + \ (model.data._meta_fields if save_base or base_only else []) + \ - (model.results._fields + ['metrics'] if save_results else []) + \ + (results_labels + ['metrics'] if save_results else []) + \ (model.algorithm.settings.names if save_settings else []) + \ (model.data._fields if save_data else [])) diff --git a/specparam/measures/estimates.py b/specparam/measures/estimates.py new file mode 100644 index 00000000..617128a5 --- /dev/null +++ b/specparam/measures/estimates.py @@ -0,0 +1,52 @@ +"""Estimate properties from data.""" + +import numpy as np + +################################################################################################### +################################################################################################### + +def estimate_fwhm(flatspec, peak_ind, freq_res): + """Estimate the Full-Width Half Max (FWHM) given a peak index for a flattened power spectrum. + + Parameters + ---------- + flatspec : 1d array + Flattened power spectrum. + peak_ind : int + Index of the peak in the flattened spectrum to compute FWHM for. + freq_res : float + Frequency resolution. + + Returns + ------- + fwhm : float + Estimated full width half maximum of a peak. + This can be NaN if the FWHM could not be estimated. + + Notes + ----- + Though FWHM are in theory symmetric (for a Gaussian), this procedure estimates the FWHM from + the shortest side of the peak. This is to deal with potential cases of overlapping peaks that + can elongate one side in a way that would bias the FWHM estimate to be greater than desired. + """ + + # Find half height index on each side of the given peak index + half_height = 0.5 * flatspec[peak_ind] + le_ind = next((val for val in range(peak_ind - 1, 0, -1) \ + if flatspec[val] <= half_height), None) + ri_ind = next((val for val in range(peak_ind + 1, len(flatspec), 1) \ + if flatspec[val] <= half_height), None) + + try: + # Get estimated width from the shortest side, ignoring a side if the half max was not found + short_side = min([abs(ind - peak_ind) \ + for ind in [le_ind, ri_ind] if ind is not None]) + + # Use short side to estimate FWHM, also converting estimate to Hz + fwhm = short_side * 2 * freq_res + + except ValueError: + # This process can fail if both sides end up as none - in which case, return as nan + fwhm = np.nan + + return fwhm diff --git a/specparam/measures/metrics.py b/specparam/measures/metrics.py index 9d1f5106..b073ccaf 100644 --- a/specparam/measures/metrics.py +++ b/specparam/measures/metrics.py @@ -20,5 +20,5 @@ 'gof_rsquared' : Metric('gof', 'rsquared', compute_r_squared), 'gof_adjrsquared' : Metric('gof', 'adjrsquared', compute_adj_r_squared, \ {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.periodic.params.size + results.params.aperiodic.params.size}) } diff --git a/specparam/measures/pointwise.py b/specparam/measures/pointwise.py index b5553636..f1106bea 100644 --- a/specparam/measures/pointwise.py +++ b/specparam/measures/pointwise.py @@ -43,7 +43,7 @@ def compute_pointwise_error(model, plot_errors=True, return_errors=False, **plt_ raise NoModelError("No model is available to use, can not proceed.") errors = compute_pointwise_error_arr(\ - model.results.modeled_spectrum_, model.data.power_spectrum) + model.results.model.modeled_spectrum, model.data.power_spectrum) if plot_errors: plot_spectral_error(model.data.freqs, errors, **plt_kwargs) @@ -89,8 +89,8 @@ def compute_pointwise_error_group(group, plot_errors=True, return_errors=False, for ind, (res, data) in enumerate(zip(group.results, group.data.power_spectra)): - model = gen_model(group.data.freqs, group.modes.aperiodic, res.aperiodic_params, - group.modes.periodic, res.gaussian_params) + model = gen_model(group.data.freqs, group.modes.aperiodic, res.aperiodic_fit, + group.modes.periodic, res.peak_fit) errors[ind, :] = np.abs(model - data) mean = np.mean(errors, 0) diff --git a/specparam/models/base.py b/specparam/models/base.py index fe630d11..07497f15 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -3,6 +3,7 @@ from copy import deepcopy from specparam.utils.array import unlog +from specparam.utils.checks import check_array_dim from specparam.modes.modes import Modes from specparam.modutils.errors import NoDataError from specparam.reports.strings import gen_modes_str, gen_settings_str, gen_issue_str @@ -11,7 +12,24 @@ ################################################################################################### class BaseModel(): - """Define BaseModel object.""" + """Define BaseModel object. + + Parameters + ---------- + aperiodic_mode : Mode or str + Mode for aperiodic component, or string specifying which mode to use. + periodic_mode : Mode or str + Mode for periodic component, or string specifying which mode to use. + verbose : bool + Whether to print out updates from the object. + + Attributes + ---------- + modes : Modes + Fit modes definitions. + verbose : bool + Verbosity status. + """ def __init__(self, aperiodic_mode, periodic_mode, verbose): """Initialize object.""" @@ -84,11 +102,11 @@ def get_data(self, component='full', space='log'): output = self.data.power_spectrum if space == 'log' \ else unlog(self.data.power_spectrum) elif component == 'aperiodic': - output = self.results._spectrum_peak_rm if space == 'log' else \ - unlog(self.data.power_spectrum) / unlog(self.results._peak_fit) + output = self.results.model._spectrum_peak_rm if space == 'log' else \ + unlog(self.data.power_spectrum) / unlog(self.results.model._peak_fit) elif component == 'peak': - output = self.results._spectrum_flat if space == 'log' else \ - unlog(self.data.power_spectrum) - unlog(self.results._ap_fit) + output = self.results.model._spectrum_flat if space == 'log' else \ + unlog(self.data.power_spectrum) - unlog(self.results.model._ap_fit) else: raise ValueError('Input for component invalid.') @@ -142,25 +160,31 @@ def _add_from_dict(self, data): Parameters ---------- data : dict - Dictionary of data to add to self. + Dictionary of data to add to current object. """ - # Catch and add convert custom objects + # Catch and add custom objects if 'aperiodic_mode' in data.keys() and 'periodic_mode' in data.keys(): self.add_modes(aperiodic_mode=data.pop('aperiodic_mode'), periodic_mode=data.pop('periodic_mode')) - if 'bands' in data.keys(): + if 'bands' in data.keys(): self.results.add_bands(data.pop('bands')) if 'metrics' in data.keys(): tmetrics = data.pop('metrics') self.results.add_metrics(list(tmetrics.keys())) self.results.metrics.add_results(tmetrics) + # TODO + for label, params in {ke : va for ke, va in data.items() if '_fit' in ke or '_converted' in ke}.items(): + if 'peak' in label: + params = check_array_dim(params) + label1, label2 = label.split('_') + component = 'periodic' if label1 == 'peak' else label1 + getattr(self.results.params, component).add_params(label2, params) + #setattr(self.results.params, label.split('_')[0], params) # Add additional attributes directly to object for key in data.keys(): - if getattr(self, key, False) is not False: - setattr(self, key, data[key]) + if getattr(self.algorithm.settings, key, False) is not False: + setattr(self.algorithm.settings, key, data[key]) elif getattr(self.data, key, False) is not False: setattr(self.data, key, data[key]) - elif getattr(self.results, key, False) is not False: - setattr(self.results, key, data[key]) diff --git a/specparam/models/event.py b/specparam/models/event.py index aa19a5d6..319b74cb 100644 --- a/specparam/models/event.py +++ b/specparam/models/event.py @@ -27,10 +27,8 @@ class SpectralTimeEventModel(SpectralTimeModel): """Model a set of event as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -43,9 +41,12 @@ class SpectralTimeEventModel(SpectralTimeModel): Notes ----- % copied in from SpectralModel object - - The event object inherits from the time model, which in turn inherits from the - group object, etc. As such it also has data attributes defined on the underlying - objects (see notes and attribute lists in inherited objects for details). + - The event object inherits from the time model, overwriting the `data` and + `results` objects with versions for fitting models across events. + Event related, temporally organized results are collected into the + `results.event_time_results` attribute, which may include sub-selecting peaks + per band (depending on settings). Note that the `results.event_group_results` attribute + is also available, which maintains the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/group.py b/specparam/models/group.py index 40e62172..2effea82 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -31,10 +31,8 @@ class SpectralGroupModel(SpectralModel): """Model a group of power spectra as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -47,14 +45,10 @@ class SpectralGroupModel(SpectralModel): Notes ----- % copied in from SpectralModel object - - The group object inherits from the model object. As such it also has data - attributes (`power_spectrum` & `modeled_spectrum_`), and parameter attributes - (`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`) - which are defined in the context of individual model fits. These attributes are - used during the fitting process, but in the group context do not store results - post-fitting. Rather, all model fit results are collected and stored into the - `group_results` attribute. To access individual parameters of the fit, use - the `get_params` method. + - The group object inherits from the model object, and in doing so overwrites the + `data` and `results` objects with versions for fitting groups of power spectra. + All model fit results are collected and stored in the `results.group_results` attribute. + To access individual parameters of the fit, use the `get_params` method. """ def __init__(self, *args, **kwargs): @@ -225,15 +219,15 @@ def load(self, file_name, file_path=None): if 'power_spectrum' in data.keys(): power_spectra.append(data.pop('power_spectrum')) + data_keys = set(data.keys()) self._add_from_dict(data) - # If settings are loaded, check and update based on the first line - if ind == 0: - self.algorithm._check_loaded_settings(data) + # For hearder line, check if settings are loaded and clear defaults if not + if ind == 0 and not set(self.algorithm.settings.names).issubset(data_keys): + self.algorithm.settings.clear() # If results part of current data added, check and update object results - if set(self.results._fields).issubset(set(data.keys())): - self.results._check_loaded_results(data) + if 'aperiodic_fit' in data_keys: self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so @@ -249,9 +243,15 @@ def load(self, file_name, file_path=None): @copy_doc_func_to_method(Results2D.get_params) - def get_params(self, name, field=None): + def get_params(self, component, field=None): + + return self.results.get_params(component, field) + + + @copy_doc_func_to_method(Results2D.get_metrics) + def get_metrics(self, category, measure=None): - return self.results.get_params(name, field) + return self.results.get_metrics(category, measure) def get_model(self, ind=None, regenerate=True): diff --git a/specparam/models/model.py b/specparam/models/model.py index 85813839..7b2545f2 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -10,7 +10,7 @@ from specparam.models.base import BaseModel from specparam.objs.data import Data from specparam.objs.results import Results -from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS +from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF from specparam.reports.save import save_model_report from specparam.reports.strings import gen_model_results_str from specparam.modutils.errors import NoDataError, FitError @@ -24,14 +24,12 @@ ################################################################################################### ################################################################################################### -@replace_docstring_sections([SPECTRAL_FIT_SETTINGS.make_docstring()]) +@replace_docstring_sections([SPECTRAL_FIT_SETTINGS_DEF.make_docstring()]) class SpectralModel(BaseModel): """Model a power spectrum as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -64,12 +62,6 @@ class SpectralModel(BaseModel): For example, raw FFT inputs are not appropriate. Where possible and appropriate, use longer time segments for power spectrum calculation to get smoother power spectra, as this will give better model fits. - - Commonly used abbreviations used in this module include: - CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic - - The gaussian params are those that define the gaussian of the fit, where as the peak - params are a modified version, in which the CF of the peak is the mean of the gaussian, - the PW of the peak is the height of the gaussian over and above the aperiodic component, - and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). """ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, @@ -160,7 +152,7 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True): # If not set to fail on NaN or Inf data at add time, check data here # This serves as a catch all for curve_fits which will fail given NaN or Inf # Because FitError's are by default caught, this allows fitting to continue - if not self.data._check_data: + if not self.data.checks['data']: if np.any(np.isinf(self.data.power_spectrum)) or \ np.any(np.isnan(self.data.power_spectrum)): raise FitError("Model fitting was skipped because there are NaN or Inf " @@ -276,21 +268,29 @@ def load(self, file_name, file_path=None, regenerate=True): # Add loaded data to object and check loaded data self._add_from_dict(data) - self.algorithm._check_loaded_settings(data) - self.results._check_loaded_results(data) + + # If settings are not loaded, clear defaults to not have potentially incorrect values + if not set(self.algorithm.settings.names).issubset(set(data.keys())): + self.algorithm.settings.clear() # Regenerate model components, based on what is available if regenerate: if self.data.freq_res: self.data._regenerate_freqs() - if np.all(self.data.freqs) and np.all(self.results.aperiodic_params_): + if np.all(self.data.freqs) and np.all(self.results.params.aperiodic): self.results._regenerate_model(self.data.freqs) @copy_doc_func_to_method(Results.get_params) - def get_params(self, name, field=None): + def get_params(self, component, field=None): + + return self.results.get_params(component, field) + + + @copy_doc_func_to_method(Results.get_metrics) + def get_metrics(self, category, measure=None): - return self.results.get_params(name, field) + return self.results.get_metrics(category, measure) @copy_doc_func_to_method(save_model_report) diff --git a/specparam/models/time.py b/specparam/models/time.py index fd626407..5540cc0e 100644 --- a/specparam/models/time.py +++ b/specparam/models/time.py @@ -22,10 +22,8 @@ class SpectralTimeModel(SpectralGroupModel): """Model a spectrogram as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -38,14 +36,12 @@ class SpectralTimeModel(SpectralGroupModel): Notes ----- % copied in from SpectralModel object - - The time object inherits from the group model, which in turn inherits from the - model object. As such it also has data attributes defined on the model object, - as well as additional attributes that are added to the group object (see notes - and attribute list in SpectralGroupModel). - - Notably, while this object organizes the results into the `time_results` - attribute, which may include sub-selecting peaks per band (depending on settings) - the `group_results` attribute is also available, which maintains the full - model results. + - The time object inherits from the group model, overwriting the `data` and + `results` objects with versions for fitting models across time. Temporally + organized results are collected into the `results.time_results` attribute, + which may include sub-selecting peaks per band (depending on settings). + Note that the `results.group_results` attribute is also available, which maintains + the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/utils.py b/specparam/models/utils.py index b3b39731..6f0171c6 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -39,7 +39,7 @@ def initialize_model_from_source(source, target): """ model = MODELS[target](**source.modes.get_modes()._asdict(), - **source.algorithm.get_settings()._asdict(), + **source.algorithm.settings.values, metrics=source.results.metrics.labels, bands=source.results.bands, verbose=source.verbose) @@ -72,14 +72,19 @@ def compare_model_objs(model_objs, aspect): outputs.append(compare_model_objs(model_objs, caspect)) return np.all(outputs) - check_input_options(aspect, ['settings', 'meta_data', 'metrics'], 'aspect') + aspects = ['modes', 'settings', 'meta_data', 'bands', 'metrics'] + check_input_options(aspect, aspects, 'aspect') # Check specified aspect of the objects are the same across instances for m_obj_1, m_obj_2 in zip(model_objs[:-1], model_objs[1:]): + if aspect == 'modes': + consistent = m_obj_1.modes.get_modes() == m_obj_2.modes.get_modes() if aspect == 'settings': consistent = m_obj_1.algorithm.get_settings() == m_obj_2.algorithm.get_settings() if aspect == 'meta_data': consistent = m_obj_1.data.get_meta_data() == m_obj_2.data.get_meta_data() + if aspect == 'bands': + consistent = m_obj_1.results.bands == m_obj_2.results.bands if aspect == 'metrics': consistent = m_obj_1.results.metrics.labels == m_obj_2.results.metrics.labels @@ -121,32 +126,33 @@ def average_group(group, bands, avg_method='mean', regenerate=True): raise ValueError("Requested average method not understood.") # Aperiodic parameters: extract & average - ap_params = avg_funcs[avg_method](group.results.get_params('aperiodic_params'), 0) + ap_params = avg_funcs[avg_method](group.results.get_params('aperiodic'), 0) # Periodic parameters: extract & average - peak_params = [] - gauss_params = [] + peak_fit_params = [] + peak_conv_params = [] for band_def in bands.definitions: - peaks = get_band_peak_group(group, band_def, attribute='peak_params') - gauss = get_band_peak_group(group, band_def, attribute='gaussian_params') + peaks_fit = get_band_peak_group(group, band_def, attribute='fit') + peaks_conv = get_band_peak_group(group, band_def, attribute='converted') # Check if there are any extracted peaks - if not, don't add - # Note that we only check peaks, but gauss should be the same - if not np.all(np.isnan(peaks)): - peak_params.append(avg_funcs[avg_method](peaks, 0)) - gauss_params.append(avg_funcs[avg_method](gauss, 0)) + # Note that we only check fit peaks, but converted should be the same + if not np.all(np.isnan(peaks_fit)): + peak_fit_params.append(avg_funcs[avg_method](peaks_fit, 0)) + peak_conv_params.append(avg_funcs[avg_method](peaks_conv, 0)) # Collect together result parameters results_params = { - 'aperiodic_params' : ap_params, - 'peak_params' : np.array(peak_params), - 'gaussian_params' : np.array(gauss_params), + 'aperiodic_fit' : ap_params, + 'aperiodic_converted' : np.array([np.nan] * len(ap_params)), + 'peak_fit' : np.array(peak_fit_params), + 'peak_converted' : np.array(peak_conv_params), } # Goodness of fit measures: extract & average - results_metrics = {label : avg_funcs[avg_method](group.results.get_params('metrics', label)) \ + results_metrics = {label : avg_funcs[avg_method](group.results.get_metrics(label)) \ for label in group.results.metrics.labels} # Create the new model object, with settings, data info, and then add average results @@ -188,7 +194,7 @@ def average_reconstructions(group, avg_method='mean'): models = np.zeros(shape=group.data.power_spectra.shape) for ind in range(len(group.results)): - models[ind, :] = group.get_model(ind, regenerate=True).results.modeled_spectrum_ + models[ind, :] = group.get_model(ind, regenerate=True).results.model.modeled_spectrum avg_model = avg_funcs[avg_method](models, 0) @@ -262,8 +268,8 @@ def combine_model_objs(model_objs): # Set the status for freqs & data checking # Check states gets set as True if any of the inputs have it on, False otherwise group.data.set_checks(\ - check_freqs=any(getattr(m_obj.data, '_check_freqs') for m_obj in model_objs), - check_data=any(getattr(m_obj.data, '_check_data') for m_obj in model_objs)) + check_freqs=any(m_obj.data.checks['freqs'] for m_obj in model_objs), + check_data=any(m_obj.data.checks['data'] for m_obj in model_objs)) # Add data information information group.data.add_meta_data(model_objs[0].data.get_meta_data()) diff --git a/specparam/modes/definitions.py b/specparam/modes/definitions.py index 7698e3b3..a230d46a 100644 --- a/specparam/modes/definitions.py +++ b/specparam/modes/definitions.py @@ -24,6 +24,7 @@ func=expo_nk_function, jacobian=None, params=params_fixed, + ndim=1, freq_space='linear', powers_space='log10', ) @@ -43,6 +44,7 @@ func=expo_function, jacobian=None, params=params_knee, + ndim=1, freq_space='linear', powers_space='log10', ) @@ -63,6 +65,7 @@ func=double_expo_function, jacobian=None, params=params_double_exp, + ndim=1, freq_space='linear', powers_space='log10', ) @@ -92,6 +95,7 @@ func=gaussian_function, jacobian=jacobian_gauss, params=params_gauss, + ndim=2, freq_space='linear', powers_space='log10', ) @@ -112,6 +116,7 @@ func=skewnorm_function, jacobian=None, params=params_skew, + ndim=2, freq_space='linear', powers_space='log10', ) @@ -131,6 +136,7 @@ func=cauchy_function, jacobian=None, params=params_cauchy, + ndim=2, freq_space='linear', powers_space='log10', ) diff --git a/specparam/modes/mode.py b/specparam/modes/mode.py index d010110a..596e6804 100644 --- a/specparam/modes/mode.py +++ b/specparam/modes/mode.py @@ -7,7 +7,7 @@ ################################################################################################### # Set valid options for Mode parameters -VALID_COMPONENTS = ['periodic', 'aperiodic'] +VALID_COMPONENTS = ['aperiodic', 'periodic'] VALID_SPACINGS = ['linear', 'log10'] @@ -18,7 +18,7 @@ class Mode(): ---------- name : str Name of the mode. - component : {'periodic', 'aperiodic'}, + component : {'aperiodic', 'periodic'}, Which component the mode relates to. description : str Description of the mode. @@ -28,6 +28,9 @@ class Mode(): Function for computing Jacobian matrix corresponding to `func`. params : dict or ParamDefinition Parameter definition. + ndim : {1, 2} + Dimensionality of the parameters. + This reflects whether they require a 1d or 2d array to store. freq_space : {'linear', 'log10'} Required spacing of the frequency values for this mode. powers_space : {'linear', 'log10'} @@ -35,7 +38,7 @@ class Mode(): """ def __init__(self, name, component, description, func, jacobian, - params, freq_space, powers_space): + params, ndim, freq_space, powers_space): """Initialize a mode.""" self.name = name @@ -49,6 +52,8 @@ def __init__(self, name, component, description, func, jacobian, params = ParamDefinition(params) self.params = params + self.ndim = ndim + self.spacing = { 'frequency' : check_input_options(freq_space, VALID_SPACINGS, 'freq_space'), 'powers' : check_input_options(powers_space, VALID_SPACINGS, 'powers_space'), @@ -78,3 +83,12 @@ def n_params(self): """Define property attribute to access the number of parameters.""" return self.params.n_params + + + def check_params(self): + """Check the description of the parameters for the current mode.""" + + print('Parameters for the {} component in {} mode:'.format(\ + self.component, self.name)) + for pkey, desc in self.params.descriptions.items(): + print('\t{:15s} {:s}'.format(pkey, desc)) diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index f2752ed8..2677d586 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -1,7 +1,7 @@ """Modes object.""" from specparam.data import ModelModes -from specparam.modes.mode import Mode +from specparam.modes.mode import Mode, VALID_COMPONENTS from specparam.modes.definitions import AP_MODES, PE_MODES ################################################################################################### @@ -21,10 +21,23 @@ class Modes(): def __init__(self, aperiodic, periodic): """Initialize modes.""" + # Set list of component names + self.components = VALID_COMPONENTS + + # Add mode definitions for each component self.aperiodic = check_mode_definition(aperiodic, AP_MODES) self.periodic = check_mode_definition(periodic, PE_MODES) + def check_params(self): + """Check the description of the parameters for each mode.""" + + if self.aperiodic: + self.aperiodic.check_params() + if self.periodic: + self.periodic.check_params() + + def get_modes(self): """Get the modes definition. @@ -34,7 +47,8 @@ def get_modes(self): Modes definition. """ - return ModelModes(aperiodic_mode=self.aperiodic.name, periodic_mode=self.periodic.name) + return ModelModes(aperiodic_mode=self.aperiodic.name if self.aperiodic else None, + periodic_mode=self.periodic.name if self.periodic else None) def check_mode_definition(mode, options): @@ -42,11 +56,16 @@ def check_mode_definition(mode, options): Parameters ---------- - mode : str or Mode + mode : str or None or Mode Fit mode. If str, should be a label corresponding to an entry in `options`. options : dict Available modes. + Returns + ------- + mode : Mode or None + Mode object, if defined, or None if not defined. + Raises ------ ValueError @@ -56,9 +75,10 @@ def check_mode_definition(mode, options): if isinstance(mode, str): assert mode in list(options.keys()), 'Specific Mode not found.' mode = options[mode] - elif isinstance(mode, Mode): - mode = mode - else: + + if mode is None: + mode = None + elif not isinstance(mode, Mode): raise ValueError('Mode input not understood.') return mode diff --git a/specparam/modutils/docs.py b/specparam/modutils/docs.py index 7b2153c6..af97b739 100644 --- a/specparam/modutils/docs.py +++ b/specparam/modutils/docs.py @@ -22,12 +22,11 @@ def get_docs_indices(docstring, sections=DOCSTRING_SECTIONS): Docstring to check indices for. sections : list of str, optional List of sections to check and get indices for. - If not provided, uses the default set of Returns ------- inds : dict - Dictionary in which each key is a section label, and each value is the corresponding index. + Dictionary where each key is a section label, and each value is the corresponding index. """ inds = {label : None for label in sections} diff --git a/specparam/objs/components.py b/specparam/objs/components.py new file mode 100644 index 00000000..25b6ee58 --- /dev/null +++ b/specparam/objs/components.py @@ -0,0 +1,90 @@ +"""Define model components object.""" + +from specparam.utils.array import unlog +from specparam.modutils.errors import NoModelError + +################################################################################################### +################################################################################################### + +class ModelComponents(): + """Object for managing model components. + + Attributes + ---------- + modeled_spectrum : 1d array + Modeled spectrum. + _spectrum_flat : 1d array + Data attribute: flattened power spectrum, with the aperiodic component removed. + _spectrum_peak_rm : 1d array + Data attribute: power spectrum, with peaks removed. + _ap_fit : 1d array + Model attribute: values of the isolated aperiodic fit. + _peak_fit : 1d array + Model attribute: values of the isolated peak fit. + """ + + def __init__(self): + """Initialize ModelComponents object.""" + + self.reset() + + + def reset(self): + """Reset model components attributes.""" + + # Full model + self.modeled_spectrum = None + + # Model components + self._ap_fit = None + self._peak_fit = None + + # Data components + self._spectrum_flat = None + self._spectrum_peak_rm = None + + + def get_component(self, component='full', space='log'): + """Get a model component. + + Parameters + ---------- + component : {'full', 'aperiodic', 'peak'} + Which model component to return. + 'full' - full model + 'aperiodic' - isolated aperiodic model component + 'peak' - isolated peak model component + space : {'log', 'linear'} + Which space to return the model component in. + 'log' - returns in log10 space. + 'linear' - returns in linear space. + + Returns + ------- + output : 1d array + Specified model component, in specified spacing. + + Notes + ----- + The 'space' parameter doesn't just define the spacing of the model component + values, but rather defines the space of the additive model such that + `model = aperiodic_component + peak_component`. + With space set as 'log', this combination holds in log space. + With space set as 'linear', this combination holds in linear space. + """ + + if self.modeled_spectrum is None: + raise NoModelError("No model fit results are available, can not proceed.") + assert space in ['linear', 'log'], "Input for 'space' invalid." + + if component == 'full': + output = self.modeled_spectrum if space == 'log' else unlog(self.modeled_spectrum) + elif component == 'aperiodic': + output = self._ap_fit if space == 'log' else unlog(self._ap_fit) + elif component == 'peak': + output = self._peak_fit if space == 'log' else \ + unlog(self.modeled_spectrum) - unlog(self._ap_fit) + else: + raise ValueError('Input for component invalid.') + + return output diff --git a/specparam/objs/data.py b/specparam/objs/data.py index e7d593dc..cd3546b4 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,6 @@ -"""Define base data objects.""" +"""Define data objects.""" +from warnings import warn from functools import wraps import numpy as np @@ -9,6 +10,7 @@ from specparam.utils.spectral import trim_spectrum from specparam.utils.checks import check_input_options from specparam.modutils.errors import DataError, InconsistentDataError +from specparam.modutils.docs import docs_get_section, replace_docstring_sections from specparam.plts.settings import PLT_COLORS from specparam.plts.spectra import plot_spectra, plot_spectrogram from specparam.plts.utils import check_plot_kwargs @@ -28,24 +30,28 @@ class Data(): Parameters ---------- check_freqs : bool - Whether to check the frequency values. - If True, checks the frequency values, and raises an error for uneven spacing. + Whether to check the frequency values. If so, raises an error for uneven spacing. check_data : bool - Whether to check the power spectrum values. - If True, checks the power values and raises an error for any NaN / Inf values. + Whether to check the spectral data. If so, raises an error for any NaN / Inf values. format : {'power'} The representation format of the data. Attributes ---------- + checks : dict + Specifiers for which aspects of the data to run checks on. freqs : 1d array - Frequency values for the power spectrum. - power_spectrum : 1d array - Power values, stored internally in log10 scale. + Frequency values for the spectral data. freq_range : list of [float, float] - Frequency range of the power spectrum, as [lowest_freq, highest_freq]. + Frequency range of the spectral data, as [lowest_freq, highest_freq]. freq_res : float - Frequency resolution of the power spectrum. + Frequency resolution of the spectral data. + power_spectrum : 1d array + Power values. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self, check_freqs=True, check_data=True, format='power'): @@ -55,9 +61,10 @@ def __init__(self, check_freqs=True, check_data=True, format='power'): self._fields = DATA_FIELDS self._meta_fields = META_DATA_FIELDS - # Define data check run statuses - self._check_freqs = check_freqs - self._check_data = check_data + self.checks = { + 'freqs' : check_freqs, + 'data' : check_data, + } check_input_options(format, FORMATS, 'format') self.format = format @@ -120,7 +127,7 @@ def get_checks(self): Object containing the check statuses from the current object. """ - return ModelChecks(**{key : getattr(self, '_' + key) for key in ModelChecks._fields}) + return ModelChecks(**{'check_' + key : value for key, value in self.checks.items()}) def get_meta_data(self): @@ -156,9 +163,9 @@ def set_checks(self, check_freqs=None, check_data=None): """ if check_freqs is not None: - self._check_freqs = check_freqs + self.checks['freqs'] = check_freqs if check_data is not None: - self._check_data = check_data + self.checks['data'] = check_data def _reset_data(self, clear_freqs=False, clear_spectrum=False): @@ -256,10 +263,10 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: + msg = "specparam fit warning - skipping frequency == 0, " \ + "as this causes a problem with fitting." + warn(msg, category=RuntimeWarning) freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") # Calculate frequency resolution, and actual frequency range of the data freq_range = [freqs.min(), freqs.max()] @@ -270,13 +277,13 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): ## Data checks - run checks on inputs based on check statuses - if self._check_freqs: + if self.checks['freqs']: # Check if the frequency data is unevenly spaced, and raise an error if so freq_diffs = np.diff(freqs) if not np.all(np.isclose(freq_diffs, freq_res)): raise DataError("The input frequency values are not evenly spaced. " "The model expects equidistant frequency values in linear space.") - if self._check_data: + if self.checks['data']: # Check if there are any infs / nans, and raise an error if so if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " @@ -288,20 +295,24 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): return freqs, powers, freq_range, freq_res +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data.__doc__, 'Attributes')]) class Data2D(Data): """Base object for managing data for spectral parameterization - for 2D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data power_spectra : 2d array Power values for the group of power spectra, as [n_power_spectra, n_freqs]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -385,20 +396,24 @@ def decorated(*args, **kwargs): return decorated +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2D.__doc__, 'Attributes')]) class Data2DT(Data2D): """Base object for managing data for spectral parameterization - for 2D transposed data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the spectrogram. + % copied in from Data2D spectrogram : 2d array Power values for the spectrogram, as [n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the spectrogram, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the spectrogram. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -451,20 +466,24 @@ def plot(self, **plt_kwargs): plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs) +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2DT.__doc__, 'Attributes')]) class Data3D(Data2DT): """Base object for managing data for spectral parameterization - for 3D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data2DT spectrograms : 3d array Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index 9170f041..1915481d 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -12,8 +12,8 @@ class Metric(): Parameters ---------- - type : str - The type of measure, e.g. 'error' or 'gof'. + category : str + The category of measure, e.g. 'error' or 'gof'. measure : str The specific measure, e.g. 'r_squared'. func : callable @@ -25,10 +25,10 @@ class Metric(): and returns the desired parameter / computed value. """ - def __init__(self, type, measure, func, kwargs=None): + def __init__(self, category, measure, func, kwargs=None): """Initialize metric.""" - self.type = type + self.category = category self.measure = measure self.func = func self.result = np.nan @@ -45,17 +45,17 @@ def __repr__(self): def label(self): """Define label property.""" - return self.type + '_' + self.measure + return self.category + '_' + self.measure @property def flabel(self): """Define formatted label property.""" - if self.type == 'error': - flabel = '{} ({})'.format(self.type.capitalize(), self.measure.upper()) - if self.type == 'gof': - flabel = '{} ({})'.format(self.type.upper(), self.measure) + if self.category == 'error': + flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper()) + if self.category == 'gof': + flabel = '{} ({})'.format(self.category.upper(), self.measure) return flabel @@ -75,7 +75,13 @@ def compute_metric(self, data, results): for key, lfunc in self.kwargs.items(): kwargs[key] = lfunc(data, results) - self.result = self.func(data.power_spectrum, results.modeled_spectrum_, **kwargs) + self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) + + + def reset(self): + """Reset metric result.""" + + self.result = np.nan class Metrics(): @@ -146,6 +152,39 @@ def add_metrics(self, metrics): self.add_metric(metric) + def get_metrics(self, category, measure=None): + """Get requested metric(s) from the object. + + Parameters + ---------- + category : str + Category of metric to extract, e.g. 'error' or 'gof'. + If 'all', returns all available metrics. + measure : str, optional + Name of the specific measure(s) to return. + + Returns + ------- + metrics : dict + Dictionary of requested metrics. + """ + + if category == 'all': + out = self.results + + else: + + out = {ke : va for ke, va in self.results.items() if category in ke} + + if measure is not None: + out = {ke : va for ke, va in out.items() if measure in ke} + + out = np.array(list(out.values())) + out = out[0] if out.size == 1 else out + + return out + + def compute_metrics(self, data, results): """Compute all currently defined metrics. @@ -162,10 +201,10 @@ def compute_metrics(self, data, results): @property - def types(self): - """Define alias for metric type of all currently defined metrics.""" + def categories(self): + """Define alias for metric categories of all currently defined metrics.""" - return [metric.type for metric in self.metrics] + return [metric.category for metric in self.metrics] @property @@ -208,3 +247,10 @@ def add_results(self, results): for key, value in results.items(): self[key].result = value + + + def reset(self): + """Reset all metric results.""" + + for metric in self.metrics: + metric.reset() diff --git a/specparam/objs/params.py b/specparam/objs/params.py new file mode 100644 index 00000000..811a1e73 --- /dev/null +++ b/specparam/objs/params.py @@ -0,0 +1,236 @@ +"""Define model parameters object.""" + +import numpy as np + +from specparam.modes.mode import Mode + +################################################################################################### +################################################################################################### + +class ModelParameters(): + """Object to manage model fit parameters. + + Parameters + ---------- + modes : Modes + Fit modes definition. + If provided, used to initialize parameter arrays to correct sizes. + + Attributes + ---------- + aperiodic : ComponentParameters + Parameters for the aperiodic component of the model fit. + periodic : ComponentParameters + Parameters for the periodic component of the model fit. + """ + + def __init__(self, modes=None): + """Initialize ModelParameters object.""" + + self.aperiodic = ComponentParameters(modes.aperiodic if modes else 'aperiodic') + self.periodic = ComponentParameters(modes.periodic if modes else 'periodic') + + + def reset(self): + """Reset component parameter definitions.""" + + self.aperiodic.reset() + self.periodic.reset() + + + def asdict(self): + """"Export model parameters to a dictionary. + + Returns + ------- + dict + Exported dictionary of the model parameters. + """ + + apdict = self.aperiodic.asdict() + pedict = self.periodic.asdict() + + return {**apdict, **pedict} + + +class ComponentParameters(): + """Object to manage parameters for a particular model component. + + Parameters + ---------- + component : str or Mode + Component that the parameters reflect. + If Mode, includes a definition of the component fit mode. + If str, should be a label to use for the component. + """ + + def __init__(self, component): + """Initialize ComponentParameters object.""" + + self._fit = np.nan + self._converted = np.nan + + self.n_params = None + self.ndim = None + self.indices = {} + + if isinstance(component, Mode): + self.component = component.component + self.n_params = component.n_params + self.ndim = component.ndim + self.add_indices(component.params.indices) + self.reset() + + else: + self.component = component + + + def _has_param(self, version): + """Helper function to check whether the object has parameter values. + + Parameters + ---------- + version : {'fit', 'converted'} + Which version of the parameters to check for. + + Returns + ------- + bool + Whether the object has the specified type of parameter values. + + Notes + ----- + Return of False can indicate either that the specified params attribute is uninitialized + (singular nan value), or that it is initialized but has no values (an array of nan). + """ + + return True if not np.all(np.isnan(getattr(self, version))) else False + + + @property + def has_fit(self): + """Property attribute for checking if object has fit parameters.""" + + return self._has_param('_fit') + + + @property + def has_converted(self): + """Property attribute for checking if object has converted parameters.""" + + return self._has_param('_converted') + + + @property + def has_params(self): + """"Property attribute for checking if any params are avaialble.""" + + return self.has_fit + + + @property + def params(self): + """Property attribute to return parameters. + + Notes + ----- + If available, this return converted parameters. If not, this returns fit parameters. + """ + + return self.get_params('converted' if self.has_converted else 'fit') + + + def reset(self): + """Reset parameter stores.""" + + if self.n_params: + self._fit = np.array([np.nan] * self.n_params, ndmin=self.ndim) + self._converted = np.array([np.nan] * self.n_params, ndmin=self.ndim) + else: + self._fit = np.nan + self._converted = np.nan + + + def add_indices(self, indices): + """Add parameter indices definition to the object.""" + + self.indices = indices + + + def add_params(self, version, params): + """Add parameter values to the object. + + Parameters + ---------- + version : {'fit', 'converted'} + Which version of the parameters to return. + params : array + Parameter values to add to the object. + """ + + if version == 'fit': + self._fit = params + if version == 'converted': + self._converted = params + + + def convert_params(self, converter): + """Convert fit parameters to converted versions and store in the object. + + Parameters + ---------- + converter : func + Callable that takes in fit parameters and returns converted version. + """ + + self.add_params('converted', converter(self.get_params('fit'))) + + + def get_params(self, version, field=None): + """Get parameter values from the object. + + Parameters + ---------- + version : {'fit', 'converted'} + Which version of the parameters to return. + field : str, optional + Which field from the parameters to return. + + Returns + ------- + params : array + Extracted parameter values. + """ + + if version is None: + output = self.params + if version == 'fit': + output = self._fit + if version == 'converted': + output = self._converted + + if field is not None: + ind = self.indices[field.lower()] if isinstance(field, str) else field + output = output[ind] if output.ndim == 1 else output[:, ind] + + return output + + + def asdict(self): + """Get the parameter values in a dictionary. + + Returns + ------- + dict + Parameter values from object in a dictionary. + """ + + # TEMP + label = 'peak' if self.component == 'periodic' else self.component + + outdict = { + label + '_fit' : self._fit, + label + '_converted' : self._converted, + } + + return outdict diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 66ea6194..bd422fe7 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -1,4 +1,4 @@ -"""Define base results objects.""" +"""Define results objects.""" from copy import deepcopy from itertools import repeat @@ -6,15 +6,18 @@ import numpy as np from specparam.bands.bands import check_bands +from specparam.modes.modes import Modes from specparam.objs.metrics import Metrics +from specparam.objs.params import ModelParameters +from specparam.objs.components import ModelComponents from specparam.measures.metrics import METRICS -from specparam.utils.array import unlog from specparam.utils.checks import check_inds, check_array_dim from specparam.modutils.errors import NoModelError -from specparam.modutils.docs import docs_get_section, replace_docstring_sections +from specparam.modutils.docs import (copy_doc_func_to_method, docs_get_section, + replace_docstring_sections) from specparam.data.data import FitResults from specparam.data.conversions import group_to_dict, event_group_to_dict -from specparam.data.utils import (get_model_params, get_group_params, +from specparam.data.utils import (get_model_params, get_group_params, get_group_metrics, get_results_by_ind, get_results_by_row) from specparam.sim.gen import gen_model @@ -22,7 +25,6 @@ ################################################################################################### # Define set of results fields & default metrics to use -RESULTS_FIELDS = ['aperiodic_params_', 'gaussian_params_', 'peak_params_'] DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] @@ -35,22 +37,37 @@ class Results(): Modes object with fit mode definitions. metrics : Metrics Metrics object with metric definitions. - bands : bands + bands : Bands Bands object with band definitions. + + Attributes + ---------- + modes : Modes + Modes object with fit mode definitions. + bands : Bands + Bands object with band definitions. + model : ModelComponents + Manages the model fit and components. + params : ModelParameters + Manages the model fit parameters. + metrics : Metrics + Metrics object with metric definitions. """ # pylint: disable=attribute-defined-outside-init, arguments-differ def __init__(self, modes=None, metrics=None, bands=None): """Initialize Results object.""" - self.modes = modes + self.modes = modes if modes else Modes(None, None) self.add_bands(bands) self.add_metrics(metrics) + self.model = ModelComponents() + self.params = ModelParameters(modes=modes) + # Initialize results attributes self._reset_results(True) - self._fields = RESULTS_FIELDS @property @@ -59,33 +76,30 @@ def has_model(self): Notes ----- - This check uses the aperiodic params, which are: - - - nan if no model has been fit - - necessarily defined, as floats, if model has been fit + This checks the aperiodic params, which are necessarily defined if a model has been fit. """ - return not np.all(np.isnan(self.aperiodic_params_)) + return self.params.aperiodic.has_params @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit in the model.""" n_peaks = None if self.has_model: - n_peaks = self.peak_params_.shape[0] + n_peaks = self.params.periodic.params.shape[0] return n_peaks @property - def n_params_(self): + def n_params(self): """The total number of parameters fit in the model.""" n_params = None if self.has_model: - n_peak_params = self.modes.periodic.n_params * self.n_peaks_ + n_peak_params = self.modes.periodic.n_params * self.n_peaks n_params = n_peak_params + self.modes.aperiodic.n_params return n_params @@ -133,13 +147,15 @@ def add_results(self, results): A data object containing the results from fitting a power spectrum model. """ - # Add parameter fields and then select and add metrics results - for pfield in self._fields: - setattr(self, pfield, getattr(results, pfield.strip('_'))) + # TODO: use check_array_dim for peak arrays? Or is / should this be done in `add_params` - self.metrics.add_results(results.metrics) + for component in self.modes.components: + for version in ['fit', 'converted']: + attr_comp = 'peak' if component == 'periodic' else component + getattr(self.params, component).add_params(\ + version, getattr(results, attr_comp + '_' + version)) - self._check_loaded_results(results._asdict()) + self.metrics.add_results(results.metrics) def get_results(self): @@ -151,69 +167,21 @@ def get_results(self): Object containing the model fit results from the current object. """ - results = FitResults( - **{key.strip('_') : getattr(self, key) for key in self._fields}, - metrics=self.metrics.results) - - return results - + return FitResults(**self.params.asdict(), metrics=self.metrics.results) - def get_component(self, component='full', space='log'): - """Get a model component. - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which model component to return. - 'full' - full model - 'aperiodic' - isolated aperiodic model component - 'peak' - isolated peak model component - space : {'log', 'linear'} - Which space to return the model component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified model component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the model component - values, but rather defines the space of the additive model such that - `model = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) - elif component == 'aperiodic': - output = self._ap_fit if space == 'log' else unlog(self._ap_fit) - elif component == 'peak': - output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum_) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - - def get_params(self, name, field=None): + def get_params(self, component, field=None, version=None): """Return model fit parameters for specified feature(s). Parameters ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract. - field : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + component : {'aperiodic', 'periodic'} + Name of the component to extract parameters for. + field : str or int, optional Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + If str, should align with a parameter label for the component fit mode. + version : {'fit', 'converted'} + Which version of the parameters to extract. Returns ------- @@ -227,29 +195,21 @@ def get_params(self, name, field=None): Notes ----- - If there are no fit peak (no peak parameters), this method will return NaN. + If there are no fit peaks (no periodic parameters), this method will return NaN. """ - if not self.has_model: - raise NoModelError("No model fit results are available to extract, can not proceed.") + component = 'periodic' if component == 'peak' else component - return get_model_params(self.get_results(), self.modes, name, field) + if not self.has_model: + raise NoModelError("No model fit results are available, can not proceed.") + return getattr(self.params, component).get_params(version, field) - def _check_loaded_results(self, data): - """Check if results have been added and check data. - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ + @copy_doc_func_to_method(Metrics.get_metrics) + def get_metrics(self, category, measure=None): - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(self._fields).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self.gaussian_params_ = check_array_dim(self.gaussian_params_) + return self.metrics.get_metrics(category, measure) def _reset_results(self, clear_results=False): @@ -262,29 +222,9 @@ def _reset_results(self, clear_results=False): """ if clear_results: - - # Aperiodic parameters - if self.modes: - self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) - else: - self.aperiodic_params_ = np.nan - - # Periodic parameters - if self.modes: - self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) - self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) - else: - self.gaussian_params_ = np.nan - self.peak_params_ = np.nan - - # Data components - self._spectrum_flat = None - self._spectrum_peak_rm = None - - # Modeled spectrum components - self.modeled_spectrum_ = None - self._ap_fit = None - self._peak_fit = None + self.params.reset() + self.model.reset() + self.metrics.reset() def _regenerate_model(self, freqs): @@ -296,19 +236,26 @@ def _regenerate_model(self, freqs): Frequency values for the power_spectrum, in linear scale. """ - self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model(freqs, \ - self.modes.aperiodic, self.aperiodic_params_, - self.modes.periodic, self.gaussian_params_, - return_components=True) + self.model.modeled_spectrum, self.model._peak_fit, self.model._ap_fit = \ + gen_model(freqs, self.modes.aperiodic, self.params.aperiodic.get_params('fit'), + self.modes.periodic, self.params.periodic.get_params('fit'), + return_components=True) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results.__doc__, 'Attributes')]) class Results2D(Results): """Object for managing results - 2D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results + group_results : list of FitResults + Results of the model fit for each power spectrum. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -364,35 +311,35 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model.""" n_peaks = None if self.has_model: - n_peaks = np.array([res.peak_params.shape[0] for res in self]) + n_peaks = np.array([res.peak_fit.shape[0] for res in self]) return n_peaks @property - def n_null_(self): + def n_null(self): """How many model fits are null.""" n_null = None if self.has_model: - n_null = sum([1 for res in self.group_results if np.isnan(res.aperiodic_params[0])]) + n_null = sum([1 for res in self.group_results if np.isnan(res.aperiodic_fit[0])]) return n_null @property - def null_inds_(self): + def null_inds(self): """The indices for model fits that are null.""" null_inds = None if self.has_model: null_inds = [ind for ind, res in enumerate(self.group_results) \ - if np.isnan(res.aperiodic_params[0])] + if np.isnan(res.aperiodic_fit[0])] return null_inds @@ -433,16 +380,16 @@ def drop(self, inds): self.group_results[ind] = null_results - def get_params(self, name, field=None): + def get_params(self, component, field=None): """Return model fit parameters for specified feature(s). Parameters ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - field : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + component : {'aperiodic', 'periodic'} + Name of the component to extract parameters for. + field : str or int, optional Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + If str, should align with a parameter label for the component fit mode. Returns ------- @@ -458,23 +405,36 @@ def get_params(self, name, field=None): Notes ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. + When extracting peak parameters, an additional column is appended to the + returned array, indicating the index that the peak came from. """ if not self.has_model: raise NoModelError("No model fit results are available, can not proceed.") - return get_group_params(self.group_results, self.modes, name, field) + return get_group_params(self.group_results, self.modes, component, field) + + @copy_doc_func_to_method(Metrics.get_metrics) + def get_metrics(self, category, measure=None): -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) + return get_group_metrics(self.group_results, category, measure) + + +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2D.__doc__, 'Attributes')]) class Results2DT(Results2D): """Object for managing results - 2D transpose version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2D + time_results : dict + Results of the model fit across each time window. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -527,13 +487,23 @@ def convert_results(self): self.time_results = group_to_dict(self.group_results, self.modes, self.bands) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2DT.__doc__, 'Attributes')]) class Results3D(Results2DT): """Object for managing results - 3D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2DT + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -571,12 +541,12 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model, for each event.""" n_peaks = None if self.has_model: - n_peaks = np.array([[res.peak_params.shape[0] for res in gres] \ + n_peaks = np.array([[res.peak_fit.shape[0] for res in gres] \ for gres in self.event_group_results]) return n_peaks @@ -637,16 +607,16 @@ def get_results(self): return self.event_time_results - def get_params(self, name, field=None): + def get_params(self, component, field=None): """Return model fit parameters for specified feature(s). Parameters ---------- - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} - Name of the data field to extract across the group. - field : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional + component : {'aperiodic', 'periodic'} + Name of the component to extract parameters for. + field : str or int, optional Column name / index to extract from selected data, if requested. - Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}. + If str, should align with a parameter label for the component fit mode. Returns ------- @@ -662,11 +632,12 @@ def get_params(self, name, field=None): Notes ----- - When extracting peak information ('peak_params' or 'gaussian_params'), an additional - column is appended to the returned array, indicating the index that the peak came from. + When extracting peak parameters, an additional column is appended to the + returned array, indicating the index that the peak came from. """ - return [get_group_params(gres, self.modes, name, field) for gres in self.event_group_results] + return [get_group_params(gres, self.modes, component, field) \ + for gres in self.event_group_results] def convert_results(self): diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index 5a8df135..0320fa66 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -3,7 +3,7 @@ import numpy as np from specparam.utils.select import nearest_ind -from specparam.data.periodic import get_band_peak +from specparam.data.periodic import get_band_peak, sort_peaks from specparam.measures.params import compute_knee_frequency, compute_fwhm from specparam.modutils.errors import NoModelError from specparam.modutils.dependencies import safe_import, check_dependency @@ -18,6 +18,56 @@ ################################################################################################### ################################################################################################### +def _recompute_flatspec(model, remove_peaks=0): + """Helper function to recompute the initial flattened spectrum from model fitting. + + Parameters + ---------- + model : SpectralModel + Model object, with model fit, data and settings available. + remove_peaks : int, optional, default: 0 + Number of peak iterations to remove from the flattened spectrum. + + Returns + ------- + flatspec : 1d array + Flattened spectrum. + """ + + flatspec = model.data.power_spectrum - \ + model.modes.aperiodic.func(model.data.freqs, \ + *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum)) + + for peak_ind in range(remove_peaks): + flatspec = _remove_flatspec_peak(model, flatspec, peak_ind) + + return flatspec + + +def _remove_flatspec_peak(model, flatspec, peak_ind): + """Helper function to remove peaks from flattened spectrum. + + Parameters + ---------- + model : SpectralModel + Model object, with model fit, data and settings available. + flatspec : 1d array + Flattened spectrum. + peak_ind : int + Index of the peak to remove from the flattened spectrum. + + Returns + ------- + flatspec : 1d array + Flattened spectrum, with peak(s) removed. + """ + + peak_fit_params = sort_peaks(model.results.params.periodic.get_params('fit'), 'PW', 'dec') + flatspec = flatspec - model.modes.periodic.func(model.data.freqs, *peak_fit_params[peak_ind, :]) + + return flatspec + + @savefig @check_dependency(plt, 'matplotlib') def plot_annotated_peak_search(model): @@ -29,50 +79,70 @@ def plot_annotated_peak_search(model): Model object, with model fit, data and settings available. """ - # Recalculate the initial aperiodic fit and flattened spectrum that - # is the same as the one that is used in the peak fitting procedure - flatspec = model.data.power_spectrum - \ - model.modes.aperiodic.func(model.data.freqs, \ - *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum),) - # Calculate ylims of the plot that are scaled to the range of the data - ylims = [min(flatspec) - 0.1 * np.abs(min(flatspec)), max(flatspec) + 0.1 * max(flatspec)] - - # Sort parameters by peak height - gaussian_params = model.results.gaussian_params_[\ - model.results.gaussian_params_[:, 1].argsort()][::-1] + flatspec = _recompute_flatspec(model) + ylim = [min(flatspec) - 0.1 * np.abs(min(flatspec)), + max(flatspec) + 0.1 * max(flatspec)] # Loop through the iterative search for each peak - for ind in range(model.results.n_peaks_ + 1): + for peak_ind in range(model.results.n_peaks + 1): + plot_individual_peak_search(model, peak_ind, flatspec, ylim=ylim) + if peak_ind != model.results.n_peaks: + flatspec = _remove_flatspec_peak(model, flatspec, peak_ind) - # This forces the creation of a new plotting axes per iteration - ax = check_ax(None, PLT_FIGSIZES['spectral']) - plot_spectra(model.data.freqs, flatspec, linewidth=2.5, - label='Flattened Spectrum', color=PLT_COLORS['data'], ax=ax) - plot_spectra(model.data.freqs, - [model.algorithm.peak_threshold * np.std(flatspec)] * len(model.data.freqs), - label='Relative Threshold', color='orange', linewidth=2.5, - linestyle='dashed', ax=ax) - plot_spectra(model.data.freqs, [model.algorithm.min_peak_height]*len(model.data.freqs), - label='Absolute Threshold', color='red', linewidth=2.5, - linestyle='dashed', ax=ax) - - maxi = np.argmax(flatspec) - ax.plot(model.data.freqs[maxi], flatspec[maxi], '.', - color=PLT_COLORS['periodic'], alpha=0.75, markersize=30) +@savefig +@check_dependency(plt, 'matplotlib') +def plot_individual_peak_search(model, iteration, flatspec=None, ax=None, **plt_kwargs): + """Plot the process of detecting and fitting an individual peak. - ax.set_ylim(ylims) - ax.set_title('Iteration #' + str(ind+1), fontsize=16) + Parameters + ---------- + model : SpectralModel + Model object, with model fit, data and settings available. + iteration : int + Which peak iteration to plot. + flatspec : array, optional + xx + plt_kwargs + Keyword arguments for managing the plot. + """ - if ind < model.results.n_peaks_: + if not model.results.has_model: + raise NoModelError("No model is available to plot, can not proceed.") - gauss = model.modes.periodic.func(model.data.freqs, *gaussian_params[ind, :]) - plot_spectra(model.data.freqs, gauss, ax=ax, label='Gaussian Fit', - color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) + if flatspec is None: + flatspec = _recompute_flatspec(model, iteration) - flatspec = flatspec - gauss + ax = check_ax(ax, PLT_FIGSIZES['spectral']) + plot_spectra(model.data.freqs, flatspec, linewidth=2.5, + label='Flattened Spectrum', color=PLT_COLORS['data'], ax=ax) + plot_spectra(model.data.freqs, + [model.algorithm.settings.peak_threshold * np.std(flatspec)] \ + * len(model.data.freqs), + label='Relative Threshold', color='orange', linewidth=2.5, + linestyle='dashed', ax=ax) + plot_spectra(model.data.freqs, + [model.algorithm.settings.min_peak_height] * len(model.data.freqs), + label='Absolute Threshold', color='red', linewidth=2.5, + linestyle='dashed', ax=ax) + + maxi = np.argmax(flatspec) + ax.plot(model.data.freqs[maxi], flatspec[maxi], '.', + color=PLT_COLORS['periodic'], alpha=0.75, markersize=30) + + ax.set_ylim(plt_kwargs.get('ylim', None)) + ax.set_title(plt_kwargs.get('title', 'Iteration #' + str(iteration+1)), fontsize=16) + + if iteration < model.results.n_peaks: + + peak_fit_params = sort_peaks(model.results.params.periodic.get_params('fit'), 'PW', 'dec') + cpeak = model.modes.periodic.func(model.data.freqs, *peak_fit_params[iteration, :]) + plot_spectra(model.data.freqs, cpeak, ax=ax, label='Gaussian Fit', + color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) + + if plt_kwargs.get('restyle', True) is not False: style_spectrum_plot(ax, False, True) @@ -136,10 +206,10 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1. bug_buff = 0.000001 - if annotate_peaks and model.results.n_peaks_: + if annotate_peaks and model.results.n_peaks: - # Extract largest peak, to annotate, grabbing gaussian params - gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian_params') + # Extract largest peak, to annotate, grabbing peak fit params + gauss = get_band_peak(model, model.data.freq_range, attribute='fit') peak_ctr, peak_hgt, peak_wid = gauss bw_freqs = [peak_ctr - 0.5 * compute_fwhm(peak_wid), @@ -183,7 +253,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Offset # Add a line to indicate offset, without adjusting plot limits below it ax.set_autoscaley_on(False) - ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.modeled_spectrum_[0]], + ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.model.modeled_spectrum[0]], color=PLT_COLORS['aperiodic'], linewidth=lw2, alpha=0.5) ax.annotate('Offset', xy=(freqs[0]+bug_buff, model.data.power_spectrum[0]-y_buff1), diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 97e2df39..3b99df68 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -86,5 +86,5 @@ def plot_event_model(event, **plot_kwargs): title='Fit Quality' if ind == 0 else None, drop_xticks=ind < len(event.results.metrics), add_xlabel=ind == len(event.results.metrics), - color=PARAM_COLORS[event.results.metrics.types[ind]], + color=PARAM_COLORS[event.results.metrics.categories[ind]], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/group.py b/specparam/plts/group.py index a2489b5d..a3e00bd0 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -77,11 +77,11 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs): """ if group.modes.aperiodic.name == 'knee': - plot_scatter_2(group.results.get_params('aperiodic_params', 'exponent'), 'Exponent', - group.results.get_params('aperiodic_params', 'knee'), 'Knee', + plot_scatter_2(group.results.get_params('aperiodic', 'exponent'), 'Exponent', + group.results.get_params('aperiodic', 'knee'), 'Knee', 'Aperiodic Fit', ax=ax) else: - plot_scatter_1(group.results.get_params('aperiodic_params', 'exponent'), 'Exponent', + plot_scatter_1(group.results.get_params('aperiodic', 'exponent'), 'Exponent', 'Aperiodic Fit', ax=ax) @@ -103,11 +103,13 @@ def plot_group_goodness(group, ax=None, **plot_kwargs): # Get indices of metrics to plot err_ind = find_first_ind(group.results.metrics.labels, 'error') + err_label = group.results.metrics.labels[err_ind] gof_ind = find_first_ind(group.results.metrics.labels, 'gof') + gof_label = group.results.metrics.labels[gof_ind] - plot_scatter_2(group.results.get_params('metrics', group.results.metrics.labels[err_ind]), + plot_scatter_2(group.results.get_metrics(err_label), group.results.metrics.flabels[err_ind], - group.results.get_params('metrics', group.results.metrics.labels[gof_ind]), + group.results.get_metrics(gof_label), group.results.metrics.flabels[gof_ind], 'Fit Quality', ax=ax) @@ -128,5 +130,5 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs): Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ - plot_hist(group.results.get_params('peak_params', 0)[:, 0], 'Center Frequency', + plot_hist(group.results.get_params('peak', 0)[:, 0], 'Center Frequency', 'Peaks - Center Frequencies', x_lims=group.data.freq_range, ax=ax) diff --git a/specparam/plts/model.py b/specparam/plts/model.py index 6a2ca633..1f547035 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -8,7 +8,6 @@ import numpy as np from specparam.modutils.dependencies import safe_import, check_dependency -from specparam.sim.gen import gen_periodic from specparam.utils.select import nearest_ind from specparam.utils.spectral import trim_spectrum from specparam.measures.params import compute_fwhm @@ -87,7 +86,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) - plot_spectra(model.data.freqs, model.results.modeled_spectrum_, + plot_spectra(model.data.freqs, model.results.model.modeled_spectrum, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit @@ -96,7 +95,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp 'alpha' : 0.5, 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None} aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) - plot_spectra(model.data.freqs, model.results._ap_fit, + plot_spectra(model.data.freqs, model.results.model._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit @@ -169,13 +168,12 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.periodic.get_params('fit'): peak_freqs = np.log10(model.data.freqs) if plt_log else model.data.freqs - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) - ax.fill_between(peak_freqs, peak_line, model.results._ap_fit, **plot_kwargs) + ax.fill_between(peak_freqs, peak_line, model.results.model._ap_fit, **plot_kwargs) def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): @@ -196,9 +194,9 @@ def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.peak_params_: + for peak in model.results.params.periodic.get_params('fit'): - ap_point = np.interp(peak[0], model.data.freqs, model.results._ap_fit) + ap_point = np.interp(peak[0], model.data.freqs, model.results.model._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak @@ -226,14 +224,13 @@ def _add_peaks_outline(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.periodic.get_params('fit'): # Define the frequency range around each peak to plot - peak bandwidth +/- 3 peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] # Generate a peak reconstruction for each peak, and trim to desired range - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) peak_freqs, peak_line = trim_spectrum(model.data.freqs, peak_line, peak_range) # Plot the peak outline @@ -261,7 +258,7 @@ def _add_peaks_line(model, plt_log, ax, **plot_kwargs): ylims = ax.get_ylim() - for peak in model.results.peak_params_: + for peak in model.results.params.periodic.get_params('fit'): freq_point = np.log10(peak[0]) if plt_log else peak[0] ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) @@ -291,7 +288,7 @@ def _add_peaks_width(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.periodic.get_params('fit'): peak_top = model.data.power_spectrum[nearest_ind(model.data.freqs, peak[0])] bw_freqs = [peak[0] - 0.5 * compute_fwhm(peak[2]), diff --git a/specparam/plts/time.py b/specparam/plts/time.py index eece420f..6fa435d0 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -74,6 +74,6 @@ def plot_time_model(time, **plot_kwargs): time.results.time_results[time.results.metrics.labels[gof_ind]]], labels=[time.results.metrics.flabels[err_ind], time.results.metrics.flabels[gof_ind]], - colors=[PARAM_COLORS[time.results.metrics.types[err_ind]], - PARAM_COLORS[time.results.metrics.types[gof_ind]]], + colors=[PARAM_COLORS[time.results.metrics.categories[err_ind]], + PARAM_COLORS[time.results.metrics.categories[gof_ind]]], xlim=xlim, title='Fit Quality', ax=next(axes)) diff --git a/specparam/plts/utils.py b/specparam/plts/utils.py index 83d2f018..29078d09 100644 --- a/specparam/plts/utils.py +++ b/specparam/plts/utils.py @@ -93,7 +93,7 @@ def add_shades(ax, shades, colors='r', shade_alpha=0.2, shades = [shades] colors = repeat(colors) if not isinstance(colors, list) else colors - shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else alpha + shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else shade_alpha for shade, color, alpha in zip(shades, colors, shade_alphas): diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index cea9261b..4af46ff1 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -211,9 +211,9 @@ def gen_settings_str(model, description=False, concise=False): # Loop through algorithm settings, and add information for name in model.algorithm.settings.names: - str_lst.append(name + ' : ' + str(getattr(model.algorithm, name))) + str_lst.append(name + ' : ' + str(getattr(model.algorithm.settings, name))) if description: - str_lst.append(model.algorithm.settings.descriptions[name].split('\n ')[0]) + str_lst.append(model.algorithm.public_settings.descriptions[name].split('\n ')[0]) # Add footer to string str_lst.extend([ @@ -337,10 +337,10 @@ def gen_methods_text_str(model=None): methods_str = template.format(MODULE_VERSION, model.modes.aperiodic.name if model else 'XX', model.modes.periodic.name if model else 'XX', - model.algorithm.peak_width_limits if model else 'XX', - model.algorithm.max_n_peaks if model else 'XX', - model.algorithm.min_peak_height if model else 'XX', - model.algorithm.peak_threshold if model else 'XX', + model.algorithm.settings.peak_width_limits if model else 'XX', + model.algorithm.settings.max_n_peaks if model else 'XX', + model.algorithm.settings.min_peak_height if model else 'XX', + model.algorithm.settings.peak_threshold if model else 'XX', *freq_range) return methods_str @@ -388,13 +388,13 @@ def gen_model_results_str(model, concise=False): 'Aperiodic Parameters (\'{}\' mode)'.format(model.modes.aperiodic.name), '(' + ', '.join(model.modes.aperiodic.params.labels) + ')', ', '.join(['{:2.4f}'] * \ - len(model.results.aperiodic_params_)).format(*model.results.aperiodic_params_), + len(model.results.params.aperiodic.params)).format(*model.results.params.aperiodic.params), '', # Peak parameters 'Peak Parameters (\'{}\' mode) {} peaks found'.format(\ - model.modes.periodic.name, model.results.n_peaks_), - *[peak_str.format(*op) for op in model.results.peak_params_], + model.modes.periodic.name, model.results.n_peaks), + *[peak_str.format(*op) for op in model.results.params.periodic.params], '', # Metrics @@ -449,21 +449,21 @@ def gen_group_results_str(group, concise=False): 'Aperiodic Parameters (\'{}\' mode)'.format(group.modes.aperiodic.name), *[el for el in [\ '{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \ - *compute_arr_desc(group.results.get_params('aperiodic_params', label))) \ + *compute_arr_desc(group.results.get_params('aperiodic', label))) \ for label in group.modes.aperiodic.params.labels]], '', # Peak Parameters 'Peak Parameters (\'{}\' mode) {} total peaks found'.format(\ - group.modes.periodic.name, sum(group.results.n_peaks_)), + group.modes.periodic.name, sum(group.results.n_peaks)), '', # Metrics 'Model fit quality metrics:', *['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\ - '{:s} ({:s})'.format(*key.split('_')), - *compute_arr_desc(group.results.get_params('metrics', key))) \ - for key in group.results.metrics.results], + '{:s} ({:s})'.format(*label.split('_')), + *compute_arr_desc(group.results.get_metrics(label))) \ + for label in group.results.metrics.labels], '', # Footer @@ -651,7 +651,7 @@ def _report_str_n_null(model): output = \ [el for el in ['{} power spectra failed to fit'.format(\ - model.results.n_null_)] if model.results.n_null_] + model.results.n_null)] if model.results.n_null] return output diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 04223559..d5720d90 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -7,7 +7,6 @@ from specparam.data import SimParams from specparam.modes.modes import check_mode_definition from specparam.modes.definitions import AP_MODES -from specparam.utils.select import groupby from specparam.utils.checks import check_flat from specparam.modutils.errors import InconsistentDataError @@ -33,7 +32,6 @@ def collect_sim_params(aperiodic_params, periodic_params, nlv): """ return SimParams(deepcopy(aperiodic_params), - #sorted(groupby(check_flat(periodic_params), 3)), deepcopy(periodic_params), nlv) diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index d16f0668..f46f4cb3 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -1,5 +1,6 @@ """Tests for specparam.algorthms.algorithm.""" +from specparam.modes.modes import Modes from specparam.algorithms.settings import SettingsDefinition from specparam.algorithms.algorithm import * @@ -16,12 +17,14 @@ def test_algorithm(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - algo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + algo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) assert algo assert algo.name == tname assert algo.description == tdescription - assert isinstance(algo.settings, SettingsDefinition) - assert algo.settings == tsettings + assert isinstance(algo.public_settings, SettingsDefinition) + assert algo.public_settings == tsettings + for setting in algo.public_settings.names: + assert getattr(algo.settings, setting) is None def test_algorithm_settings(): @@ -32,14 +35,46 @@ def test_algorithm_settings(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - talgo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + talgo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) - model_settings = talgo.settings.make_model_settings() + model_settings = talgo.public_settings.make_model_settings() settings = model_settings(a=1, b=2) talgo.add_settings(settings) for setting in settings._fields: - assert getattr(talgo, setting) == getattr(settings, setting) + assert getattr(talgo.settings, setting) == getattr(settings, setting) settings_out = talgo.get_settings() assert isinstance(settings, model_settings) assert settings_out == settings + +def test_algorithm_cf(): + + tname = 'test_algo' + tdescription = 'Test algorithm description' + tsettings = SettingsDefinition({ + 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, + 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, + }) + + algo = AlgorithmCF(name=tname, description=tdescription, public_settings=tsettings) + + assert isinstance(algo._cf_settings_desc, SettingsDefinition) + assert algo._cf_settings + for setting in algo._cf_settings.names: + assert getattr(algo._cf_settings, setting) is None + +def test_algorithm_cf_initialize(): + + algo = AlgorithmCF(name='test_algo', description='desc', + public_settings={'a' : {'type' : 'a type desc', 'description' : 'a desc'}}, + modes=Modes('fixed', 'gaussian')) + + ap_bounds = algo._initialize_bounds('aperiodic') + assert len(ap_bounds[0]) == algo.modes.aperiodic.n_params + pe_bounds = algo._initialize_bounds('periodic') + assert len(pe_bounds[0]) == algo.modes.periodic.n_params + + ap_guess = algo._initialize_guess('aperiodic') + assert len(ap_guess) == algo.modes.aperiodic.n_params + pe_guess = algo._initialize_guess('periodic') + assert len(pe_guess) == algo.modes.periodic.n_params diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py index 30501a86..4e896fbb 100644 --- a/specparam/tests/algorithms/test_settings.py +++ b/specparam/tests/algorithms/test_settings.py @@ -5,18 +5,36 @@ ################################################################################################### ################################################################################################### +def test_settings_values(): + + tsettings_names = ['a', 'b'] + settings_vals = SettingsValues(tsettings_names) + assert isinstance(settings_vals.values, dict) + assert settings_vals.names == tsettings_names + assert settings_vals.a is None + assert settings_vals.b is None + settings_vals.a = 1 + settings_vals.b = 2 + assert settings_vals.a == 1 + assert settings_vals.b == 2 + + settings_vals.clear() + assert settings_vals.a is None + assert settings_vals.b is None + def test_settings_definition(): - tsettings = { + tdefinitions = { 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, } - settings = SettingsDefinition(tsettings) - assert settings._settings == tsettings - assert settings.names == list(tsettings.keys()) - assert settings.types - assert settings.descriptions - for label in tsettings.keys(): - assert settings.make_setting_str(label) - assert settings.make_docstring() + settings_def = SettingsDefinition(tdefinitions) + assert settings_def._definitions == tdefinitions + assert len(settings_def) == len(tdefinitions) + assert settings_def.names == list(tdefinitions.keys()) + assert settings_def.types + assert settings_def.descriptions + for label in tdefinitions.keys(): + assert settings_def.make_setting_str(label) + assert settings_def.make_docstring() diff --git a/specparam/tests/data/test_conversions.py b/specparam/tests/data/test_conversions.py index b53361bb..28889aa4 100644 --- a/specparam/tests/data/test_conversions.py +++ b/specparam/tests/data/test_conversions.py @@ -18,13 +18,13 @@ def test_model_to_dict(tresults, tmodes, tbands): out = model_to_dict(tresults, tmodes, Bands(n_bands=1)) assert isinstance(out, dict) assert 'cf_0' in out - assert out['cf_0'] == tresults.peak_params[0, 0] + assert out['cf_0'] == tresults.peak_converted[0, 0] assert 'cf_1' not in out out = model_to_dict(tresults, tmodes, Bands(n_bands=2)) assert 'cf_0' in out assert 'cf_1' in out - assert out['cf_1'] == tresults.peak_params[1, 0] + assert out['cf_1'] == tresults.peak_converted[1, 0] out = model_to_dict(tresults, tmodes, Bands(n_bands=3)) assert 'cf_2' in out diff --git a/specparam/tests/data/test_data.py b/specparam/tests/data/test_data.py index e9c583f5..22ed53e7 100644 --- a/specparam/tests/data/test_data.py +++ b/specparam/tests/data/test_data.py @@ -1,5 +1,7 @@ """Tests for the specparam.data.data.""" +import numpy as np + from specparam.data.data import * ################################################################################################### @@ -31,7 +33,8 @@ def test_model_checks(): def test_fit_results(): - results = FitResults([1, 1], [10, 0.5, 1], [10, 0.5, 0.5], {'a' : 0.95, 'b' : 0.05}) + results = FitResults(\ + [1, 1], [np.nan, np.nan], [10, 0.5, 1], [10, 0.5, 0.5], {'a' : 0.95, 'b' : 0.05}) assert results for field in FitResults._fields: diff --git a/specparam/tests/data/test_periodic.py b/specparam/tests/data/test_periodic.py index 41c20d9b..3f1313f9 100644 --- a/specparam/tests/data/test_periodic.py +++ b/specparam/tests/data/test_periodic.py @@ -75,6 +75,19 @@ def test_threshold_peaks(): data = np.array([[10, 1, 1.8, 0], [13, 1, 2, 2], [14, 2, 4, 2]]) assert np.array_equal(threshold_peaks(data, 1.5), np.array([[14, 2, 4, 2]])) +def test_sort_peaks(): + + # With original order of {A B C}, these should get sorted differently for each param + # CF: A, C, B; PW: B, A, C; BW: C, B, A + tpeaks = np.array([[5, 2, 8], [15, 1, 7], [10, 3, 6]]) + + assert np.array_equal(sort_peaks(tpeaks, 'CF', 'inc')[:, 0], np.array([5, 10, 15])) + assert np.array_equal(sort_peaks(tpeaks, 'CF', 'dec')[:, 0], np.array([15, 10, 5])) + assert np.array_equal(sort_peaks(tpeaks, 'PW', 'inc')[:, 1], np.array([1, 2, 3])) + assert np.array_equal(sort_peaks(tpeaks, 'PW', 'dec')[:, 1], np.array([3, 2, 1])) + assert np.array_equal(sort_peaks(tpeaks, 'BW', 'inc')[:, 2], np.array([6, 7, 8])) + assert np.array_equal(sort_peaks(tpeaks, 'BW', 'dec')[:, 2], np.array([8, 7, 6])) + def test_empty_inputs(): data = np.empty(shape=[0, 3]) @@ -82,6 +95,7 @@ def test_empty_inputs(): assert np.all(get_band_peak_arr(data, [8, 12])) assert np.all(get_highest_peak(data)) assert np.all(threshold_peaks(data, 1)) + assert np.all(sort_peaks(data, 1)) data = np.empty(shape=[0, 4]) diff --git a/specparam/tests/data/test_utils.py b/specparam/tests/data/test_utils.py index 838a9c55..10c25149 100644 --- a/specparam/tests/data/test_utils.py +++ b/specparam/tests/data/test_utils.py @@ -9,40 +9,40 @@ def test_get_model_params(tresults, tmodes): - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'gaussian_params', 'gaussian', 'metrics']: - assert np.any(get_model_params(tresults, tmodes, dname)) + for component in tmodes.components: - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(get_model_params(tresults, tmodes, dname, dtype)) + assert np.any(get_model_params(tresults, tmodes, component)) - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(get_model_params(tresults, tmodes, dname, dtype)) + for param in getattr(tmodes, component).params.labels: - if dname == 'metrics': - for dtype in ['error_mae', 'gof_rsquared']: - assert np.any(get_model_params(tresults, tmodes, dname, dtype)) + assert np.any(get_model_params(tresults, tmodes, component, param)) def test_get_group_params(tresults, tmodes): gresults = [tresults, tresults] - for dname in ['aperiodic_params', 'peak_params', 'gaussian_params', 'metrics']: - assert np.any(get_group_params(gresults, tmodes, dname)) + for component in tmodes.components: - if dname == 'aperiodic_params': - for dtype in ['offset', 'exponent']: - assert np.any(get_group_params(gresults, tmodes, dname, dtype)) + assert np.any(get_group_params(gresults, tmodes, component)) - if dname == 'peak_params': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(get_group_params(gresults, tmodes, dname, dtype)) + for param in getattr(tmodes, component).params.labels: - if dname == 'metrics': - for dtype in ['error_mae', 'gof_rsquared']: - assert np.any(get_group_params(gresults, tmodes, dname, dtype)) + assert np.any(get_group_params(gresults, tmodes, component, param)) + +def test_get_group_metrics(tresults): + + gresults = [tresults, tresults] + measures = {'error' : 'mae', 'gof' : 'rsquared'} + + for metric in measures.keys(): + + out1 = get_group_metrics(gresults, metric) + assert np.all(out1) + assert len(out1) == len(gresults) + + out2 = get_group_metrics(gresults, metric, measures[metric]) + assert np.all(out2) + assert len(out2) == len(gresults) def test_get_results_by_ind(): diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index a22c7594..f3047be3 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -151,117 +151,119 @@ def test_load_file_contents(tfm): """Check that loaded model files contain the contents they should.""" # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - - loaded_data = load_json(file_name, TEST_DATA_PATH) + loaded_data = load_json('test_model_all', TEST_DATA_PATH) for mode in tfm.modes.get_modes()._fields: assert mode in loaded_data.keys() + assert 'bands' in loaded_data.keys() + for setting in tfm.algorithm.settings.names: assert setting in loaded_data.keys() - for result in tfm.results._fields: - assert result in loaded_data.keys() + + for rescomp in ['aperiodic', 'peak']: + for version in ['fit', 'converted']: + assert rescomp + '_' + version in loaded_data.keys() + assert 'metrics' in loaded_data.keys() + for datum in tfm.data._fields: assert datum in loaded_data.keys() def test_load_model(tfm): # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - ntfm = load_model(file_name, TEST_DATA_PATH) + ntfm = load_model('test_model_all', TEST_DATA_PATH) + assert isinstance(ntfm, SpectralModel) + compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) - # Check that all elements get loaded - assert tfm.modes.get_modes() == ntfm.modes.get_modes() - assert tfm.results.bands == ntfm.results.bands - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None - for setting in ntfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) - assert tfm.results.metrics.results == ntfm.results.metrics.results for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) + + for component in tfm.modes.components: + assert not np.all(np.isnan(getattr(ntfm.results.params, component).get_params('fit'))) + assert tfm.results.metrics.results == ntfm.results.metrics.results # Check directory matches (loading didn't add any unexpected attributes) cfm = SpectralModel() assert dir(cfm) == dir(ntfm) + assert dir(cfm.algorithm) == dir(ntfm.algorithm) assert dir(cfm.data) == dir(ntfm.data) assert dir(cfm.results) == dir(ntfm.results) + assert dir(cfm.results.params) == dir(ntfm.results.params) def test_load_model2(tfm2): # Loads file saved from `test_save_model_str2` - file_name = 'test_model_all2' - ntfm2 = load_model(file_name, TEST_DATA_PATH) - assert tfm2.modes.get_modes() == ntfm2.modes.get_modes() - compare_model_objs([tfm2, ntfm2], ['settings', 'meta_data', 'metrics']) + ntfm2 = load_model('test_model_all2', TEST_DATA_PATH) + compare_model_objs([tfm2, ntfm2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_load_group(tfg): # Loads file saved from `test_save_group` - file_name = 'test_group_all' - ntfg = load_group(file_name, TEST_DATA_PATH) + ntfg = load_group('test_group_all', TEST_DATA_PATH) assert isinstance(ntfg, SpectralGroupModel) - - # Check that all elements get loaded - assert tfg.modes.get_modes() == ntfg.modes.get_modes() - assert tfg.results.bands == ntfg.results.bands - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfg.data._fields: + assert np.array_equal(getattr(tfg.data, data), getattr(ntfg.data, data)) assert len(ntfg.results.group_results) > 0 for metric in tfg.results.metrics.labels: - assert tfg.results.metrics.results[metric] is not None - assert ntfg.data.power_spectra is not None - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None + assert np.array_equal(tfg.results.get_metrics(metric), ntfg.results.get_metrics(metric)) # Check directory matches (loading didn't add any unexpected attributes) cfg = SpectralGroupModel() assert dir(cfg) == dir(ntfg) + assert dir(cfg.algorithm) == dir(ntfg.algorithm) assert dir(cfg.data) == dir(ntfg.data) assert dir(cfg.results) == dir(ntfg.results) + assert dir(cfg.results.params) == dir(ntfg.results.params) def test_load_group2(tfg2): # Loads file saved from `test_save_group_str2` - file_name = 'test_group_all2' - ntfg2 = load_group(file_name, TEST_DATA_PATH) - assert tfg2.modes.get_modes() == ntfg2.modes.get_modes() - compare_model_objs([tfg2, ntfg2], ['settings', 'meta_data', 'metrics']) + ntfg2 = load_group('test_group_all2', TEST_DATA_PATH) + compare_model_objs([tfg2, ntfg2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) -def test_load_time(): +def test_load_time(tft): # Loads file saved from `test_save_time` - file_name = 'test_time_all' - - # Load without bands definition - tft = load_time(file_name, TEST_DATA_PATH) + ntft = load_time('test_time_all', TEST_DATA_PATH) assert isinstance(tft, SpectralTimeModel) - assert tft.results.time_results + compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tft.data._fields: + assert np.array_equal(getattr(tft.data, data), getattr(ntft.data, data)) + assert tft.results.time_results.keys() == ntft.results.time_results.keys() + for key in tft.results.time_results: + assert np.array_equal(\ + tft.results.time_results[key], ntft.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cft = SpectralTimeModel() - assert dir(cft) == dir(tft) - assert dir(cft.data) == dir(tft.data) - assert dir(cft.results) == dir(tft.results) + assert dir(cft) == dir(ntft) + assert dir(cft.algorithm) == dir(ntft.algorithm) + assert dir(cft.data) == dir(ntft.data) + assert dir(cft.results) == dir(ntft.results) + assert dir(cft.results.params) == dir(ntft.results.params) -def test_load_event(): +def test_load_event(tfe): # Loads file saved from `test_save_event` - file_name = 'test_event_all' - - # Load without bands definition - tfe = load_event(file_name, TEST_DATA_PATH) + ntfe = load_event('test_event_all', TEST_DATA_PATH) assert isinstance(tfe, SpectralTimeEventModel) + compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfe.data._fields: + assert np.array_equal(getattr(tfe.data, data), getattr(ntfe.data, data)) assert len(tfe.results) > 1 - assert tfe.results.event_time_results + assert tfe.results.time_results.keys() == ntfe.results.time_results.keys() + for key in tfe.results.time_results: + assert np.array_equal(\ + tfe.results.time_results[key], ntfe.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cfe = SpectralTimeEventModel() - assert dir(cfe) == dir(tfe) - assert dir(cfe.data) == dir(tfe.data) - assert dir(cfe.results) == dir(tfe.results) + assert dir(cfe) == dir(ntfe) + assert dir(cfe.algorithm) == dir(ntfe.algorithm) + assert dir(cfe.data) == dir(ntfe.data) + assert dir(cfe.results) == dir(ntfe.results) + assert dir(cfe.results.params) == dir(ntfe.results.params) diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/measures/test_error.py index 7b3c0650..34bed15a 100644 --- a/specparam/tests/measures/test_error.py +++ b/specparam/tests/measures/test_error.py @@ -7,26 +7,29 @@ def test_compute_mean_abs_error(tfm): - error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_mean_squared_error(tfm): - error = compute_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_root_mean_squared_error(tfm): - error = compute_root_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_root_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_median_abs_error(tfm): - error = compute_median_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_median_abs_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_error(tfm): for metric in ['mae', 'mse', 'rmse', 'medae']: - error = compute_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) diff --git a/specparam/tests/measures/test_estimates.py b/specparam/tests/measures/test_estimates.py new file mode 100644 index 00000000..62e3e884 --- /dev/null +++ b/specparam/tests/measures/test_estimates.py @@ -0,0 +1,21 @@ +"""Test functions for specparam.measures.estimates.""" + +from specparam.sim.gen import gen_freqs, gen_noise +from specparam.modes.funcs import gaussian_function +from specparam.measures.params import compute_fwhm + +from specparam.measures.estimates import * + +################################################################################################### +################################################################################################### + +def test_estimate_fwhm(): + + fres = 0.1 + freqs = gen_freqs([1, 40], fres) + gauss_params = [10, 1, 2] + peak = gaussian_function(freqs, *gauss_params) + gen_noise(freqs, 0.01) + + out = estimate_fwhm(peak, np.argmax(peak), fres) + assert isinstance(out, float) + assert np.isclose(out, compute_fwhm(gauss_params[2]), atol=0.5) diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index a1990749..6b54a4ee 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -7,16 +7,17 @@ def test_compute_r_squared(tfm): - r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(r_squared, float) def test_compute_adj_r_squared(tfm): - r_squared = compute_adj_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_, 5) + r_squared = compute_adj_r_squared(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum, 5) assert isinstance(r_squared, float) def test_compute_gof(tfm): for metric in ['r_squared', 'adj_r_squared']: - gof = compute_gof(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + gof = compute_gof(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(gof, float) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 149827be..e5583b69 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -9,6 +9,7 @@ import numpy as np from specparam.models import SpectralGroupModel, SpectralTimeModel +from specparam.models.utils import compare_model_objs from specparam.sim import sim_spectrogram from specparam.modutils.dependencies import safe_import @@ -41,8 +42,8 @@ def test_event_iter(tfe): def test_event_n_properties(tfe): - assert np.all(tfe.results.n_peaks_) - assert np.all(tfe.results.n_params_) + assert np.all(tfe.results.n_peaks) + assert np.all(tfe.results.n_params) def test_event_fit(): @@ -95,26 +96,27 @@ def test_event_report(skip_if_no_mpl): assert tfe -def test_event_load(): - - file_name_res = 'test_event_res' - file_name_set = 'test_event_set' - file_name_dat = 'test_event_dat' +def test_event_load(tfe): # Test loading results - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_res, TEST_DATA_PATH) - assert tfe.results.event_time_results + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_res', TEST_DATA_PATH) + assert ntfe.results.event_time_results # Test loading settings - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_set, TEST_DATA_PATH) - assert tfe.algorithm.get_settings() + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_set', TEST_DATA_PATH) + assert ntfe.algorithm.get_settings() # Test loading data - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tfe.data.spectrograms) + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_dat', TEST_DATA_PATH) + assert np.all(ntfe.data.spectrograms) + + # Test loading all elements + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_all', TEST_DATA_PATH) + assert compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_event_get_model(tfe): @@ -122,8 +124,7 @@ def test_event_get_model(tfe): tfm_null = tfe.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfe.algorithm.settings.names: - assert getattr(tfe.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfe.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model @@ -138,12 +139,14 @@ def test_event_get_model(tfe): assert tfm1 assert tfm1.data.has_data assert tfm1.results.has_model - assert np.all(tfm1.results.modeled_spectrum_) + assert np.all(tfm1.results.model.modeled_spectrum) def test_event_get_params(tfe): - for dname in ['aperiodic', 'peak']: - assert np.any(tfe.results.get_params(dname)) + for component in tfe.modes.components: + assert np.any(tfe.get_params(component)) + for pname in getattr(tfe.modes, component).params.labels: + assert np.any(tfe.get_params(component, pname)) def test_event_get_group(tfe): diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 3d918fe3..b706a3d5 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -12,6 +12,7 @@ from numpy.testing import assert_equal from specparam.measures.metrics import METRICS +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.sim import sim_group_power_spectra @@ -64,20 +65,20 @@ def test_has_model(tfg): def test_n_properties(tfg): """Test the n_peaks & n_params property attributes.""" - assert np.all(tfg.results.n_peaks_) - assert np.all(tfg.results.n_params_) + assert np.all(tfg.results.n_peaks) + assert np.all(tfg.results.n_params) def test_n_null(tfg): - """Test the n_null_ property attribute.""" + """Test the n_null property attribute.""" # Since there should have been no failed fits, this should return 0 - assert tfg.results.n_null_ == 0 + assert tfg.results.n_null == 0 def test_null_inds(tfg): - """Test the null_inds_ property attribute.""" + """Test the null_inds property attribute.""" # Since there should be no failed fits, this should return an empty list - assert tfg.results.null_inds_ == [] + assert tfg.results.null_inds == [] def test_fit_nk(): """Test group fit, no knee.""" @@ -91,7 +92,7 @@ def test_fit_nk(): assert out assert len(out) == n_spectra - assert np.all(out[1].aperiodic_params) + assert np.all(out[1].aperiodic_fit) def test_fit_nk_noise(): """Test group fit, no knee, on noisy data, to make sure nothing breaks.""" @@ -147,35 +148,40 @@ def test_fg_fail(): """ # Create some noisy spectra that will be hard to fit + n_spectra = 10 fs, ps = sim_group_power_spectra(\ - 10, [3, 6], {'fixed' : [1, 1]}, {'gaussian' : [10, 1, 1]}, nlvs=10) + n_spectra, [3, 6], {'fixed' : [1, 1]}, {'gaussian' : [10, 1, 1]}, nlvs=10) # Use a fg with the max iterations set so low that it will fail to converge ntfg = SpectralGroupModel() - ntfg.algorithm._maxfev = 5 + ntfg.algorithm._cf_settings.maxfev = 5 # Fit models, where some will fail, to see if it completes cleanly ntfg.fit(fs, ps) - # Check that results are all + # Check that results are all properly organized + assert len(ntfg.results) == n_spectra for res in ntfg.results.get_results(): assert res - # Test that get_params works with failed model fits - outs1 = ntfg.results.get_params('aperiodic_params') - outs2 = ntfg.results.get_params('aperiodic_params', 'exponent') - outs3 = ntfg.results.get_params('peak_params') - outs4 = ntfg.results.get_params('peak_params', 0) - outs5 = ntfg.results.get_params('gaussian_params', 2) - - # Test shortcut labels - outs6 = ntfg.results.get_params('aperiodic') - outs6 = ntfg.results.get_params('peak', 'CF') - # Test the property attributes related to null model fits # This checks that they do the right thing when there are null fits (failed fits) - assert ntfg.results.n_null_ > 0 - assert ntfg.results.null_inds_ + assert ntfg.results.n_null > 0 + assert ntfg.results.null_inds + + # Test that get_params works with failed model fits + outs1 = ntfg.results.get_params('aperiodic') + outs2 = ntfg.results.get_params('aperiodic', 'exponent') + outs3 = ntfg.results.get_params('peak') + outs4 = ntfg.results.get_params('peak', 0) + outs5 = ntfg.results.get_params('peak', 'CF') + # TODO + #outs6 = ntfg.results.get_params('peak', 2, version='fit') + + # Check that null ind values are nan + for null_ind in ntfg.results.null_inds: + assert np.isnan(ntfg.results.get_params('aperiodic', 'exponent')[null_ind]) + assert np.isnan(ntfg.results.get_metrics('error', 'mae')[null_ind]) def test_drop(): """Test function to drop results from group object.""" @@ -207,8 +213,8 @@ def test_drop(): assert np.all(np.isnan(list(dropped_fres.metrics.values()))) # Test that a group object that has had inds dropped still works with `get_params` - cfs = tfg.results.get_params('peak_params', 1) - exps = tfg.results.get_params('aperiodic_params', 'exponent') + cfs = tfg.get_params('peak', 1) + exps = tfg.get_params('aperiodic', 'exponent') assert np.all(np.isnan(exps[drop_inds])) assert np.all(np.invert(np.isnan(np.delete(exps, drop_inds)))) @@ -224,7 +230,7 @@ def test_fit_par(): assert out assert len(out) == n_spectra - assert np.all(out[1].aperiodic_params) + assert np.all(out[1].aperiodic_fit) def test_print(tfg): """Check print method (alias).""" @@ -247,21 +253,10 @@ def test_get_results(tfg): def test_get_params(tfg): """Check get_params method.""" - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'gaussian_params', 'gaussian', 'metrics']: - assert np.any(tfg.get_params(dname)) - - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfg.get_params(dname, dtype)) - - if dname == 'metrics': - for dtype in ['error_mae', 'gof_rsquared']: - assert np.any(tfg.get_params(dname, dtype)) + for component in tfg.modes.components: + assert np.any(tfg.get_params(component)) + for pname in getattr(tfg.modes, component).params.labels: + assert np.any(tfg.get_params(component, pname)) @plot_test def test_plot(tfg, skip_if_no_mpl): @@ -273,49 +268,39 @@ def test_load(tfg): """Test load into group object. Note: loads files from test_save_group in specparam/tests/io/test_models.py.""" - file_name_res = 'test_group_res' - file_name_set = 'test_group_set' - file_name_dat = 'test_group_dat' - # Test loading just results ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_res, TEST_DATA_PATH) + ntfg.load('test_group_res', TEST_DATA_PATH) assert len(ntfg.results.group_results) > 0 # Test that settings and data are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None + assert getattr(ntfg.algorithm.settings, setting) is None assert ntfg.data.power_spectra is None # Test loading just settings ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_set, TEST_DATA_PATH) - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + ntfg.load('test_group_set', TEST_DATA_PATH) + assert tfg.algorithm.settings.values == ntfg.algorithm.settings.values # Test that results and data are None - for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + for component in tfg.modes.components: + assert not getattr(ntfg.results.params, component).has_params assert ntfg.data.power_spectra is None # Test loading just data ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_dat, TEST_DATA_PATH) + ntfg.load('test_group_dat', TEST_DATA_PATH) assert ntfg.data.has_data # Test that settings and results are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None - for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + assert getattr(ntfg.algorithm.settings, setting) is None + for component in tfg.modes.components: + assert not getattr(ntfg.results.params, component).has_params # Test loading all elements ntfg = SpectralGroupModel(verbose=False) - file_name_all = 'test_group_all' - ntfg.load(file_name_all, TEST_DATA_PATH) + ntfg.load('test_group_all', TEST_DATA_PATH) + assert compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) assert len(ntfg.results.group_results) > 0 - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None - assert ntfg.data.has_data - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -335,8 +320,7 @@ def test_get_model(tfg): tfm_null = tfg.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfg.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model @@ -344,15 +328,15 @@ def test_get_model(tfg): tfm0 = tfg.get_model(0, False) assert tfm0 # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm0.algorithm, setting) + assert tfg.algorithm.settings.values == tfm0.algorithm.settings.values # Check with regenerating tfm1 = tfg.get_model(1, True) assert tfm1 - # Check that regenerated model is created - for result in tfg.results._fields: - assert np.all(getattr(tfm1.results, result)) + # Check that parameters are copied and that regenerated model is created + for component in tfg.modes.components: + assert getattr(tfm1.results.params, component).has_params + assert np.all(tfm1.results.model.modeled_spectrum) # Test when object has no data (clear a copy of tfg) new_tfg = tfg.copy() @@ -383,9 +367,8 @@ def test_get_group(tfg): assert isinstance(nfg2, SpectralGroupModel) # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(nfg1.algorithm, setting) - assert getattr(tfg.algorithm, setting) == getattr(nfg2.algorithm, setting) + assert tfg.algorithm.settings.values == nfg1.algorithm.settings.values + assert tfg.algorithm.settings.values == nfg2.algorithm.settings.values # Check that data info is copied over properly for meta_dat in tfg.data._meta_fields: diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 23210a6e..7b8ee88d 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -14,6 +14,7 @@ from specparam.measures.metrics import METRICS from specparam.sim import gen_freqs, sim_power_spectrum from specparam.modes.definitions import AP_MODES, PE_MODES +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.modutils.errors import DataError, NoDataError, InconsistentDataError @@ -56,8 +57,8 @@ def test_has_model(tfm): def test_n_properties(tfm): - assert tfm.results.n_peaks_ - assert tfm.results.n_params_ + assert tfm.results.n_peaks + assert tfm.results.n_params def test_fit_nk(): """Test fit, no knee.""" @@ -71,11 +72,11 @@ def test_fit_nk(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [0.5, 0.1]) + assert np.allclose(ap_params, tfm.results.params.aperiodic.params, [0.5, 0.1]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" @@ -102,11 +103,11 @@ def test_fit_knee(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [1, 2, 0.2]) + assert np.allclose(ap_params, tfm.results.params.aperiodic.params, [1, 2, 0.2]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0]) def test_fit_default_metrics(): """Test goodness of fit & error metrics, post model fitting.""" @@ -115,7 +116,7 @@ def test_fit_default_metrics(): # Hack fake data with known properties: total error magnitude 2 tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.results.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + tfm.results.model.modeled_spectrum = np.array([1, 2, 5, 4, 5]) # Check default goodness of fit and error measures tfm.results.metrics.compute_metrics(tfm.data, tfm.results) @@ -194,50 +195,45 @@ def test_load(tfm): # Test loading just results ntfm = SpectralModel(verbose=False) - file_name_res = 'test_model_res' - ntfm.load(file_name_res, TEST_DATA_PATH) + ntfm.load('test_model_res', TEST_DATA_PATH) + # Check that result attributes get filled - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) + for component in tfm.modes.components: + assert getattr(ntfm.results.params, component).has_params + # Test that settings and data are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None + assert getattr(ntfm.algorithm.settings, setting) is None assert ntfm.data.power_spectrum is None # Test loading just settings ntfm = SpectralModel(verbose=False) - file_name_set = 'test_model_set' - ntfm.load(file_name_set, TEST_DATA_PATH) - for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + ntfm.load('test_model_set', TEST_DATA_PATH) + assert tfm.algorithm.settings.values == ntfm.algorithm.settings.values # Test that results and data are None - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + for component in tfm.modes.components: + assert not getattr(ntfm.results.params, component).has_params assert ntfm.data.power_spectrum is None # Test loading just data ntfm = SpectralModel(verbose=False) - file_name_dat = 'test_model_dat' - ntfm.load(file_name_dat, TEST_DATA_PATH) - assert ntfm.data.power_spectrum is not None + ntfm.load('test_model_dat', TEST_DATA_PATH) + assert ntfm.data.has_data + assert np.array_equal(tfm.data.power_spectrum, ntfm.data.power_spectrum) # Test that settings and results are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + assert getattr(ntfm.algorithm.settings, setting) is None + for component in tfm.modes.components: + assert not getattr(ntfm.results.params, component).has_params # Test loading all elements ntfm = SpectralModel(verbose=False) - file_name_all = 'test_model_all' - ntfm.load(file_name_all, TEST_DATA_PATH) - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) - for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + ntfm.load('test_model_all', TEST_DATA_PATH) + assert compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) + for component in tfm.modes.components: + assert getattr(ntfm.results.params, component).has_params def test_add_data(tresults): """Tests method to add data to model objects.""" @@ -270,21 +266,10 @@ def test_add_data(tresults): def test_get_params(tfm): """Test the get_params method.""" - for dname in ['aperiodic_params', 'aperiodic', 'peak_params', 'peak', - 'gaussian_params', 'gaussian', 'metrics']: - assert np.any(tfm.get_params(dname)) - - if dname == 'aperiodic_params' or dname == 'aperiodic': - for dtype in ['offset', 'exponent']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'peak_params' or dname == 'peak': - for dtype in ['CF', 'PW', 'BW']: - assert np.any(tfm.get_params(dname, dtype)) - - if dname == 'metrics': - for dtype in ['error_mae', 'gof_rsquared']: - assert np.any(tfm.get_params(dname, dtype)) + for component in tfm.modes.components: + assert np.any(tfm.get_params(component)) + for pname in getattr(tfm.modes, component).params.labels: + assert np.any(tfm.get_params(component, pname)) def test_get_data(tfm): @@ -296,7 +281,7 @@ def test_get_component(tfm): for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: - assert isinstance(tfm.results.get_component(comp, space), np.ndarray) + assert isinstance(tfm.results.model.get_component(comp, space), np.ndarray) def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -320,19 +305,14 @@ def test_resets(): # Note: uses it's own tfm, to not clear the global one tfm = get_tfm() - tfm._reset_data_results(True, True, True) - tfm.algorithm._reset_internal_settings() - for field in tfm.data._fields: assert getattr(tfm.data, field) is None - model_components = ['modeled_spectrum_', '_spectrum_flat', - '_spectrum_peak_rm', '_ap_fit', '_peak_fit'] - for field in model_components: - assert getattr(tfm.results, field) is None - for field in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, field))) - assert tfm.data.freqs is None and tfm.results.modeled_spectrum_ is None + for key, value in tfm.results.model.__dict__.items(): + assert value is None + for component in tfm.modes.components: + assert not getattr(tfm.results.params, component).has_params + assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum is None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -347,13 +327,13 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + for component in tfm.modes.components: + assert not getattr(tfm.results.params, component).has_params ## Monkey patch to check errors in general # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. @@ -366,14 +346,14 @@ def raise_runtime_error(*args, **kwargs): tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + for component in tfm.modes.components: + assert not getattr(tfm.results.params, component).has_params def test_debug(): """Test model object in debug state, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.algorithm.set_debug(True) assert tfm.algorithm._debug is True @@ -406,8 +386,8 @@ def test_set_checks(): # Reset checks to true tfm.data.set_checks(True, True) - assert tfm.data._check_freqs is True - assert tfm.data._check_data is True + assert tfm.data.checks['freqs'] is True + assert tfm.data.checks['data'] is True def test_to_df(tfm, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_time.py b/specparam/tests/models/test_time.py index 8a5e7cf3..14d10bd7 100644 --- a/specparam/tests/models/test_time.py +++ b/specparam/tests/models/test_time.py @@ -9,6 +9,7 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import pd = safe_import('pandas') @@ -40,8 +41,8 @@ def test_time_iter(tft): def test_time_n_properties(tft): - assert np.all(tft.results.n_peaks_) - assert np.all(tft.results.n_params_) + assert np.all(tft.results.n_peaks) + assert np.all(tft.results.n_params) def test_time_fit(): @@ -78,26 +79,27 @@ def test_time_report(skip_if_no_mpl): assert tft -def test_time_load(): - - file_name_res = 'test_time_res' - file_name_set = 'test_time_set' - file_name_dat = 'test_time_dat' +def test_time_load(tft): # Test loading results - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_res, TEST_DATA_PATH) - assert tft.results.time_results + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_res', TEST_DATA_PATH) + assert ntft.results.time_results # Test loading settings - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_set, TEST_DATA_PATH) - assert tft.algorithm.get_settings() + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_set', TEST_DATA_PATH) + assert ntft.algorithm.get_settings() # Test loading data - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tft.data.power_spectra) + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_dat', TEST_DATA_PATH) + assert np.all(ntft.data.spectrogram) + + # Test loading all elements + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_all', TEST_DATA_PATH) + assert compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_time_drop(): diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index b34e0399..8e2d9903 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -33,17 +33,25 @@ def test_compare_model_objs(tfm, tfg): f_obj2 = f_obj.copy() - assert compare_model_objs([f_obj, f_obj2], ['settings', 'meta_data', 'metrics']) + assert compare_model_objs([f_obj, f_obj2], + ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + + assert compare_model_objs([f_obj, f_obj2], 'modes') + f_obj2.add_modes('knee', 'cauchy') + assert not compare_model_objs([f_obj, f_obj2], 'modes') assert compare_model_objs([f_obj, f_obj2], 'settings') - f_obj2.algorithm.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() + f_obj2.algorithm.settings.peak_width_limits = [2, 4] assert not compare_model_objs([f_obj, f_obj2], 'settings') assert compare_model_objs([f_obj, f_obj2], 'meta_data') f_obj2.data.freq_range = [5, 25] assert not compare_model_objs([f_obj, f_obj2], 'meta_data') + assert compare_model_objs([f_obj, f_obj2], 'bands') + f_obj2.results.add_bands({'new' : [1, 4]}) + assert not compare_model_objs([f_obj, f_obj2], 'bands') + assert compare_model_objs([f_obj, f_obj2], 'metrics') f_obj2.results.metrics.add_metric(METRICS['error_rmse']) assert not compare_model_objs([f_obj, f_obj2], 'metrics') @@ -128,9 +136,7 @@ def test_combine_errors(tfm, tfg): # Incompatible settings for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - f_obj2.algorithm.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() - + f_obj2.algorithm.settings.peak_width_limits = [2, 4] with raises(IncompatibleSettingsError): combine_model_objs([f_obj, f_obj2]) diff --git a/specparam/tests/modes/test_mode.py b/specparam/tests/modes/test_mode.py index 96f641e9..09cf7950 100644 --- a/specparam/tests/modes/test_mode.py +++ b/specparam/tests/modes/test_mode.py @@ -20,10 +20,11 @@ def tfit(xs, *params): })) tmode = Mode(name='tmode', component='periodic', description='test_desc', - func=tfit, jacobian=None, params=params, + func=tfit, jacobian=None, params=params, ndim=1, freq_space='linear', powers_space='linear') assert tmode assert tmode.n_params == params.n_params + tmode.check_params() def test_mode_params_dict(): @@ -36,7 +37,7 @@ def tfit2(xs, *params): } tmode = Mode(name='tmode', component='aperiodic', description='test_desc2', - func=tfit2, jacobian=None, params=params, + func=tfit2, jacobian=None, params=params, ndim=2, freq_space='linear', powers_space='linear') assert tmode assert isinstance(tmode.params, ParamDefinition) diff --git a/specparam/tests/modes/test_modes.py b/specparam/tests/modes/test_modes.py index b09a9467..6849614b 100644 --- a/specparam/tests/modes/test_modes.py +++ b/specparam/tests/modes/test_modes.py @@ -14,6 +14,7 @@ def test_modes(): assert modes assert isinstance(modes.aperiodic, Mode) assert isinstance(modes.periodic, Mode) + modes.check_params() def test_modes_get_modes(): diff --git a/specparam/tests/objs/test_components.py b/specparam/tests/objs/test_components.py new file mode 100644 index 00000000..b9fa3d32 --- /dev/null +++ b/specparam/tests/objs/test_components.py @@ -0,0 +1,13 @@ +"""Tests for specparam.objs.components.""" + +from specparam.objs.components import * + +################################################################################################### +################################################################################################### + +## ModelComponents object + +def test_model_components(): + + mc = ModelComponents() + assert mc diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 87aeab46..9b8593c0 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.data, including the data object and it's methods.""" +"""Tests for specparam.objs.data.""" from specparam.data import SpectrumMetaData, ModelChecks @@ -44,15 +44,14 @@ def test_data_get_set_checks(tdata): tdata.set_checks(False, False) tchecks1 = tdata.get_checks() assert isinstance(tchecks1, ModelChecks) - assert tdata._check_freqs == tchecks1.check_freqs == False - assert tdata._check_data == tchecks1.check_data == False + assert tdata.checks['freqs'] == tchecks1.check_freqs == False + assert tdata.checks['data'] == tchecks1.check_data == False tdata.set_checks(True, True) tchecks2 = tdata.get_checks() assert isinstance(tchecks2, ModelChecks) - assert tdata._check_freqs == tchecks2.check_freqs == True - assert tdata._check_data == tchecks2.check_data == True - + assert tdata.checks['freqs'] == tchecks2.check_freqs == True + assert tdata.checks['data'] == tchecks2.check_data == True @plot_test def test_data_plot(tdata, skip_if_no_mpl): diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py index 0885453a..793b838c 100644 --- a/specparam/tests/objs/test_metrics.py +++ b/specparam/tests/objs/test_metrics.py @@ -23,7 +23,7 @@ def test_metric_kwargs(tfm): metric = Metric('gof', 'ar2', compute_adj_r_squared, {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.periodic.params.size + results.params.aperiodic.params.size}) assert isinstance(metric, Metric) assert isinstance(metric.label, str) @@ -53,10 +53,16 @@ def test_metrics_obj(tfm): with raises(ValueError): metrics['bad-label'] + # Check getting metrics out + out1 = metrics.get_metrics('error') + assert out1 == metrics.results['error_mae'] + out2 = metrics.get_metrics('gof', 'rsquared') + assert out2 == metrics.results['gof_rsquared'] + def test_metrics_dict(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - gof_met_def = {'type' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + gof_met_def = {'category' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} metrics = Metrics([er_met_def, gof_met_def]) assert isinstance(metrics, Metrics) @@ -73,11 +79,11 @@ def test_metrics_dict(tfm): def test_metrics_kwargs(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - ar2_met_def = {'type' : 'gof', 'measure' : 'arsquared', + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + ar2_met_def = {'category' : 'gof', 'measure' : 'arsquared', 'func' : compute_adj_r_squared, 'kwargs' : {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}} + results.params.periodic.params.size + results.params.aperiodic.params.size}} metrics = Metrics([er_met_def, ar2_met_def]) assert isinstance(metrics, Metrics) diff --git a/specparam/tests/objs/test_params.py b/specparam/tests/objs/test_params.py new file mode 100644 index 00000000..8d3ad1bd --- /dev/null +++ b/specparam/tests/objs/test_params.py @@ -0,0 +1,59 @@ +"""Tests for specparam.objs.params.""" + +import numpy as np + +from specparam.objs.params import * + +################################################################################################### +################################################################################################### + +## ComponentParameters object + +def test_component_parameters_str(): + + # Test basic string definition + acp = ComponentParameters('aperiodic') + assert acp + + # Check adding values + fparams = np.array([1, 2]) + acp.add_params('fit', fparams) + assert acp.has_fit + assert np.array_equal(acp.params, fparams) + + cparams = np.array([3, 4]) + acp.add_params('converted', cparams) + assert acp.has_converted + assert np.array_equal(acp.params, cparams) + + # Check dictionary export + pdict = acp.asdict() + assert isinstance(pdict, dict) + assert np.array_equal(pdict['aperiodic_fit'], fparams) + assert np.array_equal(pdict['aperiodic_converted'], cparams) + +def test_component_parameters_modes(tmodes): + + ## Check aperiodic mode component definition + ap_params = ComponentParameters(tmodes.aperiodic) + assert ap_params._fit.ndim == tmodes.aperiodic.ndim + assert ap_params._fit.size == ap_params._converted.size == tmodes.aperiodic.params.n_params + assert ap_params.indices == tmodes.aperiodic.params.indices + assert ap_params.ndim == tmodes.aperiodic.ndim + + ## Check periodic mode component definition + pe_params = ComponentParameters(tmodes.periodic) + assert pe_params._fit.ndim == tmodes.periodic.ndim + assert pe_params._fit.size == pe_params._converted.size == tmodes.periodic.params.n_params + assert pe_params.indices == tmodes.periodic.params.indices + assert pe_params.ndim == tmodes.periodic.ndim + +## ModelParameters object + +def test_model_parameters(): + + mp = ModelParameters() + assert mp + + assert isinstance(mp.aperiodic, ComponentParameters) + assert isinstance(mp.periodic, ComponentParameters) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index d2317590..55c8514e 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.results, including the data object and it's methods.""" +"""Tests for specparam.objs.results.""" from specparam.objs.results import * @@ -12,14 +12,16 @@ def test_results(): tres = Results() assert isinstance(tres, Results) -def test_results_results(tresults): +def test_results_results(tresults, tmodes): tres = Results() tres.add_results(tresults) assert tres.has_model - for result in tres._fields: - assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + for component in tmodes.components: + attr_comp = 'peak' if component == 'periodic' else component + assert np.array_equal(getattr(tres.params, component).get_params('fit'), + getattr(tresults, attr_comp + '_fit')) results_out = tres.get_results() assert results_out == tresults diff --git a/specparam/tests/plts/test_annotate.py b/specparam/tests/plts/test_annotate.py index 21b7b6f1..e36883f2 100644 --- a/specparam/tests/plts/test_annotate.py +++ b/specparam/tests/plts/test_annotate.py @@ -14,6 +14,14 @@ def test_plot_annotated_peak_search(tfm, skip_if_no_mpl): plot_annotated_peak_search(tfm, file_path=TEST_PLOTS_PATH, file_name='test_plot_annotated_peak_search.png') +@plot_test +def test_plot_individual_peak_search(tfm, skip_if_no_mpl): + + plot_individual_peak_search(tfm, 0, file_path=TEST_PLOTS_PATH, + file_name='test_plot_individual_peak_search-0.png') + plot_individual_peak_search(tfm, 1, file_path=TEST_PLOTS_PATH, + file_name='test_plot_individual_peak_search-1.png') + @plot_test def test_plot_annotated_model(tfm, skip_if_no_mpl): diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 6791d2d0..4b71b806 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -53,7 +53,8 @@ def get_tdata2d(): def get_tfm(): """Get a model object, with a fit power spectrum, for testing.""" - tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfm.fit(*sim_power_spectrum(*default_spectrum_params())) return tfm @@ -62,6 +63,7 @@ def get_tfm2(): """Get a model object, with a fit power spectrum, for testing - custom metrics & modes.""" tfm2 = SpectralModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfm2.fit(*sim_power_spectrum(*default_spectrum_params())) @@ -72,7 +74,8 @@ def get_tfg(): """Get a group object, with some fit power spectra, for testing.""" n_spectra = 3 - tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfg.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) return tfg @@ -82,6 +85,7 @@ def get_tfg2(): n_spectra = 3 tfg2 = SpectralGroupModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfg2.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) @@ -94,7 +98,8 @@ def get_tft(): n_spectra = 3 xs, ys = sim_spectrogram(n_spectra, *default_group_params()) - tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), \ + min_peak_height=0.05, peak_width_limits=[1, 8],) tft.fit(xs, ys) return tft @@ -106,7 +111,8 @@ def get_tfe(): xs, ys = sim_spectrogram(n_spectra, *default_group_params()) ys = [ys, ys] - tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8],) tfe.fit(xs, ys) return tfe @@ -124,9 +130,10 @@ def get_tmodes(): def get_tresults(): """Get a FitResults object, for testing.""" - return FitResults(aperiodic_params=np.array([1.0, 1.00]), - peak_params=np.array([[10.0, 1.25, 2.0], [20.0, 1.0, 3.0]]), - gaussian_params=np.array([[10.0, 1.25, 1.0], [20.0, 1.0, 1.5]]), + return FitResults(aperiodic_fit=np.array([1.0, 1.00]), + aperiodic_converted=np.array([np.nan, np.nan]), + peak_fit=np.array([[10.0, 1.25, 1.0], [20.0, 1.0, 1.5]]), + peak_converted=np.array([[10.0, 1.25, 2.0], [20.0, 1.0, 3.0]]), metrics={'error_mae' : 0.01, 'gof_rsquared' : 0.97}) def get_tdocstring(): diff --git a/tutorials/plot_02-PSDModel.py b/tutorials/plot_02-PSDModel.py index 23a3d2de..b00c1206 100644 --- a/tutorials/plot_02-PSDModel.py +++ b/tutorials/plot_02-PSDModel.py @@ -100,36 +100,19 @@ # Once the power spectrum model has been calculated, the model fit parameters are stored # as object attributes that can be accessed after fitting. # -# Following scikit-learn conventions, attributes that are fit as a result of -# the model have a trailing underscore, for example: -# -# - ``aperiodic_params_`` -# - ``peak_params_`` -# - ``error_`` -# - ``r2_`` -# - ``n_peaks_`` -# - -################################################################################################### -# # Access model fit parameters from specparam object, after fitting: # ################################################################################################### # Aperiodic parameters -print('Aperiodic parameters: \n', fm.results.aperiodic_params_, '\n') +print('Aperiodic parameters: \n', fm.results.params.aperiodic.params, '\n') # Peak parameters -print('Peak parameters: \n', fm.results.peak_params_, '\n') - -# Goodness of fit measures -print('Goodness of fit:') -print(' Error - ', fm.results.metrics.results['error_mae']) -print(' R^2 - ', fm.results.metrics.results['gof_rsquared'], '\n') +print('Peak parameters: \n', fm.results.params.periodic.params, '\n') # Check how many peaks were fit -print('Number of fit peaks: \n', fm.results.n_peaks_) +print('Number of fit peaks: \n', fm.results.n_peaks) ################################################################################################### # Selecting Parameters @@ -141,18 +124,9 @@ ################################################################################################### -# Extract a model parameter with `get_params` -err = fm.get_params('metrics', 'error_mae') - # Extract parameters, indicating sub-selections of parameters -exp = fm.get_params('aperiodic_params', 'exponent') -cfs = fm.get_params('peak_params', 'CF') - -# Print out a custom parameter report -template = ("With an error level of {error:1.2f}, an exponent " - "of {exponent:1.2f} and peaks of {cfs:s} Hz were fit.") -print(template.format(error=err, exponent=exp, - cfs=' & '.join(map(str, [round(cf, 2) for cf in cfs])))) +exp = fm.get_params('aperiodic', 'exponent') +cfs = fm.get_params('periodic', 'CF') ################################################################################################### # @@ -164,6 +138,45 @@ # in general Python (ex: `help(fm.get_params)`). # +################################################################################################### +# Model Metrics +# ~~~~~~~~~~~~~ +# +# In addition to model fit parameters, the fitting procedure computes and stores various +# metrics that can be used to evaluate model fit quality. +# + +################################################################################################### + +# Goodness of fit metrics +print('Goodness of fit:') +print(' Error - ', fm.results.metrics.results['error_mae']) +print(' R^2 - ', fm.results.metrics.results['gof_rsquared'], '\n') + +################################################################################################### +# +# You can also access metrics with the :meth:`~specparam.SpectralModel.results.get_metrics` method. +# + +################################################################################################### + +# Extract a model metric with `get_metrics` +err = fm.get_metrics('error') + +################################################################################################### +# +# Extracting model fit parameters and model metrics can also be combined to evaluate +# model properties, for example using the following template: +# + +################################################################################################### + +# Print out a custom parameter report +template = ("With an error level of {error:1.2f}, an exponent " + "of {exponent:1.2f} and peaks of {cfs:s} Hz were fit.") +print(template.format(error=err, exponent=exp, + cfs=' & '.join(map(str, [round(cf, 2) for cf in cfs])))) + ################################################################################################### # Notes on Interpreting Peak Parameters # ------------------------------------- @@ -203,7 +216,8 @@ # Compare the 'peak_params_' to the underlying gaussian parameters print(' Peak Parameters \t Gaussian Parameters') -for peak, gauss in zip(fm.results.peak_params_, fm.results.gaussian_params_): +for peak, gauss in zip(fm.results.params.periodic.get_params('converted'), + fm.results.params.periodic.get_params('fit')): print('{:5.2f} {:5.2f} {:5.2f} \t {:5.2f} {:5.2f} {:5.2f}'.format(*peak, *gauss)) #################################################################################################### @@ -227,7 +241,7 @@ fres = fm.results.get_results() # You can also unpack all fit parameters when using `get_results` -ap_params, peak_params, metrics, gauss_params = fm.results.get_results() +ap_fit, ap_conv, peak_fit, peak_conv, metrics = fm.results.get_results() ################################################################################################### @@ -235,7 +249,7 @@ print(fres, '\n') # from specparamResults, you can access the different results -print('Aperiodic Parameters: \n', fres.aperiodic_params) +print('Aperiodic Parameters: \n', fres.aperiodic_fit) # Check the R^2 and error of the model fit print('R-squared: \n {:5.4f}'.format(fres.metrics['gof_rsquared'])) diff --git a/tutorials/plot_03-Algorithm.py b/tutorials/plot_03-Algorithm.py index 11fe4b60..2517b479 100644 --- a/tutorials/plot_03-Algorithm.py +++ b/tutorials/plot_03-Algorithm.py @@ -174,7 +174,7 @@ ################################################################################################### # Plot the peak fit: created by re-fitting all of the candidate peaks together -plot_spectra(fm.data.freqs, fm.results._peak_fit, plt_log, +plot_spectra(fm.data.freqs, fm.results.model.get_component('peak'), plt_log, color='green', label='Final Periodic Fit') ################################################################################################### @@ -191,7 +191,7 @@ ################################################################################################### # Plot the peak removed power spectrum, created by removing peak fit from original spectrum -plot_spectra(fm.data.freqs, fm.results._spectrum_peak_rm, plt_log, +plot_spectra(fm.data.freqs, fm.get_data('aperiodic'), plt_log, label='Peak Removed Spectrum', color='black') ################################################################################################### @@ -209,10 +209,10 @@ # Plot the final aperiodic fit, calculated on the peak removed power spectrum _, ax = plt.subplots(figsize=(12, 10)) -plot_spectra(fm.data.freqs, fm.results._spectrum_peak_rm, plt_log, +plot_spectra(fm.data.freqs, fm.get_data('aperiodic'), plt_log, label='Peak Removed Spectrum', color='black', ax=ax) -plot_spectra(fm.data.freqs, fm.results._ap_fit, plt_log, label='Final Aperiodic Fit', - color='blue', alpha=0.5, linestyle='dashed', ax=ax) +plot_spectra(fm.data.freqs, fm.results.model.get_component('aperiodic'), plt_log, + label='Final Aperiodic Fit', color='blue', alpha=0.5, linestyle='dashed', ax=ax) ################################################################################################### # Step 7: Combine the Full Model Fit @@ -229,7 +229,7 @@ ################################################################################################### # Plot full model, created by combining the peak and aperiodic fits -plot_spectra(fm.data.freqs, fm.results.modeled_spectrum_, plt_log, +plot_spectra(fm.data.freqs, fm.results.model.modeled_spectrum, plt_log, label='Full Model', color='red') ################################################################################################### diff --git a/tutorials/plot_04-ModelObject.py b/tutorials/plot_04-ModelObject.py index 3118e7a4..70460a18 100644 --- a/tutorials/plot_04-ModelObject.py +++ b/tutorials/plot_04-ModelObject.py @@ -231,8 +231,8 @@ ################################################################################################### # Print out model fit results parameters -print('aperiodic params: \t', fm.results.aperiodic_params_) -print('peak params: \t', fm.results.peak_params_) +print('aperiodic params: \t', fm.results.params.aperiodic.params) +print('peak params: \t', fm.results.params.periodic.params) # Print out metrics model fit results parameters print('fit error: \t', fm.results.metrics.results['error_mae']) diff --git a/tutorials/plot_06-GroupFits.py b/tutorials/plot_06-GroupFits.py index ecd3384e..79fa7132 100644 --- a/tutorials/plot_06-GroupFits.py +++ b/tutorials/plot_06-GroupFits.py @@ -141,16 +141,12 @@ ################################################################################################### # Extract aperiodic parameters -aps = fg.get_params('aperiodic_params') -exps = fg.get_params('aperiodic_params', 'exponent') +aps = fg.get_params('aperiodic') +exps = fg.get_params('aperiodic', 'exponent') # Extract peak parameters -peaks = fg.get_params('peak_params') -cfs = fg.get_params('peak_params', 'CF') - -# Extract goodness-of-fit metrics -errors = fg.get_params('metrics', 'error_mae') -r2s = fg.get_params('metrics', 'gof_rsquared') +peaks = fg.get_params('peak') +cfs = fg.get_params('peak', 'CF') ################################################################################################### @@ -159,7 +155,19 @@ ################################################################################################### # -# More information about the parameters you can extract is also documented in the +# Similarly, goodness of fit metrics can be accessed with the +# the :func:`~specparam.SpectralGroupModel.get_metrics` method. +# + +################################################################################################### + +# Extract goodness-of-fit metrics +errors = fg.get_metrics('error') +r2s = fg.get_metrics('gof') + +################################################################################################### +# +# More information about the parameters and metrics you can extract is also documented in the # FitResults object. # diff --git a/tutorials/plot_08-TroubleShooting.py b/tutorials/plot_08-TroubleShooting.py index 7dc54634..072db700 100644 --- a/tutorials/plot_08-TroubleShooting.py +++ b/tutorials/plot_08-TroubleShooting.py @@ -177,7 +177,7 @@ # Compare ground truth simulated parameters to model fit results print('Ground Truth \t\t Model Parameters') -for sy, fi in zip(np.array(gauss_params), fm.results.gaussian_params_): +for sy, fi in zip(np.array(gauss_params), fm.results.params.periodic.get_params('fit')): print('{:5.2f} {:5.2f} {:5.2f} \t {:5.2f} {:5.2f} {:5.2f}'.format(*sy, *fi)) ################################################################################################### @@ -236,7 +236,7 @@ # Check reconstructed parameters compared to the simulated parameters print('Ground Truth \t\t Model Parameters') -for sy, fi in zip(np.array(gauss_params), fm.results.gaussian_params_): +for sy, fi in zip(np.array(gauss_params), fm.results.params.periodic.get_params('fit')): print('{:5.2f} {:5.2f} {:5.2f} \t {:5.2f} {:5.2f} {:5.2f}'.format(*sy, *fi)) ################################################################################################### @@ -308,7 +308,7 @@ ################################################################################################### # Find the index of the worst model fit from the group -worst_fit_ind = np.argmax(fg.get_params('metrics', 'error_mae')) +worst_fit_ind = np.argmax(fg.get_metrics('error')) # Extract this model fit from the group fm = fg.get_model(worst_fit_ind, regenerate=True) @@ -360,7 +360,7 @@ ################################################################################################### # Check the average number of fit peaks, per model -print('Average number of fit peaks: ', np.mean(fg.results.n_peaks_)) +print('Average number of fit peaks: ', np.mean(fg.results.n_peaks)) ################################################################################################### # Reporting Bad Fits diff --git a/tutorials/plot_09-FurtherAnalysis.py b/tutorials/plot_09-FurtherAnalysis.py index 9958fd52..179dfa82 100644 --- a/tutorials/plot_09-FurtherAnalysis.py +++ b/tutorials/plot_09-FurtherAnalysis.py @@ -234,7 +234,7 @@ ################################################################################################### # Extract aperiodic exponent parameters from group results -exps = fg.get_params('aperiodic_params', 'exponent') +exps = fg.get_params('aperiodic', 'exponent') # Check out the aperiodic exponent results print(exps)