Source code for src.gan

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Tuple, List, Any

from .conv import ConfigurableCNN, ConvBlock, ConvTransposeBlock


def _validate_img_shape(img_shape: Tuple[int, int, int]) -> None:
    if not isinstance(img_shape, (tuple, list)) or len(img_shape) != 3:
        raise TypeError(f"img_shape must be a tuple/list of 3 ints (C, H, W), got: {img_shape!r}")
    if not all(isinstance(x, int) and x > 0 for x in img_shape):
        raise ValueError(f"img_shape elements must be positive integers, got: {img_shape!r}")


def _validate_conv_config(conv_layers_config: Any) -> None:
    if not isinstance(conv_layers_config, list) or len(conv_layers_config) == 0:
        raise TypeError("conv_layers_config must be a non-empty list of layer configuration dicts.")


class Discriminator(nn.Module):
    """Discriminator network for GANs using a configurable CNN architecture.

    This module implements a discriminator network that processes input
    images through a series of convolutional layers defined by a
    configuration list.

    Attributes
    ----------
    cnn : ConfigurableCNN
        The configurable CNN used for feature extraction and classification.

    Example
    -------
    >>> img_shape = (3, 64, 64)  # Example image shape
    >>> conv_layers_config = [
    ...     {'out_channels': 64, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
    ...     {'out_channels': 128, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
    ...     {'out_channels': 256, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
    ...     {'out_channels': 512, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
    ...     {'out_channels': 1, 'kernel_size': 4, 'stride': 1, 'padding': 0, 'activation': 'linear'},
    ... ]
    >>> discriminator = Discriminator(img_shape, conv_layers_config)
    >>> x = torch.randn(4, 3, 64, 64)  # Batch of 4 images
    >>> output = discriminator(x)
    >>> print(output.shape)
    torch.Size([4, 1, 1, 1])

    Notes
    -----
    - The input images must match the specified `img_shape`.
    - The final output shape depends on the convolutional layers configuration.
    """

    def __init__(self, img_shape: Tuple[int, int, int], conv_layers_config: List[dict]):
        """Initialize the Discriminator network.

        Parameters
        ----------
        img_shape : Tuple[int, int, int]
            The shape of the input images as (channels, height, width).
        conv_layers_config : List[dict]
            A list of dictionaries specifying the configuration of each
            convolutional layer.

        Raises
        ------
        RuntimeError
            If the ConfigurableCNN construction fails.
        """
        super(Discriminator, self).__init__()
        _validate_img_shape(img_shape)
        _validate_conv_config(conv_layers_config)
        self.img_shape = tuple(img_shape)
        try:
            self.cnn = ConfigurableCNN(layers_config=conv_layers_config)
        except Exception as e:
            raise RuntimeError(f"Failed to construct ConfigurableCNN for discriminator: {e}") from e

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward pass through the Discriminator.

        Parameters
        ----------
        img : torch.Tensor
            Input tensor representing a batch of images with shape
            (batch_size, channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the discriminator network.

        Raises
        ------
        TypeError
            If the input is not a torch.Tensor.
        ValueError
            If the input tensor does not have 4 dimensions, or if its shape
            does not match the expected image shape.
        """
        if not torch.is_tensor(img):
            raise TypeError(f"Expected img to be a torch.Tensor, got {type(img)}")
        if img.dim() != 4:
            raise ValueError(f"Expected img tensor to have 4 dims (B, C, H, W), got shape {tuple(img.shape)}")
        # Basic shape compatibility check (C, H, W)
        if tuple(img.shape[1:]) != self.img_shape:
            raise ValueError(
                f"Input images must have shape (B, {self.img_shape[0]}, {self.img_shape[1]}, {self.img_shape[2]}), "
                f"but got {tuple(img.shape)}"
            )
        out = self.cnn(img)
        return out
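
# Usage sketch (illustrative, not part of the module): a discriminator
# training step with the non-saturating GAN loss. `conv_layers_config`,
# `real_batch`, and `fake_batch` are assumed to be supplied by the caller
# (e.g. the config from the doctest above).
#
# disc = Discriminator((3, 64, 64), conv_layers_config)
# logits_real = disc(real_batch)           # (B, 1, 1, 1) realism logits
# logits_fake = disc(fake_batch.detach())  # detach to freeze the generator
# d_loss = (
#     F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
#     + F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
# )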

class DenseNetEncoder(nn.Module):
    """Encoder based on the DenseNet121 architecture for feature extraction.

    This class encapsulates a DenseNet121 network (optionally pre-trained)
    and splits its layers into sequential blocks to facilitate access to
    feature maps at different spatial resolutions.

    Attributes
    ----------
    features : nn.Sequential
        Original feature layers from DenseNet121.
    block1 : nn.Sequential
        First dense block and transition layer, reducing spatial resolution.
    block2 : nn.Sequential
        Second dense block and transition layer.
    block3 : nn.Sequential
        Third dense block and transition layer.
    """

    def __init__(self, pretrained: bool = True):
        """Initializes the DenseNet121 encoder and extracts its feature blocks.

        Parameters
        ----------
        pretrained : bool, optional
            If True, loads ImageNet pre-trained weights. Default is True.
        """
        super(DenseNetEncoder, self).__init__()
        densenet = models.densenet121(
            weights=(models.DenseNet121_Weights.DEFAULT if pretrained else None)
        )
        self.features = densenet.features
        self.block1 = nn.Sequential(self.features.denseblock1, self.features.transition1)
        self.block2 = nn.Sequential(self.features.denseblock2, self.features.transition2)
        self.block3 = nn.Sequential(self.features.denseblock3, self.features.transition3)
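
# Shape sketch (illustrative, assuming a 64-channel input such as the one
# produced by the generator's conv_in): each block halves the spatial
# resolution through its DenseNet transition layer.
#
# enc = DenseNetEncoder(pretrained=False)
# x0 = torch.randn(2, 64, 256, 256)
# x1 = enc.block1(x0)  # torch.Size([2, 128, 128, 128])
# x2 = enc.block2(x1)  # torch.Size([2, 256, 64, 64])
# x3 = enc.block3(x2)  # torch.Size([2, 512, 32, 32])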

class SideBranch(nn.Module):
    """Lateral branch for multi-scale feature fusion.

    Performs downsampling through Average Pooling followed by a 1x1
    projection convolution to adjust the channel dimension. This is used to
    inject high-resolution information into deeper stages of the network.

    Attributes
    ----------
    proj : ConvBlock
        Convolutional block that performs the linear projection (without
        Batch Normalization).
    """

    def __init__(self, in_channels: int, out_channels: int):
        """Initializes the lateral branch used for multi-scale feature fusion.

        Parameters
        ----------
        in_channels : int
            Number of channels of the input feature map.
        out_channels : int
            Number of output channels after the 1x1 projection.
        """
        super(SideBranch, self).__init__()
        self.proj = ConvBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            activation="linear",
            use_batch_norm=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SideBranch network.

        The input tensor is first downsampled by a factor of 2 using average
        pooling and then passed through a projection layer.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Projected feature map tensor after downsampling and 1x1 convolution.
        """
        x_ds = F.avg_pool2d(x, 2)
        return self.proj(x_ds)
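
# Shape sketch (illustrative): SideBranch(64, 32) maps (B, 64, H, W) to
# (B, 32, H/2, W/2). Average pooling halves H and W; the 1x1 projection only
# changes the channel count.
#
# side = SideBranch(64, 32)
# y = side(torch.randn(2, 64, 256, 256))  # torch.Size([2, 32, 128, 128])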

class DecoderBlock(nn.Module):
    """Decoder block implementing a dense structure with optional upsampling.

    This block enriches feature maps through a local dense connection,
    concatenates the input with newly generated features, reduces
    dimensionality, and optionally upsamples the spatial resolution.

    Attributes
    ----------
    dense : nn.Sequential
        Dense sub-block (Conv 1x1 -> Conv 3x3).
    up_trans : ConvTransposeBlock
        Transition block that reduces channels via a 1x1 projection.
    """

    def __init__(self, in_channels: int, grow_channels: int, out_channels: int, upsample: bool = True):
        """Initializes the decoder block composed of a dense sub-block, a
        transition projection block, and an optional upsampling step.

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        grow_channels : int
            Growth rate defining how many new channels the dense sub-block
            generates.
        out_channels : int
            Number of output channels after the projection block.
        upsample : bool, optional
            If True, upsamples the output feature map by a factor of 2 using
            nearest-neighbor interpolation. Default is True.
        """
        super(DecoderBlock, self).__init__()
        self.upsample = upsample
        self.dense = self._make_dense(in_channels, grow_channels)
        self.up_trans = ConvTransposeBlock(
            in_channels=in_channels + grow_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            activation="relu",
            use_batch_norm=True,
        )

    def _make_dense(self, in_c: int, grow_c: int) -> nn.Sequential:
        """Builds the dense sub-block formed by a 1x1 bottleneck convolution
        followed by a 3x3 convolution.

        Parameters
        ----------
        in_c : int
            Number of input channels.
        grow_c : int
            Growth rate controlling the number of channels in the 3x3
            convolution.

        Returns
        -------
        nn.Sequential
            Dense convolutional block producing `grow_c` new feature channels.
        """
        return nn.Sequential(
            ConvBlock(
                in_c,
                grow_c * 4,
                kernel_size=1,
                padding=0,
                activation="relu",
                use_batch_norm=True,
            ),
            ConvBlock(
                grow_c * 4,
                grow_c,
                kernel_size=3,
                padding=1,
                activation="relu",
                use_batch_norm=True,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the DecoderBlock.

        The input is first passed through the dense sub-block, its output is
        concatenated with the original input along the channel dimension, and
        the result is projected back down by the 1x1 transition block. If
        `upsample` is True, the projected features are then upsampled by a
        factor of 2 using nearest-neighbor interpolation.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map tensor of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Output feature map tensor with `out_channels` channels and, if
            `upsample` is True, doubled spatial dimensions (H, W).
        """
        dense_feat = self.dense(x)
        x_dense = torch.cat([x, dense_feat], dim=1)
        out = self.up_trans(x_dense)
        if self.upsample:
            out = F.interpolate(out, scale_factor=2, mode="nearest")
        return out
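
# Shape sketch (illustrative): DecoderBlock(512, 256, 256) grows 256 new
# channels (concatenated input: 512 + 256 = 768), projects back down to 256,
# and doubles the spatial resolution.
#
# dec = DecoderBlock(in_channels=512, grow_channels=256, out_channels=256)
# y = dec(torch.randn(2, 512, 8, 8))  # torch.Size([2, 256, 16, 16])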

class FDGANGenerator(nn.Module):
    """Generator module for the FD-GAN (Fusion-Discriminator GAN) architecture.

    Implements a densely connected U-Net-like structure with lateral side
    branches for multi-scale feature fusion. Designed for image-to-image
    translation tasks, such as haze removal (dehazing).

    Attributes
    ----------
    encoder : DenseNetEncoder
        Pre-trained backbone used for feature extraction.
    conv_in : ConvBlock
        Initial input convolution preserving spatial resolution.
    side_branch1, side_branch2 : SideBranch
        Lateral branches for processing and fusing low-level features.
    fusion_x1, fusion_bottleneck : ConvBlock
        Fusion blocks combining side-branch features with the main encoder
        stream.
    block4, block5, block6 : DecoderBlock
        Decoder stages for progressively recovering spatial resolution.
    final_head : nn.Sequential
        Final projection layers mapping features to the RGB image space.

    Example
    -------
    >>> generator = FDGANGenerator(output_same_size=True)
    >>> x = torch.randn(4, 3, 256, 256)  # Batch of 4 images
    >>> output = generator(x)
    >>> print(output.shape)
    torch.Size([4, 3, 256, 256])

    References
    ----------
    Dong, Y., Liu, Y., Zhang, H., Chen, S., & Qiao, Y. (2020).
    *FD-GAN: Generative adversarial networks with fusion-discriminator for
    single image dehazing*. AAAI Conference on Artificial Intelligence,
    34(07), 10729-10736.
    """

    def __init__(self, output_same_size=True):
        """Initializes the FD-GAN generator composed of a DenseNet-based
        encoder, lateral side branches, multi-scale fusion modules, decoder
        stages, and a final reconstruction head.

        Parameters
        ----------
        output_same_size : bool, optional
            Reserved flag indicating whether the output should match the
            input spatial resolution. The current implementation always
            preserves size. Default is True.
        """
        super(FDGANGenerator, self).__init__()
        self.output_same_size = output_same_size
        self.encoder = DenseNetEncoder()
        self.conv_in = ConvBlock(3, 64, kernel_size=3, padding=1, activation="relu", use_batch_norm=False)
        self.side_branch1 = SideBranch(in_channels=64, out_channels=32)
        self.side_branch2 = SideBranch(in_channels=256, out_channels=128)
        self.fusion_x1 = ConvBlock(
            32 + 128,
            128,
            kernel_size=3,
            padding=1,
            activation="linear",
            use_batch_norm=False,
        )
        self.fusion_bottleneck = ConvBlock(
            512 + 128,
            512,
            kernel_size=3,
            padding=1,
            activation="linear",
            use_batch_norm=False,
        )
        self.block4 = DecoderBlock(in_channels=512, grow_channels=256, out_channels=256, upsample=True)
        self.block5 = DecoderBlock(in_channels=512, grow_channels=128, out_channels=128, upsample=True)
        self.block6 = DecoderBlock(in_channels=256, grow_channels=64, out_channels=64, upsample=True)
        self.final_head = nn.Sequential(
            ConvBlock(64, 32, kernel_size=3, padding=1, activation="leakyrelu"),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward pass through the FD-GAN Generator.

        This method implements the U-Net-like architecture of the FD-GAN
        generator, which includes an encoder, a decoder, and specialized
        side-branch and fusion modules to combine features from different
        scales. The process is as follows:

        1. The input image is passed through an initial convolution.
        2. The result is processed by three encoder blocks to extract features.
        3. Side branches process features from early encoder stages (`x0`, `x2`).
        4. Fusion modules combine these side-branch features with deeper
           features to create inputs for the decoder and a fused skip
           connection.
        5. The decoder, consisting of three blocks, reconstructs the image,
           using skip connections from the encoder and the fused feature maps.
        6. A final head layer produces the output image.

        Parameters
        ----------
        img : torch.Tensor
            Input normalized image in the range [-1, 1]. Expected shape:
            (B, 3, H, W), where H and W are multiples of 32.

        Returns
        -------
        torch.Tensor
            Generated image in the range [-1, 1] with the same spatial
            resolution as the input.
        """
        x0 = self.conv_in(img)
        x1 = self.encoder.block1(x0)
        x2 = self.encoder.block2(x1)
        x3 = self.encoder.block3(x2)

        f_x0_side = self.side_branch1(x0)
        x1_fused = self.fusion_x1(torch.cat([f_x0_side, x1], dim=1))

        f_x2_side = self.side_branch2(x2)
        bottleneck_in = self.fusion_bottleneck(torch.cat([x3, f_x2_side], dim=1))

        d4 = self.block4(bottleneck_in)
        d4_skip = torch.cat([d4, x2], dim=1)
        d5 = self.block5(d4_skip)
        d5_skip = torch.cat([d5, x1_fused], dim=1)
        d6 = self.block6(d5_skip)

        output = self.final_head(d6)
        return output
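
# End-to-end sketch (illustrative): pairing the generator with the
# Discriminator defined above for a dehazing setup. `hazy` is assumed to be a
# batch of hazy images normalized to [-1, 1], and `conv_layers_config` is the
# hypothetical layer list from the Discriminator doctest.
#
# gen = FDGANGenerator()
# disc = Discriminator((3, 256, 256), conv_layers_config)
# restored = gen(hazy)     # (B, 3, 256, 256), values in [-1, 1]
# logits = disc(restored)  # realism logits for the generated batch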