Skip to Content
DocsРуководстваДетекторы

Детекторы

Модуль svetlanna.detector предоставляет классы для преобразования оптического поля в измеряемые величины и классификации.

Detector

Detector преобразует волновой фронт в интенсивность (изображение детектора).

from svetlanna.detector import Detector detector = Detector( simulation_parameters=params, func='intensity' # Пока единственный режим ) # Преобразование волнового фронта detector_image = detector(wavefront) # detector_image = |wavefront|² = wavefront.intensity

Detector эквивалентен вычислению wf.intensity, но интегрируется в pipeline как nn.Module.


DetectorProcessorClf

DetectorProcessorClf обрабатывает изображение детектора для задач классификации, суммируя интенсивность в заданных областях.

from svetlanna.detector import DetectorProcessorClf processor = DetectorProcessorClf( num_classes=10, simulation_parameters=params, segmentation_type='strips', # Тип сегментации segments_zone_size=None, # Размер зоны (опционально) device='cuda' ) # Обработка одного изображения class_probabilities = processor(detector_image) # shape: (1, num_classes) # Batch processing batch_probs = processor.batch_forward(batch_detector_images) # shape: (batch_size, num_classes) # Интеграл по зоне конкретного класса integral = processor.batch_zone_integral(batch_images, ind_class=0) # shape: (batch_size,)

Типы сегментации

ТипОписание
'strips'Вертикальные полосы, симметрично расположенные

Визуализация зон

# Получение маски зон zones = processor.zones # Tensor с индексами классов import matplotlib.pyplot as plt plt.imshow(zones.cpu()) plt.title('Зоны детектора') plt.colorbar(label='Класс')

Полный pipeline классификации

import torch from svetlanna import SimulationParameters, Wavefront, LinearOpticalSetup from svetlanna.elements import ThinLens, FreeSpace, DiffractiveLayer from svetlanna.detector import Detector, DetectorProcessorClf from svetlanna.transforms import ToWavefront from svetlanna.units import ureg # Параметры params = SimulationParameters.from_ranges( w_range=(-5*ureg.mm, 5*ureg.mm), w_points=256, h_range=(-5*ureg.mm, 5*ureg.mm), h_points=256, wavelength=632.8*ureg.nm ) # Оптическая система setup = LinearOpticalSetup([ DiffractiveLayer(params, mask=torch.rand(256, 256) * 2 * torch.pi), FreeSpace(params, distance=50*ureg.mm, method='AS'), ThinLens(params, focal_length=100*ureg.mm), FreeSpace(params, distance=100*ureg.mm, method='AS'), ]) # Детектор и процессор detector = Detector(params, func='intensity') processor = DetectorProcessorClf( num_classes=10, simulation_parameters=params, segmentation_type='strips' ) # Forward pass def classify(image): """Классификация изображения.""" # Изображение -> волновой фронт wf = ToWavefront(modulation_type='phase')(image) # Оптическая обработка wf_out = setup(wf) # Детектирование intensity = detector(wf_out) # Классификация probs = processor(intensity) return probs # Пример использования image = torch.rand(1, 256, 256) # Нормализованное изображение probs = classify(image) predicted_class = probs.argmax(dim=1) print(f"Предсказанный класс: {predicted_class.item()}")

Обучение классификатора

import torch.optim as optim import torch.nn.functional as F # Все компоненты в одном модуле class OpticalClassifier(torch.nn.Module): def __init__(self, params, num_classes): super().__init__() self.setup = LinearOpticalSetup([ DiffractiveLayer(params, mask=torch.rand(256, 256) * 2 * torch.pi), FreeSpace(params, distance=50*ureg.mm, method='AS'), ]) self.detector = Detector(params, func='intensity') self.processor = DetectorProcessorClf( num_classes=num_classes, simulation_parameters=params, segmentation_type='strips' ) def forward(self, wf): wf = self.setup(wf) intensity = self.detector(wf) return self.processor(intensity) # Обучение model = OpticalClassifier(params, num_classes=10) optimizer = optim.Adam(model.parameters(), lr=0.01) for epoch in range(100): optimizer.zero_grad() # Forward wf_input = Wavefront.plane_wave(params) * input_images logits = model(wf_input) # Loss loss = F.cross_entropy(logits, labels) # Backward loss.backward() optimizer.step() if epoch % 10 == 0: print(f"Epoch {epoch}: loss = {loss.item():.4f}")

См. также