Transforms
Модуль svetlanna.transforms предоставляет преобразования для подготовки входных данных к оптической обработке.
ToWavefront
Преобразует изображение (тензор) в волновой фронт.
from svetlanna.transforms import ToWavefront
# Фазовая модуляция (значения изображения -> фаза)
to_wf_phase = ToWavefront(modulation_type='phase')
# Амплитудная модуляция (значения изображения -> амплитуда)
to_wf_amp = ToWavefront(modulation_type='amp')
# Амплитудно-фазовая модуляция
to_wf_both = ToWavefront(modulation_type='amp&phase')Типы модуляции
phase
Фазовая модуляция — значения изображения кодируются в фазу:
transform = ToWavefront(modulation_type='phase')
wf = transform(image) # image ∈ [0, 1]
# wf.phase ∈ [0, 2π]
# |wf| = 1GaussModulation
Модулирует волновой фронт гауссовым профилем (имитация лазерного пучка).
from svetlanna.transforms import GaussModulation
from svetlanna.units import ureg
gauss_mod = GaussModulation(
sim_params=params,
fwhm_x=2.0*ureg.mm, # FWHM по X
fwhm_y=2.0*ureg.mm, # FWHM по Y
peak_x=0.0, # Центр по X
peak_y=0.0 # Центр по Y
)
# Применение к волновому фронту
wf_modulated = gauss_mod(wf)
# wf_modulated = wf * Gaussian(x, y)Параметры
| Параметр | Описание |
|---|---|
fwhm_x, fwhm_y | Полуширина на полувысоте |
peak_x, peak_y | Положение центра гауссиана |
Комбинирование transforms
Transforms можно комбинировать с помощью torchvision.transforms.Compose или nn.Sequential:
import torch.nn as nn
from svetlanna.transforms import ToWavefront, GaussModulation
class InputPipeline(nn.Module):
def __init__(self, params):
super().__init__()
self.to_wf = ToWavefront(modulation_type='phase')
self.gauss = GaussModulation(
sim_params=params,
fwhm_x=1.5*ureg.mm,
fwhm_y=1.5*ureg.mm
)
def forward(self, image):
wf = self.to_wf(image)
wf = self.gauss(wf)
return wf
pipeline = InputPipeline(params)
wf = pipeline(image)Нормализация входных данных
Перед применением ToWavefront убедитесь, что изображение нормализовано в диапазоне [0, 1].
import torch
def normalize_image(image):
"""Нормализация в [0, 1]."""
min_val = image.min()
max_val = image.max()
return (image - min_val) / (max_val - min_val + 1e-8)
# Применение
image_norm = normalize_image(raw_image)
wf = ToWavefront('phase')(image_norm)Пример: MNIST classification pipeline
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T
from svetlanna import SimulationParameters, LinearOpticalSetup
from svetlanna.elements import DiffractiveLayer, FreeSpace
from svetlanna.transforms import ToWavefront, GaussModulation
from svetlanna.detector import Detector, DetectorProcessorClf
from svetlanna.units import ureg
# Параметры оптической системы
params = SimulationParameters.from_ranges(
w_range=(-5*ureg.mm, 5*ureg.mm), w_points=28, # Размер MNIST
h_range=(-5*ureg.mm, 5*ureg.mm), h_points=28,
wavelength=632.8*ureg.nm
)
# Torchvision transforms -> Svetlanna transforms
class MNISTToOptical(torch.nn.Module):
def __init__(self, params):
super().__init__()
self.to_wf = ToWavefront(modulation_type='phase')
self.gauss = GaussModulation(
sim_params=params,
fwhm_x=3*ureg.mm,
fwhm_y=3*ureg.mm
)
def forward(self, image):
# image: (B, 1, 28, 28) -> (B, 28, 28)
image = image.squeeze(1)
wf = self.to_wf(image)
wf = self.gauss(wf)
return wf
# Оптическая сеть
class DONN(torch.nn.Module):
def __init__(self, params, num_classes):
super().__init__()
self.input_transform = MNISTToOptical(params)
self.optical = LinearOpticalSetup([
DiffractiveLayer(params, mask=torch.rand(28, 28) * 2 * torch.pi),
FreeSpace(params, distance=10*ureg.mm, method='AS'),
DiffractiveLayer(params, mask=torch.rand(28, 28) * 2 * torch.pi),
FreeSpace(params, distance=10*ureg.mm, method='AS'),
])
self.detector = Detector(params, func='intensity')
self.processor = DetectorProcessorClf(
num_classes=num_classes,
simulation_parameters=params,
segmentation_type='strips'
)
def forward(self, images):
wf = self.input_transform(images)
wf = self.optical(wf)
intensity = self.detector(wf)
return self.processor.batch_forward(intensity)
# Использование
model = DONN(params, num_classes=10)
# Загрузка данных
dataset = datasets.MNIST(
'./data',
train=True,
download=True,
transform=T.ToTensor()
)
loader = DataLoader(dataset, batch_size=32)
# Forward pass
images, labels = next(iter(loader))
logits = model(images)
print(f"Output shape: {logits.shape}") # (32, 10)См. также
- Detector — детектирование и классификация
- Wavefront — волновые фронты
- Туториал: D²NN — дифракционная нейросеть