Source code for eureka.S5_lightcurve_fitting.models.Model

import numpy as np
import matplotlib.pyplot as plt
import copy
import os

from ..utils import COLORS
from ...lib.readEPF import Parameters
from ...lib.split_channels import split
from ...lib import plots
from ...lib.util import resolve_param_key


[docs] class Model: def __init__(self, **kwargs): """Create a model instance. Parameters ---------- **kwargs : dict Parameters to set in the Model object. Any parameter named log will not be loaded into the Model object as Logedit objects cannot be pickled which is required for multiprocessing. """ self.default_name = 'New Model' # Set up default model attributes self.components = kwargs.get('components', []) self.name = kwargs.get('name', self.default_name) self.nchannel = kwargs.get('nchannel', 1) self.nchannel_fitted = kwargs.get('nchannel_fitted', 1) self.fitted_channels = kwargs.get('fitted_channels', [0, ]) self.wl_groups = kwargs.get('wl_groups', None) self.multwhite = kwargs.get('multwhite', False) self.nints = kwargs.get('nints') self.fitter = kwargs.get('fitter', None) self.time = kwargs.get('time', None) self.time_units = kwargs.get('time_units', 'BMJD_TDB') self.flux = kwargs.get('flux', None) self.freenames = kwargs.get('freenames', None) self._parameters = kwargs.get('parameters', Parameters()) self.longparamlist = kwargs.get('longparamlist', None) self.paramtitles = kwargs.get('paramtitles', None) self.modeltype = kwargs.get('modeltype', None) self.fmt = kwargs.get('fmt', None) # Store the arguments as attributes for arg, val in kwargs.items(): if arg != 'log': setattr(self, arg, val) if self.wl_groups is None: self.wl_groups = [0, ]*self.nchannel_fitted # --- normalize/validate metadata (cast to numpy arrays) --- try: fc = np.asarray(self.fitted_channels, dtype=int).reshape(-1) wg = np.asarray(self.wl_groups, dtype=int).reshape(-1) except Exception as e: raise ValueError(f"Could not parse fitted_channels/wl_groups: {e}") if fc.size != int(self.nchannel_fitted): raise ValueError( f"fitted_channels must have length nchannel_fitted " f"(got {fc.size}, expected {self.nchannel_fitted})." ) if wg.size != fc.size: raise ValueError( f"wl_groups must be the same length as fitted_channels " f"(got {wg.size} vs {fc.size})." ) # Store normalized arrays self.fitted_channels = fc self.wl_groups = wg # Build quick lookup table from real channel id -> position self._chan_to_pos = {int(ch): i for i, ch in enumerate(fc.tolist())} def _channels(self, channel=None): """Return the list of channel IDs to evaluate. Parameters ---------- channel : int; optional If not None, only consider one of the channels. Defaults to None. Returns ------- nchan : int The number of channels to evaluate. channels : ndarray The array of channel IDs to evaluate. """ if channel is None: channels = self.fitted_channels else: channels = np.array([channel, ]) nchan = len(channels) return nchan, channels def _wl_for_chan(self, chan): """Return wavelength-group id for a real channel id. Parameters ---------- chan : int Real channel id present in ``self.fitted_channels``. Returns ------- int Wavelength-group id for this channel. """ try: pos = self._chan_to_pos[chan] except (KeyError, ValueError, TypeError): raise ValueError( f"Channel {chan!r} not found in fitted_channels " f"{self.fitted_channels!r}" ) return self.wl_groups[pos] def _get_param_value(self, base, default=0.0, *, chan=0, wl=None, pid=0): """Resolve a parameter key (wl > ch > base) and return its value. Parameters ---------- base : str Base parameter name (e.g., 'c0', 'r0', 'A'). default : float or None; optional Fallback value if the parameter is missing. If ``None``, return ``None`` when unresolved or uncastable. Default 0.0. chan : int; optional Channel id to resolve against. Default 0. wl : int or None; optional Wavelength-group id. If None, it will be inferred from ``chan``. Pass ``wl=0`` to target the base (unsuffixed) key without consulting the channel-to-wavelength map. pid : int; optional Planet id for astrophysical parameters (0 for none). Default 0. Returns ------- float or None Scalar numeric value, or ``None`` if ``default is None`` and the key is missing or cannot be cast to float. """ params = getattr(self, "parameters", None) if params is None: # No parameters object at all return None if default is None else float(default) if wl is None: # Infer wl from chan; raise with guidance if chan isn't present. try: wl = self._wl_for_chan(chan) except ValueError as e: raise ValueError( f"_get_param_value({base}): cannot infer wl for " f"chan={chan}; pass wl explicitly (e.g., wl=0 for base)." ) from e key = resolve_param_key(base, params, pid=pid, channel=chan, wl=wl) val = getattr(self.parameters, key, default) # If unresolved and caller asked for None, propagate None. if val is default and default is None: return None # Parameters objects have a .value attribute if hasattr(val, "value"): val = val.value # Attempt robust float cast (handles numpy scalars/arrays) try: arr = np.asanyarray(val) if arr.shape == () or arr.size == 1: return float(arr.reshape(-1)[0]) return float(arr.flat[0]) except (TypeError, ValueError): return None if default is None else float(default) def __mul__(self, other): """Multiply model components to make a combined model. Parameters ---------- other : eureka.S5_lightcurve_fitting.models.Model The model to multiply. Returns ------- eureka.S5_lightcurve_fitting.models.CompositeModel The combined model. """ # Make sure it is the right type attrs = ['flux', 'time'] if not all([hasattr(other, attr) for attr in attrs]): raise TypeError('Only another Model instance may be multiplied.') # Combine the model parameters too parameters = self.parameters + other.parameters if self.paramtitles is None: if isinstance(other.paramtitles, list): paramtitles = other.paramtitles[:] else: paramtitles = other.paramtitles elif other.paramtitles is None: if isinstance(self.paramtitles, list): paramtitles = self.paramtitles[:] else: paramtitles = self.paramtitles else: # both are present: concatenate paramtitles = self.paramtitles + other.paramtitles return CompositeModel([copy.copy(self), other], parameters=parameters, paramtitles=paramtitles) @property def flux(self): """A getter for the flux.""" return self._flux @flux.setter def flux(self, flux_array): """A setter for the flux Parameters ---------- flux_array : sequence The flux array """ # Check the type if not isinstance(flux_array, (np.ndarray, tuple, list, type(None))): raise TypeError("flux axis must be a tuple, list, or numpy array.") # Set the array self._flux = np.ma.masked_array(flux_array) @property def time(self): """A getter for the time""" return self._time @time.setter def time(self, time_array): """A setter for the time""" # Check the type if not isinstance(time_array, (np.ndarray, tuple, list, type(None))): raise TypeError("Time axis must be a tuple, list, or numpy array.") # Set the array self._time = np.ma.masked_array(time_array) # Set the array for the components for component in self.components: component.time = time_array @property def parameters(self): """A getter for the parameters.""" return self._parameters @parameters.setter def parameters(self, params): """A setter for the parameters.""" # Process if it is a parameters file if isinstance(params, str) and os.path.isfile(params): params = Parameters(params) # Or a Parameters instance if (params is not None) and (type(params).__name__ != Parameters.__name__): raise TypeError("'params' argument must be a JSON file, " "ascii file, or parameters.Parameters instance.") # Set the parameters attribute self._parameters = params # Set the attribute for the components for component in self.components: component.parameters = params @property def freenames(self): """A getter for the freenames.""" return self._freenames @freenames.setter def freenames(self, freenames): """A setter for the freenames.""" # Set the freenames attribute self._freenames = freenames # Update the components' freenames for component in self.components: component.freenames = freenames @property def nints(self): """A getter for the nints.""" return self._nints @nints.setter def nints(self, nints_array): """A setter for the nints.""" self._nints = nints_array # Update the components' nints for component in self.components: component.nints = nints_array
[docs] def interp(self, new_time, nints, channel=None, **kwargs): """Evaluate the model over a different time array. Parameters ---------- new_time : sequence The time array. nints : list The number of integrations for each channel, for the new time array. channel : int; optional If not None, only consider one of the channels. Defaults to None. **kwargs : dict Additional parameters to pass to self.eval(). """ # Save the current values old_time = copy.deepcopy(self.time) old_nints = copy.deepcopy(self.nints) # Evaluate the model on the new time array self.time = new_time self.nints = nints interp_flux = self.eval(channel=channel, **kwargs) # Reset the old values self.time = old_time self.nints = old_nints return interp_flux
[docs] def update(self, newparams, **kwargs): """Update the model with new parameter values. Parameters ---------- newparams : ndarray New parameter values. **kwargs : dict Unused by the base eureka.S5_lightcurve_fitting.models.Model class. """ for val, arg in zip(newparams, self.freenames): # For now, the dict and Parameter are separate self.parameters.dict[arg][0] = val getattr(self.parameters, arg).value = val for component in self.components: component.update(newparams, **kwargs)
[docs] @plots.apply_style def plot(self, components=False, ax=None, draw=False, color='blue', zorder=np.inf, share=False, chan=0, **kwargs): """Plot the model. Parameters ---------- components : bool; optional Plot all model components. ax : Matplotlib Axes; optional The figure axes to plot on. draw : bool; optional Whether or not to display the plot. Defaults to False. color : str; optional The color to use for the plot. Defaults to 'blue'. zorder : numeric; optional The zorder for the plot. Defaults to np.inf. share : bool; optional Whether or not this model is a shared model. Defaults to False. chan : int; optional The real channel id to render. `LightCurve.plot()` passes the correct value; this function now always respects it. **kwargs : dict Additional parameters to pass to plot and self.eval(). """ # Make the figure if ax is None: fig = plt.figure(5103, figsize=(8, 6)) ax = fig.gca() # Validate channel choice (helps catch accidental chan=0 when fitting # a nonzero channel) try: fc = np.asarray(self.fitted_channels).reshape(-1) if fc.size and (chan not in fc): raise ValueError( f"Model.plot: chan={chan} not in fitted_channels {fc!r}") except Exception: # If fitted_channels is not set/array-like, skip this guard. pass # Plot the model label = self.fitter if self.name != self.default_name: label += ': '+self.name model = self.eval(channel=chan, incl_GP=True, **kwargs) time = self.time if self.multwhite: # Split the arrays that have lengths of the original time axis time = split([time, ], self.nints, chan)[0] ax.plot(time, model, '.', ls='', ms=1, label=label, color=color, zorder=zorder) if components and self.components is not None: for component in self.components: component.plot(components=components, ax=ax, draw=False, color=next(COLORS), zorder=zorder, share=share, chan=chan, **kwargs) # Format axes ax.set_xlabel(str(self.time_units)) ax.set_ylabel('Flux') if draw: fig.show() else: return
[docs] class CompositeModel(Model): """A class to create composite models.""" def __init__(self, components, **kwargs): """Initialize the composite model. Parameters ---------- components : sequence The list of model components. **kwargs : dict Additional parameters to pass to eureka.S5_lightcurve_fitting.models.Model.__init__(). """ # Inherit from Model class kwargs['name'] = kwargs.get('name', 'composite model') super().__init__(components=components, **kwargs) self.GP = False for component in self.components: if component.modeltype == 'GP': self.GP = True @property def freenames(self): """A getter for the freenames.""" return self._freenames @freenames.setter def freenames(self, freenames): """A setter for the freenames.""" # Update the components' freenames for component in self.components: component.freenames = freenames # Set the freenames attribute self._freenames = freenames
[docs] def eval(self, channel=None, incl_GP=False, **kwargs): """Evaluate the model components. Parameters ---------- channel : int; optional If not None, only consider one of the channels. Defaults to None. incl_GP : bool; optional Whether or not to include the GP's predictions in the evaluated model predictions. **kwargs : dict Must pass in the time array here if not already set. Returns ------- flux : ndarray The evaluated model predictions at the times self.time. """ # Get the time if self.time is None: # This also updates all components self.time = kwargs.get('time') if channel is None: nchan = self.nchannel_fitted else: nchan = 1 if self.multwhite and channel is None: # Evaluating all channels of a multwhite fit flux_length = len(self.time) elif self.multwhite: # Evaluating a single channel of a multwhite fit flux_length = self.nints[channel] else: # Evaluating a non-multwhite fit (individual or shared) flux_length = len(self.time)*nchan flux = np.ma.ones(flux_length) # Evaluate flux of each component for component in self.components: if component.modeltype != 'GP': flux *= component.eval(channel=channel, **kwargs) if incl_GP: flux += self.GPeval(flux, channel=channel, **kwargs) return flux
[docs] def syseval(self, channel=None, incl_GP=False, **kwargs): """Evaluate the systematic model components only. Parameters ---------- channel : int; optional If not None, only consider one of the channels. Defaults to None. incl_GP : bool; optional Whether or not to include the GP's predictions in the evaluated model predictions. **kwargs : dict Must pass in the time array here if not already set. Returns ------- flux : ndarray The evaluated systematics model predictions at the times self.time. """ # Get the time if self.time is None: self.time = kwargs.get('time') if channel is None: nchan = self.nchannel_fitted else: nchan = 1 if self.multwhite and channel is None: # Evaluating all channels of a multwhite fit flux_length = len(self.time) elif self.multwhite: # Evaluating a single channel of a multwhite fit flux_length = self.nints[channel] else: # Evaluating a non-multwhite fit (individual or shared) flux_length = len(self.time)*nchan flux = np.ma.ones(flux_length) # Evaluate flux at each component for component in self.components: if component.modeltype == 'systematic': if component.time is None: component.time = self.time flux *= component.eval(channel=channel, **kwargs) if incl_GP: flux += self.GPeval(flux, channel=channel, **kwargs) return flux
[docs] def GPeval(self, fit, channel=None, **kwargs): """Evaluate the GP model components only. Parameters ---------- fit : ndarray The model predictions (excluding the GP). channel : int; optional If not None, only consider one of the channels. Defaults to None. **kwargs : dict Must pass in the time array here if not already set. Returns ------- flux : ndarray The evaluated GP model predictions at the times self.time. """ # Get the time if self.time is None: self.time = kwargs.get('time') if channel is None: nchan = self.nchannel_fitted else: nchan = 1 if self.multwhite and channel is None: # Evaluating all channels of a multwhite fit flux_length = len(self.time) elif self.multwhite: # Evaluating a single channel of a multwhite fit flux_length = self.nints[channel] else: # Evaluating a non-multwhite fit (individual or shared) flux_length = len(self.time)*nchan flux = np.ma.zeros(flux_length) # Evaluate flux for component in self.components: if component.modeltype == 'GP': flux = component.eval(fit, channel=channel, **kwargs) return flux
[docs] def physeval(self, interp=False, channel=None, **kwargs): """Evaluate the physical model components only. Parameters ---------- interp : bool; optional Whether to uniformly sample in time or just use the self.time time points. Defaults to False. channel : int; optional If not None, only consider one of the channels. Defaults to None. **kwargs : dict Must pass in the time array here if not already set. Returns ------- flux : ndarray The evaluated physical model predictions at the times self.time if interp==False, else at evenly spaced times between self.time[0] and self.time[-1] with spacing self.time[1]-self.time[0]. new_time : ndarray The time values at which flux has been computed. nints_interp : list The number of time points per lightcurve for each lightcurve (after interpolation if interp is True). """ # Get the time if self.time is None: self.time = kwargs.get('time') if channel is None: nchan = self.nchannel_fitted channels = self.fitted_channels else: nchan = 1 channels = [channel] if interp: if self.multwhite: new_time = [] nints_interp = [] for chan in channels: # Split the arrays that have lengths of # the original time axis time = split([self.time, ], self.nints, chan)[0] # Remove masked points at the start or end to avoid # extrapolating out to those points time = time[~np.ma.getmaskarray(time)] # Get time step on full time array to ensure good steps dt = np.min(np.diff(time)) # Interpolate as needed steps = int(np.round((time[-1]-time[0])/dt+1)) nints_interp.append(steps) new_time.extend(np.linspace(time[0], time[-1], steps, endpoint=True)) new_time = np.array(new_time) else: time = self.time # Remove masked points at the start or end to avoid # extrapolating out to those points time = time[~np.ma.getmaskarray(time)] # Get time step on full time array to ensure good steps dt = np.min(np.diff(time)) # Interpolate as needed dt = time[1]-time[0] steps = int(np.round((time[-1]-time[0])/dt+1)) nints_interp = np.ones(nchan)*steps new_time = np.linspace(time[0], time[-1], steps, endpoint=True) else: new_time = self.time if self.multwhite and channel is not None: # Split the arrays that have lengths of the original time axis new_time = split([new_time, ], self.nints, channel)[0] nints_interp = self.nints # Setup the flux array if self.multwhite: flux = np.ma.ones(len(new_time)) else: flux = np.ma.ones(len(new_time)*nchan) # Evaluate flux at each component for component in self.components: if component.modeltype == 'physical': if interp: flux *= component.interp(new_time, nints_interp, channel=channel, **kwargs) else: flux *= component.eval(channel=channel, **kwargs) return flux, new_time, nints_interp