src.gan

Classes

DecoderBlock(*args, **kwargs)

Decoder block implementing a dense structure with optional upsampling.

DenseNetEncoder(*args, **kwargs)

Encoder based on the DenseNet121 architecture for feature extraction.

Discriminator(*args, **kwargs)

Discriminator network for GANs using a configurable CNN architecture.

FDGANGenerator(*args, **kwargs)

Generator module for the FD-GAN (Fusion-Discriminator GAN) architecture.

SideBranch(*args, **kwargs)

Lateral branch for multi-scale feature fusion.

class src.gan.DecoderBlock(*args: Any, **kwargs: Any)

Bases: Module

Decoder block implementing a dense structure with optional upsampling.

This block enriches feature maps through a local dense connection, concatenates the input with newly generated features, reduces dimensionality, and optionally upsamples the spatial resolution.

Variables:
  • dense (nn.Sequential) – Dense sub-block (Conv 1x1 -> Conv 3x3).

  • up_trans (ConvTransposeBlock) – Transition block that reduces channels via a 1x1 projection.
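
Example

A minimal usage sketch (sizes are illustrative; the output shape follows the documented semantics of out_channels and upsample):

>>> import torch
>>> block = DecoderBlock(in_channels=64, grow_channels=32, out_channels=128)
>>> x = torch.randn(4, 64, 32, 32)
>>> output = block(x)
>>> print(output.shape)
torch.Size([4, 128, 64, 64])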

Initializes the decoder block composed of a dense sub-block, a transition projection block, and an optional upsampling step.

Parameters:
  • in_channels (int) – Number of input channels.

  • grow_channels (int) – Growth rate defining how many new channels the dense sub-block generates.

  • out_channels (int) – Number of output channels after the projection block.

  • upsample (bool, optional) – If True, upsamples the output feature map by a factor of 2 using nearest-neighbor interpolation. Default is True.

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the DecoderBlock.

This method applies the dense sub-block to the input, concatenates the result with the original input along the channel dimension, and passes the concatenation through the up_trans transition block. An optional additional upsampling step via nearest-neighbor interpolation can then be applied.

Parameters:

x (torch.Tensor) – Input feature map tensor of shape (B, C, H, W).

Returns:

Output feature map tensor; the number of channels is determined by out_channels, and the spatial dimensions (H, W) are doubled when upsample is enabled.

Return type:

torch.Tensor

class src.gan.DenseNetEncoder(*args: Any, **kwargs: Any)

Bases: Module

Encoder based on the DenseNet121 architecture for feature extraction.

This class encapsulates a DenseNet121 network (optionally pre-trained) and splits its layers into sequential blocks to facilitate access to feature maps at different spatial resolutions.

Variables:
  • features (nn.Sequential) – Original feature layers from DenseNet121.

  • block1 (nn.Sequential) – First dense block and transition layer, reducing spatial resolution.

  • block2 (nn.Sequential) – Second dense block and transition layer.

  • block3 (nn.Sequential) – Third dense block and transition layer.

Initializes the DenseNet121 encoder and extracts its feature blocks.

Parameters:

pretrained (bool, optional) – If True, loads ImageNet pre-trained weights. Default is True.
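
For orientation, here is a minimal sketch of how such blocks can be sliced out of torchvision's densenet121 (assuming a recent torchvision; the exact grouping used by src.gan may differ):

import torch
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights

features = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1).features
# Children of features: conv0, norm0, relu0, pool0, denseblock1, transition1,
# denseblock2, transition2, denseblock3, transition3, denseblock4, norm5.
stem = nn.Sequential(features.conv0, features.norm0, features.relu0, features.pool0)
block1 = nn.Sequential(features.denseblock1, features.transition1)
block2 = nn.Sequential(features.denseblock2, features.transition2)
block3 = nn.Sequential(features.denseblock3, features.transition3)

x = torch.randn(1, 3, 256, 256)
f0 = stem(x)     # (1, 64, 64, 64)   -> 1/4 resolution
f1 = block1(f0)  # (1, 128, 32, 32)  -> 1/8
f2 = block2(f1)  # (1, 256, 16, 16)  -> 1/16
f3 = block3(f2)  # (1, 512, 8, 8)    -> 1/32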

class src.gan.Discriminator(*args: Any, **kwargs: Any)

Bases: Module

Discriminator network for GANs using a configurable CNN architecture.

This module implements a discriminator network that processes input images through a series of convolutional layers defined by a configuration list.

Variables:

cnn (ConfigurableCNN) – The configurable CNN used for feature extraction and classification.

Example

>>> import torch
>>> img_shape = (3, 64, 64)  # Example image shape
>>> conv_layers_config = [
...     {'out_channels': 64, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
...     {'out_channels': 128, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
...     {'out_channels': 256, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
...     {'out_channels': 512, 'kernel_size': 4, 'stride': 2, 'padding': 1, 'activation': 'leakyrelu'},
...     {'out_channels': 1, 'kernel_size': 4, 'stride': 1, 'padding': 0, 'activation': 'linear'},
... ]
>>> discriminator = Discriminator(img_shape, conv_layers_config)
>>> x = torch.randn(4, 3, 64, 64)  # Batch of 4 images
>>> output = discriminator(x)
>>> print(output.shape)
torch.Size([4, 1, 1, 1])

Notes

  • The input images must match the specified img_shape.

  • The final output shape depends on the convolutional layers configuration; the sketch below shows how to compute it.
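
The final spatial size follows standard convolution arithmetic. A small hypothetical helper (not part of src.gan) to check a configuration before building the network:

def conv_out_size(size, kernel_size, stride, padding):
    # Standard convolution output size: floor((size + 2*padding - kernel) / stride) + 1.
    return (size + 2 * padding - kernel_size) // stride + 1

size = 64
for layer in conv_layers_config:
    size = conv_out_size(size, layer['kernel_size'], layer['stride'], layer['padding'])
print(size)  # 1 for the example configuration above (64 -> 32 -> 16 -> 8 -> 4 -> 1)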

Initialize the Discriminator network.

Parameters:
  • img_shape (Tuple[int, int, int]) – The shape of the input images as (channels, height, width).

  • conv_layers_config (List[dict]) – A list of dictionaries specifying the configuration of each convolutional layer.

Raises:

RuntimeError – If the ConfigurableCNN construction fails.

forward(img: torch.Tensor) → torch.Tensor

Forward pass through the Discriminator.

Parameters:

img (torch.Tensor) – Input tensor representing a batch of images with shape (batch_size, channels, height, width).

Returns:

Output tensor after passing through the discriminator network.

Return type:

torch.Tensor

Raises:
  • TypeError – If the input is not a torch.Tensor.

  • ValueError – If the input tensor does not have 4 dimensions.

  • ValueError – If the input tensor’s shape does not match the expected image shape.

class src.gan.FDGANGenerator(*args: Any, **kwargs: Any)

Bases: Module

Generator module for the FD-GAN (Fusion-Discriminator GAN) architecture.

Implements a densely connected U-Net-like structure with lateral side branches for multi-scale feature fusion. Designed for image-to-image translation tasks, such as haze removal (dehazing).

Variables:
  • encoder (DenseNetEncoder) – Pre-trained backbone used for feature extraction.

  • conv_in (ConvBlock) – Initial input convolution preserving spatial resolution.

  • side_branch1, side_branch2 – Lateral branches for processing and fusing low-level features.

  • fusion_x1, fusion_bottleneck – Fusion blocks combining side-branch features with the main encoder stream.

  • block4, block5, block6 – Decoder stages for progressively recovering spatial resolution.

  • final_head (nn.Sequential) – Final projection layers mapping features to the RGB image space.

Example

>>> import torch
>>> generator = FDGANGenerator(output_same_size=True)
>>> x = torch.randn(4, 3, 256, 256)  # Batch of 4 images
>>> output = generator(x)
>>> print(output.shape)
torch.Size([4, 3, 256, 256])

References

Dong, Y., Liu, Y., Zhang, H., Chen, S., & Qiao, Y. (2020). FD-GAN: Generative adversarial networks with fusion-discriminator for single image dehazing. Proceedings of the AAAI Conference on Artificial Intelligence, 34(07), 10729-10736.

Initializes the FD-GAN generator composed of a DenseNet-based encoder, lateral side branches, multi-scale fusion modules, decoder stages, and a final reconstruction head.

Parameters:

output_same_size (bool, optional) – Reserved flag indicating whether the output should match the input spatial resolution. The current implementation always preserves size. Default is True.

forward(img: torch.Tensor) → torch.Tensor

Forward pass through the FD-GAN Generator.

This method implements the U-Net-like architecture of the FD-GAN generator, which includes an encoder, a decoder, and specialized side-branch and fusion modules to combine features from different scales.

The process is as follows:

  1. The input image is passed through an initial convolution.

  2. The result is processed by three encoder blocks to extract features.

  3. Side branches process features from early encoder stages (x0, x2).

  4. Fusion modules combine these side-branch features with deeper features to create inputs for the decoder and a fused skip connection.

  5. The decoder, consisting of three blocks, reconstructs the image, using skip connections from the encoder and the fused feature maps.

  6. A final head layer produces the output image.

Parameters:

img (torch.Tensor) – Input normalized image in the range [-1, 1]. Expected shape: (B, 3, H, W), where H and W are multiples of 32 (see the padding sketch below).

Returns:

Generated image in the range [-1, 1] with the same spatial resolution as the input.

Return type:

torch.Tensor
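
Because the encoder downsamples by a factor of 32, H and W must be multiples of 32. A hypothetical helper, not part of src.gan, sketching one way to pad arbitrary inputs (the result can be cropped back to the original (H, W) after generation):

import torch
import torch.nn.functional as F

def pad_to_multiple_of_32(img: torch.Tensor) -> torch.Tensor:
    # Reflection-pad on the right and bottom so H and W become multiples of 32.
    h, w = img.shape[-2:]
    pad_h = (32 - h % 32) % 32
    pad_w = (32 - w % 32) % 32
    return F.pad(img, (0, pad_w, 0, pad_h), mode="reflect")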

class src.gan.SideBranch(*args: Any, **kwargs: Any)

Bases: Module

Lateral branch for multi-scale feature fusion.

Performs downsampling through average pooling followed by a 1x1 projection convolution to adjust the channel dimension. This is used to inject high-resolution information into deeper stages of the network.

Variables:

proj (ConvBlock) – Convolutional block that performs the linear projection (without Batch Normalization).
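
Example

A minimal usage sketch (channel counts are illustrative; the shape follows from the documented average pooling and 1x1 projection):

>>> import torch
>>> branch = SideBranch(in_channels=64, out_channels=128)
>>> x = torch.randn(4, 64, 64, 64)
>>> output = branch(x)
>>> print(output.shape)
torch.Size([4, 128, 32, 32])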

Initializes the lateral branch used for multi-scale feature fusion.

Parameters:
  • in_channels (int) – Number of channels of the input feature map.

  • out_channels (int) – Number of output channels after the 1x1 projection.

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the SideBranch network.

The input tensor is first downsampled by a factor of 2 using average pooling and then passed through a projection layer.

Parameters:

x (torch.Tensor) – Input feature map tensor of shape (B, C, H, W).

Returns:

Projected feature map tensor after downsampling and 1x1 convolution.

Return type:

torch.Tensor