# Source code for eureka.S5_lightcurve_fitting.modelgrid

# !/usr/bin/python
# -*- coding: latin-1 -*-
"""
A module for creating and managing grids of model spectra
"""
from functools import partial
from glob import glob
import multiprocessing
import os
import pickle
from importlib.resources import files
import time
import warnings

from astropy.io import fits
from astropy.utils.exceptions import AstropyWarning
import astropy.table as at
import astropy.units as q
import h5py
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import zoom

from . import utils

warnings.simplefilter('ignore', category=AstropyWarning)
warnings.simplefilter('ignore', category=FutureWarning)


class ModelGrid(object):
    """A multi-parameter grid of model spectra and its references.

    The grid is indexed by reading the headers of the FITS files in a
    model directory into a table. Spectra are read lazily from disk by
    `get()`, and the full flux cube is only built (and cached to HDF5)
    when `load_flux()` is called.

    Attributes
    ----------
    path : str
        The directory of FITS files used to create the ModelGrid,
        including a trailing path separator.
    refs : list or None
        The references (bibcodes) for the data contained in the ModelGrid.
    wave_rng : tuple
        The lower and upper wavelength bounds applied to retrieved spectra.
    wave_units : astropy.units.UnitBase
        The units of the wavelengths returned by `get()`.
    const : float
        The factor that converts the raw wavelengths (Angstroms) of the
        model files into `wave_units`.
    resolution : int or None
        The default resolution (lambda/d_lambda) of retrieved spectra.
    n_bins : int
        The number of bins for the ModelGrid wavelength array.
    data : astropy.table.Table
        The table of header values, one row per model file.
    Teff_vals, logg_vals, FeH_vals : numpy.ndarray
        The unique parameter values present in `data`.
    flux, mu, wavelength : numpy.ndarray or None
        The arrays filled by `load_flux()`; None until then.
    flux_file : str
        The path of the HDF5 cache written by `load_flux()`.
    file : str
        The path of the pickled inventory used to quickly reload the
        grid. Only set when a whole directory is indexed.

    Notes
    -----
    Header keywords whose value is identical in every model file are
    removed from `data` and stored as instance attributes instead
    (`get()` relies on CRVAL1 and CDELT1 being stored this way).
    """

    def __init__(self, model_directory, bibcode='2013A & A...553A...6H',
                 names={'Teff': 'PHXTEFF', 'logg': 'PHXLOGG',
                        'FeH': 'PHXM_H', 'mass': 'PHXMASS',
                        'Lbol': 'PHXLUM'}, resolution=None,
                 wave_units=q.um, **kwargs):
        """Initialize the model grid.

        Creates a table with a column for each header keyword of the
        model files, or reloads that table from a pickled inventory if
        one exists in the model directory.

        Parameters
        ----------
        model_directory : str
            The path to the directory of FITS files of spectra, which may
            include a filename with a wildcard character.
        bibcode : str, array-like; optional
            The bibcode or list of bibcodes for this data set.
            Defaults to '2013A & A...553A...6H'.
        names : dict; optional
            A dictionary mapping new column names to the header keywords
            they replace. The Phoenix model keywords are given as an
            example. Defaults to {'Teff': 'PHXTEFF', 'logg': 'PHXLOGG',
            'FeH': 'PHXM_H', 'mass': 'PHXMASS', 'Lbol': 'PHXLUM'}.
            The dictionary is only read, never modified.
        resolution : int; optional
            The desired wavelength resolution (lambda/d_lambda)
            of the grid spectra. Defaults to None.
        wave_units : astropy.units.quantity; optional
            The wavelength units. Defaults to astropy.units.um.
        **kwargs : dict
            Additional arguments to pass to self.customize().

        Notes
        -----
        If no file matches `model_directory`, or none of the matching
        FITS headers can be read, a message is printed and the object is
        left only partially initialized.
        """
        # Make sure we can use glob if a directory
        # is given without a wildcard
        if '*' not in model_directory:
            model_directory = os.path.join(model_directory, '*')

        # Only a whole-directory pattern is cached as a pickled inventory
        whole_directory = model_directory.endswith(os.sep+'*')

        # Check for a precomputed pickle of this ModelGrid
        model_grid = None
        inventory_file = None
        if whole_directory:
            # Location of model_grid pickle
            inventory_file = model_directory.replace('*', 'model_grid.p')
            if os.path.isfile(inventory_file):
                # NOTE(review): unpickling can execute arbitrary code, so
                # only point this class at model directories you trust.
                with open(inventory_file, 'rb') as inventory:
                    model_grid = pickle.load(inventory)

        # Instantiate the precomputed model grid
        if model_grid is not None:
            for k, v in model_grid.items():
                setattr(self, k, v)

            self.flux_file = os.path.join(self.path, 'model_grid_flux.hdf5')
            self.flux = None
            self.wavelength = None
            self.mu = None
            del model_grid

        # Or compute it from scratch
        else:
            # Print update...
            if whole_directory:
                print("Indexing models...")

            # Create some attributes
            self.path = os.path.dirname(model_directory)+os.sep
            self.refs = None
            self.wave_rng = (0*q.um, 40*q.um)
            self.flux_file = os.path.join(self.path, 'model_grid_flux.hdf5')
            self.flux = None
            self.wavelength = None
            self.mu = None

            # Save the refs, wrapping a single bibcode in a list
            if bibcode:
                if isinstance(bibcode, str):
                    bibcode = [bibcode]
                self.refs = bibcode

            # Get list of spectral intensity files
            model_files = glob(model_directory)
            if not model_files:
                print('No files match', model_directory, '.')
                return

            # Parse the FITS headers
            filenames = []
            vals, dtypes, keys = [], [], []
            for f in model_files:
                if not f.endswith('.fits'):
                    continue

                try:
                    header = fits.getheader(f)
                    file_keys = np.array(header.cards).T[0]
                    file_dtypes = [type(card[1]) for card in header.cards]
                    file_vals = [header.get(k) for k in file_keys]
                except Exception:
                    # Best effort: skip unreadable files but let
                    # KeyboardInterrupt and SystemExit propagate
                    print(f, 'could not be read into the model grid.')
                    continue

                # Only a successfully parsed header defines the columns
                keys, dtypes = file_keys, file_dtypes
                vals.append(file_vals)
                filenames.append(os.path.basename(f))

            # Without a single readable header there is no table to build
            if not vals:
                print('No FITS files matching', model_directory,
                      'could be read into the model grid.')
                return

            # Fix data types, trim extraneous values, and make the table
            dtypes = [str if d is bool else d for d in dtypes]
            vals = [v[: len(dtypes)] for v in vals]
            table = at.Table(np.array(vals), names=keys, dtype=dtypes)

            # Add the filenames as a column
            table['filename'] = filenames

            # Rename any columns
            for new, old in names.items():
                try:
                    table.rename_column(old, new)
                except KeyError:
                    print('No column named', old)

            # Remove columns where the values are all the same
            # and store value as attribute instead. The grid parameters
            # and the filename are always kept because get() needs them,
            # even when the grid holds a single model.
            keep = ('Teff', 'logg', 'FeH', 'filename')
            for n in table.colnames:
                if n in keep:
                    continue
                val = table[n][0]
                if list(table[n]).count(val) == len(table[n]):
                    setattr(self, n, val)
                    table.remove_column(n)

            # Store the table in the data attribute
            self.data = table

            # Store the parameter ranges
            self.Teff_vals = np.asarray(np.unique(table['Teff']))
            self.logg_vals = np.asarray(np.unique(table['logg']))
            self.FeH_vals = np.asarray(np.unique(table['FeH']))

            # Write an inventory file to this directory for future table loads
            if whole_directory:
                self.file = inventory_file
                try:
                    with open(self.file, 'wb') as inventory:
                        pickle.dump(self.__dict__, inventory)
                except IOError:
                    print('Could not write model grid to', self.file)

        # Print something
        print(len(self.data), 'models loaded from', self.path)

        # In case no filter is used
        self.n_bins = 1

        # The raw model wavelengths are in Angstroms;
        # set_units() computes the conversion to the requested units
        self.wave_units = q.AA
        if wave_units:
            self.set_units(wave_units)
        else:
            self.const = 1

        # Save the desired resolution
        self.resolution = resolution

        # Customize from the get-go
        if kwargs:
            self.customize(**kwargs)

    def export(self, filepath, **kwargs):
        """Export the model with the given parameters to a FITS file.

        Parameters
        ----------
        filepath : str
            The path to the target FITS file.
        **kwargs : dict
            Additional parameters to pass to self.get().

        Raises
        ------
        IOError
            If `filepath` does not have a .fits extension.
        """
        if not filepath.endswith('.fits'):
            raise IOError("Target file must have a .fits extension.")

        # Get the model
        model = self.get(**kwargs)

        # Get a dummy FITS file
        # NOTE(review): the template is looked up in the 'ExoCTK' package;
        # confirm that package is available wherever this is used.
        ffile = str(files('ExoCTK').joinpath(
            f'data{os.sep}core{os.sep}' + 'ModelGrid_tmp.fits'))

        # The template is closed again even if writing fails
        with fits.open(ffile) as hdu:
            # Replace the data
            hdu[0].data = model['flux']
            hdu[1].data = model['mu']
            hdu[0].header['PHXTEFF'] = model['Teff']
            hdu[0].header['PHXLOGG'] = model['logg']
            hdu[0].header['PHXM_H'] = model['FeH']

            # Update the wavelength
            wave = model['wave']
            hdu[0].header['CRVAL1'] = min(wave)
            hdu[0].header['CDELT1'] = np.mean(np.diff(wave))
            hdu[0].header['CUNIT1'] = 'Micron'

            # Write the file
            hdu.writeto(filepath)

    def get(self, Teff, logg, FeH, resolution=None, interp=True):
        """Retrieve the spectrum of the given parameters.

        Parameters
        ----------
        Teff : int
            The effective temperature (K)
        logg : float
            The logarithm of the surface gravity (dex)
        FeH : float
            The logarithm of the ratio of the metallicity
            and solar metallicity (dex)
        resolution : int; optional
            The desired wavelength resolution (lambda/d_lambda).
            Defaults to None, in which case self.resolution is used.
        interp : bool; optional
            Interpolate the model if it is not a grid point.
            Defaults to True.

        Returns
        -------
        spec_dict : dict or None
            A dictionary of the table row of the model plus the 'wave',
            'flux', and 'mu' arrays. For an interpolated model this is
            the dictionary returned by `grid_interp()`. None is returned
            if the parameters are outside the grid, or are not a grid
            point while `interp` is False.
        """
        # See if the model with the desired parameters is within the grid
        in_grid = (min(self.Teff_vals) <= Teff <= max(self.Teff_vals)
                   and min(self.logg_vals) <= logg <= max(self.logg_vals)
                   and min(self.FeH_vals) <= FeH <= max(self.FeH_vals))
        if not in_grid:
            print('Teff: ', Teff, ' logg: ', logg, ' FeH: ', FeH,
                  ' model not in grid.')
            return

        # See if the model with the desired parameters is a true grid point
        match = np.asarray((self.data['Teff'] == Teff) &
                           (self.data['logg'] == logg) &
                           (self.data['FeH'] == FeH))
        rows = np.flatnonzero(match)

        # If not on the grid, interpolate to it
        if rows.size == 0:
            if interp:
                return self.grid_interp(Teff, logg, FeH)
            return

        # Get the row index and filepath
        row = rows[0]
        filepath = os.path.join(self.path, str(self.data[row]['filename']))

        # Get the flux and mu arrays
        raw_flux = fits.getdata(filepath, 0)
        mu = fits.getdata(filepath, 1)

        # Construct full wavelength scale
        if self.CRVAL1 == '-':
            # Try to get data from WAVELENGTH extension...
            dat = fits.getdata(filepath, ext=-1)
            raw_wave = np.array(dat).squeeze()
        else:
            # ...or try to generate it
            b = self.CDELT1*np.arange(len(raw_flux[0]))
            raw_wave = np.array(self.CRVAL1+b).squeeze()

        # Convert from A to desired units. Not done in place so that an
        # integer wavelength array is promoted instead of raising.
        raw_wave = raw_wave*self.const

        # Trim the wavelength and flux arrays
        wave_q = raw_wave*self.wave_units
        idx, = np.where(np.logical_and(wave_q >= self.wave_rng[0],
                                       wave_q <= self.wave_rng[1]))
        flux = raw_flux[:, idx]
        wave = raw_wave[idx]

        # Bin the spectrum if necessary
        if resolution is not None or self.resolution is not None:
            # Calculate zoom
            z = utils.calc_zoom(resolution or self.resolution, wave)
            wave = zoom(wave, z)
            flux = zoom(flux, (1, z))

        # Make a dictionary of parameters
        # This should really be a core.Spectrum() object!
        row_data = self.data[row].as_void()
        spec_dict = dict(zip(self.data.colnames, row_data))
        spec_dict['wave'] = wave
        spec_dict['flux'] = flux
        spec_dict['mu'] = mu

        return spec_dict

    def grid_interp(self, Teff, logg, FeH, plot=False):
        """Interpolate the grid to the desired parameters.

        Parameters
        ----------
        Teff : int
            The effective temperature (K)
        logg : float
            The logarithm of the surface gravity (dex)
        FeH : float
            The logarithm of the ratio of the metallicity
            and solar metallicity (dex)
        plot : bool; optional
            Currently unused; kept for backward compatibility.
            Defaults to False.

        Returns
        -------
        grid_point : dict or None
            A dictionary with the requested 'Teff', 'logg', and 'FeH',
            the interpolated 'mu' and 'flux', the grid 'wave' array, and
            the 'generators' returned by utils.interp_flux. None is
            returned if an IOError occurs while interpolating.
        """
        # Load the fluxes
        if self.flux is None:
            self.load_flux()

        # Get the flux array
        flux = self.flux.copy()

        # Get the interpolable parameters, i.e. those
        # with more than one value in the grid
        params, values = [], []
        for p, v in zip([self.Teff_vals, self.logg_vals, self.FeH_vals],
                        [Teff, logg, FeH]):
            if len(p) > 1:
                params.append(p)
                values.append(v)

        values = np.asarray(values)
        label = '{}/{}/{}'.format(Teff, logg, FeH)

        try:
            # Interpolate flux values at each wavelength
            # using a pool for multiple processes
            print('Interpolating grid point [{}]...'.format(label))
            processes = 8
            mu_index = range(flux.shape[-2])
            start = time.time()
            func = partial(utils.interp_flux, flux=flux, params=params,
                           values=values)
            pool = multiprocessing.Pool(processes)
            try:
                new_flux, generators = zip(*pool.map(func, mu_index))
            finally:
                # Release the worker processes even if the mapping fails
                pool.close()
                pool.join()

            # Clean up and time of execution
            new_flux = np.asarray(new_flux)
            generators = np.asarray(generators)
            print('Run time in seconds: ', time.time()-start)

            # Interpolate mu value
            interp_mu = RegularGridInterpolator(params, self.mu)
            mu = interp_mu(np.array(values)).squeeze()

            # Make a dictionary to return
            grid_point = {'Teff': Teff, 'logg': logg, 'FeH': FeH,
                          'mu': mu, 'flux': new_flux,
                          'wave': self.wavelength,
                          'generators': generators}

            return grid_point

        except IOError:
            print('Grid too sparse. Could not interpolate.')
            return

    def load_flux(self, reset=False):
        """Load the flux arrays of all models into the ModelGrid.

        Fills the `flux` attribute with an array of shape
        (Teff, logg, FeH, mu, wavelength), along with the matching `mu`
        and `wavelength` attributes. The arrays are read from the HDF5
        cache if it exists; otherwise they are built from the model
        files with `get()` and then written to the cache.

        Parameters
        ----------
        reset : bool; optional
            Delete the old cache file and clear the flux, mu, and
            wavelength attributes so they are rebuilt.
            Defaults to False.
        """
        if reset:
            # Delete the old file and clear the cached arrays. All three
            # are cleared together so they cannot get out of sync.
            if os.path.isfile(self.flux_file):
                os.remove(self.flux_file)
            self.flux = None
            self.mu = None
            self.wavelength = None

        if self.flux is not None:
            print('Data already loaded.')
            return

        print('Loading flux into table...')

        if os.path.isfile(self.flux_file):
            # Load the flux from the HDF5 file
            with h5py.File(self.flux_file, "r") as f:
                self.flux = f['flux'][:]
                self.mu = f['mu'][:]
                self.wavelength = f['wave'][:]
            return

        # Get array dimensions
        T, G, M = self.Teff_vals, self.logg_vals, self.FeH_vals
        shp = [len(T), len(G), len(M)]

        # n counts the models loaded so far out of N expected
        n, N = 0, np.prod(shp)

        # Iterate through rows
        for nt, teff in enumerate(T):
            for ng, logg in enumerate(G):
                for nm, feh in enumerate(M):
                    try:
                        # Retrieve flux using the `get()` method
                        d = self.get(teff, logg, feh, interp=False)
                    except IOError:
                        # No model computed so reduce total
                        N -= 1
                        continue

                    if not d:
                        continue

                    # Make sure arrays exist
                    if self.flux is None:
                        new_shp = shp+list(d['flux'].shape)
                        self.flux = np.zeros(new_shp)
                    if self.mu is None:
                        new_shp = shp+list(d['mu'].shape)
                        self.mu = np.zeros(new_shp)

                    # Add data to respective arrays
                    self.flux[nt, ng, nm] = d['flux']
                    self.mu[nt, ng, nm] = d['mu'].squeeze()

                    # Get the wavelength array
                    if self.wavelength is None:
                        self.wavelength = d['wave']

                    # Garbage collection
                    del d

                    # Print update
                    n += 1
                    msg = "{: .2f}% complete.".format(n*100./N)
                    print(msg, end='\r')

        # Load the flux into an HDF5 file
        with h5py.File(self.flux_file, "w") as f:
            f.create_dataset('flux', data=self.flux)
            f.create_dataset('mu', data=self.mu)
            f.create_dataset('wave', data=self.wavelength)

        print("100.00 percent complete!", end='\n')

    def customize(self, Teff_rng=(2300, 8000), logg_rng=(0, 6),
                  FeH_rng=(-2, 1), wave_rng=(0*q.um, 40*q.um), n_bins=''):
        """Trim the model grid by the given parameter ranges.

        Trims by effective temperature, surface gravity, and
        metallicity. Also sets the wavelength range and number of bins
        for retrieved model spectra, then reloads the flux arrays.

        Parameters
        ----------
        Teff_rng : array-like; optional
            The lower and upper inclusive bounds for the effective
            temperature (K). Defaults to (2300, 8000).
        logg_rng : array-like; optional
            The lower and upper inclusive bounds for the logarithm of the
            surface gravity (dex). Defaults to (0, 6).
        FeH_rng : array-like; optional
            The lower and upper inclusive bounds for the logarithm of the
            ratio of the metallicity and solar metallicity (dex).
            Defaults to (-2, 1).
        wave_rng : array-like; optional
            The lower and upper inclusive bounds for the wavelength,
            either as astropy quantities or as plain numbers in microns.
            Defaults to (0*q.um, 40*q.um).
        n_bins : int; optional
            The number of bins for the wavelength axis. Defaults to '',
            which keeps the current value.
        """
        # Make a copy of the grid
        grid = self.data.copy()

        # Store the bounds as quantities because get() compares them
        # against wavelengths that carry units
        wave_bounds = tuple(q.Quantity(bound, q.um) for bound in wave_rng)
        self.wave_rng = wave_bounds
        self.n_bins = n_bins or self.n_bins

        # Filter grid by given parameters
        in_range = np.asarray((grid['Teff'] >= Teff_rng[0]) &
                              (grid['Teff'] <= Teff_rng[1]) &
                              (grid['logg'] >= logg_rng[0]) &
                              (grid['logg'] <= logg_rng[1]) &
                              (grid['FeH'] >= FeH_rng[0]) &
                              (grid['FeH'] <= FeH_rng[1]))
        self.data = grid[in_range]

        # Print a summary of the returned grid
        print('{}/{}'.format(len(self.data), len(grid)),
              'spectra in parameter range',
              'Teff: ', Teff_rng, ', logg: ', logg_rng,
              ', FeH: ', FeH_rng, ', wavelength: ', wave_rng)

        # Do nothing if the cut leaves the grid empty
        if len(self.data) == 0:
            self.data = grid
            print('The given param ranges would leave 0 models in the grid.')
            print('The model grid has not been updated. Please try again.')
            return

        # Update the wavelength and flux attributes
        if isinstance(self.wavelength, np.ndarray):
            w = self.wavelength
            w_q = w*self.wave_units
            W_idx, = np.where((w_q >= wave_bounds[0]) &
                              (w_q <= wave_bounds[1]))
            T_idx, = np.where((self.Teff_vals >= Teff_rng[0]) &
                              (self.Teff_vals <= Teff_rng[1]))
            G_idx, = np.where((self.logg_vals >= logg_rng[0]) &
                              (self.logg_vals <= logg_rng[1]))
            M_idx, = np.where((self.FeH_vals >= FeH_rng[0]) &
                              (self.FeH_vals <= FeH_rng[1]))

            # Trim arrays
            self.wavelength = w[W_idx]
            self.flux = self.flux[T_idx[0]: T_idx[-1]+1,
                                  G_idx[0]: G_idx[-1]+1,
                                  M_idx[0]: M_idx[-1]+1,
                                  :,
                                  W_idx[0]: W_idx[-1]+1]
            self.mu = self.mu[T_idx[0]: T_idx[-1]+1,
                              G_idx[0]: G_idx[-1]+1,
                              M_idx[0]: M_idx[-1]+1]

        # Update the parameter attributes
        self.Teff_vals = np.unique(self.data['Teff'])
        self.logg_vals = np.unique(self.data['logg'])
        self.FeH_vals = np.unique(self.data['FeH'])

        # Reload the flux array with the new grid parameters
        self.load_flux(reset=True)

        # Clear the grid copy from memory
        del grid

    def info(self):
        """Print a table of info about the current ModelGrid."""
        # Get the attributes that have a printable type
        tp = (int, bytes, bool, str, float, tuple, list, np.ndarray)
        rows = [[name, str(value)] for name, value in vars(self).items()
                if isinstance(value, tp)]

        # Make the table
        table = at.Table(np.asarray(rows).reshape(len(rows), 2),
                         names=['Attributes', 'Values'])

        # Sort and print
        table.sort('Attributes')
        table.pprint(max_width=-1, align=['>', '<'])

    def reset(self):
        """Reset the current grid to the original state.

        Deletes the HDF5 flux cache, if any, and re-runs `__init__` on
        the model directory with the default arguments.
        """
        flux_cache = os.path.join(self.path, 'model_grid_flux.hdf5')
        if os.path.isfile(flux_cache):
            os.remove(flux_cache)

        self.__init__(self.path)

    def set_units(self, wave_units=q.um):
        """Set the wavelength units.

        Parameters
        ----------
        wave_units : str, astropy.units.core.PrefixUnit/CompositeUnit; optional
            The wavelength units. Defaults to astropy.units.um.

        Notes
        -----
        Only `wave_units` and `const` are updated; wavelength arrays
        that are already loaded are not rescaled.
        """
        # Set wavelength units
        self.wave_units = q.Unit(wave_units)

        # get() always applies `const` to the raw wavelengths, which are
        # in Angstroms, so the factor is computed from Angstroms rather
        # than from the previously selected units
        self.const = (q.AA/self.wave_units).decompose().scale