learned_stress¶
Learned stress-factor module.
Replaces the empirical water-stress curve (TRANRF) with a small MLP that
maps soil-water descriptors to a stress factor in \([0, 1]\) via a
sigmoid output.
LearnedStressFactor (Module)
¶
MLP-based replacement for empirical water-stress curves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
input_dim |
int |
Number of scalar inputs concatenated along the feature
axis — typically |
4 |
hidden_dim |
int |
Hidden-layer width. |
32 |
n_hidden |
int |
Number of hidden layers (each followed by |
2 |
Source code in torchcrop/nn/learned_stress.py
class LearnedStressFactor(nn.Module):
"""MLP-based replacement for empirical water-stress curves.
Args:
input_dim: Number of scalar inputs concatenated along the feature
axis — typically ``2 + len(extra)`` for the mechanistic
``tranrf``/``nstress`` pair plus optional context tensors.
hidden_dim: Hidden-layer width.
n_hidden: Number of hidden layers (each followed by ``Tanh``).
"""
def __init__(
self,
input_dim: int = 4,
hidden_dim: int = 32,
n_hidden: int = 2,
) -> None:
super().__init__()
layers: list[nn.Module] = []
d = input_dim
for _ in range(n_hidden):
layers += [nn.Linear(d, hidden_dim), nn.Tanh()]
d = hidden_dim
layers.append(nn.Linear(d, 1))
self.mlp = nn.Sequential(*layers)
def forward(self, tranrf: torch.Tensor, nstress: torch.Tensor, *extra: torch.Tensor) -> torch.Tensor:
"""Return a combined stress factor in ``[0, 1]``.
The mechanistic inputs (``tranrf``, ``nstress``) are stacked with any
extra context tensors along the last axis and passed through the MLP,
whose scalar output is squashed by a sigmoid.
Args:
tranrf: Water-stress factor in ``[0, 1]`` from
`WaterBalance`, shape ``[B]``.
nstress: Nutrient-stress factor in ``[0, 1]`` from
`NutrientDemand`, shape ``[B]``.
*extra: Optional additional context tensors, each of shape
``[B]``.
Returns:
Combined stress factor in ``[0, 1]``, shape ``[B]``.
"""
stacked = torch.stack([tranrf, nstress, *extra], dim=-1)
return torch.sigmoid(self.mlp(stacked)).squeeze(-1)
forward(self, tranrf, nstress, *extra)
¶
Return a combined stress factor in [0, 1].
The mechanistic inputs (tranrf, nstress) are stacked with any
extra context tensors along the last axis and passed through the MLP,
whose scalar output is squashed by a sigmoid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
tranrf |
torch.Tensor |
Water-stress factor in |
required |
nstress |
torch.Tensor |
Nutrient-stress factor in |
required |
*extra |
torch.Tensor |
Optional additional context tensors, each of shape
|
() |
Returns:
| Type | Description |
|---|---|
torch.Tensor |
Combined stress factor in |
Source code in torchcrop/nn/learned_stress.py
def forward(self, tranrf: torch.Tensor, nstress: torch.Tensor, *extra: torch.Tensor) -> torch.Tensor:
"""Return a combined stress factor in ``[0, 1]``.
The mechanistic inputs (``tranrf``, ``nstress``) are stacked with any
extra context tensors along the last axis and passed through the MLP,
whose scalar output is squashed by a sigmoid.
Args:
tranrf: Water-stress factor in ``[0, 1]`` from
`WaterBalance`, shape ``[B]``.
nstress: Nutrient-stress factor in ``[0, 1]`` from
`NutrientDemand`, shape ``[B]``.
*extra: Optional additional context tensors, each of shape
``[B]``.
Returns:
Combined stress factor in ``[0, 1]``, shape ``[B]``.
"""
stacked = torch.stack([tranrf, nstress, *extra], dim=-1)
return torch.sigmoid(self.mlp(stacked)).squeeze(-1)