small fixes and improvments

This commit is contained in:
2024-10-17 17:05:35 +02:00
parent bd7cad46a1
commit f6d62bff73
8 changed files with 119 additions and 174 deletions

View File

@@ -66,7 +66,7 @@ except ImportError:
from utils import PCconf, princ_angle, rot2D, sci_not
def plot_obs(data_array, headers, rectangle=None, savename=None, plots_folder="", **kwargs):
def plot_obs(data_array, headers, rectangle=None, shifts=None, savename=None, plots_folder="", **kwargs):
"""
Plots raw observation imagery with some information on the instrument and
filters.
@@ -77,16 +77,14 @@ def plot_obs(data_array, headers, rectangle=None, savename=None, plots_folder=""
single observation with multiple polarizers of an instrument
headers : header list
List of headers corresponding to the images in data_array
vmin : float, optional
Min pixel value that should be displayed.
Defaults to 0.
vmax : float, optional
Max pixel value that should be displayed.
Defaults to 6.
rectangle : numpy.ndarray, optional
Array of parameters for matplotlib.patches.Rectangle objects that will
be displayed on each output image. If None, no rectangle displayed.
Defaults to None.
shifts : numpy.ndarray, optional
Array of vector coordinates corresponding to images shifts with respect
to the others. If None, no shifts displayed.
Defaults to None.
savename : str, optional
Name of the figure the map should be saved to. If None, the map won't
be saved (only displayed).
@@ -100,6 +98,7 @@ def plot_obs(data_array, headers, rectangle=None, savename=None, plots_folder=""
nb_obs = np.max([np.sum([head["filtnam1"] == curr_pol for head in headers]) for curr_pol in ["POL0", "POL60", "POL120"]])
shape = np.array((3, nb_obs))
fig, ax = plt.subplots(shape[0], shape[1], figsize=(3 * shape[1], 3 * shape[0]), layout="constrained", sharex=True, sharey=True)
used = np.zeros(shape, dtype=bool)
r_pol = dict(pol0=0, pol60=1, pol120=2)
c_pol = dict(pol0=0, pol60=0, pol120=0)
for i, (data, head) in enumerate(zip(data_array, headers)):
@@ -112,15 +111,17 @@ def plot_obs(data_array, headers, rectangle=None, savename=None, plots_folder=""
c_pol[filt.lower()] += 1
if shape[1] != 1:
ax_curr = ax[r_ax][c_ax]
used[r_ax][c_ax] = True
else:
ax_curr = ax[r_ax]
ax_curr[r_ax] = True
# plots
if "vmin" in kwargs.keys() or "vmax" in kwargs.keys():
vmin, vmax = kwargs["vmin"], kwargs["vmax"]
del kwargs["vmin"], kwargs["vmax"]
else:
vmin, vmax = convert * data[data > 0.0].min() / 10.0, convert * data[data > 0.0].max()
for key, value in [["cmap", [["cmap", "gray"]]], ["norm", [["norm", LogNorm(vmin, vmax)]]]]:
for key, value in [["cmap", [["cmap", "inferno"]]], ["norm", [["norm", LogNorm(vmin, vmax)]]]]:
try:
_ = kwargs[key]
except KeyError:
@@ -129,17 +130,29 @@ def plot_obs(data_array, headers, rectangle=None, savename=None, plots_folder=""
# im = ax[r_ax][c_ax].imshow(convert*data, origin='lower', **kwargs)
data[data * convert < vmin * 10.0] = vmin * 10.0 / convert
im = ax_curr.imshow(convert * data, origin="lower", **kwargs)
if shifts is not None:
x, y = np.array(data.shape[::-1]) / 2.0 - shifts[i]
dx, dy = shifts[i]
ax_curr.arrow(x, y, dx, dy, length_includes_head=True, width=0.1, head_width=0.3, color="g")
ax_curr.plot([x, x], [0, data.shape[0] - 1], "--", lw=2, color="g", alpha=0.85)
ax_curr.plot([0, data.shape[1] - 1], [y, y], "--", lw=2, color="g", alpha=0.85)
if rectangle is not None:
x, y, width, height, angle, color = rectangle[i]
ax_curr.add_patch(Rectangle((x, y), width, height, angle=angle, edgecolor=color, fill=False))
# position of centroid
ax_curr.plot([data.shape[1] / 2, data.shape[1] / 2], [0, data.shape[0] - 1], "--", lw=1, color="grey", alpha=0.5)
ax_curr.plot([0, data.shape[1] - 1], [data.shape[1] / 2, data.shape[1] / 2], "--", lw=1, color="grey", alpha=0.5)
ax_curr.plot([data.shape[1] / 2, data.shape[1] / 2], [0, data.shape[0] - 1], "--", lw=2, color="b", alpha=0.85)
ax_curr.plot([0, data.shape[1] - 1], [data.shape[0] / 2, data.shape[0] / 2], "--", lw=2, color="b", alpha=0.85)
ax_curr.annotate(
instr + ":" + rootname, color="white", fontsize=5, xy=(0.01, 1.00), xycoords="axes fraction", verticalalignment="top", horizontalalignment="left"
instr + ":" + rootname, color="white", fontsize=10, xy=(0.01, 1.00), xycoords="axes fraction", verticalalignment="top", horizontalalignment="left"
)
ax_curr.annotate(filt, color="white", fontsize=10, xy=(0.01, 0.01), xycoords="axes fraction", verticalalignment="bottom", horizontalalignment="left")
ax_curr.annotate(exptime, color="white", fontsize=5, xy=(1.00, 0.01), xycoords="axes fraction", verticalalignment="bottom", horizontalalignment="right")
ax_curr.annotate(filt, color="white", fontsize=15, xy=(0.01, 0.01), xycoords="axes fraction", verticalalignment="bottom", horizontalalignment="left")
ax_curr.annotate(
exptime, color="white", fontsize=10, xy=(1.00, 0.01), xycoords="axes fraction", verticalalignment="bottom", horizontalalignment="right"
)
unused = np.logical_not(used)
ii, jj = np.indices(shape)
for i, j in zip(ii[unused], jj[unused]):
fig.delaxes(ax[i][j])
# fig.subplots_adjust(hspace=0.01, wspace=0.01, right=1.02)
fig.colorbar(im, ax=ax, location="right", shrink=0.75, aspect=50, pad=0.025, label=r"Flux [$ergs \cdot cm^{-2} \cdot s^{-1} \cdot \AA^{-1}$]")
@@ -349,7 +362,7 @@ def polarization_map(
fig = plt.figure(figsize=(7 * ratiox, 7 * ratioy), layout="constrained")
if ax is None:
ax = fig.add_subplot(111, projection=wcs)
ax.set(aspect="equal", fc="k", ylim=[-stkI.shape[0] * 0.10, stkI.shape[0] * 1.15])
ax.set(aspect="equal", fc="k", ylim=[-stkI.shape[0] * 0.01, stkI.shape[0] * 1.01])
# fig.subplots_adjust(hspace=0, wspace=0, left=0.102, right=1.02)
# ax.coords.grid(True, color='white', ls='dotted', alpha=0.5)
@@ -677,7 +690,7 @@ class align_maps(object):
except KeyError:
for key_i, val_i in value:
kwargs[key_i] = val_i
self.map_ax.imshow(self.map_data * self.map_convert, aspect="equal", **kwargs)
self.im = self.map_ax.imshow(self.map_data * self.map_convert, aspect="equal", **kwargs)
if kwargs["cmap"] in [
"inferno",
@@ -766,7 +779,7 @@ class align_maps(object):
except KeyError:
for key_i, val_i in value:
other_kwargs[key_i] = val_i
self.other_ax.imshow(self.other_data * self.other_convert, aspect="equal", **other_kwargs)
self.other_im = self.other_ax.imshow(self.other_data * self.other_convert, aspect="equal", **other_kwargs)
px_size2 = self.other_wcs.wcs.get_cdelt()[0] * 3600.0
px_sc2 = AnchoredSizeBar(
@@ -1561,6 +1574,7 @@ class overplot_pol(align_maps):
while not self.aligned:
self.align()
self.overplot(levels=levels, P_cut=P_cut, SNRi_cut=SNRi_cut, scale_vec=scale_vec, savename=savename, **kwargs)
plt.show(block=True)
def add_vector(self, position="center", pol_deg=1.0, pol_ang=0.0, **kwargs):
if isinstance(position, str) and position == "center":
@@ -2148,7 +2162,7 @@ class image_lasso_selector(object):
self.mask = np.zeros(self.img.shape[:2], dtype=bool)
self.mask[self.indices] = True
if hasattr(self, "cont"):
for coll in self.cont:
for coll in self.cont.collections:
coll.remove()
self.cont = self.ax.contour(self.mask.astype(float), levels=[0.5], colors="white", linewidths=1)
if not self.embedded:
@@ -2261,7 +2275,7 @@ class slit(object):
for p in self.pix:
self.mask[tuple(p)] = (np.abs(np.dot(rot2D(-self.angle), p - self.rect.get_center()[::-1])) < (self.height / 2.0, self.width / 2.0)).all()
if hasattr(self, "cont"):
for coll in self.cont:
for coll in self.cont.collections:
try:
coll.remove()
except AttributeError:
@@ -2364,7 +2378,7 @@ class aperture(object):
x0, y0 = self.circ.center
self.mask = np.sqrt((xx - x0) ** 2 + (yy - y0) ** 2) < self.radius
if hasattr(self, "cont"):
for coll in self.cont:
for coll in self.cont.collections:
try:
coll.remove()
except AttributeError:
@@ -2499,7 +2513,7 @@ class pol_map(object):
self.selected = False
self.region = deepcopy(self.select_instance.mask.astype(bool))
self.select_instance.displayed.remove()
for coll in self.select_instance.cont:
for coll in self.select_instance.cont.collections:
coll.remove()
self.select_instance.lasso.set_active(False)
self.set_data_mask(deepcopy(self.region))
@@ -2543,7 +2557,7 @@ class pol_map(object):
self.select_instance.update_mask()
self.region = deepcopy(self.select_instance.mask.astype(bool))
self.select_instance.displayed.remove()
for coll in self.select_instance.cont:
for coll in self.select_instance.cont.collections:
coll.remove()
self.select_instance.circ.set_visible(False)
self.set_data_mask(deepcopy(self.region))
@@ -2601,7 +2615,7 @@ class pol_map(object):
self.select_instance.update_mask()
self.region = deepcopy(self.select_instance.mask.astype(bool))
self.select_instance.displayed.remove()
for coll in self.select_instance.cont:
for coll in self.select_instance.cont.collections:
coll.remove()
self.select_instance.rect.set_visible(False)
self.set_data_mask(deepcopy(self.region))
@@ -3324,7 +3338,7 @@ class pol_map(object):
)
if hasattr(self, "cont"):
for coll in self.cont:
for coll in self.cont.collections:
try:
coll.remove()
except AttributeError: