Right number of WCS axis for each HDU in output

This commit is contained in:
2025-08-06 17:28:18 +02:00
parent f47c650dc5
commit e639695618
3 changed files with 96 additions and 17 deletions

View File

@@ -16,7 +16,7 @@ from astropy.io import fits
from astropy.wcs import WCS
from .convex_hull import clean_ROI
from .utils import wcs_CD_to_PC, wcs_PA
from .utils import wcs_CD_to_PC, wcs_PA, add_stokes_axis_to_header, remove_stokes_axis_from_header
def get_obs_data(infiles, data_folder="", compute_flux=False):
@@ -141,15 +141,17 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
Only returned if return_hdul is True.
"""
# Create new WCS object given the modified images
new_wcs = WCS(header_stokes).deepcopy()
new_wcs = WCS(header_stokes).celestial.deepcopy()
header = remove_stokes_axis_from_header(header_stokes).copy()
if data_mask.shape != (1, 1):
vertex = clean_ROI(data_mask)
shape = vertex[1::2] - vertex[0::2]
new_wcs.array_shape = (4, *shape)
new_wcs.wcs.crpix[1:] = np.array(new_wcs.wcs.crpix[1:]) - vertex[0::-2]
new_wcs.array_shape = shape
new_wcs.wcs.crpix = np.array(new_wcs.wcs.crpix) - vertex[0::-2]
for key, val in list(new_wcs.to_header().items()) + [("NAXIS", 2), ("NAXIS1", new_wcs.array_shape[1]), ("NAXIS2", new_wcs.array_shape[0])]:
header[key] = val
header = new_wcs.to_header()
header["TELESCOP"] = (header_stokes["TELESCOP"] if "TELESCOP" in list(header_stokes.keys()) else "HST", "telescope used to acquire data")
header["INSTRUME"] = (header_stokes["INSTRUME"] if "INSTRUME" in list(header_stokes.keys()) else "FOC", "identifier for instrument used to acuire data")
header["PHOTPLAM"] = (header_stokes["PHOTPLAM"], "Pivot Wavelength")
@@ -190,8 +192,9 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
# Add I_stokes as PrimaryHDU
header["datatype"] = ("STOKES", "type of data stored in the HDU")
Stokes[np.broadcast_to((1 - data_mask).astype(bool), Stokes.shape)] = 0.0
primary_hdu = fits.PrimaryHDU(data=Stokes, header=header)
primary_hdu.name = "Stokes"
hdu_head = add_stokes_axis_to_header(header, 2)
primary_hdu = fits.PrimaryHDU(data=Stokes, header=hdu_head)
primary_hdu.name = "STOKES"
hdul.append(primary_hdu)
# Add Stokes_cov, P, s_P, PA, s_PA to the HDUList
@@ -206,13 +209,15 @@ def save_Stokes(Stokes, Stokes_cov, P, debiased_P, s_P, s_P_P, PA, s_PA, s_PA_P,
[s_PA_P, "Pol_ang_stat_err"],
[data_mask, "Data_mask"],
]:
hdu_header = header.copy()
hdu_header["datatype"] = name
hdu_head = header.copy()
hdu_head["datatype"] = name
if name == "STOKES_COV":
hdu_head = add_stokes_axis_to_header(hdu_head, 2)
hdu_head = add_stokes_axis_to_header(hdu_head, 3)
data[np.broadcast_to((1 - data_mask).astype(bool), data.shape)] = 0.0
else:
data[(1 - data_mask).astype(bool)] = 0.0
hdu = fits.ImageHDU(data=data, header=hdu_header)
hdu = fits.ImageHDU(data=data, header=hdu_head)
hdu.name = name
hdul.append(hdu)