diff --git a/src/lib/plots.py b/src/lib/plots.py index 70cbc9f..ec78625 100755 --- a/src/lib/plots.py +++ b/src/lib/plots.py @@ -12,9 +12,9 @@ prototypes : from copy import deepcopy import numpy as np import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle +from matplotlib.patches import Rectangle, Circle from matplotlib.path import Path -from matplotlib.widgets import RectangleSelector, Button, Slider, TextBox, LassoSelector +from matplotlib.widgets import RectangleSelector, LassoSelector, Button, Slider, TextBox from matplotlib.colors import LogNorm import matplotlib.font_manager as fm from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar, AnchoredDirectionArrows @@ -971,7 +971,7 @@ class image_lasso_selector: plt.ion() lineprops = {'color': 'grey', 'linewidth': 1, 'alpha': 0.8} - self.lasso = LassoSelector(self.ax, self.onselect,props=lineprops, useblit=False) + self.lasso = LassoSelector(self.ax, self.onselect, props=lineprops, useblit=False) self.lasso.set_visible(True) pix_x = np.arange(self.img.shape[0]) @@ -1011,6 +1011,110 @@ class image_lasso_selector: else: self.on_close() + +class aperture: + def __init__(self, img, cdelt=np.array([1.,1.]), radius=1., fig=None, ax=None): + """ + img must have shape (X, Y) + """ + self.selected = False + self.img = img + self.vmin, self.vmax = 0., np.max(self.img[self.img>0.]) + plt.ioff() # see https://github.com/matplotlib/matplotlib/issues/17013 + if fig is None: + self.fig = plt.figure(figsize=(15,15)) + else: + self.fig = fig + if ax is None: + self.ax = self.fig.gca() + self.mask_alpha = 1. + self.embedded = False + else: + self.ax = ax + self.mask_alpha = 0.1 + self.embedded = True + + self.displayed = self.ax.imshow(self.img, vmin=self.vmin, vmax=self.vmax, aspect='equal', cmap='inferno',alpha=self.mask_alpha) + plt.ion() + + xx, yy = np.indices(self.img.shape) + self.pix = np.vstack( (xx.flatten(), yy.flatten()) ).T + + self.x0, self.y0 = np.array(self.img.shape)/2. + if np.abs(cdelt).max() != 1.: + self.cdelt = cdelt + self.radius = radius/np.abs(self.cdelt).max()/3600. + + self.circ = Circle((self.x0, self.y0), self.radius, alpha=0.8, ec='grey',fc='none') + self.ax.add_patch(self.circ) + + self.fig.canvas.mpl_connect('button_press_event', self.on_press) + self.fig.canvas.mpl_connect('button_release_event', self.on_release) + self.fig.canvas.mpl_connect('motion_notify_event', self.on_move) + self.fig.canvas.mpl_connect('close_event', self.on_close) + self.x0, self.y0 = self.circ.center + self.pressevent = None + plt.show() + + def on_close(self, event=None) -> None: + if not hasattr(self, 'mask'): + self.mask = np.zeros(self.img.shape[:2],dtype=bool) + self.selected = True + + def on_press(self, event): + if event.inaxes != self.ax: + return + + if not self.circ.contains(event)[0]: + return + + self.pressevent = event + + def on_release(self, event): + self.pressevent = None + self.x0, self.y0 = self.circ.center + self.update_mask() + + def on_move(self, event): + if self.pressevent is None or event.inaxes != self.pressevent.inaxes: + return + + dx = event.xdata - self.pressevent.xdata + dy = event.ydata - self.pressevent.ydata + self.circ.center = self.x0 + dx, self.y0 + dy + self.fig.canvas.draw_idle() + + def update_radius(self, radius): + self.radius = radius/np.abs(self.cdelt).max()/3600 + self.circ.set_radius(self.radius) + self.fig.canvas.draw_idle() + + def update_mask(self): + if hasattr(self, 'displayed'): + try: + self.displayed.remove() + except: + return + self.displayed = self.ax.imshow(self.img, vmin=self.vmin, vmax=self.vmax, aspect='equal', cmap='inferno',alpha=self.mask_alpha) + array = self.displayed.get_array().data + + yy, xx = np.indices(self.img.shape[:2]) + x0, y0 = self.circ.center + self.mask = np.sqrt((xx-x0)**2+(yy-y0)**2) < self.radius + if hasattr(self, 'cont'): + for coll in self.cont.collections: + try: + coll.remove() + except: + return + self.cont = self.ax.contour(self.mask.astype(float),levels=[0.5], colors='white', linewidths=1) + if not self.embedded: + self.displayed.set_data(array) + self.fig.canvas.draw_idle() + else: + self.on_close() + + class pol_map(object): """ Class to interactively study polarization maps. @@ -1076,6 +1180,62 @@ class pol_map(object): s_P_cut.on_changed(update_snrp) b_snr_reset.on_clicked(reset_snr) + #Set axe for Aperture selection + ax_aper = self.fig.add_axes([0.55, 0.040, 0.05, 0.02]) + ax_aper_reset = self.fig.add_axes([0.605, 0.040, 0.05, 0.02]) + ax_aper_radius = self.fig.add_axes([0.55, 0.020, 0.10, 0.01]) + self.selected = False + b_aper = Button(ax_aper,"Aperture") + b_aper_reset = Button(ax_aper_reset,"Reset") + s_aper_radius = Slider(ax_aper_radius, r"$R_{aper}$", 0.5, 3.5, valstep=0.1, valinit=1) + + def select_aperture(event): + if self.data is None: + self.data = self.Stokes[0].data + if self.selected: + self.selected = False + self.region = deepcopy(self.select_instance.mask.astype(bool)) + self.select_instance.displayed.remove() + for coll in self.select_instance.cont.collections[:]: + coll.remove() + self.select_instance.circ.set_visible(False) + self.set_data_mask(deepcopy(self.region)) + self.pol_int() + else: + self.selected = True + self.region = None + self.select_instance = aperture(self.data, fig=self.fig, ax=self.ax, cdelt=self.wcs.wcs.cdelt, radius=s_aper_radius.val) + self.select_instance.circ.set_visible(True) + #k = 0 + #while not self.select_instance.selected and k<60: + # self.fig.canvas.start_event_loop(timeout=1) + # k+=1 + #select_aperture(event) + + self.fig.canvas.draw_idle() + + def update_aperture(val): + if hasattr(self, 'select_instance'): + if hasattr(self.select_instance, 'radius'): + self.select_instance.update_radius(val) + else: + self.selected = True + self.select_instance = aperture(self.data, fig=self.fig, ax=self.ax, cdelt=self.wcs.wcs.cdelt, radius=val) + else: + self.selected = True + self.select_instance = aperture(self.data, fig=self.fig, ax=self.ax, cdelt=self.wcs.wcs.cdelt, radius=val) + self.fig.canvas.draw_idle() + + + def reset_aperture(event): + self.region = None + self.pol_int() + self.fig.canvas.draw_idle() + + b_aper.on_clicked(select_aperture) + b_aper_reset.on_clicked(reset_aperture) + s_aper_radius.on_changed(update_aperture) + #Set axe for ROI selection ax_select = self.fig.add_axes([0.55, 0.070, 0.05, 0.02]) ax_roi_reset = self.fig.add_axes([0.605, 0.070, 0.05, 0.02])