Right number of WCS axis for each HDU in output

This commit is contained in:
2025-08-06 17:28:18 +02:00
parent f47c650dc5
commit e639695618
3 changed files with 96 additions and 17 deletions

View File

@@ -16,7 +16,7 @@ from astropy.io import fits
from astropy.wcs import WCS
from .convex_hull import clean_ROI
from .utils import wcs_CD_to_PC, wcs_PA
from .utils import wcs_CD_to_PC, wcs_PA, add_stokes_axis_to_header, remove_stokes_axis_from_header
def get_obs_data(infiles, data_folder="", compute_flux=False):
@@ -141,15 +141,17 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
Only returned if return_hdul is True.
"""
# Create new WCS object given the modified images
new_wcs = WCS(header_stokes).deepcopy()
new_wcs = WCS(header_stokes).celestial.deepcopy()
header = remove_stokes_axis_from_header(header_stokes).copy()
if data_mask.shape != (1, 1):
vertex = clean_ROI(data_mask)
shape = vertex[1::2] - vertex[0::2]
new_wcs.array_shape = (4, *shape)
new_wcs.wcs.crpix[1:] = np.array(new_wcs.wcs.crpix[1:]) - vertex[0::-2]
new_wcs.array_shape = shape
new_wcs.wcs.crpix = np.array(new_wcs.wcs.crpix) - vertex[0::-2]
for key, val in list(new_wcs.to_header().items()) + [("NAXIS", 2), ("NAXIS1", new_wcs.array_shape[1]), ("NAXIS2", new_wcs.array_shape[0])]:
header[key] = val
header = new_wcs.to_header()
header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data")
header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data")
header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength")
@@ -190,8 +192,9 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
# Add I_stokes as PrimaryHDU
header["datatype"] = ("STOKES", "type of data stored in the HDU")
Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0
primary_hdu = fits.PrimaryHDU(data=Stokes, header=header)
primary_hdu.name = "Stokes"
hdu_head = add_stokes_axis_to_header(header, 2)
primary_hdu = fits.PrimaryHDU(data=Stokes, header=hdu_head)
primary_hdu.name = "STOKES"
hdul.append(primary_hdu)
# Add Stokes_cov, P, s_P, PA, s_PA to the HDUList
@@ -206,13 +209,15 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
[s_PA_P, "Pol_ang_stat_err"],
[data_mask, "Data_mask"],
]:
hdu_header = header.copy()
hdu_header["datatype"] = name
hdu_head = header.copy()
hdu_head["datatype"] = name
if name == "STOKES_COV":
hdu_head = add_stokes_axis_to_header(hdu_head, 2)
hdu_head = add_stokes_axis_to_header(hdu_head, 3)
data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0
else:
data[(1 - data_mask).astype(bool)] = 0.0
hdu = fits.ImageHDU(data=data, header=hdu_header)
hdu = fits.ImageHDU(data=data, header=hdu_head)
hdu.name = name
hdul.append(hdu)

View File

@@ -46,7 +46,6 @@ import matplotlib.pyplot as plt
import numpy as np
from astropy import log
from astropy.wcs import WCS
from astropy.wcs.utils import add_stokes_axis_to_wcs
from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle
from scipy.ndimage import rotate as sc_rotate
@@ -58,7 +57,7 @@ from .convex_hull import image_hull
from .cross_correlation import phase_cross_correlation
from .deconvolve import deconvolve_im, gaussian2d, gaussian_psf, zeropad
from .plots import plot_obs
from .utils import princ_angle
from .utils import princ_angle, add_stokes_axis_to_header
log.setLevel("ERROR")
@@ -1415,11 +1414,7 @@ def compute_Stokes(data_array, error_array, data_mask, headers, FWHM=None, scale
Stokes[np.isnan(Stokes)] = 0.0
Stokes[1:][np.broadcast_to(Stokes[0] == 0.0, Stokes[1:].shape)] = 0.0
Stokes_cov[np.isnan(Stokes_cov)] = fmax
wcs_Stokes = add_stokes_axis_to_wcs(WCS(header_stokes), 0)
wcs_Stokes.array_shape = (4, *Stokes.shape[1:])[::-1]
header_stokes["NAXIS1"], header_stokes["NAXIS2"], header_stokes["NAXIS3"] = wcs_Stokes.array_shape[::-1]
for key, val in list(wcs_Stokes.to_header().items()) + list(zip(["PC1_1", "PC1_2", "PC1_3", "PC2_1", "PC3_1", "CUNIT1"], [1, 0, 0, 0, 0, "STOKES"])):
header_stokes[key] = val
header_stokes = add_stokes_axis_to_header(header_stokes, 0)
if integrate:
# Compute integrated values for P, PA before any rotation

View File

@@ -242,3 +242,82 @@ def wcs_PA(PC, cdelt):
rot2 = np.pi / 2.0 - np.arctan2(abs(cdelt[0]) * PC[0, 1], cdelt[1] * PC[1, 1])
orient = 0.5 * (rot + rot2) * 180.0 / np.pi
return orient
def add_stokes_axis_to_header(header, ind=0):
"""
Add a new Stokes axis to the WCS cards in the header.
----------
Inputs:
header : astropy.io.fits.header.Header
The header in which the WCS to work on is saved.
ind : int, optional
Index of the WCS to insert the new Stokes axis in front of.
To add at the end, do add_before_ind = wcs.wcs.naxis
The beginning is at position 0.
Default to 0.
----------
Returns:
new_head : astropy.io.fits.header.Header
A new Header instance with an additional Stokes axis
"""
from astropy.wcs import WCS
from astropy.wcs.utils import add_stokes_axis_to_wcs
wcs = WCS(header).deepcopy()
wcs_Stokes = add_stokes_axis_to_wcs(wcs, ind).deepcopy()
wcs_Stokes.array_shape = (*wcs.array_shape[ind:], 4, *wcs.array_shape[:ind]) if ind < wcs.wcs.naxis else (4, *wcs.array_shape)
new_head = header.copy()
new_head["NAXIS"] = wcs_Stokes.wcs.naxis
for key in wcs.to_header().keys():
if key not in wcs_Stokes.to_header().keys():
del new_head[key]
for key, val in (
list(wcs_Stokes.to_header().items())
+ [("NAXIS%d" % (i + 1), k) for i, k in enumerate(wcs_Stokes.array_shape[::-1])]
+ [("CUNIT%d" % (ind + 1), "STOKES")]
):
if key not in header.keys() and key[:-1] + str(wcs.wcs.naxis) in header.keys():
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
elif key not in header.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in header.keys():
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
else:
new_head[key] = val
return new_head
def remove_stokes_axis_from_header(header):
"""
Remove a Stokes axis to the WCS cards in the header.
----------
Inputs:
header : astropy.io.fits.header.Header
The header in which the WCS to work on is saved.
----------
Returns:
new_head : astropy.io.fits.header.Header
A new Header instance with only a celestial WCS.
"""
from astropy.wcs import WCS
wcs = WCS(header).deepcopy()
new_wcs = WCS(header).celestial.deepcopy()
new_head = header.copy()
del new_head["NAXIS%d" % (new_wcs.wcs.naxis + 1)]
new_head["NAXIS"] = new_wcs.wcs.naxis
for i, k in enumerate(new_wcs.array_shape[::-1]):
new_head["NAXIS%d" % (i + 1)] = k
for key in list(WCS(header).to_header().keys()) + list(
np.unique([["PC%d_%d" % (i + 1, j + 1) for i in range(wcs.wcs.naxis)] for j in range(wcs.wcs.naxis)])
):
if key in new_head.keys() and key not in new_wcs.to_header().keys():
del new_head[key]
for key, val in new_wcs.to_header().items():
if key not in new_head.keys() and key[:-1] + str(wcs.wcs.naxis) in new_head.keys():
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=True)
elif key not in new_head.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in new_head.keys():
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=True)
else:
new_head[key] = val
return new_head