import os
from typing import Callable, Union
import kornia.filters
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset
ACTIVATION_MAP = {
"relu": nn.ReLU,
"leakyrelu": nn.LeakyReLU,
"tanh": nn.Tanh,
"sigmoid": nn.Sigmoid,
"gelu": nn.GELU,
"selu": nn.SELU,
"elu": nn.ELU,
"identity": nn.Identity,
}
POOLING_MAP = {"max": nn.MaxPool2d, "avg": nn.AvgPool2d}
[docs]
def get_lf_hf(images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute the low- and high-frequency components of a batch of images.
This function extracts two complementary frequency representations from an
input batch of images. The low-frequency (LF) component is obtained using a
Gaussian blur, which removes fine details while preserving coarse structure.
The high-frequency (HF) component is computed by applying a Laplacian filter
to the grayscale version of the input and then replicating the result across
the RGB channels. Because the Laplacian may produce high-magnitude values,
the HF output is normalized to the range ``[-1, 1]`` using a hyperbolic
tangent to ensure numerical stability during training.
Parameters
----------
images : torch.Tensor
Input batch of images with shape ``(B, C, H, W)`` and RGB channels.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
A pair ``(lf, hf)`` where:
- ``lf`` is the low-frequency component obtained via Gaussian blur.
- ``hf`` is the high-frequency component derived from the Laplacian and
normalized to ``[-1, 1]``.
Notes
-----
- The HF component is expanded to three channels so that both LF and HF can
be concatenated with the original RGB image if needed (e.g., for
discriminators in GAN architectures).
- Normalization with ``tanh`` prevents unstable gradients that may arise
from the Laplacian operator.
"""
lf = TF.gaussian_blur(images, kernel_size=15, sigma=3.0)
gray = TF.rgb_to_grayscale(images)
laplacian = kornia.filters.laplacian(gray, kernel_size=3)
hf = laplacian.repeat(1, 3, 1, 1)
hf = torch.tanh(hf)
return lf, hf
[docs]
def weights_init_normal(m: nn.Module) -> None:
"""Initialize model weights following a normal distribution scheme.
This function is intended to be passed to ``nn.Module.apply`` to initialize
convolutional and batch-normalization layers commonly used in GANs.
Convolutional layers are initialized with a normal distribution centered at
0, while batch-normalization layers are initialized with weights centered at
1 and biases at 0.
Parameters
----------
m : nn.Module
A PyTorch module. If ``m`` is an instance of ``Conv2d`` or
``ConvTranspose2d``, its weights are initialized using
``N(0, 0.02)``. If ``m`` is a ``BatchNorm2d`` or ``BatchNorm1d`` layer,
its weights are initialized using ``N(1, 0.02)`` and its biases are set
to zero.
Notes
-----
- This initialization scheme follows standard practices in GAN training
(e.g., DCGAN), improving stability during early optimization.
- Use as: ``model.apply(weights_init_normal)``.
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
if m.bias is not None:
torch.nn.init.constant_(m.bias.data, 0.0)
elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
[docs]
class DehazingDataset(Dataset):
"""A paired-image dataset for single-image dehazing tasks.
This dataset assumes a directory structure:
``root_dir/clear`` — clean (ground-truth) images
``root_dir/hazy`` — corresponding hazy/degraded images
Both subdirectories must contain images with matching filenames.
The dataset returns pairs ``(hazy, clear)`` suitable for supervised
training of dehazing models.
Parameters
----------
root_dir : str or Path-like
Path to the dataset root directory containing ``clear/`` and ``hazy/``.
transform : callable, optional
A function or transform applied to both images. It must accept and
return a PIL image or a tensor. When provided, the same transform is
applied to both hazy and clear images.
Attributes
----------
root_dir : str
Path to the dataset root.
clear_dir : str
Directory containing the clean images.
hazy_dir : str
Directory containing the hazy images.
images : list of str
Filenames present in ``clear_dir`` (matching files are expected in
``hazy_dir``).
transform : callable or None
Transform applied to both images in a sample.
"""
def __init__(self, root_dir: str, transform: Union[Callable, None] = None):
self.root_dir = root_dir
self.transform = transform
self.clear_dir = os.path.join(root_dir, "clear")
self.hazy_dir = os.path.join(root_dir, "hazy")
self.images = [
f for f in os.listdir(self.clear_dir) if f.endswith((".jpg", ".png"))
]
def __len__(self) -> int:
"""Return the number of image pairs in the dataset."""
return len(self.images)
def __getitem__(
self, idx: int
) -> tuple[Union[torch.Tensor, Image.Image], Union[torch.Tensor, Image.Image]]:
"""Load and return a hazy/clear image pair.
Parameters
----------
idx : int
Index of the image pair to load.
Returns
-------
tuple[torch.Tensor or PIL.Image, torch.Tensor or PIL.Image]
A tuple ``(hazy_img, clear_img)``. If a transform is provided,
both elements are tensors. Otherwise, they are returned as PIL
images converted to RGB.
"""
img_name = self.images[idx]
clear_path = os.path.join(self.clear_dir, img_name)
hazy_path = os.path.join(self.hazy_dir, img_name)
clear_img = Image.open(clear_path).convert("RGB")
hazy_img = Image.open(hazy_path).convert("RGB")
if self.transform:
clear_img = self.transform(clear_img)
hazy_img = self.transform(hazy_img)
return hazy_img, clear_img