Right number of WCS axis for each HDU in output

This commit is contained in:
2025-08-06 17:28:18 +02:00
parent f47c650dc5
commit e639695618
3 changed files with 96 additions and 17 deletions

View File

@@ -16,7 +16,7 @@ from astropy.io import fits
from astropy.wcs import WCS from astropy.wcs import WCS
from .convex_hull import clean_ROI from .convex_hull import clean_ROI
from .utils import wcs_CD_to_PC, wcs_PA from .utils import wcs_CD_to_PC, wcs_PA, add_stokes_axis_to_header, remove_stokes_axis_from_header
def get_obs_data(infiles, data_folder="", compute_flux=False): def get_obs_data(infiles, data_folder="", compute_flux=False):
@@ -141,15 +141,17 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
Only returned if return_hdul is True. Only returned if return_hdul is True.
""" """
# Create new WCS object given the modified images # Create new WCS object given the modified images
new_wcs = WCS(header_stokes).deepcopy() new_wcs = WCS(header_stokes).celestial.deepcopy()
header = remove_stokes_axis_from_header(header_stokes).copy()
if data_mask.shape != (1, 1): if data_mask.shape != (1, 1):
vertex = clean_ROI(data_mask) vertex = clean_ROI(data_mask)
shape = vertex[1::2] - vertex[0::2] shape = vertex[1::2] - vertex[0::2]
new_wcs.array_shape = (4, *shape) new_wcs.array_shape = shape
new_wcs.wcs.crpix[1:] = np.array(new_wcs.wcs.crpix[1:]) - vertex[0::-2] new_wcs.wcs.crpix = np.array(new_wcs.wcs.crpix) - vertex[0::-2]
for key, val in list(new_wcs.to_header().items()) + [("NAXIS", 2), ("NAXIS1", new_wcs.array_shape[1]), ("NAXIS2", new_wcs.array_shape[0])]:
header[key] = val
header = new_wcs.to_header()
header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data") header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data")
header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data") header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data")
header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength") header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength")
@@ -190,8 +192,9 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
# Add I_stokes as PrimaryHDU # Add I_stokes as PrimaryHDU
header["datatype"] = ("STOKES", "type of data stored in the HDU") header["datatype"] = ("STOKES", "type of data stored in the HDU")
Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0 Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0
primary_hdu = fits.PrimaryHDU(data=Stokes, header=header) hdu_head = add_stokes_axis_to_header(header, 2)
primary_hdu.name = "Stokes" primary_hdu = fits.PrimaryHDU(data=Stokes, header=hdu_head)
primary_hdu.name = "STOKES"
hdul.append(primary_hdu) hdul.append(primary_hdu)
# Add Stokes_cov, P, s_P, PA, s_PA to the HDUList # Add Stokes_cov, P, s_P, PA, s_PA to the HDUList
@@ -206,13 +209,15 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
[s_PA_P, "Pol_ang_stat_err"], [s_PA_P, "Pol_ang_stat_err"],
[data_mask, "Data_mask"], [data_mask, "Data_mask"],
]: ]:
hdu_header = header.copy() hdu_head = header.copy()
hdu_header["datatype"] = name hdu_head["datatype"] = name
if name == "STOKES_COV": if name == "STOKES_COV":
hdu_head = add_stokes_axis_to_header(hdu_head, 2)
hdu_head = add_stokes_axis_to_header(hdu_head, 3)
data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0 data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0
else: else:
data[(1 - data_mask).astype(bool)] = 0.0 data[(1 - data_mask).astype(bool)] = 0.0
hdu = fits.ImageHDU(data=data, header=hdu_header) hdu = fits.ImageHDU(data=data, header=hdu_head)
hdu.name = name hdu.name = name
hdul.append(hdu) hdul.append(hdu)

View File

@@ -46,7 +46,6 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from astropy import log from astropy import log
from astropy.wcs import WCS from astropy.wcs import WCS
from astropy.wcs.utils import add_stokes_axis_to_wcs
from matplotlib.colors import LogNorm from matplotlib.colors import LogNorm
from matplotlib.patches import Rectangle from matplotlib.patches import Rectangle
from scipy.ndimage import rotate as sc_rotate from scipy.ndimage import rotate as sc_rotate
@@ -58,7 +57,7 @@ from .convex_hull import image_hull
from .cross_correlation import phase_cross_correlation from .cross_correlation import phase_cross_correlation
from .deconvolve import deconvolve_im, gaussian2d, gaussian_psf, zeropad from .deconvolve import deconvolve_im, gaussian2d, gaussian_psf, zeropad
from .plots import plot_obs from .plots import plot_obs
from .utils import princ_angle from .utils import princ_angle, add_stokes_axis_to_header
log.setLevel("ERROR") log.setLevel("ERROR")
@@ -1415,11 +1414,7 @@ def compute_Stokes(data_array, error_array, data_mask, headers, FWHM=None, scale
Stokes[np.isnan(Stokes)] = 0.0 Stokes[np.isnan(Stokes)] = 0.0
Stokes[1:][np.broadcast_to(Stokes[0] == 0.0, Stokes[1:].shape)] = 0.0 Stokes[1:][np.broadcast_to(Stokes[0] == 0.0, Stokes[1:].shape)] = 0.0
Stokes_cov[np.isnan(Stokes_cov)] = fmax Stokes_cov[np.isnan(Stokes_cov)] = fmax
wcs_Stokes = add_stokes_axis_to_wcs(WCS(header_stokes), 0) header_stokes = add_stokes_axis_to_header(header_stokes, 0)
wcs_Stokes.array_shape = (4, *Stokes.shape[1:])[::-1]
header_stokes["NAXIS1"], header_stokes["NAXIS2"], header_stokes["NAXIS3"] = wcs_Stokes.array_shape[::-1]
for key, val in list(wcs_Stokes.to_header().items()) + list(zip(["PC1_1", "PC1_2", "PC1_3", "PC2_1", "PC3_1", "CUNIT1"], [1, 0, 0, 0, 0, "STOKES"])):
header_stokes[key] = val
if integrate: if integrate:
# Compute integrated values for P, PA before any rotation # Compute integrated values for P, PA before any rotation

View File

@@ -242,3 +242,82 @@ def wcs_PA(PC, cdelt):
rot2 = np.pi / 2.0 - np.arctan2(abs(cdelt[0]) * PC[0, 1], cdelt[1] * PC[1, 1]) rot2 = np.pi / 2.0 - np.arctan2(abs(cdelt[0]) * PC[0, 1], cdelt[1] * PC[1, 1])
orient = 0.5 * (rot + rot2) * 180.0 / np.pi orient = 0.5 * (rot + rot2) * 180.0 / np.pi
return orient return orient
def add_stokes_axis_to_header(header, ind=0):
"""
Add a new Stokes axis to the WCS cards in the header.
----------
Inputs:
header : astropy.io.fits.header.Header
The header in which the WCS to work on is saved.
ind : int, optional
Index of the WCS to insert the new Stokes axis in front of.
To add at the end, do add_before_ind = wcs.wcs.naxis
The beginning is at position 0.
Default to 0.
----------
Returns:
new_head : astropy.io.fits.header.Header
A new Header instance with an additional Stokes axis
"""
from astropy.wcs import WCS
from astropy.wcs.utils import add_stokes_axis_to_wcs
wcs = WCS(header).deepcopy()
wcs_Stokes = add_stokes_axis_to_wcs(wcs, ind).deepcopy()
wcs_Stokes.array_shape = (*wcs.array_shape[ind:], 4, *wcs.array_shape[:ind]) if ind < wcs.wcs.naxis else (4, *wcs.array_shape)
new_head = header.copy()
new_head["NAXIS"] = wcs_Stokes.wcs.naxis
for key in wcs.to_header().keys():
if key not in wcs_Stokes.to_header().keys():
del new_head[key]
for key, val in (
list(wcs_Stokes.to_header().items())
+ [("NAXIS%d" % (i + 1), k) for i, k in enumerate(wcs_Stokes.array_shape[::-1])]
+ [("CUNIT%d" % (ind + 1), "STOKES")]
):
if key not in header.keys() and key[:-1] + str(wcs.wcs.naxis) in header.keys():
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
elif key not in header.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in header.keys():
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=int(key[-1]) < wcs.wcs.naxis)
else:
new_head[key] = val
return new_head
def remove_stokes_axis_from_header(header):
"""
Remove a Stokes axis to the WCS cards in the header.
----------
Inputs:
header : astropy.io.fits.header.Header
The header in which the WCS to work on is saved.
----------
Returns:
new_head : astropy.io.fits.header.Header
A new Header instance with only a celestial WCS.
"""
from astropy.wcs import WCS
wcs = WCS(header).deepcopy()
new_wcs = WCS(header).celestial.deepcopy()
new_head = header.copy()
del new_head["NAXIS%d" % (new_wcs.wcs.naxis + 1)]
new_head["NAXIS"] = new_wcs.wcs.naxis
for i, k in enumerate(new_wcs.array_shape[::-1]):
new_head["NAXIS%d" % (i + 1)] = k
for key in list(WCS(header).to_header().keys()) + list(
np.unique([["PC%d_%d" % (i + 1, j + 1) for i in range(wcs.wcs.naxis)] for j in range(wcs.wcs.naxis)])
):
if key in new_head.keys() and key not in new_wcs.to_header().keys():
del new_head[key]
for key, val in new_wcs.to_header().items():
if key not in new_head.keys() and key[:-1] + str(wcs.wcs.naxis) in new_head.keys():
new_head.insert(key[:-1] + str(wcs.wcs.naxis), (key, val), after=True)
elif key not in new_head.keys() and key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis) in new_head.keys():
new_head.insert(key[:2] + str(wcs.wcs.naxis) + key[2:-1] + str(wcs.wcs.naxis), (key, val), after=True)
else:
new_head[key] = val
return new_head