Source code for tuskitoo.SpectralExtraction.spectra_extraction_results

from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from .utils import gaussian,moffat
from copy import deepcopy
import numpy as np
import pandas as pd
import itertools
from scipy.interpolate import interp1d
import astropy.io
from astropy.stats import sigma_clip

[docs] class spectral_extraction_results_handler: def __init__(self,spectral_extraction_results,conditions={"rsquared":0.7},header=None,names=None,band="NoK",name="noname",nsigmas=10,wavelength=None,relevant_keywords_header=None): """ Initialize the spectral_extraction_results_handler class. Parameters: ---------- spectral_extraction_results : dict Dictionary containing results from the spectral extraction process. conditions : dict, optional, default={"rsquared": 0.7} Conditions for cleaning the pandas DataFrame. The default condition is to have rsquared >= 0.7. header : FITS header or None, optional Header information, typically from a FITS file. nsigmas : Number of sigmas of the sigma cliping to make nan outlayers """ self.name = name self.spectral_extraction_results = spectral_extraction_results self.conditions = conditions self.image,self.full_fit,self.normalization_array,self.source_number,self.distribution,self.mask,self.original_image = list(self.spectral_extraction_results.values()) ##### self.distribution_function = gaussian if self.distribution=="gaussian" else moffat self.parameter_number = 3 if self.distribution=="gaussian" else 4 self.columns_distribtuion = ["center","height","sigma"] if self.distribution=="gaussian" else ["center","height","alpha","sigma"] ##### #its not clear called pandas results self.pandas_results,self.image_model = self.array_to_pandas() self.band = band self.pandas_results["band"] = self.band self.cleaned_panda = spectral_extraction_results_handler.clean_pandas(deepcopy(self.pandas_results),conditions=self.conditions) self.residuals = self.image - self.image_model #################### self.spectras1d_raw = {i:spectral_extraction_results_handler.interpolate_1d(self.pandas_results[i].values,nsigmas=100) for i in self.cleaned_panda.columns if "flux" in i} self.spectras1d = {i:spectral_extraction_results_handler.interpolate_1d(self.cleaned_panda[i].values,nsigmas=nsigmas) for i in self.cleaned_panda.columns if "flux" in i} self.cleaned_panda[[i for i in self.spectras1d.keys()]] = np.array(list(self.spectras1d.values())).T if header: self.header=header if isinstance(self.header,astropy.io.fits.header.Header): self.relevant_keywords_header = {i:self.header[i] for i in ["ORIGIN","INSTRUME","OBJECT","NAXIS1","CRVAL1","CD1_1","CUNIT1"] if i in list(self.header.keys()) } to_angs = 1 if "CUNIT1" in self.relevant_keywords_header.keys(): if self.relevant_keywords_header["CUNIT1"]=="nm": to_angs=10 self.wavelength = np.array([(self.relevant_keywords_header["CRVAL1"]+i*self.relevant_keywords_header["CD1_1"])*to_angs for i in self.cleaned_panda["n_pixel"].values]) self.cleaned_panda["wavelength"] = self.wavelength #is possible add the relevant keywords to this directly in some cases else: print("your results will not have wavelength try to define a header") self.names = np.arange(1,self.source_number+1).astype("str") if names: self.set_names(names)
[docs] def set_names(self,names): if len(names)==len(self.names): self.renames = {col_name:col_name.replace(self.names[n],names[n]) for n in range(len(self.names)) for col_name in self.cleaned_panda.columns if self.names[n] in col_name } self.cleaned_panda = self.cleaned_panda.rename(columns=self.renames) self.spectras1d = {key.replace(self.names[n],names[n]):value for n,(key,value) in enumerate(self.spectras1d.items())} self.spectras1d_raw = {key.replace(self.names[n],names[n]):value for n,(key,value) in enumerate(self.spectras1d_raw.items())} #self.spectras1d = #sprint({n:key for n,(key,value) in enumerate(zip(self.spectras1d.items()))}) self.names = names else: print("check well your defined names and your number of objects in the extraction")
[docs] def unset_names(self): names = np.arange(1,self.source_number+1).astype("str") self.renames = {col_name:col_name.replace(self.names[n],names[n]) for n in range(len(self.names)) for col_name in self.cleaned_panda.columns if self.names[n] in col_name } self.cleaned_panda = self.cleaned_panda.rename(columns=self.renames) self.spectras1d = {key.replace(self.names[n],names[n]):value for n,(key,value) in enumerate(self.spectras1d.items())} self.spectras1d_raw = {key.replace(self.names[n],names[n]):value for n,(key,value) in enumerate(self.spectras1d_raw.items())} self.names = names
[docs] def array_to_pandas(self): """ Convert array results to pandas DataFrame. Returns: ------- tuple A tuple containing the pandas DataFrame of results and the 2D model image. """ separation_as_parameter = False #superM = pd.DataFrame(self.full_fit[:,2*self.parameter_number*self.source_number:3*(self.parameter_number*self.source_number)]).dropna().values[1] #data[~np.isnan(data).any(axis=1)] df = pd.DataFrame(self.full_fit[:,2*self.parameter_number*self.source_number:3*(self.parameter_number*self.source_number)]) df_cleaned = df.replace("nan", np.nan) # Drop all rows that contain only NaN values df_non_nan = df_cleaned.dropna(how='all') # Drop all columns that contain only NaN values #df_non_nan = df_non_nan.dropna(axis=1, how='all') # Reset index if needed columns_model = df_non_nan.values[0] #columns_model = pd.DataFrame(self.full_fit[:,2*self.parameter_number*self.source_number:3*(self.parameter_number*self.source_number)]).dropna().values[1] #[i for i in np.unique(deepcopy(self.full_fit[:,2*self.parameter_number*self.source_number:3*(self.parameter_number*self.source_number)]),axis=0) if "nan" not in i][0] columns_flux =[f"flux_{n}" for n in range(1,self.source_number+1)] columns_stats = ["chisqr","redchi","aic","bic","rsquared","n_pixel","x_num"] columns_model_init =["value_norm_"+i if "height" in i else "value_"+i for i in columns_model] # if separation exist should be here columns_std_init = ["std_norm_"+i if "height" in i else "std_"+i for i in columns_model] if any([bool("separation" in i) for i in columns_model]): separation_as_parameter = True columns_model_final = ["value_"+i.replace("separation","center") if "separation" in i else "value_"+i for i in columns_model] columns_std_final = ["std_"+i.replace("separation","center") if "separation" in i else "std_"+i for i in columns_model] else: columns_model_final = ["value_"+i for i in columns_model] columns_std_final = ["std_"+i for i in columns_model] panda_columns = columns_flux+columns_model_final+columns_std_final+columns_stats+columns_model_init#+columns_std_init model_parameters = deepcopy(self.full_fit[:,:self.parameter_number*self.source_number]).astype(float) std = deepcopy(self.full_fit[:,self.parameter_number*self.source_number:2*(self.parameter_number*self.source_number)]).astype(float) stats = deepcopy(self.full_fit[:,3*(self.parameter_number*self.source_number):]).astype(float) pre_values = deepcopy(model_parameters) # ############################################ model_parameters[:,[i for i in range(1,self.parameter_number*self.source_number,self.parameter_number)]] = model_parameters[:,[i for i in range(1,self.parameter_number*self.source_number,self.parameter_number)]] * self.normalization_array if separation_as_parameter: model_parameters[:,[i for i in range(self.parameter_number,self.parameter_number*self.source_number,self.parameter_number)]] = (model_parameters[:,[i for i in range(self.parameter_number,self.parameter_number*self.source_number,self.parameter_number)]].T + model_parameters[:,0]).T multiple_dist = np.array([self.distribution_function(np.arange(stats[0][-1])[:, np.newaxis],*i.T) for i in model_parameters.reshape(len(model_parameters),self.source_number,self.parameter_number)]) fluxes = multiple_dist.sum(axis=1) image_2d_model = multiple_dist.T.sum(axis=0) sumary_results = pd.DataFrame(np.hstack((fluxes,model_parameters,std,stats,pre_values)),columns=panda_columns)#,columns=panda_columns) sumary_results["distribution"] = [self.distribution] * len(sumary_results) sumary_results["source_number"] = [self.source_number] * len(sumary_results) return sumary_results.loc[:,~sumary_results.columns.duplicated()].copy(),image_2d_model#values,std,fluxes,stats,pre_v
[docs] @staticmethod def interpolate_1d(flux,nsigmas=10): """ Interpolate 1D flux data. Parameters: ---------- flux : array-like 1D array of flux values to be interpolated. Returns: ------- array-like Interpolated 1D flux data. Notes: ------ This require more analize given the posibility of what happend when we are working with a cuted 2d image,add table with parameters """ clip_,lower,upper = sigma_clip(flux,sigma=nsigmas, cenfunc='median', return_bounds=True) flux[flux>upper] = np.nan if np.isnan(flux[0]): flux[0] = np.nanmedian(flux) if np.isnan(flux[-1]): flux[-1] = np.nanmedian(flux) x = np.arange(len(flux)) mask_nan = np.isnan(flux) flux_1_no_nan = flux[~mask_nan] x_non_nan = x[~mask_nan] function_to_interpolate = interp1d(x_non_nan, flux_1_no_nan, kind='linear') return function_to_interpolate(x)
[docs] @staticmethod def clean_pandas(pandas_no_clean,conditions={"min":{"rsquared":0.7}}): """ Clean pandas DataFrame based on conditions. Parameters: ---------- pandas_no_clean : DataFrame The pandas DataFrame to be cleaned. conditions : dict, optional, default={"min": {"rsquared": 0.7}} Conditions for cleaning the DataFrame. The default condition is to have rsquared >= 0.7. Returns: ------- DataFrame Cleaned pandas DataFrame. Notes: ------ Here will be a good idea add the posibility of decide over what "pandas" plot the column so is more clear where could be the problem """ for super_key,super_values in conditions.items(): if super_key=="min": for key,values in super_values.items(): indices = pandas_no_clean.index[pandas_no_clean[key] < values] flux_columns = [col for col in pandas_no_clean.columns if "flux" in col] pandas_no_clean.loc[indices, flux_columns] = np.nan elif super_key=="max": for key,values in super_values.items(): indices = pandas_no_clean.index[pandas_no_clean[key] > values] flux_columns = [col for col in pandas_no_clean.columns if "flux" in col] pandas_no_clean.loc[indices, flux_columns] = np.nan return pandas_no_clean
[docs] def plot_2d_image_residuals(self,save=False): model_result = {"original_image":self.image/self.image.max(axis=0),"model_image":self.image_model/self.image_model.max(axis=0),"residuals original-model":self.residuals/np.max(np.abs(self.residuals),axis=0)} fig, axes = plt.subplots(1,3, figsize=(50, 10)) for ax, (key, spectra2d) in zip(axes, model_result.items()): vmin,vmax,label=0,1,"normalize" ax.set_title(key, fontsize=30) if key=="residuals original-model": vmin,vmax,label=-1,1,"(image-model)/max" im = ax.imshow(spectra2d,aspect="auto",vmin=vmin,vmax=vmax, cmap='coolwarm') cbar = plt.colorbar(im, ax=ax, shrink=1) ax.set_xlabel("Pixel",fontsize=20) cbar.ax.tick_params(labelsize=20) cbar.set_label(label, fontsize=20) ax.tick_params(axis='both', which='major', labelsize=20) #fig.colorbar(im, ax=ax, shrink=1,label=label, fontsize=14) if save: plt.savefig(f"images/{self.name}_{self.band}_2Dspectra.jpg", bbox_inches='tight') plt.show()
[docs] def plot_1d(self,n_pixel,save=None): parameters=self.cleaned_panda[self.cleaned_panda['n_pixel'].isin([n_pixel])][[f"value_{c}_{n}" for n in self.names for c in self.columns_distribtuion]] pixel_1d =self.image.T[n_pixel] x = np.linspace(0,len(pixel_1d),100) plt.plot(self.image.T[n_pixel],label="raw data") separated_sources = np.array([self.distribution_function(x,*i) for i in parameters.values[0].reshape(self.source_number,self.parameter_number)]) plt.plot(x,np.sum(separated_sources,axis=0),color="k",label="added models") [plt.plot(x ,i, linestyle="--", linewidth=1.5,label=f"source {self.names[n]}") for n,i in enumerate(separated_sources)] plt.plot(np.arange(len(pixel_1d)),pixel_1d-np.sum(np.array([self.distribution_function(np.arange(len(pixel_1d)),*i) for i in parameters.values[0].reshape(self.source_number,self.parameter_number)]),axis=0),label="residuals",alpha=0.5) plt.title(f"pixel {n_pixel}") plt.legend() plt.show()
[docs] def plot_column(self,column_name="",**kwargs): if column_name not in list(self.cleaned_panda.columns): print(f"{column_name} is not a avalaible column try \n {list(self.cleaned_panda.columns)}") return fig, ax = plt.subplots(figsize=(20, 6)) column = self.cleaned_panda[column_name].values print(f"mean value for {column_name} if {np.nanmedian(column)}") mdian = np.nanmedian(column) ax.plot(column) ax.axhline(mdian,zorder=10,c="k", linewidth=1.5) ax.set_title(f"column {column_name.replace('value','')}: {mdian:.3f}") if "xlim" in kwargs.keys(): ax.set_xlim(*kwargs["xlim"]) ax.set_ylim([mdian*0.2, mdian*1.7]) if "ylim" in kwargs.keys(): ax.set_ylim(*kwargs["ylim"]) ax.tick_params(which="both", bottom=True, top=True, left=True, right=True, length=10, width=2, labelsize=35) # Increase tick length and width ax.xaxis.label.set_size(40) # Set x-axis label font size ax.yaxis.label.set_size(40) # Set y-axis label font size plt.show() return mdian
[docs] def plot_spectra(self,obj=None,xlim=None,ylim=None,save=False,add_lines=False,xlabel=None): #it will be interesting can change between clear and not clear in this routine to check the diferences if isinstance(obj,str): obj = [obj] wavelength = None if xlabel=="pixel": wavelength = np.arange(len(self.cleaned_panda)) else: try: wavelength = self.wavelength xlabel="Observe wavelength" except: print("not wavelength in the class") wavelength = np.arange(len(self.cleaned_panda)) xlabel="pixel" plt.figure(figsize=(20,10)) if not obj: obj = self.spectras1d.keys() [plt.plot(wavelength,self.spectras1d[key],label=key,linewidth=0.5) for key in obj] plt.xlabel(xlabel, fontsize=20) plt.ylabel('Flux', fontsize=20) #ax1.set_title(title, fontsize=20) plt.xlim(np.min(wavelength),np.max(wavelength)) if xlim: plt.xlim(*xlim) if ylim: plt.ylim(*ylim) if add_lines: if xlabel=="pixel": print("not zs informed") else: #maybe pre render a kind of plots import os #tableau_colors = list(mcolors.TABLEAU_COLORS.values()) plt.text(0.05, 0.95, r"$z_{source}=$"+f"{self.zs}", transform=plt.gca().transAxes, fontsize=30, verticalalignment='top', horizontalalignment='left') module_dir = os.path.dirname(os.path.abspath(__file__)) xmin,xmax=plt.gca().get_xlim() _,ymax = plt.gca().get_ylim() line_name,wv = np.loadtxt(os.path.join(module_dir,"tabuled_values/linelist.txt"),dtype="str").T for key,value in zip(line_name,wv): value = float(value)*(1+self.zs) if xmin<value<xmax: #remove lines in masked zone if "Fe" in key or "H1" in key or "H9" in key or "H8" in key: continue#print(key,value) plt.axvline(float(value),c="k",ls="--",alpha=0.2) plt.text(float(value),ymax,key, rotation=90, verticalalignment='bottom', fontsize=20) plt.tick_params(axis='both', which='major', labelsize=20) plt.legend(loc='upper right', prop={'size': 24}, frameon=False, ncol=2) if save: plt.savefig(f"images/{self.name}_{self.band}_spectra.jpg")