Right number of WCS axis for each HDU in output
This commit is contained in:
@@ -16,7 +16,7 @@ from astropy.io import fits
|
||||
from astropy.wcs import WCS
|
||||
|
||||
from .convex_hull import clean_ROI
|
||||
from .utils import wcs_CD_to_PC, wcs_PA
|
||||
from .utils import wcs_CD_to_PC, wcs_PA, add_stokes_axis_to_header, remove_stokes_axis_from_header
|
||||
|
||||
|
||||
def get_obs_data(infiles, data_folder="", compute_flux=False):
|
||||
@@ -141,15 +141,17 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
|
||||
Only returned if return_hdul is True.
|
||||
"""
|
||||
# Create new WCS object given the modified images
|
||||
new_wcs = WCS(header_stokes).deepcopy()
|
||||
new_wcs = WCS(header_stokes).celestial.deepcopy()
|
||||
header = remove_stokes_axis_from_header(header_stokes).copy()
|
||||
|
||||
if data_mask.shape != (1, 1):
|
||||
vertex = clean_ROI(data_mask)
|
||||
shape = vertex[1::2] - vertex[0::2]
|
||||
new_wcs.array_shape = (4, *shape)
|
||||
new_wcs.wcs.crpix[1:] = np.array(new_wcs.wcs.crpix[1:]) - vertex[0::-2]
|
||||
new_wcs.array_shape = shape
|
||||
new_wcs.wcs.crpix = np.array(new_wcs.wcs.crpix) - vertex[0::-2]
|
||||
for key, val in list(new_wcs.to_header().items()) + [("NAXIS", 2), ("NAXIS1", new_wcs.array_shape[1]), ("NAXIS2", new_wcs.array_shape[0])]:
|
||||
header[key] = val
|
||||
|
||||
header = new_wcs.to_header()
|
||||
header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data")
|
||||
header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data")
|
||||
header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength")
|
||||
@@ -190,8 +192,9 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
|
||||
# Add I_stokes as PrimaryHDU
|
||||
header["datatype"] = ("STOKES", "type of data stored in the HDU")
|
||||
Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0
|
||||
primary_hdu = fits.PrimaryHDU(data=Stokes, header=header)
|
||||
primary_hdu.name = "Stokes"
|
||||
hdu_head = add_stokes_axis_to_header(header, 2)
|
||||
primary_hdu = fits.PrimaryHDU(data=Stokes, header=hdu_head)
|
||||
primary_hdu.name = "STOKES"
|
||||
hdul.append(primary_hdu)
|
||||
|
||||
# Add Stokes_cov, P, s_P, PA, s_PA to the HDUList
|
||||
@@ -206,13 +209,15 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
|
||||
[s_PA_P, "Pol_ang_stat_err"],
|
||||
[data_mask, "Data_mask"],
|
||||
]:
|
||||
hdu_header = header.copy()
|
||||
hdu_header["datatype"] = name
|
||||
hdu_head = header.copy()
|
||||
hdu_head["datatype"] = name
|
||||
if name == "STOKES_COV":
|
||||
hdu_head = add_stokes_axis_to_header(hdu_head, 2)
|
||||
hdu_head = add_stokes_axis_to_header(hdu_head, 3)
|
||||
data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0
|
||||
else:
|
||||
data[(1 - data_mask).astype(bool)] = 0.0
|
||||
hdu = fits.ImageHDU(data=data, header=hdu_header)
|
||||
hdu = fits.ImageHDU(data=data, header=hdu_head)
|
||||
hdu.name = name
|
||||
hdul.append(hdu)
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from astropy import log
|
||||
from astropy.wcs import WCS
|
||||
from astropy.wcs.utils import add_stokes_axis_to_wcs
|
||||
from matplotlib.colors import LogNorm
|
||||
from matplotlib.patches import Rectangle
|
||||
from scipy.ndimage import rotate as sc_rotate
|
||||
@@ -58,7 +57,7 @@ from .convex_hull import image_hull
|
||||
from .cross_correlation import phase_cross_correlation
|
||||
from .deconvolve import deconvolve_im, gaussian2d, gaussian_psf, zeropad
|
||||
from .plots import plot_obs
|
||||
from .utils import princ_angle
|
||||
from .utils import princ_angle, add_stokes_axis_to_header
|
||||
|
||||
log.setLevel("ERROR")
|
||||
|
||||
@@ -1415,11 +1414,7 @@ def compute_Stokes(data_array, error_array, data_mask, headers, FWHM=None, scale
|
||||
Stokes[np.isnan(Stokes)] = 0.0
|
||||
Stokes[1:][np.broadcast_to(Stokes[0] == 0.0, Stokes[1:].shape)] = 0.0
|
||||
Stokes_cov[np.isnan(Stokes_cov)] = fmax
|
||||
wcs_Stokes = add_stokes_axis_to_wcs(WCS(header_stokes), 0)
|
||||
wcs_Stokes.array_shape = (4, *Stokes.shape[1:])[::-1]
|
||||
header_stokes["NAXIS1"], header_stokes["NAXIS2"], header_stokes["NAXIS3"] = wcs_Stokes.array_shape[::-1]
|
||||
for key, val in list(wcs_Stokes.to_header().items()) + list(zip(["PC1_1", "PC1_2", "PC1_3", "PC2_1", "PC3_1", "CUNIT1"], [1, 0, 0, 0, 0, "STOKES"])):
|
||||
header_stokes[key] = val
|
||||
header_stokes = add_stokes_axis_to_header(header_stokes, 0)
|
||||
|
||||
if integrate:
|
||||
# Compute integrated values for P, PA before any rotation
|
||||
|
||||
@@ -242,3 +242,82 @@ def wcs_PA(PC, cdelt):
|
||||
rot2 = np.pi / 2.0 - np.arctan2(abs(cdelt[0]) * PC[0, 1], cdelt[1] * PC[1, 1])
|
||||
orient = 0.5 * (rot + rot2) * 180.0 / np.pi
|
||||
return orient
|
||||
|
||||
|
||||
def add_stokes_axis_to_header(header, ind=0):
|
||||
"""
|
||||
Add a new Stokes axis to the WCS cards in the header.
|
||||
----------
|
||||
Inputs:
|
||||
header : astropy.io.fits.header.Header
|
||||
The header in which the WCS to work on is saved.
|
||||
ind : int, optional
|
||||
Index of the WCS to insert the new Stokes axis in front of.
|
||||
To add at the end, do add_before_ind = wcs.wcs.naxis
|
||||
The beginning is at position 0.
|
||||
Default to 0.
|
||||
----------
|
||||
Returns:
|
||||
new_head : astropy.io.fits.header.Header
|
||||
A new Header instance with an additional Stokes axis
|
||||
"""
|
||||
from astropy.wcs import WCS
|
||||
from astropy.wcs.utils import add_stokes_axis_to_wcs
|
||||
|
||||
wcs = WCS(header).deepcopy()
|
||||
wcs_Stokes = add_stokes_axis_to_wcs(wcs, ind).deepcopy()
|
||||
wcs_Stokes.array_shape = (*wcs.array_shape[ind:], 4, *wcs.array_shape[:ind]) if ind < wcs.wcs.naxis else (4, *wcs.array_shape)
|
||||
new_head = header.copy()
|
||||
new_head["NAXIS"] = wcs_Stokes.wcs.naxis
|
||||
for key in wcs.to_header().keys():
|
||||
if key not in wcs_Stokes.to_header().keys():
|
||||
del new_head[key]
|
||||
for key, val in (
|
||||
list(wcs_Stokes.to_header().items())
|
||||
+ [("NAXIS%d" % (i + 1), k) for i, k in enumerate(wcs_Stokes.array_shape[::-1])]
|
||||
+ [("CUNIT%d" % (ind + 1), "STOKES")]
|
||||
):
|
||||
if key not in header.keys() and key[:-1] + str(wcs.wcs.naxis) in header.keys():
|
||||
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
|
||||
elif key not in header.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in header.keys():
|
||||
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
|
||||
else:
|
||||
new_head[key] = val
|
||||
return new_head
|
||||
|
||||
|
||||
def remove_stokes_axis_from_header(header):
|
||||
"""
|
||||
Remove a Stokes axis to the WCS cards in the header.
|
||||
----------
|
||||
Inputs:
|
||||
header : astropy.io.fits.header.Header
|
||||
The header in which the WCS to work on is saved.
|
||||
----------
|
||||
Returns:
|
||||
new_head : astropy.io.fits.header.Header
|
||||
A new Header instance with only a celestial WCS.
|
||||
"""
|
||||
from astropy.wcs import WCS
|
||||
|
||||
wcs = WCS(header).deepcopy()
|
||||
new_wcs = WCS(header).celestial.deepcopy()
|
||||
new_head = header.copy()
|
||||
del new_head["NAXIS%d" % (new_wcs.wcs.naxis + 1)]
|
||||
new_head["NAXIS"] = new_wcs.wcs.naxis
|
||||
for i, k in enumerate(new_wcs.array_shape[::-1]):
|
||||
new_head["NAXIS%d" % (i + 1)] = k
|
||||
for key in list(WCS(header).to_header().keys()) + list(
|
||||
np.unique([["PC%d_%d" % (i + 1, j + 1) for i in range(wcs.wcs.naxis)] for j in range(wcs.wcs.naxis)])
|
||||
):
|
||||
if key in new_head.keys() and key not in new_wcs.to_header().keys():
|
||||
del new_head[key]
|
||||
for key, val in new_wcs.to_header().items():
|
||||
if key not in new_head.keys() and key[:-1] + str(wcs.wcs.naxis) in new_head.keys():
|
||||
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=True)
|
||||
elif key not in new_head.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in new_head.keys():
|
||||
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=True)
|
||||
else:
|
||||
new_head[key] = val
|
||||
|
||||
return new_head
|
||||
|
||||
Reference in New Issue
Block a user