Source code for src.losses

import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia import losses
from torchvision import models


class PerceptualLoss(nn.Module):
    """L1 distance between early VGG16 feature maps of two images.

    Both inputs are passed through the first four layers of a pretrained
    VGG16 (``conv1_1 -> relu -> conv1_2 -> relu``, i.e. the ``relu1_2``
    activation) and the mean absolute difference of the resulting feature
    maps is returned.

    Notes
    -----
    - The VGG16 backbone is loaded with the default (ImageNet) weights.
    - The extractor is put in eval mode and all of its parameters are
      frozen; gradients still flow back to the inputs.
    - No input normalization is applied here. Callers are expected to bring
      images into the desired domain (e.g. ``[0, 1]``) beforehand.

    References
    ----------
    Dong, Y., Liu, Y., Zhang, H., Chen, S., & Qiao, Y. (2020).
    *FD-GAN: Generative adversarial networks with fusion-discriminator for
    single image dehazing*. AAAI Conference on Artificial Intelligence,
    34(07), 10729-10736.

    Examples
    --------
    >>> loss_fn = PerceptualLoss()
    >>> x = torch.rand(1, 3, 224, 224, requires_grad=True)
    >>> y = torch.rand(1, 3, 224, 224)
    >>> loss = loss_fn(x, y)
    >>> loss.backward()
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        # Layers 0..3 of ``features`` end at the relu1_2 activation.
        early_layers = list(backbone.features.children())[:4]
        self.feature_extractor = nn.Sequential(*early_layers)
        self.feature_extractor.eval()
        # Freeze the backbone: it is a fixed feature space, not a trainable part.
        self.feature_extractor.requires_grad_(False)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the mean absolute difference between the features of ``x`` and ``y``.

        Parameters
        ----------
        x : torch.Tensor
            Image batch of shape ``(B, C, H, W)``.
        y : torch.Tensor
            Image batch of shape ``(B, C, H, W)``.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the perceptual loss.
        """
        feats_x = self.feature_extractor(x)
        feats_y = self.feature_extractor(y)
        return F.l1_loss(feats_x, feats_y)
class FDGANLoss(nn.Module):
    """Combined generator loss used in FD-GAN for single-image dehazing.

    The loss is a weighted sum of four terms:

    - L1 loss between the generated and the ground-truth image
    - SSIM loss (via Kornia, ``window_size=11``)
    - Perceptual VGG-based loss (:class:`PerceptualLoss`)
    - Adversarial BCE-with-logits loss against an all-ones target

    With the default weights the total is::

        L_total = 2 * L1 + 1 * L_ssim + 2 * L_percep + 0.1 * L_adv

    Parameters
    ----------
    device : torch.device or str, optional
        Device on which the frozen perceptual network is placed. When
        ``None`` (default) CUDA is used if it is available, otherwise CPU,
        so the loss can also be constructed on CPU-only machines.
    lambda_l1, lambda_ssim, lambda_percep, lambda_adv : float, optional
        Keyword-only weights of the individual terms. Defaults are
        ``2.0``, ``1.0``, ``2.0`` and ``0.1``.

    Notes
    -----
    - ``fake`` and ``real`` are expected in the range ``[-1, 1]``. The L1
      term is computed on them directly.
    - For the SSIM and perceptual terms the images are rescaled internally
      to ``[0, 1]`` via ``(x + 1) / 2``.
    - Input tensors must live on the same device as the perceptual network
      (see ``device``); calling ``.to(...)`` on this module moves it as well.

    References
    ----------
    Dong, Y., Liu, Y., Zhang, H., Chen, S., & Qiao, Y. (2020).
    *FD-GAN: Generative adversarial networks with fusion-discriminator for
    single image dehazing*. AAAI Conference on Artificial Intelligence,
    34(07), 10729-10736.

    Examples
    --------
    >>> criterion = FDGANLoss(device="cpu")
    >>> fake = (torch.rand(1, 3, 256, 256) * 2 - 1).requires_grad_()
    >>> real = torch.rand(1, 3, 256, 256) * 2 - 1
    >>> pred = torch.randn(1, 1)
    >>> loss, components = criterion(fake, real, pred)
    >>> loss.backward()
    """

    def __init__(
        self,
        device=None,
        *,
        lambda_l1=2.0,
        lambda_ssim=1.0,
        lambda_percep=2.0,
        lambda_adv=0.1,
    ):
        super().__init__()
        self.lambda_l1 = lambda_l1
        self.lambda_ssim = lambda_ssim
        self.lambda_percep = lambda_percep
        self.lambda_adv = lambda_adv

        self.l1 = nn.L1Loss()
        self.ssim = losses.SSIMLoss(window_size=11, reduction="mean")
        self.perceptual = PerceptualLoss()
        self.bce = nn.BCEWithLogitsLoss()
        # Previously hard-coded to "cuda", which raised on CPU-only hosts.
        self.perceptual.to(self._resolve_device(device))

    @staticmethod
    def _resolve_device(device=None):
        """Return ``device`` as a ``torch.device``; default to CUDA if available, else CPU."""
        if device is not None:
            return torch.device(device)
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, fake: torch.Tensor, real: torch.Tensor, disc_pred: torch.Tensor) -> tuple:
        """Compute the combined FD-GAN generator loss.

        Parameters
        ----------
        fake : torch.Tensor
            Generated image tensor in the range ``[-1, 1]``.
        real : torch.Tensor
            Ground-truth image tensor in the range ``[-1, 1]``.
        disc_pred : torch.Tensor
            Discriminator logits for the generated images (fed to
            ``BCEWithLogitsLoss`` against a target of ones).

        Returns
        -------
        tuple
            ``(total_loss, components)`` where ``total_loss`` is the scalar
            weighted sum and ``components`` is a dict of the unweighted
            terms::

                {
                    "l1": L1 loss,
                    "ssim": SSIM loss,
                    "percep": perceptual loss,
                    "adv": adversarial BCE loss
                }
        """
        loss_l1 = self.l1(fake, real)

        # Map [-1, 1] -> [0, 1] for the SSIM and perceptual terms.
        fake_01 = (fake + 1) / 2
        real_01 = (real + 1) / 2
        loss_ssim = self.ssim(fake_01, real_01)
        loss_percep = self.perceptual(fake_01, real_01)

        # Generator objective: the discriminator should label fakes as real (1).
        loss_adv = self.bce(disc_pred, torch.ones_like(disc_pred))

        total_loss = (
            (self.lambda_l1 * loss_l1)
            + (self.lambda_ssim * loss_ssim)
            + (self.lambda_percep * loss_percep)
            + (self.lambda_adv * loss_adv)
        )
        components = {
            "l1": loss_l1,
            "ssim": loss_ssim,
            "percep": loss_percep,
            "adv": loss_adv,
        }
        return total_loss, components
def denoising_score_matching_loss(score_net, clean_samples, sigma=0.1):
    """Compute a denoising score matching (DSM) loss in noise-prediction form.

    Standard Gaussian noise is drawn with the same shape as
    ``clean_samples``, scaled by ``sigma`` and added to the samples. The
    network is evaluated on the perturbed samples and its output is
    compared against the *unscaled* noise with a mean squared error.

    Parameters
    ----------
    score_net : callable
        Maps perturbed samples to a noise prediction of the same shape.
    clean_samples : torch.Tensor
        Clean input samples of any shape.
    sigma : float, optional
        Standard deviation of the Gaussian perturbation. Default is ``0.1``.

    Returns
    -------
    torch.Tensor
        Scalar MSE between the predicted and the sampled noise.

    Examples
    --------
    >>> net = lambda x: torch.zeros_like(x)
    >>> x = torch.randn(8, 3, 32, 32)
    >>> loss = denoising_score_matching_loss(net, x)
    """
    eps = torch.randn_like(clean_samples)
    noisy_inputs = clean_samples + eps * sigma
    return F.mse_loss(score_net(noisy_inputs), eps)