Дифракционная нейросеть
Дифракционные глубокие нейронные сети (D²NN, Diffractive Deep Neural Networks) — это оптические нейронные сети, где вычисления выполняются светом при прохождении через последовательность дифракционных слоёв.
Теория
Архитектура D²NN
D²NN состоит из последовательности дифракционных слоёв, каждый из которых модулирует фазу проходящего света:
Вход → Слой 1 → Распространение → Слой 2 → ... → Слой N → ДетекторКаждый пиксель слоя действует как “нейрон”, модулирующий фазу света.
Обучение
Фазовые маски слоёв оптимизируются методом обратного распространения ошибки, так же как веса в обычных нейронных сетях.
Шаг 1: Настройка
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from svetlanna import SimulationParameters, Wavefront
from svetlanna.elements import FreeSpace
from svetlanna.units import ureg
# Параметры симуляции
params = SimulationParameters.from_ranges(
w_range=(-1*ureg.mm, 1*ureg.mm), w_points=64, # Небольшое разрешение для скорости
h_range=(-1*ureg.mm, 1*ureg.mm), h_points=64,
wavelength=632.8*ureg.nm
)
device = "cuda" if torch.cuda.is_available() else "cpu"
params = params.to(device)Шаг 2: Архитектура D²NN
class DiffractiveLayer(nn.Module):
"""Один дифракционный слой."""
def __init__(self, size):
super().__init__()
# Оптимизируемая фаза
self.phase = nn.Parameter(torch.zeros(size, size))
def forward(self, wf):
# Модуляция фазы
return wf * torch.exp(1j * self.phase)
class D2NN(nn.Module):
"""Дифракционная глубокая нейронная сеть."""
def __init__(self, params, n_layers, layer_distance, n_classes):
super().__init__()
self.params = params
self.n_classes = n_classes
size = len(params.axes.W)
# Дифракционные слои
self.layers = nn.ModuleList([
DiffractiveLayer(size) for _ in range(n_layers)
])
# Распространение между слоями
self.propagate = FreeSpace(params, distance=layer_distance, method="AS")
# Области детектора для классификации
self.register_buffer('detector_masks', self._create_detector_masks(size))
def _create_detector_masks(self, size):
"""Создаёт маски детекторов для каждого класса."""
masks = torch.zeros(self.n_classes, size, size)
# Располагаем детекторы по кругу
center = size // 2
radius = size // 4
angles = torch.linspace(0, 2*torch.pi, self.n_classes + 1)[:-1]
detector_size = size // 10
for i, angle in enumerate(angles):
cx = int(center + radius * torch.cos(angle))
cy = int(center + radius * torch.sin(angle))
# Квадратный детектор
y_start = max(0, cy - detector_size // 2)
y_end = min(size, cy + detector_size // 2)
x_start = max(0, cx - detector_size // 2)
x_end = min(size, cx + detector_size // 2)
masks[i, y_start:y_end, x_start:x_end] = 1.0
return masks
def forward(self, wf):
# Прохождение через слои
for layer in self.layers:
wf = layer(wf)
wf = self.propagate(wf)
# Интенсивность на выходе
intensity = wf.intensity
# Суммируем интенсивность в областях детекторов
outputs = []
for mask in self.detector_masks:
energy = (intensity * mask).sum(dim=(-2, -1))
outputs.append(energy)
return torch.stack(outputs, dim=-1)Шаг 3: Подготовка данных
Создадим простой датасет — изображения цифр:
def create_digit_dataset(params, n_samples_per_class=100):
"""Создаёт простой датасет цифр 0-9."""
size = len(params.axes.W)
X = []
y = []
for digit in range(10):
for _ in range(n_samples_per_class):
# Создаём изображение цифры с небольшим шумом
img = torch.zeros(size, size)
# Простые паттерны для цифр
center = size // 2
r = size // 4
if digit == 0:
# Круг
Y, X_grid = torch.meshgrid(
torch.arange(size), torch.arange(size), indexing='ij'
)
dist = torch.sqrt((X_grid - center)**2 + (Y - center)**2)
img = ((dist > r*0.7) & (dist < r)).float()
elif digit == 1:
# Вертикальная линия
img[:, center-2:center+2] = 1.0
elif digit == 2:
# Горизонтальные линии
img[center-r:center-r+4, center-r:center+r] = 1.0
img[center:center+4, center-r:center+r] = 1.0
img[center+r-4:center+r, center-r:center+r] = 1.0
# ... (упрощённо для остальных цифр)
else:
# Случайный паттерн для остальных
img[center-r:center+r, center-r:center+r] = (
torch.rand(2*r, 2*r) > 0.5 - digit*0.05
).float()
# Добавляем шум
img = img + 0.1 * torch.randn_like(img)
img = torch.clamp(img, 0, 1)
X.append(img)
y.append(digit)
X = torch.stack(X)
y = torch.tensor(y)
return X, y
# Создаём датасет
X_train, y_train = create_digit_dataset(params, n_samples_per_class=50)
X_test, y_test = create_digit_dataset(params, n_samples_per_class=10)
# DataLoader
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)Шаг 4: Входной волновой фронт
def create_input_wavefront(images, params):
"""Преобразует изображения в волновые фронты."""
# Нормализуем и преобразуем в комплексную амплитуду
wf = images.to(torch.complex64)
return Wavefront(wf)Шаг 5: Обучение
# Модель
model = D2NN(
params=params,
n_layers=5,
layer_distance=10*ureg.mm,
n_classes=10
).to(device)
# Оптимизатор
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
# Обучение
n_epochs = 50
train_losses = []
train_accs = []
for epoch in range(n_epochs):
model.train()
epoch_loss = 0
correct = 0
total = 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
# Создаём волновые фронты
wf = create_input_wavefront(images, params)
# Forward
optimizer.zero_grad()
outputs = model(wf)
# Loss
loss = criterion(outputs, labels)
# Backward
loss.backward()
optimizer.step()
epoch_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
train_loss = epoch_loss / len(train_loader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
train_accs.append(train_acc)
if epoch % 10 == 0:
print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={train_acc:.2f}%")Шаг 6: Оценка модели
def evaluate(model, test_loader, params, device):
"""Оценка модели на тестовом наборе."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
wf = create_input_wavefront(images, params)
outputs = model(wf)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return 100. * correct / total
test_acc = evaluate(model, test_loader, params, device)
print(f"Test accuracy: {test_acc:.2f}%")Шаг 7: Визуализация
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Обучение
axes[0, 0].plot(train_losses)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss')
axes[0, 1].plot(train_accs)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('Training Accuracy')
# Фазовые маски слоёв
for i, layer in enumerate(model.layers[:3]):
ax = axes[0, 2] if i == 0 else axes[1, i-1]
im = ax.imshow(layer.phase.detach().cpu(), cmap='twilight')
ax.set_title(f'Layer {i+1} Phase')
plt.colorbar(im, ax=ax)
# Детекторы
axes[1, 2].imshow(model.detector_masks.sum(0).cpu(), cmap='hot')
axes[1, 2].set_title('Detector Regions')
plt.tight_layout()
plt.show()Шаг 8: Визуализация работы сети
def visualize_propagation(model, image, params):
"""Визуализирует распространение света через D²NN."""
model.eval()
with torch.no_grad():
wf = create_input_wavefront(image.unsqueeze(0), params)
intensities = [wf.intensity[0].cpu()]
for layer in model.layers:
wf = layer(wf)
wf = model.propagate(wf)
intensities.append(wf.intensity[0].cpu())
fig, axes = plt.subplots(1, len(intensities), figsize=(3*len(intensities), 3))
for i, intensity in enumerate(intensities):
axes[i].imshow(intensity, cmap='hot')
axes[i].set_title(f'{"Input" if i == 0 else f"After Layer {i}"}')
axes[i].axis('off')
plt.tight_layout()
plt.show()
# Визуализация для одного примера
visualize_propagation(model, X_test[0].to(device), params)Полный код
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from svetlanna import SimulationParameters, Wavefront
from svetlanna.elements import FreeSpace
from svetlanna.units import ureg
# Параметры
params = SimulationParameters.from_ranges(
w_range=(-1*ureg.mm, 1*ureg.mm), w_points=64,
h_range=(-1*ureg.mm, 1*ureg.mm), h_points=64,
wavelength=632.8*ureg.nm
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# D²NN
class D2NN(nn.Module):
def __init__(self, params, n_layers, n_classes):
super().__init__()
size = len(params.axes.W)
self.phases = nn.ParameterList([
nn.Parameter(torch.zeros(size, size)) for _ in range(n_layers)
])
self.prop = FreeSpace(params, distance=10*ureg.mm, method="AS")
def forward(self, wf):
for phase in self.phases:
wf = wf * torch.exp(1j * phase)
wf = self.prop(wf)
return wf.intensity
# Создание и обучение
model = D2NN(params, n_layers=5, n_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
# ... обучение ...
print("D²NN готова!")Выводы
- D²NN реализует нейронную сеть полностью оптическими средствами
- Обучение происходит через обратное распространение ошибки в симуляции
- Фазовые маски слоёв являются обучаемыми параметрами
- SVETlANNa позволяет создавать дифференцируемые модели оптических систем
Применения
- Классификация изображений со скоростью света
- Распознавание объектов в реальном времени
- Оптические вычисления с низким энергопотреблением
- Специализированные оптические процессоры
Ссылки
Что дальше?
- Оптимизация — методы обучения
- Оптические системы — сборка систем