Source code for src.mlp

from typing import Sequence, Union, Type
import torch
import torch.nn as nn

class MLP(nn.Module):
    r"""Multi-layer perceptron built from nn.Linear and activation layers.

    This module implements a feedforward neural network of the form:

    .. math::

        f(x) = W_n \sigma( W_{n-1} \sigma(\dots \sigma(W_1 x + b_1)\dots ) + b_{n-1}) + b_n

    Attributes
    ----------
    layers : nn.Sequential
        The sequential container holding all linear and activation layers.

    Example
    -------
    >>> mlp = MLP(input_dim=10, hidden_dims=[20, 30], output_dim=5, activation=nn.ReLU)
    >>> x = torch.randn(4, 10)  # Batch of 4 samples
    >>> output = mlp(x)
    >>> print(output.shape)
    torch.Size([4, 5])

    Notes
    -----
    - No activation is applied after the final (output) linear layer.
    - Activation arguments should be activation classes (callables that return
      nn.Module instances), not instantiated modules.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Union[int, Sequence[int]],
        output_dim: int,
        activation: Union[Type[nn.Module], Sequence[Type[nn.Module]]] = nn.ReLU,
    ) -> None:
        """Initialize the Multilayer Perceptron (MLP).

        Parameters
        ----------
        input_dim : int
            Dimension of the input features.
        hidden_dims : int or Sequence[int]
            Size(s) of the hidden layer(s).
        output_dim : int
            Dimension of the output features.
        activation : Type[nn.Module] or Sequence[Type[nn.Module]], optional
            Activation function(s) to use. Defaults to nn.ReLU.

        Notes
        -----
        The number of hidden layers equals ``len(hidden_dims)``. If
        ``activation`` is a single class, the same activation is applied to
        all hidden layers.
        """
        super().__init__()

        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]
        dims = [input_dim] + list(hidden_dims) + [output_dim]

        if not isinstance(activation, (list, tuple)):
            activation = [activation] * (len(dims) - 2)
        assert len(activation) == len(dims) - 2, (
            f"{len(dims) - 2} activation functions expected, "
            f"but {len(activation)} were received"
        )

        layers = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            layers.append(nn.Linear(in_dim, out_dim))
            if i < len(dims) - 2:  # No activation after the last layer
                layers.append(activation[i]())
        self.layers = nn.Sequential(*layers)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the MLP.

        The MLP is applied independently to the last dimension, so additional
        batch dimensions are preserved.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (..., input_dim).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (..., output_dim).
        """
        return self.layers(x)
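
A minimal usage sketch, separate from the module source above. It assumes the class is importable as ``from src.mlp import MLP`` (matching the module path in the page title) and exercises both the shared-activation default and the per-hidden-layer activation list described in the docstring.

import torch
import torch.nn as nn

from src.mlp import MLP

# Same activation (nn.ReLU) after every hidden layer.
mlp = MLP(input_dim=10, hidden_dims=[20, 30], output_dim=5)

# One activation class per hidden layer: Tanh after the first, GELU after the second.
mlp_mixed = MLP(
    input_dim=10,
    hidden_dims=[20, 30],
    output_dim=5,
    activation=[nn.Tanh, nn.GELU],
)

# Extra leading batch dimensions are preserved; only the last dimension changes.
x = torch.randn(2, 7, 10)
print(mlp_mixed(x).shape)  # torch.Size([2, 7, 5])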