residual

Neural residual correction networks.

A NeuralResidual adds a small MLP-based correction to the output of a mechanistic process; the MLP output is passed through tanh and scaled so the correction stays bounded.

NeuralResidual (Module)

Bounded additive residual from an MLP.

The residual output is produced by

\[ f_\theta(\mathbf{x}) = \text{scale} \cdot \tanh(\text{MLP}(\mathbf{x})) \]

which keeps the correction bounded to [-scale, +scale] so that the learned term cannot overwhelm the mechanistic prediction it augments.
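Because tanh maps any real input into (-1, 1), the scaled output can never exceed the cap regardless of how large the MLP's raw output grows. A quick numeric check of the bound (plain Python, no torch required):

```python
import math

scale = 0.5
# tanh maps any real z into (-1, 1), so scale * tanh(z) stays in (-scale, scale)
for z in (-1e6, -2.0, 0.0, 3.0, 1e6):
    r = scale * math.tanh(z)
    assert -scale <= r <= scale
```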

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_dim` | `int` | Size of the context feature vector fed to the MLP. | *required* |
| `output_dim` | `int` | Dimensionality of the residual correction. | `1` |
| `hidden_dim` | `int` | Hidden-layer width. | `32` |
| `n_hidden` | `int` | Number of hidden layers (each followed by `Tanh`). | `2` |
| `scale` | `float` | Magnitude cap on the residual output. | `1.0` |
Source code in torchcrop/nn/residual.py
```python
class NeuralResidual(nn.Module):
    """Bounded additive residual from an MLP.

    The residual output is produced by

    $$
    f_\theta(\mathbf{x}) = \text{scale} \cdot \tanh(\text{MLP}(\mathbf{x}))
    $$

    which keeps the correction bounded to ``[-scale, +scale]`` so that the
    learned term cannot overwhelm the mechanistic prediction it augments.

    Args:
        input_dim: Size of the context feature vector fed to the MLP.
        output_dim: Dimensionality of the residual correction (default 1).
        hidden_dim: Hidden-layer width.
        n_hidden: Number of hidden layers (each followed by ``Tanh``).
        scale: Magnitude cap on the residual output.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int = 1,
        hidden_dim: int = 32,
        n_hidden: int = 2,
        scale: float = 1.0,
    ) -> None:
        super().__init__()
        layers: list[nn.Module] = []
        d = input_dim
        for _ in range(n_hidden):
            layers += [nn.Linear(d, hidden_dim), nn.Tanh()]
            d = hidden_dim
        layers.append(nn.Linear(d, output_dim))
        self.mlp = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the bounded residual correction.

        Args:
            x: Context feature tensor of shape ``[..., input_dim]``.

        Returns:
            Residual correction of shape ``[..., output_dim]`` with values
            in ``[-scale, +scale]``.
        """
        return self.scale * torch.tanh(self.mlp(x))
```
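A minimal usage sketch of the hybrid pattern the class is designed for. The class body is repeated from the source above so the snippet is self-contained; `mechanistic_model` is a hypothetical stand-in for a real process model, not part of torchcrop:

```python
import torch
import torch.nn as nn


class NeuralResidual(nn.Module):
    """Bounded additive residual from an MLP (as defined above)."""

    def __init__(self, input_dim, output_dim=1, hidden_dim=32, n_hidden=2, scale=1.0):
        super().__init__()
        layers = []
        d = input_dim
        for _ in range(n_hidden):
            layers += [nn.Linear(d, hidden_dim), nn.Tanh()]
            d = hidden_dim
        layers.append(nn.Linear(d, output_dim))
        self.mlp = nn.Sequential(*layers)
        self.scale = scale

    def forward(self, x):
        return self.scale * torch.tanh(self.mlp(x))


# Hypothetical mechanistic baseline: any fixed or differentiable function works.
def mechanistic_model(x):
    return x.sum(dim=-1, keepdim=True)


residual = NeuralResidual(input_dim=4, scale=0.25)
x = torch.randn(8, 4)

# Hybrid prediction: mechanistic output plus a bounded learned correction.
y = mechanistic_model(x) + residual(x)

assert y.shape == (8, 1)
assert residual(x).abs().max().item() <= 0.25  # correction respects the cap
```

Because the correction is capped at `scale=0.25`, the learned term can nudge the mechanistic prediction but never dominate it, even before any training has occurred.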

forward(self, x)

Compute the bounded residual correction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | Context feature tensor of shape `[..., input_dim]`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Residual correction of shape `[..., output_dim]` with values in `[-scale, +scale]`. |

Source code in torchcrop/nn/residual.py
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the bounded residual correction.

    Args:
        x: Context feature tensor of shape ``[..., input_dim]``.

    Returns:
        Residual correction of shape ``[..., output_dim]`` with values
        in ``[-scale, +scale]``.
    """
    return self.scale * torch.tanh(self.mlp(x))
```