Skip to Content
DocsРуководстваОптимизация и обучение

Оптимизация и обучение

SVETlANNa полностью дифференцируема благодаря интеграции с PyTorch. Это позволяет использовать градиентные методы для оптимизации оптических систем.

Основы автоматического дифференцирования

Оптимизируемые параметры

Используйте nn.Parameter для параметров, которые нужно оптимизировать:

import torch import torch.nn as nn class OptimizableLens(nn.Module): def __init__(self, params, initial_focal_length): super().__init__() # Оптимизируемое фокусное расстояние self.focal_length = nn.Parameter( torch.tensor(initial_focal_length) ) self.params = params def forward(self, wf): k = 2 * torch.pi / self.params.axes.wavelength W, H = torch.meshgrid( self.params.axes.W, self.params.axes.H, indexing='xy' ) phase = -k / (2 * self.focal_length) * (W**2 + H**2) return wf * torch.exp(1j * phase)

Вычисление градиентов

# Forward pass wf_out = system(wf_in) loss = loss_function(wf_out) # Backward pass loss.backward() # Градиенты доступны for name, param in system.named_parameters(): print(f"{name}: grad = {param.grad}")

Функции потерь

Максимизация интенсивности в точке

def peak_intensity_loss(wf, target_x, target_y, params): """Максимизация интенсивности в заданной точке.""" # Найти индексы ближайшей точки ix = torch.argmin(torch.abs(params.axes.W - target_x)) iy = torch.argmin(torch.abs(params.axes.H - target_y)) # Отрицательная интенсивность (минимизируем) return -wf.intensity[iy, ix]

Соответствие целевому распределению

def target_distribution_loss(wf, target_intensity): """MSE между интенсивностью и целевым распределением.""" return nn.functional.mse_loss( wf.intensity / wf.intensity.max(), target_intensity / target_intensity.max() )

Минимизация FWHM

def fwhm_loss(wf, params): """Минимизация размера пятна.""" fwhm_x, fwhm_y = wf.fwhm(params) return fwhm_x + fwhm_y

Эффективность в область

def efficiency_loss(wf, target_mask): """Максимизация энергии в целевой области.""" total_energy = wf.intensity.sum() target_energy = (wf.intensity * target_mask).sum() efficiency = target_energy / total_energy return -efficiency # Минимизируем отрицательную эффективность

Оптимизаторы PyTorch

Adam (рекомендуется)

optimizer = torch.optim.Adam(system.parameters(), lr=1e-3)

SGD с momentum

optimizer = torch.optim.SGD( system.parameters(), lr=1e-2, momentum=0.9 )

LBFGS (для точной оптимизации)

optimizer = torch.optim.LBFGS( system.parameters(), lr=1.0, max_iter=20 ) def closure(): optimizer.zero_grad() wf_out = system(wf_in) loss = loss_function(wf_out) loss.backward() return loss optimizer.step(closure)

Цикл обучения

Базовый цикл

import torch from svetlanna import SimulationParameters, Wavefront from svetlanna.units import ureg # Параметры params = SimulationParameters.from_ranges( w_range=(-1*ureg.mm, 1*ureg.mm), w_points=256, h_range=(-1*ureg.mm, 1*ureg.mm), h_points=256, wavelength=632.8*ureg.nm ) # Входной волновой фронт wf_in = Wavefront.gaussian_beam(params, waist_radius=0.3*ureg.mm) # Система с оптимизируемыми параметрами system = OptimizableSystem(params) # Оптимизатор optimizer = torch.optim.Adam(system.parameters(), lr=1e-3) # Обучение losses = [] for epoch in range(1000): optimizer.zero_grad() wf_out = system(wf_in) loss = loss_function(wf_out) loss.backward() optimizer.step() losses.append(loss.item()) if epoch % 100 == 0: print(f"Epoch {epoch}: loss = {loss.item():.6f}")

С валидацией

for epoch in range(1000): # Training system.train() optimizer.zero_grad() wf_out = system(wf_in) train_loss = loss_function(wf_out) train_loss.backward() optimizer.step() # Validation (без градиентов) system.eval() with torch.no_grad(): wf_val = system(wf_validation) val_loss = loss_function(wf_val) print(f"Epoch {epoch}: train={train_loss:.6f}, val={val_loss:.6f}")

Примеры оптимизации

Оптимизация фазовой маски SLM

class SLMOptimization(nn.Module): def __init__(self, params, focal_length): super().__init__() # Оптимизируемая фаза (от 0 до 2π) self.phase = nn.Parameter(torch.zeros(256, 256)) self.lens = ThinLens(params, focal_length=focal_length) self.prop = FreeSpace(params, distance=focal_length, method="AS") def forward(self, wf): # Ограничиваем фазу диапазоном [0, 2π] phase = torch.sigmoid(self.phase) * 2 * torch.pi wf = wf * torch.exp(1j * phase) wf = self.lens(wf) wf = self.prop(wf) return wf # Целевое распределение (например, кольцо) r = torch.sqrt(W**2 + H**2) target = ((r > 10e-6) & (r < 20e-6)).float() system = SLMOptimization(params, focal_length=100*ureg.mm) optimizer = torch.optim.Adam(system.parameters(), lr=0.1) for epoch in range(500): optimizer.zero_grad() wf_out = system(wf_in) loss = nn.functional.mse_loss(wf_out.intensity, target) loss.backward() optimizer.step()

Восстановление фазы (Gerchberg-Saxton)

def gerchberg_saxton(params, target_amplitude, n_iterations=100): """Итеративный алгоритм восстановления фазы.""" # Начальная фаза — случайная phase = torch.rand(256, 256) * 2 * torch.pi for _ in range(n_iterations): # Прямое преобразование field = target_amplitude * torch.exp(1j * phase) spectrum = torch.fft.fft2(field) # Ограничение в частотной области spectrum_phase = torch.angle(spectrum) spectrum = torch.abs(target_spectrum) * torch.exp(1j * spectrum_phase) # Обратное преобразование field = torch.fft.ifft2(spectrum) # Ограничение в пространственной области phase = torch.angle(field) return phase

GPU ускорение

Перенос на GPU

device = "cuda" if torch.cuda.is_available() else "cpu" params = params.to(device) wf_in = wf_in.to(device) system = system.to(device) target = target.to(device)

Смешанная точность (AMP)

from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() for epoch in range(1000): optimizer.zero_grad() with autocast(): wf_out = system(wf_in) loss = loss_function(wf_out) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

Регуляризация

L2 регуляризация фазы

def regularized_loss(wf_out, target, phase, lambda_reg=0.01): main_loss = nn.functional.mse_loss(wf_out.intensity, target) reg_loss = lambda_reg * (phase**2).mean() return main_loss + reg_loss

Total Variation (гладкость фазы)

def total_variation(phase): """Total variation для гладкости фазы.""" diff_x = torch.abs(phase[:, 1:] - phase[:, :-1]) diff_y = torch.abs(phase[1:, :] - phase[:-1, :]) return diff_x.mean() + diff_y.mean()

Ограничение диапазона фазы

# Мягкое ограничение через sigmoid phase_constrained = torch.sigmoid(self.phase) * 2 * torch.pi # Или через clamp phase_constrained = torch.clamp(self.phase, 0, 2 * torch.pi)

Визуализация обучения

import matplotlib.pyplot as plt fig, axes = plt.subplots(1, 3, figsize=(15, 4)) # Loss axes[0].plot(losses) axes[0].set_xlabel('Epoch') axes[0].set_ylabel('Loss') axes[0].set_title('Training Loss') axes[0].set_yscale('log') # Результат axes[1].imshow(wf_out.intensity.detach().cpu(), cmap='hot') axes[1].set_title('Output Intensity') # Оптимизированная фаза axes[2].imshow(system.phase.detach().cpu(), cmap='twilight') axes[2].set_title('Optimized Phase') plt.tight_layout() plt.show()

См. также