# StackFast / python_idl_funcs.py
# Python re-implementations of IDL astronomy helper routines
# (match2, hextract, WCS ad2xy/xy2ad, hist_nd).
import matplotlib.patches as patches
from scipy.spatial import ConvexHull
from astropy.io import fits,ascii
import astropy
import gc
import numpy as np
import os
import time
import copy
import gzip
import sys
import pickle
import shutil
import statistics as stats
import glob
from multiprocessing import Pool
from contextlib import closing
from matplotlib import path
from scipy.interpolate import interp1d
from astropy import units as u
import astropy.coordinates as ac
#from shapely.geometry import Point, Polygon
# Print numpy arrays in full instead of summarising them with '...'.
# NOTE: this is a module-level call, so it changes numpy's print options for
# the whole process as soon as this module is imported.
np.set_printoptions(threshold=sys.maxsize)
from astropy.table import Table, vstack, hstack,Column, join


def _first_match_indices(values,lookup):
	'''
	For each element of values, return the index of its first occurrence
	in lookup as a float array, with NaN where the element does not occur.
	'''
	result=np.full(len(values),np.nan)
	if len(values)==0 or len(lookup)==0:
		return result
	# sorted distinct values of lookup plus the index of each value's first occurrence
	uniq,first_idx=np.unique(lookup,return_index=True)
	# candidate slot for every value; clip so values beyond the last one stay in range
	pos=np.clip(np.searchsorted(uniq,values),0,len(uniq)-1)
	# NaN never compares equal, so NaNs are reported as unmatched
	found=uniq[pos]==values
	result[found]=first_idx[pos[found]]
	return result

def match2_quick(arr1,arr2):
	'''
	Vectorised cross-match of two 1-D arrays (same result as match2,
	without the missing-element report).

	Parameters
	----------
	arr1, arr2 : array-like
		Values to match; duplicates are allowed.

	Returns
	-------
	sub_a : float ndarray, len(arr1)
		sub_a[i] is the index of the first occurrence of arr1[i] in arr2,
		or NaN if arr1[i] does not occur in arr2.
	sub_b : float ndarray, len(arr2)
		sub_b[j] is the index of the first occurrence of arr2[j] in arr1,
		or NaN if absent.

	The previous implementation marked matched slots with -1 and repeated
	np.intersect1d until nothing matched. That looped forever when -1 was a
	genuine value, and it failed on unsigned dtypes. This version uses a
	single sort-based lookup and never modifies the inputs.
	'''
	first=np.asarray(arr1)
	second=np.asarray(arr2)
	return _first_match_indices(first,second),_first_match_indices(second,first)

def _match2_one_way(source,target):
	'''
	For every element of source, find the index of its first occurrence in
	target (NaN if absent). Returns (indices_list, missing_elements_list).
	'''
	subs=[]
	missing=[]
	for ele in source:
		hits=np.where(target==ele)[0]
		if len(hits)==0:
			subs.append(np.nan)
			missing.append(ele)
		else:
			subs.append(hits[0])
	return subs,missing

def match2(arr1,arr2,print_missing=False):
	'''
	Element-by-element cross-match of two arrays (pure loop version; this
	function does NOT call match2_quick).

	Parameters
	----------
	arr1, arr2 : array-like
		Values to match. Inputs are converted with np.asarray, so plain lists
		are compared element-wise.
	print_missing : bool
		If True, print the elements of each array that have no counterpart
		in the other, plus the first 100 elements of each array.

	Returns
	-------
	sub_a : ndarray, len(arr1)
		Index in arr2 of the first occurrence of each arr1 element, NaN if absent.
	sub_b : ndarray, len(arr2)
		Index in arr1 of the first occurrence of each arr2 element, NaN if absent.
	Each output is an integer array when every element matched. It is a
	float array (containing NaN) otherwise.

	Cost is O(len(arr1) * len(arr2)).
	'''
	# a list compared with '==' gives a single bool, which made every element "missing"
	arr1=np.asarray(arr1)
	arr2=np.asarray(arr2)
	sub_a,missing_in_b=_match2_one_way(arr1,arr2)
	sub_b,missing_in_a=_match2_one_way(arr2,arr1)
	if print_missing:
		print("Missing in B:", missing_in_b)
		print("B:", arr2[:100])
		print("Missing in A:", missing_in_a)
		print("A:",arr1[:100])
	return np.array(sub_a),np.array(sub_b)

def avoid_nans(main_arra,indices,ret_indices=False):
	'''
	Gather values of main_arra (flattened) at the valid entries of indices.

	Parameters
	----------
	main_arra : array-like
		Source values; it is flattened before indexing.
	indices : array-like
		Index array that may contain NaN or negative entries, e.g. the
		float output of match2. Entries that are NaN or < 0 are invalid.
	ret_indices : bool
		If True, return only the valid indices, as an integer array.

	Returns
	-------
	If ret_indices: integer ndarray of the valid indices, in input order.
	Otherwise: float ndarray of len(indices) holding
	main_arra.flatten()[indices[i]] where indices[i] is valid, and 0 elsewhere.
	'''
	indices=np.asarray(indices)
	valid=np.logical_and(indices>=0,~np.isnan(indices))
	# NaN forces a float dtype; numpy refuses float arrays as indices, so cast
	good_inds=indices[valid].astype(int)
	if ret_indices:
		return good_inds
	fin_arr=np.zeros(len(indices))
	fin_arr[valid]=np.asarray(main_arra).ravel()[good_inds]
	return fin_arr


def hextract(img, img_header, xlims, ylims):
	'''
	Extract a rectangular sub-image and return it with an updated header
	(Python port of IDL HEXTRACT).

	Parameters
	----------
	img : 2-D ndarray
		Indexed as img[y, x]: rows are y, columns are x (NAXIS1 = columns).
	img_header : header-like mapping (e.g. astropy Header) with CRPIX1/CRPIX2
		It is shallow-copied, never modified.
	xlims, ylims : [start, stop]
		Pixel limits along x (columns) and y (rows). They are cast to int and
		used as Python slice bounds, so ``stop`` is EXCLUSIVE (unlike IDL
		HEXTRACT, whose last pixel is inclusive). Limits are not range-checked;
		out-of-range values follow normal Python slicing rules.

	Returns
	-------
	extracted_image : view img[ylims[0]:ylims[1], xlims[0]:xlims[1]]
	img_header_new : copy of img_header with
		- NAXIS1/NAXIS2 set to the actual width/height of extracted_image,
		- a 'hist' note describing the extraction, and
		- CRPIX1/CRPIX2 shifted by the lower limits (kept as floats so
		  sub-pixel reference positions are preserved).
	'''
	ysize,xsize=img.shape
	xlims,ylims=np.array(xlims).astype(int),np.array(ylims).astype(int)
	extracted_image=img[ylims[0]:ylims[1],xlims[0]:xlims[1]]
	# the header must describe the data actually returned (was x1-x0+1, one too many)
	naxis2,naxis1=extracted_image.shape
	img_header_new=copy.copy(img_header)
	img_header_new["NAXIS1"],img_header_new["NAXIS2"]=naxis1,naxis2
	hist = 'Original image size was '+ str(xsize) + ' by ' + str(ysize)+\
			' Extracted Image: ['+str(xlims[0])+':'+str(xlims[1])+\
			','+ str(ylims[0])+':'+str(ylims[1])+']'
	img_header_new["hist"]=hist
	# no alternate WCS handled; int() used to truncate fractional CRPIX values
	img_header_new["CRPIX1"]=float(img_header_new["CRPIX1"]-xlims[0])
	img_header_new["CRPIX2"]=float(img_header_new["CRPIX2"]-ylims[0])
	return extracted_image,img_header_new


def wcs_rotate(longit,latit,theta0,crval,reverse=False,longpole=180.0, latpole=0.0):
	'''
	Spherical rotation between celestial and native coordinates for a
	zenithal projection (simplified port of IDL WCS_ROTATE).

	Forward (reverse=False): longit, latit are celestial longitude/latitude
	in degrees; returns native (phi, theta) in degrees.
	Reverse (reverse=True): longit, latit are native (phi, theta) in
	degrees; returns celestial (longitude, latitude) in degrees.

	Parameters
	----------
	crval : [longitude, latitude] of the reference point, in degrees.
	longpole : native longitude of the celestial pole, in degrees.
	theta0, latpole : accepted for interface compatibility but unused; the
		rotation always assumes theta0 = 90 (zenithal projections).

	In the reverse direction, only points where at least one input
	coordinate is not infinite are rotated. Points with both infinite are
	returned unchanged. The reverse branch works on float copies, so the
	caller's arrays are never modified in place.
	'''
	def _rotate(mat,lon_rad,lat_rad):
		# rotate unit vectors (l, m, n) by mat and return (longitude, latitude) in degrees
		l = np.cos(lat_rad)*np.cos(lon_rad)
		m = np.cos(lat_rad)*np.sin(lon_rad)
		n = np.sin(lat_rad)
		b0 = mat[0,0]*l + mat[0,1]*m + mat[0,2]*n
		b1 = mat[1,0]*l + mat[1,1]*m + mat[1,2]*n
		# clip guards arcsin against rounding just outside [-1, 1]
		b2 = np.clip(mat[2,0]*l + mat[2,1]*m + mat[2,2]*n, -1, 1)
		return np.arctan2(b1, b0)*RADEG, np.arcsin(b2)*RADEG

	phi_p = longpole/RADEG
	sp,cp = np.sin(phi_p), np.cos(phi_p)
	alpha_p, delta_p = crval[0]/RADEG, crval[1]/RADEG
	sa,ca,sd,cd = np.sin(alpha_p),np.cos(alpha_p),np.sin(delta_p),np.cos(delta_p)
	r = np.array([[-sa*sp - ca*cp*sd,  ca*sp - sa*cp*sd, cp*cd],
				  [ sa*cp - ca*sp*sd, -ca*cp - sa*sp*sd, sp*cd],
				  [ ca*cd,             sa*cd,            sd   ]])
	if reverse:
		# float copies: the old code aliased (and overwrote) the caller's arrays
		longitude=np.array(longit,dtype=float)
		latitude=np.array(latit,dtype=float)
		g=np.logical_or(~np.isinf(latitude),~np.isinf(longitude))
		# native -> celestial uses the transpose of r
		longitude[g],latitude[g]=_rotate(r.T,longitude[g]/RADEG,latitude[g]/RADEG)
		return longitude,latitude
	# celestial -> native uses r itself
	phi, theta = _rotate(r,longit/RADEG,latit/RADEG)
	return phi, theta

def wcs_spher_to_xy(a,d,ctype,crval,pv2=0.0,
					  longpole=180.0, latpole=0.0,
					  north_offset=1.e-7,south_offset=1.e-7):
	'''
	Gnomonic (TAN) projection of celestial coordinates onto the projection
	plane (simplified port of IDL WCSSPH2XY for the RA---TAN/DEC--TAN case).

	Parameters
	----------
	a, d : scalar or array-like
		Longitude and latitude in degrees. Longitudes >= 180 are first
		shifted by -360.
	ctype : pair of CTYPE strings
		Accepted for interface compatibility; the TAN projection is always used.
	crval : [longitude, latitude] of the reference point, in degrees.
	pv2 : unused.
	longpole, latpole : passed through to wcs_rotate (defaults 180 and 0).
	north_offset, south_offset :
		Latitudes within offset*RADEG of a pole are pulled back by that
		amount to avoid the polar singularity.

	Returns
	-------
	x, y : float ndarrays
		Projection-plane coordinates scaled by RADEG (degrees if
		RADEG = 180/pi). Points with native latitude theta <= 0 (the far
		hemisphere) cannot be projected and are returned as -inf.
	'''
	# float copies: inputs are not modified, and pole clamping is not truncated for int input
	lng=np.array(a,dtype=float,ndmin=1)
	lat=np.array(d,dtype=float,ndmin=1)
	lng[lng>=180.]-=360.0
	# deal with input points which are too close to the north pole
	near_north=np.abs(lat-90.0) < north_offset*RADEG
	lat[near_north] = 90.0 - north_offset*RADEG
	# deal with input points which are too close to the south pole
	near_south=np.abs(lat+90.0) < south_offset*RADEG
	lat[near_south] = south_offset*RADEG - 90.0
	# TAN is zenithal
	theta0 = 90.
	phi, theta = wcs_rotate(lng,lat,theta0,crval,longpole=longpole,latpole=latpole)
	phi,theta = phi/RADEG,theta/RADEG
	x,y=np.full(len(theta),-np.inf),np.full(len(theta),-np.inf)
	front=theta>0
	if np.any(front):
		r_theta=1.*RADEG/np.tan(theta[front])
		x[front] = r_theta*np.sin(phi[front])
		y[front] = -r_theta*np.cos(phi[front])
	return x,y

def wcs_xy_to_sph(x,y,ctype,crval,pv2=0.0,
				  longpole=180.0, latpole=0.0,
				  north_offset=1.e-7,south_offset=1.e-7):
	'''
	Inverse gnomonic (TAN) projection: projection-plane coordinates to
	celestial longitude/latitude (simplified port of IDL WCSXY2SPH).

	Parameters
	----------
	x, y : scalar or array-like
		Projection-plane coordinates in the same RADEG-scaled units that
		wcs_spher_to_xy produces.
	ctype : pair of CTYPE strings
		Accepted for interface compatibility; TAN is always assumed.
	crval : [longitude, latitude] of the reference point, in degrees.
	pv2, north_offset, south_offset : unused.
	longpole, latpole : passed through to wcs_rotate.

	Returns
	-------
	longitude, latitude : float ndarrays in degrees. Negative longitudes are
		wrapped by +360, and values >= 359.99 are then shifted by -360.
		Points at r = 0 map to native latitude PI2 (the reference point).
	'''
	xx=np.array(x,dtype=float,ndmin=1)
	yy=np.array(y,dtype=float,ndmin=1)
	r=np.sqrt(xx**2+yy**2)
	theta=np.full(len(xx),PI2)
	off_centre=r>0
	theta[off_centre]=np.arctan(RADEG/r[off_centre])
	phi=np.arctan2(xx,-yy)*RADEG
	theta=theta*RADEG
	# zenithal projection
	theta0=90.
	longitude,latitude = wcs_rotate(phi,theta,theta0,crval,True,
									longpole=longpole,latpole=latpole)
	longitude[longitude<0.0]+=360.
	longitude[longitude>=359.99000000000001]-=360.
	return longitude,latitude

def ad2xy(header_info,a,d,pos_limits=False,
		  minims=(0,0),maxims=(0,0)):
	'''
	Convert sky coordinates (a, d) to pixel coordinates (x, y) using the
	linear WCS keywords of header_info (port of IDL AD2XY).

	Parameters
	----------
	header_info : mapping
		Event-file keywords TCRPX11/12, TCRVL11/12, TCDLT11/12, TCTYP11/12
		are tried first. If any is missing, the image keywords CRPIX1/2,
		CRVAL1/2, CDELT1/2, CTYPE1/2 are used instead. An empty CTYPE
		defaults to RA---TAN / DEC--TAN.
	a, d : scalar or array-like
		Longitude/latitude (or linear world coordinates for non-spherical
		CTYPEs). Scalars are wrapped into 1-element arrays.
	pos_limits : bool
		If True, keep only points with minims[0] <= x <= maxims[0] and
		minims[1] <= y <= maxims[1].
	minims, maxims : length-2 sequences
		Lower/upper (x, y) limits used when pos_limits is True.

	Returns
	-------
	x, y, or (x[relevant], y[relevant], relevant) when pos_limits is True.
	Pixel positions are offset by CRPIX - 1 (zero-based if CRPIX follows
	the 1-based FITS convention). A spherical CTYPE (5th character '-')
	uses the TAN projection in wcs_spher_to_xy, which can return
	non-finite values for points that cannot be projected.

	Only diagonal CDELT scaling is supported (no CD/PC matrix, no rotation).
	'''
	try:
		crpixX,crpixY,crvalX,crvalY,\
		cdeltX,cdeltY,ctypeX,ctypeY=header_info['TCRPX11'],header_info['TCRPX12'],\
								header_info['TCRVL11'],header_info['TCRVL12'],\
								header_info['TCDLT11'],header_info['TCDLT12'],\
								header_info['TCTYP11'],header_info['TCTYP12']
	except KeyError:
		crpixX,crpixY,crvalX,crvalY,\
		cdeltX,cdeltY,ctypeX,ctypeY=header_info['CRPIX1'],header_info['CRPIX2'],\
								header_info['CRVAL1'],header_info['CRVAL2'],\
								header_info['CDELT1'],header_info['CDELT2'],\
								header_info['CTYPE1'],header_info['CTYPE2']
	if ctypeX=="":
		ctypeX,ctypeY='RA---TAN','DEC--TAN'
		print('No CTYPE specified - assuming TANgent projection')
	if np.ndim(a)==0:
		# any scalar (not just int/float32) must become an array for the vector code below
		a=np.array([a])
		d=np.array([d])
	cdelt=np.array([cdeltX,cdeltY])
	cd=np.array([[1.,0.],[0.,1.]])
	try:
		spherical = ctypeX[4]=='-'
	except IndexError:
		spherical = False
	ctype,crpix = np.array([ctypeX, ctypeY]),np.array([crpixX,crpixY])
	coord=[ctypeX[:4],ctypeY[:4]]
	# latitude-first axis order: swap crval now and (xsi, eta) after projecting
	reverse=(((coord[0]=='DEC-') and (coord[1]=='RA--')) or\
			 ((coord[0]=='GLAT') and (coord[1]=='GLON')) or\
			 ((coord[0]=='ELAT') and (coord[1]=='ELON')))
	if reverse:
		crval=np.array([crvalY,crvalX])
	else:
		crval=np.array([crvalX,crvalY])
	if spherical:
		xsi,eta=wcs_spher_to_xy(a,d,ctype,crval)
	else:
		xsi,eta=a-crval[0],d-crval[1]
	if reverse:
		xsi,eta=eta,xsi
	crpix_working=crpix-1.
	# IDL convention: scale the rows of the (identity) CD matrix by CDELT
	if cdelt[0]!=1.0:
		cd[0,:]*=cdelt[0]
		cd[1,:]*=cdelt[1]
	# cd is diagonal, so its inverse is the element-wise reciprocal of the non-zero entries
	cdinv=np.zeros_like(cd)
	nonzero=cd!=0
	cdinv[nonzero]=1./cd[nonzero]
	x=cdinv[0,0]*xsi+cdinv[0,1]*eta+crpix_working[0]
	y=cdinv[1,0]*xsi+cdinv[1,1]*eta+crpix_working[1]
	if pos_limits:
		relevant=np.logical_and(np.logical_and(x>=minims[0],x<=maxims[0]),
								np.logical_and(y>=minims[1],y<=maxims[1]))
		return x[relevant],y[relevant],relevant
	return x,y

def xy2ad(header_info,x,y):
	'''
	Convert pixel coordinates (x, y) to sky coordinates (a, d) using the
	linear WCS keywords of header_info (port of IDL XY2AD; inverse of ad2xy).

	Parameters
	----------
	header_info : mapping
		Event-file keywords TCRPX11/12, TCRVL11/12, TCDLT11/12, TCTYP11/12
		are tried first. If any is missing, CRPIX1/2, CRVAL1/2, CDELT1/2,
		CTYPE1/2 are used instead. An empty CTYPE defaults to
		RA---TAN / DEC--TAN.
	x, y : scalar or array-like
		Pixel coordinates, using the same CRPIX - 1 offset as ad2xy.
		Scalars are wrapped into 1-element arrays.

	Returns
	-------
	a, d : longitude and latitude arrays. A spherical CTYPE (5th character
		'-') uses the TAN de-projection in wcs_xy_to_sph; otherwise the
		transform is linear. If neither keyword set is present,
		header_info is printed and (-99., -99.) is returned.

	Only diagonal CDELT scaling is supported (no CD/PC matrix, no rotation).
	The old version also read TLMAX5/TLMAX6 for pixel bounds it never used.
	A missing TLMAX key then caused valid TC* keywords to be discarded.
	'''
	try:
		crpixX,crpixY,crvalX,crvalY,\
		cdeltX,cdeltY,ctypeX,ctypeY=header_info['TCRPX11'],header_info['TCRPX12'],\
								header_info['TCRVL11'],header_info['TCRVL12'],\
								header_info['TCDLT11'],header_info['TCDLT12'],\
								header_info['TCTYP11'],header_info['TCTYP12']
	except KeyError:
		try:
			crpixX,crpixY,crvalX,crvalY,\
			cdeltX,cdeltY,ctypeX,ctypeY=header_info['CRPIX1'],header_info['CRPIX2'],\
									header_info['CRVAL1'],header_info['CRVAL2'],\
									header_info['CDELT1'],header_info['CDELT2'],\
									header_info['CTYPE1'],header_info['CTYPE2']
		except KeyError:
			print(header_info)
			return -99.,-99.
	if ctypeX=="":
		ctypeX,ctypeY='RA---TAN','DEC--TAN'
		print('No CTYPE specified - assuming TANgent projection')
	if np.ndim(x)==0:
		x=np.array([x])
		y=np.array([y])
	cdelt=np.array([cdeltX,cdeltY])
	cd=np.array([[1.,0.],[0.,1.]])
	# IDL convention: scale the rows of the (identity) CD matrix by CDELT
	if cdelt[0]!=1.0:
		cd[0,:]*=cdelt[0]
		cd[1,:]*=cdelt[1]
	try:
		spherical = ctypeX[4]=='-'
	except IndexError:
		spherical = False
	ctype,crpix = np.array([ctypeX, ctypeY]),np.array([crpixX,crpixY])
	xdif = x - (crpix[0]-1)
	ydif = y - (crpix[1]-1)
	# element-wise rather than matrix product so x and y may be vectors
	xsi = cd[0,0]*xdif + cd[0,1]*ydif
	eta = cd[1,0]*xdif + cd[1,1]*ydif
	coord=[ctypeX[:4],ctypeY[:4]]
	reverse=(((coord[0]=='DEC-') and (coord[1]=='RA--')) or\
			 ((coord[0]=='GLAT') and (coord[1]=='GLON')) or\
			 ((coord[0]=='ELAT') and (coord[1]=='ELON')))
	if reverse:
		crval=np.array([crvalY,crvalX])
		xsi, eta = eta, xsi
	else:
		crval=np.array([crvalX,crvalY])
	if spherical:
		return wcs_xy_to_sph(xsi, eta, ctype, crval)
	return crval[0] + xsi, crval[1] + eta

def hist_nd(input_arr,hist_min,hist_max,boxsize):#,distinguisher=1000):#distinguisher=10000
	'''
	Histogram input_arr with fixed-width bins and build an IDL-style
	reverse-index vector (port of IDL HIST_ND with REVERSE_INDICES).

	Parameters
	----------
	input_arr : ndarray
		1-D case (input_arr.ndim == 1): the values to histogram.
		N-D case: indexed as input_arr[dim, point]. It is zipped per
		dimension and transposed before np.histogramdd.
	hist_min, hist_max, boxsize :
		Scalars in the 1-D case, per-dimension sequences in the N-D case.
		Bin edges are np.arange(min, max + boxsize, boxsize).

	Returns
	-------
	hist_orig : counts from np.histogram / np.histogramdd
	rev_ind : integer ndarray
		The first nbins+1 entries are offsets into rev_ind itself. The rest
		are point indices ordered by bin. The points falling in (flattened)
		bin ind are input_arr[..., rev_ind[rev_ind[ind]:rev_ind[ind+1]]],
		with a Python exclusive end (IDL: ri[ri[ind]:ri[ind+1]-1]). For N-D,
		ind = i + nx*(j + ny*k), where nx, ny are the numbers of bins.
	edges : bin edges returned by numpy

	NOTE(review): points below the first edge get digitize index -1 and sort
	before every in-range point, but are not counted in hist_orig. That would
	shift all offsets, so this appears to assume all points lie within
	[hist_min, hist_max]. Confirm with callers.
	np.argsort uses its default (not stable) sort, so the order of indices
	inside one bin is not guaranteed to be ascending.
	'''
	t1=time.time()
	'''
	Note that this is a 1-dimensional reverse index
	vector (see HISTOGRAM).  E.g., to find the indices of points
	which fell in a histogram bin [i,j,k], look up:
	ind=[i+nx*(j+ny*k)]
	ri[ri[ind]:ri[ind+1]-1]
	For 3D:
	IDL> h_rev=hist_nd(v,.25,MIN=0,MAX=1,REVERSE_INDICES=ri)
	IDL> ind=[1+nx*(0+ny*0)]
	IDL> ind
	1
	IDL> v[0,ri[ri[ind]:ri[ind+1]-1]]
	0.33962137
	0.32754040
	0.35023004
	0.46381110
	IDL> v[1,ri[ri[ind]:ri[ind+1]-1]]
	0.22486423
	0.074997164
	0.025470998
	0.10085465
	IDL> v[2,ri[ri[ind]:ri[ind+1]-1]]
	0.23146111
	0.13568136
	0.23750301
	0.13226520
	As you can see, 0-th element falls between 0.25-0.5, 1st and 2nc element within 0-0.25
	which is what you would expect from i,j,k = 1,0,0 
	'''	
	# ---- 1-D case: scalar hist_min / hist_max / boxsize ----
	if len(input_arr.shape)==1:
		bins_now=np.arange(hist_min,hist_max+boxsize,boxsize)
		#print(bins_now)
		hist_orig,edges=np.histogram(input_arr,bins=len(bins_now)-1,
									 range=(bins_now[0],bins_now[-1]))
		#address of the ind of the bin where these objects should fall
		#has the same dimension as input arr
		# cumulative counts with a leading 0: entry i is where bin i starts in the sorted list
		hist_fake=np.insert(hist_orig,0,0)#hist_orig[0]
		edited_hist=np.cumsum(hist_fake)#np.cumsum(np.insert(hist_orig,0,1))-1#np.cumsum(np.insert(hist_orig,0,1))
		#edited_hist=np.insert(edited_hist,0,hist_orig[0])
		# zero-based bin number of each point (-1 below the first edge)
		address=np.digitize(input_arr,bins_now)-1
		# point indices grouped by bin number
		sorted_address=np.argsort(address)
		# offsets are shifted by their own length so they index into rev_ind directly
		rev_ind=np.concatenate((edited_hist+len(edited_hist),sorted_address))
		return hist_orig,rev_ind,edges
	# ---- N-D case: per-dimension min / max / boxsize sequences ----
	else:	
		n_dim=len(input_arr)
		all_ind=[]
		options_bin=[]
		mesh=[]
		options_two=[1.]
		hist_range=[]
		for mini,maxi,boxi,arr_ele in zip(hist_min,hist_max,boxsize,input_arr):
			bins_now=np.arange(mini,maxi+boxi,boxi)
			# nx is the number of bin EDGES along this dimension (bins = nx - 1)
			nx=len(bins_now)
			hist_range.append((bins_now[0],bins_now[-1]))
			all_ind.append(np.digitize(arr_ele,bins_now)-1)
			options_bin.append(nx)
			options_two.append(nx)
		all_ind=np.array(all_ind)
		options_bin=np.array(options_bin)
		# options_two is not used below
		options_two=np.array(options_two[:-1])
		'''
		i have changed from address=i+nx*(j+ny*k),
		which is not a unique identifier, for example if nx=2 and ny=6,
		we will run into issues as i=1, j=1 results is uniq_ind=1+2*1 = 3
		as i=0 and j=3
		to uniquely identify each bin using a base 
		we should use a distinguisher (a variable name), the smallest value of which
		should be equal to the length of the largest dimension
		This could lead to memory issues if we move to ridiculously large images,
		but we will cross that bridge if we must
		'''
		# mixed-radix key: bin index along dim 0 varies fastest, matching the
		# C-order flatten of hist_orig.T used for the offsets below
		distinguisher=np.max(options_bin)
		address=np.matmul(distinguisher**np.arange(n_dim),all_ind)
		sorted_address=np.argsort(address)
		hist_orig,edges=np.histogramdd(input_arr.T,bins=options_bin-1,
									   range=hist_range)
		#we need this blah_new step because IDL creates a new bin at the edge of the last bin
		# transpose before flattening so the flat bin index is i + nx*(j + ny*k), as in IDL
		hist=np.cumsum(hist_orig.T.flatten())
		hist=np.concatenate((np.array([0]),hist))
		hist=len(hist)+hist
		rev_ind=np.concatenate((hist,sorted_address)).astype(int)
		return hist_orig, rev_ind, edges


def match2_improved(arr1,arr2):
	'''
	Cross-match two arrays, keeping only the elements that are found.

	Returns (sub_a, sub_b). sub_a lists, in arr1 order, the index of the first
	occurrence in arr2 of each arr1 element that appears there. sub_b does
	the same with the roles swapped. Unmatched elements are skipped, so
	neither output is aligned position-by-position with its input.
	Marked "not used for anything" by the original author.
	'''
	def _found_positions(needles,haystack):
		positions=[]
		for needle in needles:
			hits=np.where(haystack==needle)[0]
			if len(hits)==0:
				continue
			positions.append(hits[0])
		return positions
	return np.array(_found_positions(arr1,arr2)),np.array(_found_positions(arr2,arr1))


def match2_only_subb(arr1,arr2):
	'''
	For each element of arr2, return the index of its first occurrence in arr1.

	Elements of arr2 that do not occur in arr1 give NaN. The result is
	therefore a float array when anything is missing and an integer array
	otherwise. Marked "not used for anything" by the original author.
	'''
	def _first_hit(value):
		hits=np.where(arr1==value)[0]
		return hits[0] if len(hits) else np.nan
	return np.array([_first_hit(value) for value in arr2])