Source code for src.score_net

import torch
import torch.nn as nn

[docs] class ScoreNet(nn.Module): """A simple convolutional neural network for score estimation. This module implements a small encoder-decoder architecture using convolutional and transposed convolutional layers with ELU activations to estimate the score function :math:`\\nabla_x \\log p(x)`. Attributes ---------- main : nn.Sequential The sequential model containing the convolutional and transposed convolutional layers. """ def __init__(self, channels: int = 3): """Initialize the ScoreNet model. Parameters ---------- channels : int, optional Number of input and output channels. For RGB images is typically 3. Default is 3. """ super(ScoreNet, self).__init__() self.main = nn.Sequential( # Down nn.Conv2d(channels, 64, 3, 1, 1), nn.ELU(), nn.Conv2d(64, 128, 3, 2, 1), nn.ELU(), # 128 x H/2 x W/2 nn.Conv2d(128, 256, 3, 2, 1), nn.ELU(), # 256 x H/4 x W/4 # Up nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.ELU(), # Back to H/2 x W/2 nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.ELU(), # Back to H x W nn.Conv2d(64, channels, 3, 1, 1) )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Compute the forward pass of the ScoreNet. Parameters ---------- x : torch.Tensor Input image batch of shape (batch_size, channels, height, width). Returns ------- torch.Tensor Output tensor with the same shape as the input, representing the estimated score for each pixel. """ return self.main(x)