Skip to Content

Transforms

Модуль svetlanna.transforms предоставляет преобразования для подготовки входных данных к оптической обработке.

ToWavefront

Преобразует изображение (тензор) в волновой фронт.

from svetlanna.transforms import ToWavefront # Фазовая модуляция (значения изображения -> фаза) to_wf_phase = ToWavefront(modulation_type='phase') # Амплитудная модуляция (значения изображения -> амплитуда) to_wf_amp = ToWavefront(modulation_type='amp') # Амплитудно-фазовая модуляция to_wf_both = ToWavefront(modulation_type='amp&phase')

Типы модуляции

Фазовая модуляция — значения изображения кодируются в фазу:

E(x,y)=exp(i2πI(x,y))E(x, y) = \exp(i \cdot 2\pi \cdot I(x, y))

transform = ToWavefront(modulation_type='phase') wf = transform(image) # image ∈ [0, 1] # wf.phase ∈ [0, 2π] # |wf| = 1

GaussModulation

Модулирует волновой фронт гауссовым профилем (имитация лазерного пучка).

from svetlanna.transforms import GaussModulation from svetlanna.units import ureg gauss_mod = GaussModulation( sim_params=params, fwhm_x=2.0*ureg.mm, # FWHM по X fwhm_y=2.0*ureg.mm, # FWHM по Y peak_x=0.0, # Центр по X peak_y=0.0 # Центр по Y ) # Применение к волновому фронту wf_modulated = gauss_mod(wf) # wf_modulated = wf * Gaussian(x, y)

Параметры

ПараметрОписание
fwhm_x, fwhm_yПолуширина на полувысоте
peak_x, peak_yПоложение центра гауссиана

Комбинирование transforms

Transforms можно комбинировать с помощью torchvision.transforms.Compose или nn.Sequential:

import torch.nn as nn from svetlanna.transforms import ToWavefront, GaussModulation class InputPipeline(nn.Module): def __init__(self, params): super().__init__() self.to_wf = ToWavefront(modulation_type='phase') self.gauss = GaussModulation( sim_params=params, fwhm_x=1.5*ureg.mm, fwhm_y=1.5*ureg.mm ) def forward(self, image): wf = self.to_wf(image) wf = self.gauss(wf) return wf pipeline = InputPipeline(params) wf = pipeline(image)

Нормализация входных данных

Перед применением ToWavefront убедитесь, что изображение нормализовано в диапазоне [0, 1].

import torch def normalize_image(image): """Нормализация в [0, 1].""" min_val = image.min() max_val = image.max() return (image - min_val) / (max_val - min_val + 1e-8) # Применение image_norm = normalize_image(raw_image) wf = ToWavefront('phase')(image_norm)

Пример: MNIST classification pipeline

import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms as T from svetlanna import SimulationParameters, LinearOpticalSetup from svetlanna.elements import DiffractiveLayer, FreeSpace from svetlanna.transforms import ToWavefront, GaussModulation from svetlanna.detector import Detector, DetectorProcessorClf from svetlanna.units import ureg # Параметры оптической системы params = SimulationParameters.from_ranges( w_range=(-5*ureg.mm, 5*ureg.mm), w_points=28, # Размер MNIST h_range=(-5*ureg.mm, 5*ureg.mm), h_points=28, wavelength=632.8*ureg.nm ) # Torchvision transforms -> Svetlanna transforms class MNISTToOptical(torch.nn.Module): def __init__(self, params): super().__init__() self.to_wf = ToWavefront(modulation_type='phase') self.gauss = GaussModulation( sim_params=params, fwhm_x=3*ureg.mm, fwhm_y=3*ureg.mm ) def forward(self, image): # image: (B, 1, 28, 28) -> (B, 28, 28) image = image.squeeze(1) wf = self.to_wf(image) wf = self.gauss(wf) return wf # Оптическая сеть class DONN(torch.nn.Module): def __init__(self, params, num_classes): super().__init__() self.input_transform = MNISTToOptical(params) self.optical = LinearOpticalSetup([ DiffractiveLayer(params, mask=torch.rand(28, 28) * 2 * torch.pi), FreeSpace(params, distance=10*ureg.mm, method='AS'), DiffractiveLayer(params, mask=torch.rand(28, 28) * 2 * torch.pi), FreeSpace(params, distance=10*ureg.mm, method='AS'), ]) self.detector = Detector(params, func='intensity') self.processor = DetectorProcessorClf( num_classes=num_classes, simulation_parameters=params, segmentation_type='strips' ) def forward(self, images): wf = self.input_transform(images) wf = self.optical(wf) intensity = self.detector(wf) return self.processor.batch_forward(intensity) # Использование model = DONN(params, num_classes=10) # Загрузка данных dataset = datasets.MNIST( './data', train=True, download=True, transform=T.ToTensor() ) loader = DataLoader(dataset, batch_size=32) # Forward pass images, labels = next(iter(loader)) logits = model(images) print(f"Output shape: {logits.shape}") # (32, 10)

См. также