# Source: WISCIpaper1/analysis/dust_features/silicon_feature/plot_fits.py
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from matplotlib.ticker import FormatStrFormatter
import astropy.units as u
import astropy.constants as const
from astropy.modeling import models, fitting
from numpy import linspace,exp
from models_mcmc_extension import EmceeFitter
from spectres import spectres


from astropy.modeling import models, fitting
from astropy import units as u
from astropy.convolution import convolve_models

from astropy.modeling.models import Polynomial1D
from astropy.modeling.physical_models import Drude1D
from scipy.interpolate import UnivariateSpline

from astropy.table import Table

from astropy.modeling.models import custom_model

from scipy import special
from scipy.integrate import quad
import math


# Define model
@custom_model
def skewed_gaussian(x, amplitude=0.2, mean=9.8, stddev=2., gamma=0.01):
    """
    One-dimensional skewed (skew-normal) Gaussian model.

    Evaluates ``amplitude`` times the skew-normal probability density

        2 / stddev * phi(z) * Phi(gamma * z),   z = (x - mean) / stddev

    where ``phi`` is the standard normal PDF and ``Phi`` the standard normal
    CDF, the latter written as 0.5 * (1 + erf(t / sqrt(2))).

    Parameters
    ----------
    amplitude : float
        Scale factor applied to the unit-area skew-normal density. Because
        that density integrates to 1 over x, this is the integrated area of
        the profile. It is NOT the peak value.
    mean : float
        Location parameter. It coincides with the mean and peak of the
        profile only when ``gamma == 0``.
    stddev : float
        Scale parameter. It coincides with the standard deviation only when
        ``gamma == 0``.
    gamma : float
        Shape (skewness) parameter. ``gamma == 0`` gives a symmetric
        Gaussian; positive values give a longer tail towards larger x.
    """
    z = (x - mean) / stddev
    # Standard normal PDF evaluated at z.
    pdf = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    # Standard normal CDF evaluated at gamma * z.
    cdf = 0.5 * (1 + special.erf(gamma * z / np.sqrt(2)))
    return amplitude * 2 / stddev * pdf * cdf


def optical_depthfit(continuum_wav, continuum_flux, sel):
    """Fit a power-law continuum and convert the spectrum to optical depth.

    A ``PowerLaw1D`` model (initial amplitude 0.5, x_0 = 1, alpha = 2) is
    fitted with the simplex least-squares fitter (up to 10000 iterations) to
    the flux at the points selected by ``sel``. The optical depth on the full
    grid is then ``tau = ln(F_cont / F)``. Two diagnostic figures are shown
    with ``plt.show()``: lambda^2 * F and plain F, each with the fitted
    continuum.

    Parameters
    ----------
    continuum_wav, continuum_flux : array-like
        Wavelength grid and the flux on that grid.
    sel : array of bool (or index array)
        Points treated as pure continuum for the fit.

    Returns
    -------
    list
        ``[tau_data, continuum]``: the optical depth and the fitted continuum,
        both evaluated on ``continuum_wav``.
    """
    fitter = fitting.SimplexLSQFitter()
    model_power = models.PowerLaw1D(amplitude=0.5, x_0=1., alpha=2.)
    power_cont = fitter(model_power, continuum_wav[sel], continuum_flux[sel], maxiter=10000)

    print(power_cont)

    # Evaluate the fitted continuum once and reuse it for tau and the plots.
    continuum_model = power_cont(continuum_wav)
    tau_data = np.log(continuum_model / continuum_flux)

    lam_squared = continuum_wav**2.
    plt.plot(continuum_wav, continuum_flux * lam_squared)
    plt.plot(continuum_wav, continuum_model * lam_squared)
    plt.show()

    plt.plot(continuum_wav, continuum_flux)
    plt.plot(continuum_wav, continuum_model)
    plt.show()

    return [tau_data, continuum_model]



def fit_carbonyl_gauss_skewed(settings_array, continuum_wav, continuum_flux, sel, sel_lines):
    """Fit a skewed Gaussian to the optical-depth profile (no error estimate).

    The continuum is a ``PowerLaw1D`` fitted with the simplex fitter to the
    points in ``sel``. The optical depth ``ln(F_cont / F)`` is then fitted
    with ``skewed_gaussian`` over the points in ``sel_lines``, using
    Levenberg-Marquardt and one uniform weight equal to 1 / std(tau[sel]).

    ``settings_array`` holds the initial guesses in the order
    (amplitude, stddev, mean, gamma).

    The starting model is drawn over the data in the current figure, which
    this function neither shows nor closes.

    Returns the fitted ``skewed_gaussian`` model.
    """
    simplex = fitting.SimplexLSQFitter()
    levmar = fitting.LevMarLSQFitter()

    continuum_fit = simplex(models.PowerLaw1D(amplitude=0.5, x_0=1., alpha=2.),
                            continuum_wav[sel], continuum_flux[sel], maxiter=10000)

    tau = np.log(continuum_fit(continuum_wav) / continuum_flux)
    # One global weight: inverse scatter of tau over the continuum points.
    uniform_weight = 1.0 / np.std(tau[sel])

    amp_guess, width_guess, centre_guess, skew_guess = (
        settings_array[0], settings_array[1], settings_array[2], settings_array[3])
    starting_model = skewed_gaussian(amplitude=amp_guess, stddev=width_guess,
                                     mean=centre_guess, gamma=skew_guess)

    print()

    # Overlay the starting guess on the data.
    plt.plot(continuum_wav, tau)
    plt.plot(continuum_wav, starting_model(continuum_wav))
    plt.ylabel('optical depth mod', fontsize=14)
    plt.xlabel(r'             Wavelength [$\mu$m]', fontsize=18)
    plt.title('initial model')

    return levmar(starting_model, continuum_wav[sel_lines], tau[sel_lines],
                  weights=uniform_weight, maxiter=10000)

def fit_carbonyl_errors_gauss_skewed(name, settings_array, continuum_wav, continuum_flux, sel, sel_lines,
                                     emcee_filebase="/home/zeegers/wisci_first_shot/carbonyl_github/carbonyl_24_3/carbonyl/silicon_feature/plots_fshot/emcee_res",
                                     nsteps=5000, burnfrac=0.1):
    """Fit a skewed Gaussian to the optical depth and sample it with emcee.

    Steps:
    1. Fit a ``PowerLaw1D`` continuum (simplex fitter) to the points in
       ``sel`` and form ``tau = ln(F_cont / F)``.
    2. Fit ``skewed_gaussian`` to ``tau[sel_lines]`` with Levenberg-Marquardt,
       using one uniform weight equal to 1 / std(tau[sel]).
    3. Refine with ``EmceeFitter`` and save its diagnostic plots.

    Parameters
    ----------
    name : str
        Source name. It is appended to ``emcee_filebase`` so each source gets
        its own emcee output. Previously ``name`` was unused and every source
        was given the same filebase.
    settings_array : sequence
        Initial guesses (amplitude, stddev, mean, gamma).
    continuum_wav, continuum_flux : array-like
        Wavelength grid and flux.
    sel, sel_lines : array of bool
        Continuum points and points used in the profile fit.
    emcee_filebase : str, optional
        Base path handed to ``plot_emcee_results``. This assumes it is used
        as a filename prefix; confirm in models_mcmc_extension.
    nsteps, burnfrac : int, float, optional
        Number of emcee steps and the burn-in fraction. The same burn-in is
        discarded from the returned flattened chain.

    Returns
    -------
    list
        ``[fit_mcmc_result, chains, stddevs]``:
        - the model returned by the emcee fitter;
        - the flattened post-burn-in chain;
        - the scalar scatter of tau over the continuum points.
    """
    fit_cont = fitting.SimplexLSQFitter()
    fit = fitting.LevMarLSQFitter()

    power_cont = fit_cont(models.PowerLaw1D(amplitude=0.5, x_0=1., alpha=2.),
                          continuum_wav[sel], continuum_flux[sel], maxiter=10000)

    tau_data = np.log(power_cont(continuum_wav) / continuum_flux)
    stddevs = np.std(tau_data[sel])
    weights = 1.0 / stddevs

    gl_init_skewgaus = skewed_gaussian(amplitude=settings_array[0], stddev=settings_array[1],
                                       mean=settings_array[2], gamma=settings_array[3])

    # Least-squares fit first; its result seeds the MCMC.
    gl_fit = fit(gl_init_skewgaus, continuum_wav[sel_lines], tau_data[sel_lines],
                 weights=weights, maxiter=100000)

    fit2 = EmceeFitter(nsteps=nsteps, burnfrac=burnfrac)
    fit_mcmc_result = fit2(gl_fit, continuum_wav[sel_lines], tau_data[sel_lines], weights=weights)

    fit2.plot_emcee_results(fit_mcmc_result, filebase=f"{emcee_filebase}_{name}")
    plt.show()
    print(fit_mcmc_result.parameters)
    print(fit_mcmc_result.uncs)

    # Discard the same burn-in the fitter was configured with, instead of
    # repeating the numbers by hand.
    burn_in = int(burnfrac * nsteps)
    chains = fit2.fit_info['sampler'].get_chain(flat=True, discard=burn_in)

    tuple_return = [fit_mcmc_result, chains, stddevs]
    print('tuple', tuple_return)

    return tuple_return

def plot_fits(filename, continuum_wav, tau_data, continuum_error, tau_model=None):
    """Plot the optical depth with its model (top) and the residuals (bottom).

    Parameters
    ----------
    filename : str or None
        If truthy, the figure is saved to this path before it is shown.
    continuum_wav, tau_data : array-like
        Wavelength grid and measured optical depth.
    continuum_error : array-like
        Currently unused, because error bars are disabled. Kept for call
        compatibility.
    tau_model : array-like, optional
        Model optical depth on ``continuum_wav``. The original signature had
        no such argument and the body read a module-level ``tau_model``. When
        this argument is omitted, that global is still used so existing calls
        keep working.
    """
    if tau_model is None:
        try:
            tau_model = globals()["tau_model"]
        except KeyError:
            raise NameError(
                "plot_fits() needs a model: pass tau_model explicitly"
            ) from None

    fig, axs = plt.subplots(2, 1)

    axs[0].plot(continuum_wav, tau_data)
    #axs[0].errorbar(continuum_wav,tau_data, yerr=continuum_error, linestyle="-",marker='')
    axs[0].plot(continuum_wav, tau_model)
    axs[0].set_ylabel('optical depth', fontsize=14)
    axs[0].set_xlabel(r'             Wavelength [$\mu$m]', fontsize=18)

    axs[1].plot(continuum_wav, tau_data - tau_model)
    axs[1].set_ylabel('residuals', fontsize=14)
    axs[1].set_xlabel(r'Wavelength ($\mu$m)', fontsize=14)

    plt.subplots_adjust(hspace=0, wspace=0.17)

    # Save before showing, so the file is written independently of the
    # interactive window.
    if filename:
        fig.savefig(filename)
    plt.show()

def result_table_rows(source, av, mode, mode_error, tau, tau_error, fwhm, fwhm_error, area, area_error):
    """Return ``(txt_row, latex_row)`` for one source.

    Each row is ordered to match its table's columns in ``print_results``:
    - the plain-text table lists FWHM before optical depth;
    - the LaTeX table lists optical depth before FWHM.
    """
    txt_row = (source, av, mode, mode_error, fwhm, fwhm_error, tau, tau_error, area, area_error)
    latex_row = (
        source,
        f'${av:.3f}$',
        rf'${mode:.3f}\pm{mode_error:.3f}$',
        rf'${tau:.3f}\pm{tau_error:.4f}$',
        rf'${fwhm:.3f}\pm{fwhm_error:.3f}$',
        rf'${area:.3f}\pm {area_error:.3f}$',
    )
    return txt_row, latex_row


def print_results(filename, source_name, Av, wav_mode, wav_mode_error, optical_depths, optical_depths_error, fwhm, fwhm_error, surface_area, surface_area_error):
    """Write the fit results to a plain-text table and an AASTeX table.

    Both files (``fit_fshots_micron.txt`` and ``fit_fshots_micron.tex``) are
    written under the hard-coded ``table_path``. ``filename`` is currently
    unused. All per-source arguments are indexed in parallel with
    ``source_name``.

    NOTE(review): an older comment here said the integrated area should be in
    cm^-1, but the column headers say micron. Confirm which unit is intended.
    """
    table_path = '/home/zeegers/wisci_first_shot/carbonyl_github/carbonyl_24_3/carbonyl/silicon_feature/sil_res_fshot/'

    # Plain-text table.
    table_txt = Table(
        names=(
            "Source",
            "AV",
            "Mode [micron]",
            "Mode error [micron]",
            "FWHM[micron]",
            "FWHM error [micron]",
            "Optical depth [micron]",
            "Optical depth error [micron]",
            "Integrated Area [micron]",
            "Integrated Area error [micron]",
        ),
        dtype=(
            "str",
            "float64",
            "float64",
            "float64",
            "float64",
            "float64",
            "float64",
            "float64",
            "float64",
            "float64",
        ),
    )
    # LaTeX table for the paper.
    table_latex = Table(
        names=(
            "Source",
            r"$A_{V}$",
            r"Mode (\micron)",
            r"Optical depth (\micron)",
            "FWHM",
            "Integrated Area",
        ),
        dtype=("str", "str", "str", "str", "str", "str"),
    )

    for i in range(len(source_name)):
        # Build both rows in their own tables' column order. Previously the
        # text row put optical depth where the FWHM columns are and vice versa.
        txt_row, latex_row = result_table_rows(
            source_name[i], Av[i], wav_mode[i], wav_mode_error[i],
            optical_depths[i], optical_depths_error[i],
            fwhm[i], fwhm_error[i],
            surface_area[i], surface_area_error[i],
        )
        table_txt.add_row(txt_row)
        table_latex.add_row(latex_row)

    tabname = "fit_fshots_micron"
    table_txt.write(
        table_path + f"{tabname}.txt",
        format="ascii.commented_header",
        overwrite=True,
    )

    table_latex.write(
        table_path + f"{tabname}.tex",
        format="aastex",
        # One alignment character per column (6 columns).
        col_align="lccccc",
        latexdict={
            "caption": r"Fitting results. \label{tab:fit_results}",
        },
        overwrite=True,
    )

# to do list:
# Get 10 Lac on the list of sources and fit it
# put names of sources on the plots
#

# 10lac_nircam_mrs_merged.fits

# names sources (file-name stems and full 2MASS designations, same order)
names_miri = ['10lac', '2MASSJ085747', '2MASSJ150958']

names_miri_2mass = ['10lac', '2MASSJ08574757-4609145', '2MASSJ15095841-5958463']

# Number of sources; sizes every per-source output array below.
n_sources = len(names_miri)

# main program
direc = '/home/zeegers/wisci_first_shot/'
datafile = '_nircam_mrs_merged.fits'

# creating a table for the results

output_directory = '/home/zeegers/wisci_first_shot/carbonyl_github/carbonyl_24_3/carbonyl/silicon_feature'
filename_results = 'output_res_carbonyl.txt'

# output directory for plots
output_directory_plots = '/home/zeegers/wisci_first_shot/carbonyl_github/carbonyl_24_3/carbonyl/silicon_feature/plots_fshot/rebin/'
output_directory_plots_emcee = '/home/zeegers/wisci_first_shot/carbonyl_github/carbonyl_24_3/carbonyl/silicon_feature/plots_fshot/rebin/emcee_res'


# Extinction and stellar parameters per source (same order as names_miri).
Av_array = np.array([0.21, 5.1, 4.7]) # old estimates
Av_array_new = np.array([0.21, 4.98, 4.35])

Rv_array_new = ([3.1, 3.32, 3.13])
# The values of these two were previously swapped: effective temperatures
# (K) were stored under `logg` and surface gravities (log g) under
# `temperature`.
temperature = ([36000, 24380, 15330])
logg = ([4.03, 3.25, 3.4])


# Settings for a plain (symmetric) Gaussian fit; used only by the
# commented-out Gaussian fit in the main loop. Columns:
# amplitude, ampl range (lo, hi), stdev, stdev range (lo, hi), wavelength, wav range - and +
settings_array_gauss = np.array([[0.2, 0.1, 1.0, 0.8, 0.3, 3.0, 9.8, 0.2, 0.2],
                                 [0.15, 0.05, 1.0, 0.8, 0.3, 3.0, 9.8, 0.05, 0.05],
                                 [0.3, 0.05, 1.0, 0.8, 0.3, 3.0, 9.8, 0.05, 0.05],
                                 ])

# output arrays (one entry per source)
amplitude_array = np.zeros(n_sources)
amplitude_array_error = np.zeros(n_sources)
amplitude_array_unc_plus = np.zeros(n_sources)
amplitude_array_unc_minus = np.zeros(n_sources)

mode_array = np.zeros(n_sources)
mode_array_error = np.zeros(n_sources)
mode_array_unc_plus = np.zeros(n_sources)
mode_array_unc_minus = np.zeros(n_sources)

mean_array = np.zeros(n_sources)
mean_array_error = np.zeros(n_sources)
mean_array_unc_plus = np.zeros(n_sources)
mean_array_unc_minus = np.zeros(n_sources)

stddev_array = np.zeros(n_sources)
stddev_array_error = np.zeros(n_sources)
stddev_array_unc_plus = np.zeros(n_sources)
stddev_array_unc_minus = np.zeros(n_sources)

fwhm_array = np.zeros(n_sources)
fwhm_array_error = np.zeros(n_sources)
fwhm_array_unc_plus = np.zeros(n_sources)
fwhm_array_unc_minus = np.zeros(n_sources)

surface_area_cm = np.zeros(n_sources)
surface_area_error = np.zeros(n_sources)
surface_area_unc_plus = np.zeros(n_sources)
surface_area_unc_minus = np.zeros(n_sources)

optical_depths = np.zeros(n_sources)
optical_depths_error = np.zeros(n_sources)
optical_depths_unc_plus = np.zeros(n_sources)
optical_depths_unc_minus = np.zeros(n_sources)

# try to fit with Drude profile

settings_array_drude = np.array([[0.2, 0.1, 9.8, 0.001, 0.001, 0.01, 0.3],
                                 [0.2, 2.0, 9.8, 0.001, 0.001, 0.01, 0.3],
                                 [0.2, 2.0, 9.8, 0.001, 0.001, 0.01, 0.3],
                                 ])


# scale (central amplitude), x_o (central wavelength), gamma_o(full-width-half-maximum of profile), asym(asymmetry where a value of 0 results in a standard Drude profile)
# NOTE(review): row 0 does not follow this column order (0.1 sits in the
# x_o slot and 9.8 in gamma_o); only the commented-out Drude fit reads it.
settings_array_drude_modified = np.array([[0.2, 0.1, 9.8, 0.001],
                                          [0.2, 9.8, 2.0, 0.01],
                                          [0.2, 9.8, 2.0, 0.01],
                                          ])

# Skewed-Gaussian initial guesses, one row per source.
# Columns: amplitude, stdev, wavelength (mean), gamma
#settings_array_skewed_gauss = np.array([[0.02, 0.1,9.8,0.0001],
                  #[0.2, 0.1,9.8,0.001],
                  #[0.2, 0.1,9.8,0.001]
                  #])

settings_array_skewed_gauss = np.array([[0.001, 1.0, 9.8, 0.01],
                                        [0.1, 0.1, 9.8, 0.01],
                                        [0.1, 0.1, 9.8, 0.01],
                                        ])

#filename_plots = names_miri+''

for i, source in enumerate(names_miri):

    # Read the merged NIRCam + MIRI MRS spectrum of this source.
    data_miri_merged = fits.getdata(direc + 'data/jwst/delivery_v6/' + source + datafile)

    wavelength_merged = data_miri_merged['WAVELENGTH']
    flux_merged = data_miri_merged['FLUX']
    uncs_merged = data_miri_merged['UNC']

    plt.plot(wavelength_merged, flux_merged)
    plt.show()

    # Drop non-finite and non-positive fluxes; the optical depth computed
    # later is the log of a flux ratio.
    good = np.where(np.isfinite(flux_merged) & (flux_merged > 0.))
    flux_merged_new = flux_merged[good]
    wavelength_merged_new = wavelength_merged[good]
    uncs_merged_new = uncs_merged[good]

    # Keep the 6.59-13.1 micron window around the silicate feature.
    continuum_sel = (wavelength_merged_new >= 6.59) & (wavelength_merged_new <= 13.1)
    new_waves = wavelength_merged_new[continuum_sel]
    new_fluxs = flux_merged_new[continuum_sel]
    continuum_error = uncs_merged_new[continuum_sel]

    # Rebin onto a uniform grid from 6.6 to 13.0 micron. The step is
    # (13.0 - 0.5 * (13.0 - 6.6)) / 1000 = 9.8 / 1000 micron, i.e. the
    # mid-range wavelength divided by 1000 (presumably resolving power
    # R = 1000 at the centre of the range).
    rebin = [6.6, 13.0, 1000]
    continuum_wav = np.arange(rebin[0], rebin[1],
                              (rebin[1] - 0.5 * (rebin[1] - rebin[0])) / rebin[2])
    continuum_flux = spectres(continuum_wav, new_waves, new_fluxs)

    # We need to get the signal to noise from the ETC calculation, probably the best we've got so far?
    # Line list for CPD: 66371.2

    # Wavelength windows (micron) that can be masked: the silicate feature
    # and narrow lines.
    feature_sil = (continuum_wav > 8.0) & (continuum_wav <= 12.3)
    stellar_line1 = (continuum_wav > 6.92) & (continuum_wav <= 6.953)
    stellar_line2 = (continuum_wav > 7.45) & (continuum_wav <= 7.466)
    stellar_line3 = (continuum_wav > 7.424) & (continuum_wav <= 7.53)
    stellar_line4 = (continuum_wav > 7.76) & (continuum_wav <= 7.785)
    stellar_line5 = (continuum_wav > 8.124) & (continuum_wav <= 8.179)
    stellar_line6 = (continuum_wav > 11.25) & (continuum_wav <= 11.35)
    stellar_line7 = (continuum_wav > 12.30) & (continuum_wav <= 12.40)
    stellar_line8 = (continuum_wav > 8.7) & (continuum_wav <= 8.8)
    stellar_line9 = (continuum_wav > 9.6) & (continuum_wav <= 9.8)
    stellar_line10 = (continuum_wav > 12.565) & (continuum_wav <= 12.631)
    stellar_line11 = (continuum_wav > 13.0) & (continuum_wav <= 13.2)

    # Continuum points: outside the silicate feature and the main lines
    # (can be expanded if more features will be added).
    sel = ~(stellar_line2 | stellar_line3 | feature_sil | stellar_line6 | stellar_line7)
    # Points used in the profile fit: only the main lines are masked.
    sel_lines = ~(stellar_line2 | stellar_line3 | stellar_line6 | stellar_line7)
    # Stricter variants that mask every listed line.
    sel_lines_extreme = ~(stellar_line1 | stellar_line2 | stellar_line3 | stellar_line4 | stellar_line5
                          | stellar_line6 | stellar_line7 | stellar_line8 | stellar_line9
                          | stellar_line10 | stellar_line11)
    sel_extreme = ~(stellar_line1 | stellar_line2 | stellar_line3 | stellar_line4 | stellar_line5
                    | feature_sil | stellar_line6 | stellar_line7 | stellar_line8 | stellar_line9
                    | stellar_line10 | stellar_line11)

    # Source index 7 uses the stricter masks. With the current three-source
    # list this never triggers; presumably it is left over from a longer
    # source list.
    if i == 7:
        cont_sel, line_sel = sel_extreme, sel_lines_extreme
    else:
        cont_sel, line_sel = sel, sel_lines

    optical_depth_return = optical_depthfit(continuum_wav, continuum_flux, cont_sel)
    tau_data = optical_depth_return[0]
    power_cont_array = optical_depth_return[1]

    # Quick least-squares skewed-Gaussian fit (no uncertainties).
    skewed_gauss_fit = fit_carbonyl_gauss_skewed(settings_array_skewed_gauss[i], continuum_wav,
                                                 continuum_flux, cont_sel, line_sel)
    tau_model = skewed_gauss_fit(continuum_wav)
    print(skewed_gauss_fit)

    fig, axs = plt.subplots(2, 1)

    axs[0].plot(continuum_wav, tau_data)
    axs[0].plot(continuum_wav, tau_model)
    axs[0].set_ylabel('optical depth', fontsize=14)
    axs[0].set_xlabel(r'             Wavelength [$\mu$m]', fontsize=18)

    axs[1].plot(continuum_wav, tau_data - tau_model)
    axs[1].set_ylabel('residuals', fontsize=14)
    axs[1].set_xlabel(r'Wavelength ($\mu$m)', fontsize=14)

    plt.subplots_adjust(hspace=0, wspace=0.17)
    plt.show()

    # Least-squares fit refined with emcee; also returns the uniform tau scatter.
    fit_result = fit_carbonyl_errors_gauss_skewed(source, settings_array_skewed_gauss[i], continuum_wav,
                                                  continuum_flux, cont_sel, line_sel)
    print(fit_result)

    fit_mcmc_result = fit_result[0]
    tau_model2 = fit_mcmc_result(continuum_wav)
    stddev = fit_result[2]

    # Summary figure: continuum fit (top), tau with model (middle),
    # normalised residuals (bottom).
    fig, axs = plt.subplots(3, 1, height_ratios=[2, 3, 1], figsize=(3.6, 5.5))

    axs[0].plot(continuum_wav, continuum_flux * continuum_wav**2., color="black")
    # Mark the continuum points actually used in the fit. Previously the
    # sel_extreme points were drawn even when the fit used sel.
    axs[0].plot(continuum_wav[cont_sel], continuum_flux[cont_sel] * continuum_wav[cont_sel]**2.,
                marker='o', linestyle=" ", markersize=1)
    axs[0].plot(continuum_wav, power_cont_array * continuum_wav**2., color="green")
    axs[0].set_ylabel(r'$\lambda^2\cdot$ F($\lambda$) [$\mu\mathrm{m}^2$ Jy]', fontsize=10)
    axs[0].yaxis.set_tick_params(labelsize=9)
    axs[0].set_xlim([6.6, 13.0])
    axs[0].set_xticks([])

    axs[1].errorbar(continuum_wav, tau_data, yerr=stddev, marker='', linestyle="-", color="black", ecolor="grey")
    axs[1].plot(continuum_wav, tau_model2, color="red", lw=2, zorder=10, label=r'10 $\mu$m silicate')
    axs[1].axhline(y=0, color='k', ls=":")
    axs[1].set_ylabel('optical depth', fontsize=10)
    axs[1].set_ylim([-0.06, 0.26])
    axs[1].yaxis.set_tick_params(labelsize=9)
    axs[1].set_xlim([6.6, 13.0])
    axs[1].set_xticks([])

    axs[2].plot(continuum_wav, (tau_data - tau_model2) / stddev, color="grey")
    axs[2].axhline(y=0, color='k', ls=":")
    axs[2].set_ylim([-6.4, 6.4])
    axs[2].set_xlim([6.6, 13.0])
    axs[2].xaxis.set_tick_params(labelsize=9)
    axs[2].yaxis.set_tick_params(labelsize=9)
    axs[2].set_ylabel('residuals', fontsize=9)
    axs[2].set_xlabel(r'Wavelength [$\mu$m]', fontsize=9)

    plt.subplots_adjust(hspace=0, left=0.22, right=0.96, top=0.98, bottom=0.13)
    # Save before showing, so the file does not depend on the interactive window.
    fig.savefig(output_directory_plots + source + "plot_res_carbonyl.pdf", dpi=300)
    plt.show()
    plt.close()

    # Simple tau + model figure without residuals.
    fig2, ax = plt.subplots()

    ax.plot(continuum_wav, tau_data)
    ax.plot(continuum_wav, tau_model2, color="crimson", lw=2, zorder=10)
    ax.axhline(y=0, color='k', ls=":")
    ax.set_ylabel('optical depth', fontsize=14)
    ax.set_xlabel(r'Wavelength ($\mu$m)', fontsize=14)
    fig2.savefig(output_directory_plots + source + "plot_noresiduals_silicates.pdf", dpi=200, format="pdf")
    plt.show()

    plt.close()