reformat code using python-lsp-ruff

2024-07-01 15:21:52 +02:00
parent 271ecbb631
commit 5a62fa4983
13 changed files with 1271 additions and 860 deletions


@@ -1,6 +1,7 @@
 """
 Library functions for phase cross-correlation computation.
 """
+
 # Prefer FFTs via the new scipy.fft module when available (SciPy 1.4+)
 # Otherwise fall back to numpy.fft.
 # Like numpy 1.15+ scipy 1.3+ is also using pocketfft, but a newer
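The fallback these comments describe (and which the try/except around the elided imports performs, per the "except ImportError:" context in the next hunk) follows the usual pattern; a minimal sketch, not a copy of the elided lines:

    # Prefer scipy.fft (pocketfft-based, SciPy 1.4+); otherwise use numpy.fft,
    # which exposes the same fftn/ifftn/fftfreq calls used later in this module.
    try:
        import scipy.fft as fft
    except ImportError:
        import numpy.fft as fft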
@@ -13,8 +14,7 @@ except ImportError:
 import numpy as np
 
 
-def _upsampled_dft(data, upsampled_region_size, upsample_factor=1,
-                   axis_offsets=None):
+def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None):
     """
     Upsampled DFT by matrix multiplication.
     This code is intended to provide the same result as if the following
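A usage sketch (not part of this commit) of how the refinement step further down calls this helper: given the Fourier-domain image product, it samples the cross-correlation on a finer grid around zero shift without building the full upsampled array. Region size and upsample factor here are arbitrary illustrative values.

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.random((32, 32))
    b = np.roll(a, (2, 3), axis=(0, 1))

    image_product = np.fft.fftn(a) * np.fft.fftn(b).conj()
    # 16 x 16 samples of the cross-correlation near zero shift,
    # spaced 1/10 pixel apart (up to a constant normalization)
    cc_fine = _upsampled_dft(image_product.conj(), 16, upsample_factor=10).conj()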
@@ -48,26 +48,27 @@ def _upsampled_dft(data, upsampled_region_size, upsample_factor=1,
     """
     # if people pass in an integer, expand it to a list of equal-sized sections
     if not hasattr(upsampled_region_size, "__iter__"):
-        upsampled_region_size = [upsampled_region_size, ] * data.ndim
+        upsampled_region_size = [
+            upsampled_region_size,
+        ] * data.ndim
     else:
         if len(upsampled_region_size) != data.ndim:
-            raise ValueError("shape of upsampled region sizes must be equal "
-                             "to input data's number of dimensions.")
+            raise ValueError("shape of upsampled region sizes must be equal " "to input data's number of dimensions.")
 
     if axis_offsets is None:
-        axis_offsets = [0, ] * data.ndim
+        axis_offsets = [
+            0,
+        ] * data.ndim
     else:
         if len(axis_offsets) != data.ndim:
-            raise ValueError("number of axis offsets must be equal to input "
-                             "data's number of dimensions.")
+            raise ValueError("number of axis offsets must be equal to input " "data's number of dimensions.")
 
     im2pi = 1j * 2 * np.pi
 
     dim_properties = list(zip(data.shape, upsampled_region_size, axis_offsets))
 
-    for (n_items, ups_size, ax_offset) in dim_properties[::-1]:
-        kernel = ((np.arange(ups_size) - ax_offset)[:, None]
-                  * fft.fftfreq(n_items, upsample_factor))
+    for n_items, ups_size, ax_offset in dim_properties[::-1]:
+        kernel = (np.arange(ups_size) - ax_offset)[:, None] * fft.fftfreq(n_items, upsample_factor)
         kernel = np.exp(-im2pi * kernel)
         # Equivalent to:
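Two notes on the reformatted lines above. First, the "..." "..." form in the raise calls is implicit adjacent-string concatenation, so the error messages are unchanged. Second, the kernel built in the loop is a small per-axis DFT matrix: since fftfreq(n, d)[f] = f / (d * n) for signed integer frequency f, each entry is exp(-2j * pi * (k - ax_offset) * f / (n_items * upsample_factor)), so output row k samples the transform (k - ax_offset) / upsample_factor pixels from the origin. A standalone restatement with illustrative values:

    import numpy as np

    n_items, ups_size, ax_offset, upsample_factor = 8, 4, 0.0, 10.0
    kernel = np.exp(
        -2j * np.pi
        * (np.arange(ups_size) - ax_offset)[:, None]
        * np.fft.fftfreq(n_items, upsample_factor)
    )
    # kernel.shape == (ups_size, n_items); one such matrix is applied per axis.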
@@ -100,14 +101,11 @@ def _compute_error(cross_correlation_max, src_amp, target_amp):
     target_amp : float
         The normalized average image intensity of the target image
     """
-    error = 1.0 - cross_correlation_max * cross_correlation_max.conj() /\
-        (src_amp * target_amp)
+    error = 1.0 - cross_correlation_max * cross_correlation_max.conj() / (src_amp * target_amp)
     return np.sqrt(np.abs(error))
 
 
-def phase_cross_correlation(reference_image, moving_image, *,
-                            upsample_factor=1, space="real",
-                            return_error=True, overlap_ratio=0.3):
+def phase_cross_correlation(reference_image, moving_image, *, upsample_factor=1, space="real", return_error=True, overlap_ratio=0.3):
     """
     Efficient subpixel image translation registration by cross-correlation.
     This code gives the same precision as the FFT upsampled cross-correlation
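Since x * x.conj() equals |x|**2, the reformatted one-liner in _compute_error is the usual normalized RMS error. A numerically equivalent restatement (hypothetical helper name, assuming real-valued src_amp and target_amp):

    import numpy as np

    def _compute_error_restated(cc_max, src_amp, target_amp):
        # same value as _compute_error above, with the magnitude taken explicitly
        return np.sqrt(np.abs(1.0 - np.abs(cc_max) ** 2 / (src_amp * target_amp)))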
@@ -174,11 +172,11 @@ def phase_cross_correlation(reference_image, moving_image, *,
         raise ValueError("images must be same shape")
 
     # assume complex data is already in Fourier space
-    if space.lower() == 'fourier':
+    if space.lower() == "fourier":
         src_freq = reference_image
         target_freq = moving_image
     # real data needs to be fft'd.
-    elif space.lower() == 'real':
+    elif space.lower() == "real":
         src_freq = fft.fftn(reference_image)
         target_freq = fft.fftn(moving_image)
     else:
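A usage sketch (not part of this commit) of the two space modes handled above: space="fourier" lets a caller reuse spectra it has already computed, while the default "real" branch runs fftn internally.

    import numpy as np

    rng = np.random.default_rng(0)
    ref = rng.random((64, 64))
    mov = np.roll(ref, (5, -3), axis=(0, 1))

    # real-space inputs (default)
    shift, error, phasediff = phase_cross_correlation(ref, mov, upsample_factor=10)

    # the same registration with precomputed spectra
    shift2, error2, phasediff2 = phase_cross_correlation(
        np.fft.fftn(ref), np.fft.fftn(mov), space="fourier", upsample_factor=10
    )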
@@ -190,8 +188,7 @@ def phase_cross_correlation(reference_image, moving_image, *,
     cross_correlation = fft.ifftn(image_product)
 
     # Locate maximum
-    maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)),
-                              cross_correlation.shape)
+    maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)), cross_correlation.shape)
     midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])
 
     shifts = np.stack(maxima).astype(np.float64)
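The argmax above yields a non-negative peak index per axis; because the cross-correlation is periodic, indices past the midpoint correspond to negative translations, which is presumably what the midpoints computed above are used for further down. A sketch of that wrap-around convention with hypothetical values:

    import numpy as np

    shape = (64, 64)
    maxima = np.array([61.0, 5.0])           # hypothetical peak location from argmax
    midpoints = np.fix(np.array(shape) / 2)  # array([32., 32.])
    shifts = maxima.copy()
    shifts[shifts > midpoints] -= np.array(shape, dtype=float)
    # shifts == array([-3., 5.]): index 61 on a 64-sample axis is a -3 pixel shift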
@@ -213,14 +210,10 @@ def phase_cross_correlation(reference_image, moving_image, *,
         dftshift = np.fix(upsampled_region_size / 2.0)
         upsample_factor = np.array(upsample_factor, dtype=np.float64)
         # Matrix multiply DFT around the current shift estimate
-        sample_region_offset = dftshift - shifts*upsample_factor
-        cross_correlation = _upsampled_dft(image_product.conj(),
-                                           upsampled_region_size,
-                                           upsample_factor,
-                                           sample_region_offset).conj()
+        sample_region_offset = dftshift - shifts * upsample_factor
+        cross_correlation = _upsampled_dft(image_product.conj(), upsampled_region_size, upsample_factor, sample_region_offset).conj()
         # Locate maximum and map back to original pixel grid
-        maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)),
-                                  cross_correlation.shape)
+        maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)), cross_correlation.shape)
         CCmax = cross_correlation[maxima]
 
         maxima = np.stack(maxima).astype(np.float64) - dftshift
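Samples in the upsampled window are 1/upsample_factor pixels apart, and the window is centered so that dftshift corresponds to the coarse estimate; the line above therefore gives the refined peak relative to that center, and dividing by upsample_factor (the corresponding division falls outside this hunk) presumably converts it back to pixels. A sketch with hypothetical numbers:

    import numpy as np

    upsample_factor = 10.0
    upsampled_region_size = 16                       # hypothetical window size
    dftshift = np.fix(upsampled_region_size / 2.0)   # 8.0, as computed above
    maxima_up = np.array([9.0, 6.0])                 # hypothetical upsampled peak
    subpixel = (maxima_up - dftshift) / upsample_factor
    # subpixel == array([ 0.1, -0.2]); this correction refines the coarse shift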
@@ -240,10 +233,8 @@ def phase_cross_correlation(reference_image, moving_image, *,
     if return_error:
         # Redirect user to masked_phase_cross_correlation if NaNs are observed
         if np.isnan(CCmax) or np.isnan(src_amp) or np.isnan(target_amp):
-            raise ValueError(
-                "NaN values found, please remove NaNs from your input data")
+            raise ValueError("NaN values found, please remove NaNs from your input data")
 
-        return shifts, _compute_error(CCmax, src_amp, target_amp), \
-            _compute_phasediff(CCmax)
+        return shifts, _compute_error(CCmax, src_amp, target_amp), _compute_phasediff(CCmax)
     else:
         return shifts
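A usage sketch of the return contract implemented above: with the default return_error=True the function returns (shifts, error, phasediff), otherwise only shifts, and NaNs in the inputs trip the ValueError that points users to masked_phase_cross_correlation.

    import numpy as np

    rng = np.random.default_rng(1)
    ref = rng.random((32, 32))
    mov = np.roll(ref, 4, axis=0)

    shifts, error, phasediff = phase_cross_correlation(ref, mov)          # 3-tuple
    shifts_only = phase_cross_correlation(ref, mov, return_error=False)   # shifts only

    # Inputs containing NaN propagate into CCmax/src_amp/target_amp and raise
    # ValueError; masked_phase_cross_correlation is the suggested alternative.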