Skip to Content
DocsТуториалыДифракционная нейросеть

Дифракционная нейросеть

Дифракционные глубокие нейронные сети (D²NN, Diffractive Deep Neural Networks) — это оптические нейронные сети, где вычисления выполняются светом при прохождении через последовательность дифракционных слоёв.

Теория

Архитектура D²NN

D²NN состоит из последовательности дифракционных слоёв, каждый из которых модулирует фазу проходящего света:

Вход → Слой 1 → Распространение → Слой 2 → ... → Слой N → Детектор

Каждый пиксель слоя действует как “нейрон”, модулирующий фазу света.

Обучение

Фазовые маски слоёв оптимизируются методом обратного распространения ошибки, так же как веса в обычных нейронных сетях.

Шаг 1: Настройка

import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, TensorDataset import matplotlib.pyplot as plt from svetlanna import SimulationParameters, Wavefront from svetlanna.elements import FreeSpace from svetlanna.units import ureg # Параметры симуляции params = SimulationParameters.from_ranges( w_range=(-1*ureg.mm, 1*ureg.mm), w_points=64, # Небольшое разрешение для скорости h_range=(-1*ureg.mm, 1*ureg.mm), h_points=64, wavelength=632.8*ureg.nm ) device = "cuda" if torch.cuda.is_available() else "cpu" params = params.to(device)

Шаг 2: Архитектура D²NN

class DiffractiveLayer(nn.Module): """Один дифракционный слой.""" def __init__(self, size): super().__init__() # Оптимизируемая фаза self.phase = nn.Parameter(torch.zeros(size, size)) def forward(self, wf): # Модуляция фазы return wf * torch.exp(1j * self.phase) class D2NN(nn.Module): """Дифракционная глубокая нейронная сеть.""" def __init__(self, params, n_layers, layer_distance, n_classes): super().__init__() self.params = params self.n_classes = n_classes size = len(params.axes.W) # Дифракционные слои self.layers = nn.ModuleList([ DiffractiveLayer(size) for _ in range(n_layers) ]) # Распространение между слоями self.propagate = FreeSpace(params, distance=layer_distance, method="AS") # Области детектора для классификации self.register_buffer('detector_masks', self._create_detector_masks(size)) def _create_detector_masks(self, size): """Создаёт маски детекторов для каждого класса.""" masks = torch.zeros(self.n_classes, size, size) # Располагаем детекторы по кругу center = size // 2 radius = size // 4 angles = torch.linspace(0, 2*torch.pi, self.n_classes + 1)[:-1] detector_size = size // 10 for i, angle in enumerate(angles): cx = int(center + radius * torch.cos(angle)) cy = int(center + radius * torch.sin(angle)) # Квадратный детектор y_start = max(0, cy - detector_size // 2) y_end = min(size, cy + detector_size // 2) x_start = max(0, cx - detector_size // 2) x_end = min(size, cx + detector_size // 2) masks[i, y_start:y_end, x_start:x_end] = 1.0 return masks def forward(self, wf): # Прохождение через слои for layer in self.layers: wf = layer(wf) wf = self.propagate(wf) # Интенсивность на выходе intensity = wf.intensity # Суммируем интенсивность в областях детекторов outputs = [] for mask in self.detector_masks: energy = (intensity * mask).sum(dim=(-2, -1)) outputs.append(energy) return torch.stack(outputs, dim=-1)

Шаг 3: Подготовка данных

Создадим простой датасет — изображения цифр:

def create_digit_dataset(params, n_samples_per_class=100): """Создаёт простой датасет цифр 0-9.""" size = len(params.axes.W) X = [] y = [] for digit in range(10): for _ in range(n_samples_per_class): # Создаём изображение цифры с небольшим шумом img = torch.zeros(size, size) # Простые паттерны для цифр center = size // 2 r = size // 4 if digit == 0: # Круг Y, X_grid = torch.meshgrid( torch.arange(size), torch.arange(size), indexing='ij' ) dist = torch.sqrt((X_grid - center)**2 + (Y - center)**2) img = ((dist > r*0.7) & (dist < r)).float() elif digit == 1: # Вертикальная линия img[:, center-2:center+2] = 1.0 elif digit == 2: # Горизонтальные линии img[center-r:center-r+4, center-r:center+r] = 1.0 img[center:center+4, center-r:center+r] = 1.0 img[center+r-4:center+r, center-r:center+r] = 1.0 # ... (упрощённо для остальных цифр) else: # Случайный паттерн для остальных img[center-r:center+r, center-r:center+r] = ( torch.rand(2*r, 2*r) > 0.5 - digit*0.05 ).float() # Добавляем шум img = img + 0.1 * torch.randn_like(img) img = torch.clamp(img, 0, 1) X.append(img) y.append(digit) X = torch.stack(X) y = torch.tensor(y) return X, y # Создаём датасет X_train, y_train = create_digit_dataset(params, n_samples_per_class=50) X_test, y_test = create_digit_dataset(params, n_samples_per_class=10) # DataLoader train_dataset = TensorDataset(X_train, y_train) test_dataset = TensorDataset(X_test, y_test) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=32)

Шаг 4: Входной волновой фронт

def create_input_wavefront(images, params): """Преобразует изображения в волновые фронты.""" # Нормализуем и преобразуем в комплексную амплитуду wf = images.to(torch.complex64) return Wavefront(wf)

Шаг 5: Обучение

# Модель model = D2NN( params=params, n_layers=5, layer_distance=10*ureg.mm, n_classes=10 ).to(device) # Оптимизатор optimizer = torch.optim.Adam(model.parameters(), lr=0.1) criterion = nn.CrossEntropyLoss() # Обучение n_epochs = 50 train_losses = [] train_accs = [] for epoch in range(n_epochs): model.train() epoch_loss = 0 correct = 0 total = 0 for images, labels in train_loader: images = images.to(device) labels = labels.to(device) # Создаём волновые фронты wf = create_input_wavefront(images, params) # Forward optimizer.zero_grad() outputs = model(wf) # Loss loss = criterion(outputs, labels) # Backward loss.backward() optimizer.step() epoch_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() train_loss = epoch_loss / len(train_loader) train_acc = 100. * correct / total train_losses.append(train_loss) train_accs.append(train_acc) if epoch % 10 == 0: print(f"Epoch {epoch}: loss={train_loss:.4f}, acc={train_acc:.2f}%")

Шаг 6: Оценка модели

def evaluate(model, test_loader, params, device): """Оценка модели на тестовом наборе.""" model.eval() correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: images = images.to(device) labels = labels.to(device) wf = create_input_wavefront(images, params) outputs = model(wf) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() return 100. * correct / total test_acc = evaluate(model, test_loader, params, device) print(f"Test accuracy: {test_acc:.2f}%")

Шаг 7: Визуализация

fig, axes = plt.subplots(2, 3, figsize=(15, 10)) # Обучение axes[0, 0].plot(train_losses) axes[0, 0].set_xlabel('Epoch') axes[0, 0].set_ylabel('Loss') axes[0, 0].set_title('Training Loss') axes[0, 1].plot(train_accs) axes[0, 1].set_xlabel('Epoch') axes[0, 1].set_ylabel('Accuracy (%)') axes[0, 1].set_title('Training Accuracy') # Фазовые маски слоёв for i, layer in enumerate(model.layers[:3]): ax = axes[0, 2] if i == 0 else axes[1, i-1] im = ax.imshow(layer.phase.detach().cpu(), cmap='twilight') ax.set_title(f'Layer {i+1} Phase') plt.colorbar(im, ax=ax) # Детекторы axes[1, 2].imshow(model.detector_masks.sum(0).cpu(), cmap='hot') axes[1, 2].set_title('Detector Regions') plt.tight_layout() plt.show()

Шаг 8: Визуализация работы сети

def visualize_propagation(model, image, params): """Визуализирует распространение света через D²NN.""" model.eval() with torch.no_grad(): wf = create_input_wavefront(image.unsqueeze(0), params) intensities = [wf.intensity[0].cpu()] for layer in model.layers: wf = layer(wf) wf = model.propagate(wf) intensities.append(wf.intensity[0].cpu()) fig, axes = plt.subplots(1, len(intensities), figsize=(3*len(intensities), 3)) for i, intensity in enumerate(intensities): axes[i].imshow(intensity, cmap='hot') axes[i].set_title(f'{"Input" if i == 0 else f"After Layer {i}"}') axes[i].axis('off') plt.tight_layout() plt.show() # Визуализация для одного примера visualize_propagation(model, X_test[0].to(device), params)

Полный код

import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset from svetlanna import SimulationParameters, Wavefront from svetlanna.elements import FreeSpace from svetlanna.units import ureg # Параметры params = SimulationParameters.from_ranges( w_range=(-1*ureg.mm, 1*ureg.mm), w_points=64, h_range=(-1*ureg.mm, 1*ureg.mm), h_points=64, wavelength=632.8*ureg.nm ) device = "cuda" if torch.cuda.is_available() else "cpu" # D²NN class D2NN(nn.Module): def __init__(self, params, n_layers, n_classes): super().__init__() size = len(params.axes.W) self.phases = nn.ParameterList([ nn.Parameter(torch.zeros(size, size)) for _ in range(n_layers) ]) self.prop = FreeSpace(params, distance=10*ureg.mm, method="AS") def forward(self, wf): for phase in self.phases: wf = wf * torch.exp(1j * phase) wf = self.prop(wf) return wf.intensity # Создание и обучение model = D2NN(params, n_layers=5, n_classes=10).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # ... обучение ... print("D²NN готова!")

Выводы

  1. D²NN реализует нейронную сеть полностью оптическими средствами
  2. Обучение происходит через обратное распространение ошибки в симуляции
  3. Фазовые маски слоёв являются обучаемыми параметрами
  4. SVETlANNa позволяет создавать дифференцируемые модели оптических систем

Применения

  • Классификация изображений со скоростью света
  • Распознавание объектов в реальном времени
  • Оптические вычисления с низким энергопотреблением
  • Специализированные оптические процессоры

Ссылки

Что дальше?