param_net¶
Parameter net — predict crop parameters from site/weather embeddings.
ParameterNet (Module)
¶
MLP that maps a context embedding to a vector of crop parameters.
This is useful when you want to condition parameters on site/cultivar metadata (e.g., a learned embedding for each cultivar identifier).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
embedding_dim |
int |
Size of the input context embedding. |
required |
n_params |
int |
Number of crop parameters to predict. |
required |
hidden_dim |
int |
Hidden-layer width. |
64 |
n_hidden |
int |
Number of hidden layers (each followed by |
2 |
Source code in torchcrop/nn/param_net.py
class ParameterNet(nn.Module):
"""MLP that maps a context embedding to a vector of crop parameters.
This is useful when you want to condition parameters on site/cultivar
metadata (e.g., a learned embedding for each cultivar identifier).
Args:
embedding_dim: Size of the input context embedding.
n_params: Number of crop parameters to predict.
hidden_dim: Hidden-layer width.
n_hidden: Number of hidden layers (each followed by ``Tanh``).
"""
def __init__(
self,
embedding_dim: int,
n_params: int,
hidden_dim: int = 64,
n_hidden: int = 2,
) -> None:
super().__init__()
layers: list[nn.Module] = []
d = embedding_dim
for _ in range(n_hidden):
layers += [nn.Linear(d, hidden_dim), nn.Tanh()]
d = hidden_dim
layers.append(nn.Linear(d, n_params))
self.mlp = nn.Sequential(*layers)
def forward(self, embedding: torch.Tensor) -> torch.Tensor:
"""Predict crop parameters from a context embedding.
Args:
embedding: Context embedding tensor of shape
``[..., embedding_dim]``.
Returns:
Predicted parameter vector of shape ``[..., n_params]``.
"""
return self.mlp(embedding)
forward(self, embedding)
¶
Predict crop parameters from a context embedding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
embedding |
torch.Tensor |
Context embedding tensor of shape
|
required |
Returns:
| Type | Description |
|---|---|
torch.Tensor |
Predicted parameter vector of shape |
Source code in torchcrop/nn/param_net.py
def forward(self, embedding: torch.Tensor) -> torch.Tensor:
"""Predict crop parameters from a context embedding.
Args:
embedding: Context embedding tensor of shape
``[..., embedding_dim]``.
Returns:
Predicted parameter vector of shape ``[..., n_params]``.
"""
return self.mlp(embedding)