Восстановление фазы
Восстановление фазы — классическая задача оптики: определить фазу волнового фронта по измерениям интенсивности. В этом туториале мы реализуем алгоритм Gerchberg-Saxton.
Проблема фазы
При измерении света детекторы регистрируют только интенсивность , теряя информацию о фазе . Однако фаза критически важна для многих приложений:
- Формирование пучка
- Адаптивная оптика
- Голография
- Оптические пинцеты
Алгоритм Gerchberg-Saxton
Идея
Если известны распределения интенсивности в двух сопряжённых плоскостях (например, объектной и фокальной), можно итеративно восстановить фазу.
Шаги алгоритма
- Начать со случайной фазы
- Применить известную амплитуду в плоскости 1
- Выполнить преобразование (FFT или распространение)
- Применить известную амплитуду в плоскости 2
- Выполнить обратное преобразование
- Повторять до сходимости
Шаг 1: Настройка
import torch
import matplotlib.pyplot as plt
from svetlanna import SimulationParameters, Wavefront
from svetlanna.elements import ThinLens, FreeSpace
from svetlanna.units import ureg
# Параметры
params = SimulationParameters.from_ranges(
w_range=(-2*ureg.mm, 2*ureg.mm), w_points=256,
h_range=(-2*ureg.mm, 2*ureg.mm), h_points=256,
wavelength=632.8*ureg.nm
)
# Координаты
W, H = torch.meshgrid(params.axes.W, params.axes.H, indexing='xy')Шаг 2: Создание тестовых данных
Создадим “известные” распределения интенсивности:
# Целевая амплитуда в плоскости объекта — буква "A"
def create_letter_A(W, H, size=1e-3):
"""Создаёт маску в форме буквы A."""
mask = torch.zeros_like(W)
# Вертикальные линии
left = (W > -size*0.4) & (W < -size*0.3) & (H > -size*0.5) & (H < size*0.5)
right = (W > size*0.3) & (W < size*0.4) & (H > -size*0.5) & (H < size*0.5)
# Горизонтальная перекладина
bar = (W > -size*0.3) & (W < size*0.3) & (H > -size*0.1) & (H < size*0.1)
# Верхняя часть (треугольник)
top = (H > size*0.3) & (H < size*0.5) & (torch.abs(W) < (size*0.5 - H)*0.8)
mask = (left | right | bar | top).float()
return mask
source_amplitude = create_letter_A(W, H)
# Целевая амплитуда в фокальной плоскости — кольцо
r = torch.sqrt(W**2 + H**2)
inner_r = 0.3e-3
outer_r = 0.5e-3
target_amplitude = ((r > inner_r) & (r < outer_r)).float()
# Нормализация
source_amplitude = source_amplitude / source_amplitude.max()
target_amplitude = target_amplitude / target_amplitude.sum().sqrt()Шаг 3: Прямое и обратное преобразование
Для простоты используем FFT как модель распространения:
def forward_propagate(field):
"""Прямое преобразование (FFT)."""
return torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(field)))
def backward_propagate(field):
"""Обратное преобразование (IFFT)."""
return torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(field)))Шаг 4: Алгоритм Gerchberg-Saxton
def gerchberg_saxton(source_amplitude, target_amplitude, n_iterations=100):
"""
Алгоритм Gerchberg-Saxton для восстановления фазы.
Parameters
----------
source_amplitude : Tensor
Амплитуда в исходной плоскости
target_amplitude : Tensor
Амплитуда в целевой плоскости
n_iterations : int
Количество итераций
Returns
-------
phase : Tensor
Восстановленная фаза
errors : list
История ошибок
"""
# Инициализация случайной фазой
phase = torch.rand_like(source_amplitude) * 2 * torch.pi
errors = []
for iteration in range(n_iterations):
# 1. Поле в исходной плоскости с текущей фазой
source_field = source_amplitude * torch.exp(1j * phase)
# 2. Прямое распространение
target_field = forward_propagate(source_field)
# 3. Вычисление ошибки
target_intensity = torch.abs(target_field)**2
error = torch.mean((torch.sqrt(target_intensity) - target_amplitude)**2)
errors.append(error.item())
# 4. Замена амплитуды, сохранение фазы
target_phase = torch.angle(target_field)
target_field = target_amplitude * torch.exp(1j * target_phase)
# 5. Обратное распространение
source_field = backward_propagate(target_field)
# 6. Извлечение фазы
phase = torch.angle(source_field)
if iteration % 20 == 0:
print(f"Итерация {iteration}: ошибка = {error:.6f}")
return phase, errors
# Запуск алгоритма
recovered_phase, errors = gerchberg_saxton(
source_amplitude,
target_amplitude,
n_iterations=100
)Шаг 5: Проверка результата
# Применяем восстановленную фазу
source_field = source_amplitude * torch.exp(1j * recovered_phase)
target_field = forward_propagate(source_field)
# Сравнение
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Исходная плоскость
axes[0, 0].imshow(source_amplitude.cpu(), cmap='gray')
axes[0, 0].set_title('Исходная амплитуда')
axes[0, 1].imshow(recovered_phase.cpu(), cmap='twilight')
axes[0, 1].set_title('Восстановленная фаза')
axes[0, 2].plot(errors)
axes[0, 2].set_xlabel('Итерация')
axes[0, 2].set_ylabel('Ошибка')
axes[0, 2].set_title('Сходимость')
axes[0, 2].set_yscale('log')
# Целевая плоскость
axes[1, 0].imshow(target_amplitude.cpu()**2, cmap='hot')
axes[1, 0].set_title('Целевая интенсивность')
axes[1, 1].imshow(torch.abs(target_field).cpu()**2, cmap='hot')
axes[1, 1].set_title('Полученная интенсивность')
# Разница
diff = torch.abs(torch.abs(target_field)**2 - target_amplitude**2)
axes[1, 2].imshow(diff.cpu(), cmap='hot')
axes[1, 2].set_title('Разница')
plt.tight_layout()
plt.show()Шаг 6: Модифицированный алгоритм с ограничениями
Добавим ограничения на фазу (например, только положительные значения):
def gerchberg_saxton_constrained(source_amplitude, target_amplitude,
n_iterations=100, phase_range=(0, 2*torch.pi)):
"""GS с ограничением диапазона фазы."""
phase = torch.rand_like(source_amplitude) * (phase_range[1] - phase_range[0]) + phase_range[0]
errors = []
for iteration in range(n_iterations):
source_field = source_amplitude * torch.exp(1j * phase)
target_field = forward_propagate(source_field)
error = torch.mean((torch.abs(target_field) - target_amplitude)**2)
errors.append(error.item())
target_phase = torch.angle(target_field)
target_field = target_amplitude * torch.exp(1j * target_phase)
source_field = backward_propagate(target_field)
phase = torch.angle(source_field)
# Ограничение фазы
phase = torch.clamp(phase, phase_range[0], phase_range[1])
return phase, errorsШаг 7: Использование с SVETlANNa элементами
Реализуем GS с использованием оптических элементов:
def gerchberg_saxton_optical(params, source_amplitude, target_amplitude,
focal_length, n_iterations=100):
"""
GS с использованием оптической системы (линза + распространение).
"""
lens = ThinLens(params, focal_length=focal_length)
prop = FreeSpace(params, distance=focal_length, method="AS")
phase = torch.rand_like(source_amplitude) * 2 * torch.pi
errors = []
for iteration in range(n_iterations):
# Прямое распространение через линзу
source_field = source_amplitude * torch.exp(1j * phase)
wf = Wavefront(source_field.to(torch.complex64))
wf = lens(wf)
wf = prop(wf)
target_field = wf
# Ошибка
error = torch.mean((torch.sqrt(target_field.intensity) - target_amplitude)**2)
errors.append(error.item())
# Замена амплитуды
target_phase = target_field.phase
target_field = Wavefront(
(target_amplitude * torch.exp(1j * target_phase)).to(torch.complex64)
)
# Обратное распространение
prop_back = FreeSpace(params, distance=-focal_length, method="AS")
lens_inv = ThinLens(params, focal_length=-focal_length)
wf = prop_back(target_field)
wf = lens_inv(wf)
phase = wf.phase
return phase, errorsШаг 8: Градиентная оптимизация
Альтернативный подход — использовать autograd PyTorch:
def phase_retrieval_gradient(source_amplitude, target_amplitude,
n_iterations=500, lr=0.1):
"""Восстановление фазы градиентным спуском."""
# Оптимизируемая фаза
phase = torch.nn.Parameter(torch.rand_like(source_amplitude) * 2 * torch.pi)
optimizer = torch.optim.Adam([phase], lr=lr)
errors = []
for iteration in range(n_iterations):
optimizer.zero_grad()
# Forward
source_field = source_amplitude * torch.exp(1j * phase)
target_field = forward_propagate(source_field)
# Loss
loss = torch.mean((torch.abs(target_field) - target_amplitude)**2)
# Backward
loss.backward()
optimizer.step()
errors.append(loss.item())
if iteration % 100 == 0:
print(f"Итерация {iteration}: loss = {loss.item():.6f}")
return phase.detach(), errorsПолный код
import torch
import matplotlib.pyplot as plt
from svetlanna import SimulationParameters, Wavefront
from svetlanna.units import ureg
# Параметры
params = SimulationParameters.from_ranges(
w_range=(-2*ureg.mm, 2*ureg.mm), w_points=256,
h_range=(-2*ureg.mm, 2*ureg.mm), h_points=256,
wavelength=632.8*ureg.nm
)
W, H = torch.meshgrid(params.axes.W, params.axes.H, indexing='xy')
# Целевое распределение — кольцо
r = torch.sqrt(W**2 + H**2)
target = ((r > 0.3e-3) & (r < 0.5e-3)).float()
target = target / target.sum().sqrt()
# Исходная амплитуда — гауссов пучок
source = torch.exp(-(W**2 + H**2) / (0.5e-3)**2)
# GS алгоритм
phase = torch.rand(256, 256) * 2 * torch.pi
for i in range(100):
field = source * torch.exp(1j * phase)
spectrum = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(field)))
spectrum = target * torch.exp(1j * torch.angle(spectrum))
field = torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(spectrum)))
phase = torch.angle(field)
print("Восстановление завершено!")Выводы
- Алгоритм GS эффективно восстанавливает фазу при известных интенсивностях
- Сходимость достигается за 50-100 итераций
- Градиентный подход может быть точнее, но требует больше вычислений
- SVETlANNa позволяет использовать реалистичные модели распространения
Что дальше?
- Дифракционная нейросеть — обучение D²NN
- Оптимизация — градиентные методы