Source code for phot_class.utils

#!/usr/bin/env python3.7
# -*- coding: UTF-8 -*-

"""A collection of general utilities used across the parent package."""

import signal
from copy import deepcopy

import numpy as np
import sncosmo
from astropy.table import Table


class timeout:
    """A timeout context manager based on the ``SIGALRM`` signal

    On entry a ``SIGALRM`` handler is installed and an alarm is scheduled.
    If the alarm fires before the context exits, a ``TimeoutError`` is raised.
    On exit the alarm is cancelled and the handler that was installed before
    entry is put back.
    """

    def __init__(self, seconds=1, error_message='Timeout'):
        """A timeout context manager

        Args:
            seconds       (int): The number of seconds until timeout
            error_message (str): The TimeOutError message on timeout
        """

        self.seconds = seconds
        self.error_message = error_message

        # Handler that was active before ``__enter__``; restored on exit
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler that raises the timeout error

        Args:
            signum: The signal number (unused)
            frame: The interrupted stack frame (unused)

        Raises:
            TimeoutError: Always, with ``self.error_message`` as the message
        """

        raise TimeoutError(self.error_message)

    def __enter__(self):
        """Install the alarm handler and start the countdown

        Returns:
            This context manager instance
        """

        # ``signal.signal`` returns the handler it replaces
        self._previous_handler = signal.signal(
            signal.SIGALRM, self.handle_timeout)

        signal.alarm(self.seconds)
        return self

    def __exit__(self, type_, value, traceback):
        """Cancel the pending alarm and restore the previous handler"""

        signal.alarm(0)

        # The previous handler is ``None`` when it was not installed from
        # Python, in which case it cannot be re-installed from here.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)

        self._previous_handler = None
def parse_config_dict(obj_id, config_dict):
    """Look up the priors and fitting kwargs of one object in a config dict

    Missing objects, and missing ``priors`` / ``kwargs`` entries for an
    object, are returned as empty dictionaries.

    Args:
        obj_id       (str): The object id in the dictionary
        config_dict (dict): A dictionary with data from a config file

    Returns:
        A tuple with four dictionaries, in this order:
          - Object priors for the hsiao_x1 model
          - Fitting kwargs for the hsiao_x1 model
          - Object priors for the sn91bg model
          - Fitting kwargs for the sn91bg model
    """

    def _section_for(model_name, section):
        # Each model entry maps object ids onto their config sections
        per_object = config_dict[model_name].get(obj_id, {})
        return per_object.get(section, {})

    return tuple(
        _section_for(model_name, section)
        for model_name in ('hsiao_x1', 'sn91bg')
        for section in ('priors', 'kwargs')
    )
def calc_model_chisq(data, result, model):
    """Calculate the chi-squared for a given data table and model

    Chi-squareds are calculated using parameter values from ``model``.
    Degrees of freedom are calculated using the number of varied parameters
    specified in the ``result`` object.

    Observations are dropped when their time falls outside the model's time
    range or when their band-pass extends beyond the model's wavelength range.

    Args:
        data    (Table): An sncosmo input table
        result (Result): sncosmo fitting result
        model   (Model): An sncosmo Model

    Returns:
        The un-normalized chi-squared
        The degrees of freedom (number of data points used in the calculation
        minus the number of varied parameters)

    Raises:
        ValueError: If no data lies within the range of the model
    """

    # Resolve each band-pass once, then collect its wavelength limits as
    # arrays so the comparisons below are element-wise regardless of whether
    # the model limits are numpy scalars or plain Python floats.
    bandpasses = [sncosmo.get_bandpass(band) for band in data['band']]
    min_band_wave = np.array([bandpass.minwave() for bandpass in bandpasses])
    max_band_wave = np.array([bandpass.maxwave() for bandpass in bandpasses])

    # Drop any data that is not within the model's range. Indexing with a
    # boolean mask builds a new table, so the input is left untouched.
    is_in_range = (
        (data['time'] >= model.mintime())
        & (data['time'] <= model.maxtime())
        & (min_band_wave >= model.minwave())
        & (max_band_wave <= model.maxwave())
    )
    data = data[is_in_range]

    if len(data) == 0:
        raise ValueError('No data within model range')

    return sncosmo.chisq(data, model), len(data) - len(result.vparam_names)
def split_bands(bands, lambda_eff, redshift=0):
    """Partition band-passes into blue and red groups by rest frame wavelength

    A band is blue when its rest frame effective wavelength is < 5500 Ang.
    Every other band is red.

    Args:
        bands        (array[str]): Name of band-passes
        lambda_eff (array[float]): Effective wavelength of band-passes
        redshift          (float): The redshift of the rest frame

    Returns:
        An array of blue filter names
        An array of red filter names
    """

    # Shift the effective wavelengths blue-ward into the rest frame
    rest_frame_wavelengths = np.array(lambda_eff) / (1 + redshift)
    blue_mask = np.less(rest_frame_wavelengths, 5500)

    names = np.array(bands)
    blue_names = names[blue_mask]
    red_names = names[np.logical_not(blue_mask)]
    return blue_names, red_names
def split_data(data_table, band_names, lambda_eff, z, cutoff=700):
    """Split a data table into blue and red data (by rest frame)

    Wavelengths are expected to be in angstroms.

    Each observation is assigned the survey band whose effective wavelength
    is closest to the observation's rest frame effective wavelength. The
    observation is then sorted into the blue or red table depending on
    whether that assigned band is blue-ward or red-ward of 5500 Ang. If the
    closest band for an observation is ``cutoff`` or more angstroms away,
    the observation is dropped.

    Args:
        data_table (Table): An SNCosmo input table with column 'band'
        band_names  (iter): List of all bands available in the survey
        lambda_eff  (iter): The effective wavelength of each band in band_names
        z          (float): The redshift of the observed target
        cutoff     (float): The cutoff distance for dropping an observation

    Returns:
        A list with two elements:
          - A SNCosmo input table with only blue bands
          - A SNCosmo input table with only red bands

    Raises:
        ValueError: If a band in the data table is missing from band_names
    """

    # Type cast to allow numpy indexing
    band_names = np.array(band_names)
    lambda_eff = np.array(lambda_eff)

    # Check an effective wavelength was specified for each band in the
    # data table. This avoids a cryptic error message later on.
    observed_bands, inverse = np.unique(
        data_table['band'], return_inverse=True)

    band_has_lambda_eff = np.isin(observed_bands, band_names)
    if not band_has_lambda_eff.all():
        missing_bands = observed_bands[~band_has_lambda_eff]
        raise ValueError(f'Missing effective wavelength for: {missing_bands}')

    # Look up the effective wavelength once per distinct band and broadcast
    # it back onto the observations. Unlike ``np.vectorize`` this also works
    # for a table without rows and always yields one scalar per observation.
    band_index = np.array(
        [np.flatnonzero(band_names == band)[0] for band in observed_bands],
        dtype=int)

    # Calculate rest frame effective wavelengths for each observation
    observed_lambda = lambda_eff[band_index][np.ravel(inverse)]
    rest_frame_lambda = observed_lambda / (1 + z)

    # Get the name of the observer frame band with the smallest distance
    # to each rest frame lambda. Rows are bands, columns are observations.
    delta_lambda = np.abs(
        rest_frame_lambda[np.newaxis, :] - lambda_eff[:, np.newaxis])

    min_indx = np.argmin(delta_lambda, axis=0)
    rest_frame_filters = band_names[min_indx]

    # Keep only observations whose closest band is within ``cutoff``
    # angstroms of the rest frame effective wavelength
    within_dif_range = delta_lambda[
        min_indx, np.arange(delta_lambda.shape[1])] < cutoff

    # Split into blue and red band passes. The matched bands are classified
    # by their own effective wavelength, hence no redshift is passed here.
    out_list = []
    for bands in split_bands(band_names, lambda_eff):
        is_in_bands = np.isin(rest_frame_filters, bands)
        indices = np.logical_and(is_in_bands, within_dif_range)
        out_list.append(data_table[indices])

    return out_list
def classification_filter_factory(classifications, ftype='exclude'):
    """Build a function that decides whether a data table should be kept

    The function returned by this factory has signature
    ``returned_function(table: astropy.Table) -> boolean``. The boolean
    indicates whether the data should be kept (i.e. not skipped).

    The class of each object is read from ``table.meta['classification']``.
    Tables without a ``'classification'`` key in their meta data are always
    kept. The value of ``ftype`` is only checked once a table that has a
    classification is passed to the returned function.

    Args:
        classifications (list[str]): The classifications to filter on
        ftype                 (str): 'exclude' or 'include' the given classes

    Returns:
        A filter function for sndata
    """

    def filter_func(table):
        meta = table.meta

        # Unclassified targets are never skipped
        if 'classification' not in meta:
            return True

        if ftype not in ('exclude', 'include'):
            raise ValueError(f'Unknown filter type: {ftype}')

        is_listed = meta['classification'] in classifications
        if ftype == 'include':
            return is_listed

        return not is_listed

    return filter_func