import os
import numpy as np
import pandas as pd
import magnificat.input_data as input_data
def load_sdss_dr7_catalogs(bandpasses):
    """Load the SDSS DR7 Stripe 82 DRW fits of MacLeod et al 2010 and merge
    them with the Shen et al 2008 quasar catalog.

    Parameters
    ----------
    bandpasses : iterable of str
        SDSS bandpass names, a subset of 'ugriz'

    Returns
    -------
    pd.DataFrame
        One row per (AGN, bandpass) pair, carrying the DRW columns, the
        derived rest-frame quantities, and the Shen et al 2008 properties
        (inner-merged on ra/dec)
    """
    drw_path = os.path.join(input_data.__path__[0],
                            'sdss_dr7_s82', 's82drw_{:s}.dat')
    shen_path = os.path.join(input_data.__path__[0],
                             'sdss_dr7_s82', 'DB_QSO_S82.dat')
    # Normalizing wavelength
    wavelength_norm = 4000.0  # in angstroms, from MacLeod et al 2010
    # Central wavelength (value) for each SDSS bandpass (key) used to
    # estimate the rest-frame wavelength
    wavelength_center = {'u': 3520.0,
                         'g': 4800.0,
                         'r': 6250.0,
                         'i': 7690.0,
                         'z': 9110.0}  # in angstroms, from MacLeod et al 2010
    # Cols of s82drw_*.dat
    # http://faculty.washington.edu/ivezic/macleod/qso_dr7/Southern_format_drw.html
    drw_columns = ['SDR5ID', 'ra', 'dec', 'redshift', 'M_i', 'log_mass_BH', 'chi2']
    drw_columns += ['log_tau', 'log_sighat', 'log_tau_lowlim', 'log_tau_uplim']
    drw_columns += ['log_sfhat_lowlim', 'log_sfhat_uplim']
    drw_columns += ['edge', 'll_model', 'll_noise', 'll_inf', 'mu', 'N_obs']
    # Dictionary of pandas dataframes, each dataframe representing the band
    drw_dict = {}
    for bp in bandpasses:
        drw_bp = pd.read_csv(drw_path.format(bp),
                             index_col=False, sep=r'\s+',
                             skiprows=3, names=drw_columns)
        # Factor 1/(1+z) for converting observed-frame to rest-frame
        z_corr = 1.0/(1.0 + drw_bp['redshift'].values)
        # Normalized rest-frame wavelength
        drw_bp['rf_wavelength'] = wavelength_center[bp]*z_corr/wavelength_norm
        drw_bp['log_rf_wavelength'] = np.log10(drw_bp['rf_wavelength'].values)
        # Log of rest-frame tau in days
        drw_bp['log_rf_tau'] = drw_bp['log_tau'].values - np.log10(1.0 + drw_bp['redshift'].values)
        # Log of SF_inf in mag (with the proper unit conversion, see Note (4)
        # in the schema linked above)
        drw_bp['log_sf_inf'] = drw_bp['log_sighat'].values + 0.5*drw_bp['log_rf_tau'] - 0.5*np.log10(365.0)
        drw_bp['bandpass'] = bp
        # Can't use SDR5ID b/c it collapses new QSOs as -1 (not unique)
        drw_bp['id'] = np.arange(len(drw_bp))
        drw_dict[bp] = drw_bp
    # Concatenate across bands
    drw_all = pd.concat(drw_dict.values(), axis=0)
    # http://faculty.washington.edu/ivezic/macleod/qso_dr7/Southern_format_DB.html
    shen = clean_shen_et_al_2008_catalog(shen_path, bandpasses)
    merged = drw_all.merge(shen, on=['ra', 'dec'],
                           how='inner', suffixes=('', '_shen'))
    # Sanity check: columns shared by the two catalogs must agree
    np.testing.assert_array_almost_equal(merged['M_i'].values,
                                         merged['M_i_shen'].values)
    np.testing.assert_array_almost_equal(merged['redshift'].values,
                                         merged['redshift_shen'].values)
    merged.drop(['M_i_shen', 'redshift_shen'], axis=1, inplace=True)
    return merged
def clean_shen_et_al_2008_catalog(in_path, bandpasses, out_path=None):
    """Clean the catalog in
    http://faculty.washington.edu/ivezic/macleod/qso_dr7/Southern_format_DB.html

    Parameters
    ----------
    in_path : str
        Path to the raw DB_QSO_S82.dat file
    bandpasses : iterable of str
        SDSS bandpass names, a subset of 'ugriz'
    out_path : str, optional
        If given, the cleaned catalog is also written to this path as CSV

    Returns
    -------
    pd.DataFrame
        Cleaned catalog with one row per AGN that has a BH mass measurement
    """
    from astropy.io import ascii
    # Reading the raw data file
    DATA = ascii.read(in_path)
    # Extracting useful columns from the DATA file
    # (column positions follow the schema linked above)
    dbID = DATA.field('col1')
    SDSS_id = DATA.field('col4')
    RA = DATA.field('col2')
    DEC = DATA.field('col3')
    M_i = DATA.field('col5')
    redshift = DATA.field('col7')
    BH_MASS = DATA.field('col8')
    mags = dict()
    mags['u'] = DATA.field('col10')
    mags['g'] = DATA.field('col11')
    mags['r'] = DATA.field('col12')
    mags['i'] = DATA.field('col13')
    mags['z'] = DATA.field('col14')
    # Converting the ID to an integer
    dbID = [int(i) for i in dbID]
    # Generating columns for the cleaned Stripe 82 data
    shen_df = pd.DataFrame(dbID, columns=['dbID'])
    shen_df['SDSS_id'] = SDSS_id
    shen_df['ra'] = RA
    shen_df['dec'] = DEC
    shen_df['M_i'] = M_i
    shen_df['redshift'] = redshift
    shen_df['BH_mass'] = BH_MASS
    for bp in bandpasses:
        shen_df[bp] = mags[bp]
    # Remove AGN without BH mass measurements (missing masses are stored
    # as values <= 0.1)
    shen_df = shen_df[shen_df['BH_mass'] > 0.1]
    if out_path is not None:
        shen_df.to_csv(out_path, index=False)
    # Return the cleaned catalog in all cases (previously returned None
    # when out_path was given)
    return shen_df
class S82Sampler:
    """Sampler of AGN parameters from S82. Parameters can be those used to
    render the light curves or those to be inferred by the network.
    """
    # Mapping from SDSS bandpass name to an integer index
    bp_to_int = dict(zip(list('ugriz'), range(5)))

    def __init__(self,
                 agn_params,  # e.g. ['BH_mass', 'redshift', 'M_i']
                 bp_params,  # e.g. ['log_rf_tau', 'log_sf_inf']
                 bandpasses,
                 out_dir,
                 seed=123):
        """
        Parameters
        ----------
        agn_params : list of str
            Per-AGN parameter names (shared across bandpasses)
        bp_params : list of str
            Per-bandpass parameter names
        bandpasses : list of str
            SDSS bandpass names; must include 'i'
        out_dir : str
            Output directory; created if it does not exist
        seed : int
            Seed for the random number generator used in `sample`
        """
        self.agn_params = agn_params
        self.bp_params = bp_params
        self.bandpasses = bandpasses
        assert 'i' in self.bandpasses
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)
        os.makedirs(os.path.join(self.out_dir, 'processed'), exist_ok=True)
        # For standardizing params
        self.mean_params = None
        self.std_params = None
        self.slice_params = None
        self.cleaned_metadata_path = os.path.join(self.out_dir,
                                                  'processed',
                                                  'metadata.dat')
        self.rng = np.random.default_rng(seed)
        self._idx = None  # init

    def _apply_selection(self):
        """Apply the MacLeod et al 2010 selection cuts to `self.metadata`
        and store the surviving rows in `self.selected`.
        """
        merged = self.metadata.copy()
        before_cut = len(self.metadata)
        # 1. Number of observations
        # merged = merged.query('(log_tau != -10.0) & (log_sighat != -10.0)')  # same result as below
        merged = merged.query('N_obs >= 10')
        after_1 = len(merged)
        print("N_obs selection removed: ", before_cut - after_1)
        print("and left: ", after_1)
        # 2. Model likelihood: DRW model must beat pure noise
        merged['del_noise'] = merged['ll_model'] - merged['ll_noise']
        merged = merged.query('del_noise > 2.0')
        after_2 = len(merged)
        print("Model likelihood selection removed: ", after_1 - after_2)
        print("and left: ", after_2)
        # 3. Tau is not a lower limit
        merged['del_inf'] = merged['ll_model'] - merged['ll_inf']
        merged = merged.query('del_inf > 0.05')
        after_3 = len(merged)
        print("Tau lower limit selection removed: ", after_2 - after_3)
        print("and left: ", after_3)
        # 4. Edge flag must be unset
        merged = merged.query('edge == 0')
        after_edgecut = len(merged)
        print("Edge selection removed: ", after_3 - after_edgecut)
        print("and left: ", after_edgecut)
        for bp in list('ugriz'):
            print(bp, merged[merged['bandpass'] == bp].shape[0])
        self.selected = merged

    def _set_keep_agn(self, keep_agn_mode):
        """Set which AGN IDs to keep.

        Parameters
        ----------
        keep_agn_mode : str
            Currently only 'max_obs' is supported: keep the AGNs observed in
            the bandpass with the most surviving rows after selection

        Raises
        ------
        NotImplementedError
            If `keep_agn_mode` is not 'max_obs'
        """
        n_obs = np.empty(len(self.bandpasses))
        for bp_i, bp in enumerate(self.bandpasses):
            n_obs[bp_i] = self.selected[self.selected['bandpass'] == bp].shape[0]
        if keep_agn_mode == 'max_obs':
            max_obs_i = np.argmax(n_obs)
            max_bp = self.bandpasses[max_obs_i]
            keep_id = self.selected[self.selected['bandpass'] == max_bp]['dbID'].unique()
            self.metadata = self.metadata[self.metadata['dbID'].isin(keep_id)].copy()
        else:
            raise NotImplementedError

    def _collapse_bandpasses(self):
        """Collapse rows for bandpasses belonging to one AGN to a single row,
        turning per-bandpass params into columns like 'log_rf_tau_g'.
        """
        # Keyword arguments required: DataFrame.pivot is keyword-only
        # in pandas >= 2.0
        self.metadata = self.metadata.pivot(index=['id'] + self.agn_params,
                                            columns='bandpass',
                                            values=self.bp_params)
        # Flatten the (param, bandpass) MultiIndex columns into 'param_bp'
        self.metadata.columns = ['_'.join(col).strip() for col in
                                 self.metadata.columns.values]
        self.metadata = self.metadata.reset_index()
        self.metadata['id'] = np.arange(self.metadata.shape[0])

    @property
    def idx(self):
        """Indices of AGNs to sample from"""
        return self._idx

    @idx.setter
    def idx(self, idx):
        """Indices of AGNs to sample from"""
        if max(idx) > len(self) - 1:
            raise ValueError("Indices can't exceed size of dataset, "
                             f"which is {len(self)}.")
        self._idx = idx

    def sample(self):
        """Sample one AGN's metadata row uniformly (with replacement)
        from the allowed indices in `self.idx`.
        """
        i = self.rng.choice(self.idx, replace=True)
        return self.metadata.iloc[i]

    def __len__(self):
        # Size is the number of metadata rows; 0 before metadata is loaded
        if hasattr(self, 'metadata'):
            return self.metadata.shape[0]
        else:
            return 0
if __name__ == '__main__':
    # Smoke test: build the sampler, process the metadata, and draw samples
    sampler = S82Sampler(agn_params=['BH_mass', 'redshift', 'M_i'],
                         bp_params=['log_rf_tau', 'log_sf_inf'],
                         bandpasses=list('ugriz'),
                         out_dir='s82_sampler_testing',
                         seed=123)
    sampler.process_metadata()
    sampler.idx = [0, 1]
    params = sampler.sample()
    print(params)
    params = sampler.sample()
    print(params)
    # Sampled params should come back as float64
    assert params['redshift'].dtype == np.float64