Оптимизация и обучение
SVETlANNa полностью дифференцируема благодаря интеграции с PyTorch. Это позволяет использовать градиентные методы для оптимизации оптических систем.
Основы автоматического дифференцирования
Оптимизируемые параметры
Используйте nn.Parameter для параметров, которые нужно оптимизировать:
import torch
import torch.nn as nn
class OptimizableLens(nn.Module):
def __init__(self, params, initial_focal_length):
super().__init__()
# Оптимизируемое фокусное расстояние
self.focal_length = nn.Parameter(
torch.tensor(initial_focal_length)
)
self.params = params
def forward(self, wf):
k = 2 * torch.pi / self.params.axes.wavelength
W, H = torch.meshgrid(
self.params.axes.W,
self.params.axes.H,
indexing='xy'
)
phase = -k / (2 * self.focal_length) * (W**2 + H**2)
return wf * torch.exp(1j * phase)Вычисление градиентов
# Forward pass
wf_out = system(wf_in)
loss = loss_function(wf_out)
# Backward pass
loss.backward()
# Градиенты доступны
for name, param in system.named_parameters():
print(f"{name}: grad = {param.grad}")Функции потерь
Максимизация интенсивности в точке
def peak_intensity_loss(wf, target_x, target_y, params):
"""Максимизация интенсивности в заданной точке."""
# Найти индексы ближайшей точки
ix = torch.argmin(torch.abs(params.axes.W - target_x))
iy = torch.argmin(torch.abs(params.axes.H - target_y))
# Отрицательная интенсивность (минимизируем)
return -wf.intensity[iy, ix]Соответствие целевому распределению
def target_distribution_loss(wf, target_intensity):
"""MSE между интенсивностью и целевым распределением."""
return nn.functional.mse_loss(
wf.intensity / wf.intensity.max(),
target_intensity / target_intensity.max()
)Минимизация FWHM
def fwhm_loss(wf, params):
"""Минимизация размера пятна."""
fwhm_x, fwhm_y = wf.fwhm(params)
return fwhm_x + fwhm_yЭффективность в область
def efficiency_loss(wf, target_mask):
"""Максимизация энергии в целевой области."""
total_energy = wf.intensity.sum()
target_energy = (wf.intensity * target_mask).sum()
efficiency = target_energy / total_energy
return -efficiency # Минимизируем отрицательную эффективностьОптимизаторы PyTorch
Adam (рекомендуется)
optimizer = torch.optim.Adam(system.parameters(), lr=1e-3)SGD с momentum
optimizer = torch.optim.SGD(
system.parameters(),
lr=1e-2,
momentum=0.9
)LBFGS (для точной оптимизации)
optimizer = torch.optim.LBFGS(
system.parameters(),
lr=1.0,
max_iter=20
)
def closure():
optimizer.zero_grad()
wf_out = system(wf_in)
loss = loss_function(wf_out)
loss.backward()
return loss
optimizer.step(closure)Цикл обучения
Базовый цикл
import torch
from svetlanna import SimulationParameters, Wavefront
from svetlanna.units import ureg
# Параметры
params = SimulationParameters.from_ranges(
w_range=(-1*ureg.mm, 1*ureg.mm), w_points=256,
h_range=(-1*ureg.mm, 1*ureg.mm), h_points=256,
wavelength=632.8*ureg.nm
)
# Входной волновой фронт
wf_in = Wavefront.gaussian_beam(params, waist_radius=0.3*ureg.mm)
# Система с оптимизируемыми параметрами
system = OptimizableSystem(params)
# Оптимизатор
optimizer = torch.optim.Adam(system.parameters(), lr=1e-3)
# Обучение
losses = []
for epoch in range(1000):
optimizer.zero_grad()
wf_out = system(wf_in)
loss = loss_function(wf_out)
loss.backward()
optimizer.step()
losses.append(loss.item())
if epoch % 100 == 0:
print(f"Epoch {epoch}: loss = {loss.item():.6f}")С валидацией
for epoch in range(1000):
# Training
system.train()
optimizer.zero_grad()
wf_out = system(wf_in)
train_loss = loss_function(wf_out)
train_loss.backward()
optimizer.step()
# Validation (без градиентов)
system.eval()
with torch.no_grad():
wf_val = system(wf_validation)
val_loss = loss_function(wf_val)
print(f"Epoch {epoch}: train={train_loss:.6f}, val={val_loss:.6f}")Примеры оптимизации
Оптимизация фазовой маски SLM
class SLMOptimization(nn.Module):
def __init__(self, params, focal_length):
super().__init__()
# Оптимизируемая фаза (от 0 до 2π)
self.phase = nn.Parameter(torch.zeros(256, 256))
self.lens = ThinLens(params, focal_length=focal_length)
self.prop = FreeSpace(params, distance=focal_length, method="AS")
def forward(self, wf):
# Ограничиваем фазу диапазоном [0, 2π]
phase = torch.sigmoid(self.phase) * 2 * torch.pi
wf = wf * torch.exp(1j * phase)
wf = self.lens(wf)
wf = self.prop(wf)
return wf
# Целевое распределение (например, кольцо)
r = torch.sqrt(W**2 + H**2)
target = ((r > 10e-6) & (r < 20e-6)).float()
system = SLMOptimization(params, focal_length=100*ureg.mm)
optimizer = torch.optim.Adam(system.parameters(), lr=0.1)
for epoch in range(500):
optimizer.zero_grad()
wf_out = system(wf_in)
loss = nn.functional.mse_loss(wf_out.intensity, target)
loss.backward()
optimizer.step()Восстановление фазы (Gerchberg-Saxton)
def gerchberg_saxton(params, target_amplitude, n_iterations=100):
"""Итеративный алгоритм восстановления фазы."""
# Начальная фаза — случайная
phase = torch.rand(256, 256) * 2 * torch.pi
for _ in range(n_iterations):
# Прямое преобразование
field = target_amplitude * torch.exp(1j * phase)
spectrum = torch.fft.fft2(field)
# Ограничение в частотной области
spectrum_phase = torch.angle(spectrum)
spectrum = torch.abs(target_spectrum) * torch.exp(1j * spectrum_phase)
# Обратное преобразование
field = torch.fft.ifft2(spectrum)
# Ограничение в пространственной области
phase = torch.angle(field)
return phaseGPU ускорение
Перенос на GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
params = params.to(device)
wf_in = wf_in.to(device)
system = system.to(device)
target = target.to(device)Смешанная точность (AMP)
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for epoch in range(1000):
optimizer.zero_grad()
with autocast():
wf_out = system(wf_in)
loss = loss_function(wf_out)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()Регуляризация
L2 регуляризация фазы
def regularized_loss(wf_out, target, phase, lambda_reg=0.01):
main_loss = nn.functional.mse_loss(wf_out.intensity, target)
reg_loss = lambda_reg * (phase**2).mean()
return main_loss + reg_lossTotal Variation (гладкость фазы)
def total_variation(phase):
"""Total variation для гладкости фазы."""
diff_x = torch.abs(phase[:, 1:] - phase[:, :-1])
diff_y = torch.abs(phase[1:, :] - phase[:-1, :])
return diff_x.mean() + diff_y.mean()Ограничение диапазона фазы
# Мягкое ограничение через sigmoid
phase_constrained = torch.sigmoid(self.phase) * 2 * torch.pi
# Или через clamp
phase_constrained = torch.clamp(self.phase, 0, 2 * torch.pi)Визуализация обучения
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Loss
axes[0].plot(losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].set_yscale('log')
# Результат
axes[1].imshow(wf_out.intensity.detach().cpu(), cmap='hot')
axes[1].set_title('Output Intensity')
# Оптимизированная фаза
axes[2].imshow(system.phase.detach().cpu(), cmap='twilight')
axes[2].set_title('Optimized Phase')
plt.tight_layout()
plt.show()См. также
- Оптические системы — сборка систем
- Туториал: Восстановление фазы — GS алгоритм
- Туториал: Дифракционная нейросеть — D²NN