"""Source code for eureka.S5_lightcurve_fitting.models.StepModel."""

import numpy as np

from .Model import Model
from ...lib.split_channels import split, get_trim


class StepModel(Model):
    """Model for step-functions in time.

    The model is 1 everywhere, plus ``step{N}`` added to every sample whose
    local time is at or after ``steptime{N}``, for each index ``N`` where
    both parameters resolve for the channel being evaluated.
    """

    def __init__(self, **kwargs):
        """Initialize the step-function model.

        Parameters
        ----------
        **kwargs : dict
            Additional parameters to pass to
            eureka.S5_lightcurve_fitting.models.Model.__init__().
        """
        # Inherit from Model class
        super().__init__(**kwargs)
        self.name = 'step'

        # Define model type (physical, systematic, other)
        self.modeltype = 'systematic'

    @property
    def time(self):
        """A getter for the time."""
        return self._time

    @time.setter
    def time(self, time_array):
        """A setter for the time.

        Stores ``time_array`` with non-finite entries masked
        (``np.ma.masked_invalid``) and derives ``time_local``: time measured
        from a reference sample. With ``self.multwhite`` this is done per
        fitted channel, otherwise over the whole array. Passing None clears
        both ``_time`` and ``time_local``.
        """
        if time_array is None:
            self._time = None
            self.time_local = None
            return

        self._time = np.ma.masked_invalid(time_array)

        # Convert to local time
        if self.multwhite:
            self.time_local = np.ma.zeros(self._time.shape)
            for chan in self.fitted_channels:
                # Split the arrays that have lengths
                # of the original time axis
                trim1, trim2 = get_trim(self.nints, chan)
                piece = self._time[trim1:trim2]
                self.time_local[trim1:trim2] = (
                    piece - self._reference_time(piece))
        else:
            self.time_local = self._time - self._reference_time(self._time)

    @staticmethod
    def _reference_time(arr):
        """Return the value that local time is measured from.

        This is the first raw sample (``arr.data[0]``). If that sample is not
        finite (e.g. a NaN that ``masked_invalid`` masked), the first finite
        sample is used instead. Otherwise the whole local time axis would be
        NaN and no step could ever be applied. An empty ``arr`` raises
        IndexError, and an all-non-finite ``arr`` returns ``data[0]``.
        """
        data = np.ma.getdata(arr)
        finite_idx = np.flatnonzero(np.isfinite(data))
        if finite_idx.size == 0 or finite_idx[0] == 0:
            return data[0]
        return data.flat[finite_idx[0]]

    @staticmethod
    def _leading_int(text):
        """Return the int spelled by the leading decimal digits of ``text``.

        Returns None when ``text`` does not start with a decimal digit.
        ``isdecimal`` is used rather than ``isdigit``, because ``isdigit``
        accepts characters such as superscripts that ``int()`` rejects.
        """
        digits = []
        for char in text:
            if not char.isdecimal():
                break
            digits.append(char)
        return int(''.join(digits)) if digits else None

    def _index_set_for_chan(self, chan):
        """Discover usable step indices for a given channel.

        Enumerate integer indices ``N`` present in any ``step{N}*`` or
        ``steptime{N}*`` key. For each candidate ``N``, use
        ``_get_param_value(..., default=None, chan=chan)`` to apply the
        standard precedence (``wl > ch > base``). Keep only the indices where
        *both* a step and a steptime resolve to defined values.

        Parameters
        ----------
        chan : int
            Real channel id.

        Returns
        -------
        list of int
            Sorted indices where both step and steptime are defined after
            resolution for this (chan, wl).
        """
        if getattr(self, "parameters", None) is None:
            return []

        # Collect every integer N appearing right after 'steptime' or 'step'
        # (up to the first non-digit). 'steptime' must be tested first:
        # every 'steptime*' key also starts with 'step', so testing 'step'
        # first would leave the 'steptime' branch unreachable.
        candidates = set()
        for key in self.parameters.dict.keys():
            if key.startswith('steptime'):
                rest = key[len('steptime'):]
            elif key.startswith('step'):
                rest = key[len('step'):]
            else:
                continue
            index = self._leading_int(rest)
            if index is not None:
                candidates.add(index)

        out = []
        for n in sorted(candidates):
            has_step = self._get_param_value(
                f'step{n}', default=None, chan=chan) is not None
            has_time = self._get_param_value(
                f'steptime{n}', default=None, chan=chan) is not None
            if has_step and has_time:
                out.append(n)
        return out

    def _read_steps_for_chan(self, chan):
        """Read and sort step pairs for a given channel.

        For each index ``N`` discovered by ``_index_set_for_chan``, read
        ``step{N}`` and ``steptime{N}`` via ``_get_param_value`` with the
        same channel resolution. Pairs with zero amplitude are skipped. The
        result is sorted by step time.

        Parameters
        ----------
        chan : int
            Real channel id.

        Returns
        -------
        list of tuple
            A list of ``(t_step, step)`` pairs sorted by ``t_step``.
        """
        pairs = []
        for n in self._index_set_for_chan(chan):
            # Resolve values for this channel (None if missing)
            step = self._get_param_value(f'step{n}', default=None, chan=chan)
            tstep = self._get_param_value(f'steptime{n}', default=None,
                                          chan=chan)
            if step is None or tstep is None or step == 0.0:
                continue
            pairs.append((tstep, step))

        # Ensure deterministic application order.
        pairs.sort(key=lambda pair: pair[0])
        return pairs

    def eval(self, channel=None, **kwargs):
        """Evaluate the step model.

        Parameters
        ----------
        channel : int; optional
            If not None, only consider one channel. Defaults to None.
        **kwargs : dict
            Must pass in the time array here if not already set.

        Returns
        -------
        lcfinal : np.ma.MaskedArray
            The model values at self.time. Each value is 1 plus the sum of
            every step whose step time is <= the local time. Values are
            masked wherever the local time is masked.
        """
        _, channels = self._channels(channel)

        # Get the time
        if self.time is None:
            self.time = kwargs.get('time')

        pieces = []
        for chan in channels:
            t = self.time_local
            if self.multwhite:
                # Split the arrays that have lengths of the original time axis
                t = split([t], self.nints, chan)[0]

            lcpiece = np.ma.ones(t.shape)
            for tstep, step in self._read_steps_for_chan(chan):
                # A masked time gives a masked comparison. Treat it explicitly
                # as "before the step" instead of relying on the data hidden
                # under the mask; those samples are re-masked below anyway.
                after = np.ma.filled(t >= tstep, False)
                lcpiece[after] += step
            lcpiece = np.ma.masked_where(np.ma.getmaskarray(t), lcpiece)
            pieces.append(lcpiece)

        if len(pieces) == 1:
            return pieces[0]
        return np.ma.concatenate(pieces)