diff --git a/package/lib/fits.py b/package/lib/fits.py index 20eb2da..1276b3f 100755 --- a/package/lib/fits.py +++ b/package/lib/fits.py @@ -16,7 +16,7 @@ from astropy.io import fits from astropy.wcs import WCS from .convex_hull import clean_ROI -from .utils import wcs_CD_to_PC, wcs_PA +from .utils import wcs_CD_to_PC, wcs_PA, add_stokes_axis_to_header, remove_stokes_axis_from_header def get_obs_data(infiles, data_folder="", compute_flux=False): @@ -141,15 +141,17 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P, Only returned if return_hdul is True. """ # Create new WCS object given the modified images - new_wcs = WCS(header_stokes).deepcopy() + new_wcs = WCS(header_stokes).celestial.deepcopy() + header = remove_stokes_axis_from_header(header_stokes).copy() if data_mask.shape != (1, 1): vertex = clean_ROI(data_mask) shape = vertex[1::2] - vertex[0::2] - new_wcs.array_shape = (4, *shape) - new_wcs.wcs.crpix[1:] = np.array(new_wcs.wcs.crpix[1:]) - vertex[0::-2] + new_wcs.array_shape = shape + new_wcs.wcs.crpix = np.array(new_wcs.wcs.crpix) - vertex[0::-2] + for key, val in list(new_wcs.to_header().items()) + [("NAXIS", 2), ("NAXIS1", new_wcs.array_shape[1]), ("NAXIS2", new_wcs.array_shape[0])]: + header[key] = val - header = new_wcs.to_header() header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data") header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data") header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength") @@ -190,8 +192,9 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P, # Add I_stokes as PrimaryHDU header["datatype"] = ("STOKES", "type of data stored in the HDU") Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0 - primary_hdu = fits.PrimaryHDU(data=Stokes, header=header) - primary_hdu.name = "Stokes" + hdu_head = add_stokes_axis_to_header(header, 2) + primary_hdu = fits.PrimaryHDU(data=Stokes, header=hdu_head) + primary_hdu.name = "STOKES" hdul.append(primary_hdu) # Add Stokes_cov, P, s_P, PA, s_PA to the HDUList @@ -206,13 +209,15 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P, [s_PA_P, "Pol_ang_stat_err"], [data_mask, "Data_mask"], ]: - hdu_header = header.copy() - hdu_header["datatype"] = name + hdu_head = header.copy() + hdu_head["datatype"] = name if name == "STOKES_COV": + hdu_head = add_stokes_axis_to_header(hdu_head, 2) + hdu_head = add_stokes_axis_to_header(hdu_head, 3) data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0 else: data[(1 - data_mask).astype(bool)] = 0.0 - hdu = fits.ImageHDU(data=data, header=hdu_header) + hdu = fits.ImageHDU(data=data, header=hdu_head) hdu.name = name hdul.append(hdu) diff --git a/package/lib/reduction.py b/package/lib/reduction.py index 2239faa..749ff5f 100755 --- a/package/lib/reduction.py +++ b/package/lib/reduction.py @@ -46,7 +46,6 @@ import matplotlib.pyplot as plt import numpy as np from astropy import log from astropy.wcs import WCS -from astropy.wcs.utils import add_stokes_axis_to_wcs from matplotlib.colors import LogNorm from matplotlib.patches import Rectangle from scipy.ndimage import rotate as sc_rotate @@ -58,7 +57,7 @@ from .convex_hull import image_hull from .cross_correlation import phase_cross_correlation from .deconvolve import deconvolve_im, gaussian2d, gaussian_psf, zeropad from .plots import plot_obs -from .utils import princ_angle +from .utils import princ_angle, add_stokes_axis_to_header log.setLevel("ERROR") @@ -1415,11 +1414,7 @@ def compute_Stokes(data_array, error_array, data_mask, headers, FWHM=None, scale Stokes[np.isnan(Stokes)] = 0.0 Stokes[1:][np.broadcast_to(Stokes[0] == 0.0, Stokes[1:].shape)] = 0.0 Stokes_cov[np.isnan(Stokes_cov)] = fmax - wcs_Stokes = add_stokes_axis_to_wcs(WCS(header_stokes), 0) - wcs_Stokes.array_shape = (4, *Stokes.shape[1:])[::-1] - header_stokes["NAXIS1"], header_stokes["NAXIS2"], header_stokes["NAXIS3"] = wcs_Stokes.array_shape[::-1] - for key, val in list(wcs_Stokes.to_header().items()) + list(zip(["PC1_1", "PC1_2", "PC1_3", "PC2_1", "PC3_1", "CUNIT1"], [1, 0, 0, 0, 0, "STOKES"])): - header_stokes[key] = val + header_stokes = add_stokes_axis_to_header(header_stokes, 0) if integrate: # Compute integrated values for P, PA before any rotation diff --git a/package/lib/utils.py b/package/lib/utils.py index 943de37..452a66c 100755 --- a/package/lib/utils.py +++ b/package/lib/utils.py @@ -242,3 +242,82 @@ def wcs_PA(PC, cdelt): rot2 = np.pi / 2.0 - np.arctan2(abs(cdelt[0]) * PC[0, 1], cdelt[1] * PC[1, 1]) orient = 0.5 * (rot + rot2) * 180.0 / np.pi return orient + + +def add_stokes_axis_to_header(header, ind=0): + """ + Add a new Stokes axis to the WCS cards in the header. + ---------- + Inputs: + header : astropy.io.fits.header.Header + The header in which the WCS to work on is saved. + ind : int, optional + Index of the WCS to insert the new Stokes axis in front of. + To add at the end, do add_before_ind = wcs.wcs.naxis + The beginning is at position 0. + Default to 0. + ---------- + Returns: + new_head : astropy.io.fits.header.Header + A new Header instance with an additional Stokes axis + """ + from astropy.wcs import WCS + from astropy.wcs.utils import add_stokes_axis_to_wcs + + wcs = WCS(header).deepcopy() + wcs_Stokes = add_stokes_axis_to_wcs(wcs, ind).deepcopy() + wcs_Stokes.array_shape = (*wcs.array_shape[ind:], 4, *wcs.array_shape[:ind]) if ind < wcs.wcs.naxis else (4, *wcs.array_shape) + new_head = header.copy() + new_head["NAXIS"] = wcs_Stokes.wcs.naxis + for key in wcs.to_header().keys(): + if key not in wcs_Stokes.to_header().keys(): + del new_head[key] + for key, val in ( + list(wcs_Stokes.to_header().items()) + + [("NAXIS%d" % (i + 1), k) for i, k in enumerate(wcs_Stokes.array_shape[::-1])] + + [("CUNIT%d" % (ind + 1), "STOKES")] + ): + if key not in header.keys() and key[:-1] + str(wcs.wcs.naxis) in header.keys(): + new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis) + elif key not in header.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in header.keys(): + new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis) + else: + new_head[key] = val + return new_head + + +def remove_stokes_axis_from_header(header): + """ + Remove a Stokes axis to the WCS cards in the header. + ---------- + Inputs: + header : astropy.io.fits.header.Header + The header in which the WCS to work on is saved. + ---------- + Returns: + new_head : astropy.io.fits.header.Header + A new Header instance with only a celestial WCS. + """ + from astropy.wcs import WCS + + wcs = WCS(header).deepcopy() + new_wcs = WCS(header).celestial.deepcopy() + new_head = header.copy() + del new_head["NAXIS%d" % (new_wcs.wcs.naxis + 1)] + new_head["NAXIS"] = new_wcs.wcs.naxis + for i, k in enumerate(new_wcs.array_shape[::-1]): + new_head["NAXIS%d" % (i + 1)] = k + for key in list(WCS(header).to_header().keys()) + list( + np.unique([["PC%d_%d" % (i + 1, j + 1) for i in range(wcs.wcs.naxis)] for j in range(wcs.wcs.naxis)]) + ): + if key in new_head.keys() and key not in new_wcs.to_header().keys(): + del new_head[key] + for key, val in new_wcs.to_header().items(): + if key not in new_head.keys() and key[:-1] + str(wcs.wcs.naxis) in new_head.keys(): + new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=True) + elif key not in new_head.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in new_head.keys(): + new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=True) + else: + new_head[key] = val + + return new_head