Детекторы
Модуль svetlanna.detector предоставляет классы для преобразования оптического поля в измеряемые величины и классификации.
Detector
Detector преобразует волновой фронт в интенсивность (изображение детектора).
from svetlanna.detector import Detector
detector = Detector(
simulation_parameters=params,
func='intensity' # Пока единственный режим
)
# Преобразование волнового фронта
detector_image = detector(wavefront)
# detector_image = |wavefront|² = wavefront.intensityDetector эквивалентен вычислению wf.intensity, но интегрируется в pipeline как nn.Module.
DetectorProcessorClf
DetectorProcessorClf обрабатывает изображение детектора для задач классификации, суммируя интенсивность в заданных областях.
from svetlanna.detector import DetectorProcessorClf
processor = DetectorProcessorClf(
num_classes=10,
simulation_parameters=params,
segmentation_type='strips', # Тип сегментации
segments_zone_size=None, # Размер зоны (опционально)
device='cuda'
)
# Обработка одного изображения
class_probabilities = processor(detector_image)
# shape: (1, num_classes)
# Batch processing
batch_probs = processor.batch_forward(batch_detector_images)
# shape: (batch_size, num_classes)
# Интеграл по зоне конкретного класса
integral = processor.batch_zone_integral(batch_images, ind_class=0)
# shape: (batch_size,)Типы сегментации
| Тип | Описание |
|---|---|
'strips' | Вертикальные полосы, симметрично расположенные |
Визуализация зон
# Получение маски зон
zones = processor.zones # Tensor с индексами классов
import matplotlib.pyplot as plt
plt.imshow(zones.cpu())
plt.title('Зоны детектора')
plt.colorbar(label='Класс')Полный pipeline классификации
import torch
from svetlanna import SimulationParameters, Wavefront, LinearOpticalSetup
from svetlanna.elements import ThinLens, FreeSpace, DiffractiveLayer
from svetlanna.detector import Detector, DetectorProcessorClf
from svetlanna.transforms import ToWavefront
from svetlanna.units import ureg
# Параметры
params = SimulationParameters.from_ranges(
w_range=(-5*ureg.mm, 5*ureg.mm), w_points=256,
h_range=(-5*ureg.mm, 5*ureg.mm), h_points=256,
wavelength=632.8*ureg.nm
)
# Оптическая система
setup = LinearOpticalSetup([
DiffractiveLayer(params, mask=torch.rand(256, 256) * 2 * torch.pi),
FreeSpace(params, distance=50*ureg.mm, method='AS'),
ThinLens(params, focal_length=100*ureg.mm),
FreeSpace(params, distance=100*ureg.mm, method='AS'),
])
# Детектор и процессор
detector = Detector(params, func='intensity')
processor = DetectorProcessorClf(
num_classes=10,
simulation_parameters=params,
segmentation_type='strips'
)
# Forward pass
def classify(image):
"""Классификация изображения."""
# Изображение -> волновой фронт
wf = ToWavefront(modulation_type='phase')(image)
# Оптическая обработка
wf_out = setup(wf)
# Детектирование
intensity = detector(wf_out)
# Классификация
probs = processor(intensity)
return probs
# Пример использования
image = torch.rand(1, 256, 256) # Нормализованное изображение
probs = classify(image)
predicted_class = probs.argmax(dim=1)
print(f"Предсказанный класс: {predicted_class.item()}")Обучение классификатора
import torch.optim as optim
import torch.nn.functional as F
# Все компоненты в одном модуле
class OpticalClassifier(torch.nn.Module):
def __init__(self, params, num_classes):
super().__init__()
self.setup = LinearOpticalSetup([
DiffractiveLayer(params, mask=torch.rand(256, 256) * 2 * torch.pi),
FreeSpace(params, distance=50*ureg.mm, method='AS'),
])
self.detector = Detector(params, func='intensity')
self.processor = DetectorProcessorClf(
num_classes=num_classes,
simulation_parameters=params,
segmentation_type='strips'
)
def forward(self, wf):
wf = self.setup(wf)
intensity = self.detector(wf)
return self.processor(intensity)
# Обучение
model = OpticalClassifier(params, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
optimizer.zero_grad()
# Forward
wf_input = Wavefront.plane_wave(params) * input_images
logits = model(wf_input)
# Loss
loss = F.cross_entropy(logits, labels)
# Backward
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}: loss = {loss.item():.4f}")См. также
- LinearOpticalSetup — оптические системы
- Transforms — преобразования входных данных
- Оптимизация — обучение систем