Skip to Content
DocsТуториалыВосстановление фазы

Восстановление фазы

Восстановление фазы — классическая задача оптики: определить фазу волнового фронта по измерениям интенсивности. В этом туториале мы реализуем алгоритм Gerchberg-Saxton.

Проблема фазы

При измерении света детекторы регистрируют только интенсивность I=E2I = |E|^2, теряя информацию о фазе ϕ=arg(E)\phi = \arg(E). Однако фаза критически важна для многих приложений:

  • Формирование пучка
  • Адаптивная оптика
  • Голография
  • Оптические пинцеты

Алгоритм Gerchberg-Saxton

Идея

Если известны распределения интенсивности в двух сопряжённых плоскостях (например, объектной и фокальной), можно итеративно восстановить фазу.

Шаги алгоритма

  1. Начать со случайной фазы
  2. Применить известную амплитуду в плоскости 1
  3. Выполнить преобразование (FFT или распространение)
  4. Применить известную амплитуду в плоскости 2
  5. Выполнить обратное преобразование
  6. Повторять до сходимости

Шаг 1: Настройка

import torch import matplotlib.pyplot as plt from svetlanna import SimulationParameters, Wavefront from svetlanna.elements import ThinLens, FreeSpace from svetlanna.units import ureg # Параметры params = SimulationParameters.from_ranges( w_range=(-2*ureg.mm, 2*ureg.mm), w_points=256, h_range=(-2*ureg.mm, 2*ureg.mm), h_points=256, wavelength=632.8*ureg.nm ) # Координаты W, H = torch.meshgrid(params.axes.W, params.axes.H, indexing='xy')

Шаг 2: Создание тестовых данных

Создадим “известные” распределения интенсивности:

# Целевая амплитуда в плоскости объекта — буква "A" def create_letter_A(W, H, size=1e-3): """Создаёт маску в форме буквы A.""" mask = torch.zeros_like(W) # Вертикальные линии left = (W > -size*0.4) & (W < -size*0.3) & (H > -size*0.5) & (H < size*0.5) right = (W > size*0.3) & (W < size*0.4) & (H > -size*0.5) & (H < size*0.5) # Горизонтальная перекладина bar = (W > -size*0.3) & (W < size*0.3) & (H > -size*0.1) & (H < size*0.1) # Верхняя часть (треугольник) top = (H > size*0.3) & (H < size*0.5) & (torch.abs(W) < (size*0.5 - H)*0.8) mask = (left | right | bar | top).float() return mask source_amplitude = create_letter_A(W, H) # Целевая амплитуда в фокальной плоскости — кольцо r = torch.sqrt(W**2 + H**2) inner_r = 0.3e-3 outer_r = 0.5e-3 target_amplitude = ((r > inner_r) & (r < outer_r)).float() # Нормализация source_amplitude = source_amplitude / source_amplitude.max() target_amplitude = target_amplitude / target_amplitude.sum().sqrt()

Шаг 3: Прямое и обратное преобразование

Для простоты используем FFT как модель распространения:

def forward_propagate(field): """Прямое преобразование (FFT).""" return torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(field))) def backward_propagate(field): """Обратное преобразование (IFFT).""" return torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(field)))

Шаг 4: Алгоритм Gerchberg-Saxton

def gerchberg_saxton(source_amplitude, target_amplitude, n_iterations=100): """ Алгоритм Gerchberg-Saxton для восстановления фазы. Parameters ---------- source_amplitude : Tensor Амплитуда в исходной плоскости target_amplitude : Tensor Амплитуда в целевой плоскости n_iterations : int Количество итераций Returns ------- phase : Tensor Восстановленная фаза errors : list История ошибок """ # Инициализация случайной фазой phase = torch.rand_like(source_amplitude) * 2 * torch.pi errors = [] for iteration in range(n_iterations): # 1. Поле в исходной плоскости с текущей фазой source_field = source_amplitude * torch.exp(1j * phase) # 2. Прямое распространение target_field = forward_propagate(source_field) # 3. Вычисление ошибки target_intensity = torch.abs(target_field)**2 error = torch.mean((torch.sqrt(target_intensity) - target_amplitude)**2) errors.append(error.item()) # 4. Замена амплитуды, сохранение фазы target_phase = torch.angle(target_field) target_field = target_amplitude * torch.exp(1j * target_phase) # 5. Обратное распространение source_field = backward_propagate(target_field) # 6. Извлечение фазы phase = torch.angle(source_field) if iteration % 20 == 0: print(f"Итерация {iteration}: ошибка = {error:.6f}") return phase, errors # Запуск алгоритма recovered_phase, errors = gerchberg_saxton( source_amplitude, target_amplitude, n_iterations=100 )

Шаг 5: Проверка результата

# Применяем восстановленную фазу source_field = source_amplitude * torch.exp(1j * recovered_phase) target_field = forward_propagate(source_field) # Сравнение fig, axes = plt.subplots(2, 3, figsize=(15, 10)) # Исходная плоскость axes[0, 0].imshow(source_amplitude.cpu(), cmap='gray') axes[0, 0].set_title('Исходная амплитуда') axes[0, 1].imshow(recovered_phase.cpu(), cmap='twilight') axes[0, 1].set_title('Восстановленная фаза') axes[0, 2].plot(errors) axes[0, 2].set_xlabel('Итерация') axes[0, 2].set_ylabel('Ошибка') axes[0, 2].set_title('Сходимость') axes[0, 2].set_yscale('log') # Целевая плоскость axes[1, 0].imshow(target_amplitude.cpu()**2, cmap='hot') axes[1, 0].set_title('Целевая интенсивность') axes[1, 1].imshow(torch.abs(target_field).cpu()**2, cmap='hot') axes[1, 1].set_title('Полученная интенсивность') # Разница diff = torch.abs(torch.abs(target_field)**2 - target_amplitude**2) axes[1, 2].imshow(diff.cpu(), cmap='hot') axes[1, 2].set_title('Разница') plt.tight_layout() plt.show()

Шаг 6: Модифицированный алгоритм с ограничениями

Добавим ограничения на фазу (например, только положительные значения):

def gerchberg_saxton_constrained(source_amplitude, target_amplitude, n_iterations=100, phase_range=(0, 2*torch.pi)): """GS с ограничением диапазона фазы.""" phase = torch.rand_like(source_amplitude) * (phase_range[1] - phase_range[0]) + phase_range[0] errors = [] for iteration in range(n_iterations): source_field = source_amplitude * torch.exp(1j * phase) target_field = forward_propagate(source_field) error = torch.mean((torch.abs(target_field) - target_amplitude)**2) errors.append(error.item()) target_phase = torch.angle(target_field) target_field = target_amplitude * torch.exp(1j * target_phase) source_field = backward_propagate(target_field) phase = torch.angle(source_field) # Ограничение фазы phase = torch.clamp(phase, phase_range[0], phase_range[1]) return phase, errors

Шаг 7: Использование с SVETlANNa элементами

Реализуем GS с использованием оптических элементов:

def gerchberg_saxton_optical(params, source_amplitude, target_amplitude, focal_length, n_iterations=100): """ GS с использованием оптической системы (линза + распространение). """ lens = ThinLens(params, focal_length=focal_length) prop = FreeSpace(params, distance=focal_length, method="AS") phase = torch.rand_like(source_amplitude) * 2 * torch.pi errors = [] for iteration in range(n_iterations): # Прямое распространение через линзу source_field = source_amplitude * torch.exp(1j * phase) wf = Wavefront(source_field.to(torch.complex64)) wf = lens(wf) wf = prop(wf) target_field = wf # Ошибка error = torch.mean((torch.sqrt(target_field.intensity) - target_amplitude)**2) errors.append(error.item()) # Замена амплитуды target_phase = target_field.phase target_field = Wavefront( (target_amplitude * torch.exp(1j * target_phase)).to(torch.complex64) ) # Обратное распространение prop_back = FreeSpace(params, distance=-focal_length, method="AS") lens_inv = ThinLens(params, focal_length=-focal_length) wf = prop_back(target_field) wf = lens_inv(wf) phase = wf.phase return phase, errors

Шаг 8: Градиентная оптимизация

Альтернативный подход — использовать autograd PyTorch:

def phase_retrieval_gradient(source_amplitude, target_amplitude, n_iterations=500, lr=0.1): """Восстановление фазы градиентным спуском.""" # Оптимизируемая фаза phase = torch.nn.Parameter(torch.rand_like(source_amplitude) * 2 * torch.pi) optimizer = torch.optim.Adam([phase], lr=lr) errors = [] for iteration in range(n_iterations): optimizer.zero_grad() # Forward source_field = source_amplitude * torch.exp(1j * phase) target_field = forward_propagate(source_field) # Loss loss = torch.mean((torch.abs(target_field) - target_amplitude)**2) # Backward loss.backward() optimizer.step() errors.append(loss.item()) if iteration % 100 == 0: print(f"Итерация {iteration}: loss = {loss.item():.6f}") return phase.detach(), errors

Полный код

import torch import matplotlib.pyplot as plt from svetlanna import SimulationParameters, Wavefront from svetlanna.units import ureg # Параметры params = SimulationParameters.from_ranges( w_range=(-2*ureg.mm, 2*ureg.mm), w_points=256, h_range=(-2*ureg.mm, 2*ureg.mm), h_points=256, wavelength=632.8*ureg.nm ) W, H = torch.meshgrid(params.axes.W, params.axes.H, indexing='xy') # Целевое распределение — кольцо r = torch.sqrt(W**2 + H**2) target = ((r > 0.3e-3) & (r < 0.5e-3)).float() target = target / target.sum().sqrt() # Исходная амплитуда — гауссов пучок source = torch.exp(-(W**2 + H**2) / (0.5e-3)**2) # GS алгоритм phase = torch.rand(256, 256) * 2 * torch.pi for i in range(100): field = source * torch.exp(1j * phase) spectrum = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(field))) spectrum = target * torch.exp(1j * torch.angle(spectrum)) field = torch.fft.fftshift(torch.fft.ifft2(torch.fft.ifftshift(spectrum))) phase = torch.angle(field) print("Восстановление завершено!")

Выводы

  1. Алгоритм GS эффективно восстанавливает фазу при известных интенсивностях
  2. Сходимость достигается за 50-100 итераций
  3. Градиентный подход может быть точнее, но требует больше вычислений
  4. SVETlANNa позволяет использовать реалистичные модели распространения

Что дальше?