engine¶
Simulation engine — orchestrates the daily time-stepping loop.
The engine owns no trainable parameters of its own; it wires together the
process sub-modules supplied by Lintul5Model.
SimulationEngine (Module)
¶
Daily time-stepping loop.
The engine expects a callable compute_rates(state, weather_day, doy)
returning a dict of rate tensors indexed by state-field name, plus a
callable update_state(state, rates, dt).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
compute_rates |
Callable[..., dict[str, torch.Tensor]] |
Callable returning per-day rate tensors for a given
|
required |
update_state |
Callable[..., ModelState] |
Callable applying an integration step
|
required |
dt |
float |
Integration step size in days. Defaults to |
1.0 |
Source code in torchcrop/engine.py
class SimulationEngine(nn.Module):
"""Daily time-stepping loop.
The engine expects a callable ``compute_rates(state, weather_day, doy)``
returning a dict of rate tensors indexed by state-field name, plus a
callable ``update_state(state, rates, dt)``.
Args:
compute_rates: Callable returning per-day rate tensors for a given
``(state, weather_day, doy, params...)`` tuple.
update_state: Callable applying an integration step
``(state, rates, dt) -> ModelState``.
dt: Integration step size in days. Defaults to ``1.0``.
"""
def __init__(
self,
compute_rates: Callable[..., dict[str, torch.Tensor]],
update_state: Callable[..., ModelState],
dt: float = 1.0,
) -> None:
super().__init__()
self._compute_rates = compute_rates
self._update_state = update_state
self.dt = dt
def step(
self,
state: ModelState,
weather_day: dict[str, torch.Tensor],
doy: torch.Tensor,
crop_params: CropParameters,
soil_params: SoilParameters,
site_params: SiteParameters,
) -> StepResult:
rates = self._compute_rates(
state=state,
weather_day=weather_day,
doy=doy,
crop_params=crop_params,
soil_params=soil_params,
site_params=site_params,
)
new_state = self._update_state(state, rates, self.dt)
return StepResult(state=new_state, rates=rates)
def run(
self,
state: ModelState,
weather: WeatherDriver,
start_doy: int,
crop_params: CropParameters,
soil_params: SoilParameters,
site_params: SiteParameters,
) -> tuple[list[ModelState], list[dict[str, torch.Tensor]]]:
"""Run the full trajectory.
Args:
state: Initial `ModelState` at day 0.
weather: `WeatherDriver` carrying the daily forcing.
start_doy: Day-of-year of the first simulated day.
crop_params: Species-specific crop parameters.
soil_params: Soil-specific parameters.
site_params: Site-level parameters (e.g. latitude).
Returns:
A ``(states, rates)`` tuple where ``states`` is a list of length
``T + 1`` of per-day `ModelState` snapshots (the first
entry is the initial state) and ``rates`` is a list of length
``T`` of per-day rate dicts.
"""
states: list[ModelState] = [state]
rates_all: list[dict[str, torch.Tensor]] = []
n_days = weather.n_days
for t in range(n_days):
weather_day = weather.day(t)
doy_t = torch.full_like(
state.dvs,
float(((start_doy - 1 + t) % 365) + 1),
)
result = self.step(
state=states[-1],
weather_day=weather_day,
doy=doy_t,
crop_params=crop_params,
soil_params=soil_params,
site_params=site_params,
)
states.append(result.state)
rates_all.append(result.rates)
return states, rates_all
run(self, state, weather, start_doy, crop_params, soil_params, site_params)
¶
Run the full trajectory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state |
ModelState |
Initial |
required |
weather |
WeatherDriver |
|
required |
start_doy |
int |
Day-of-year of the first simulated day. |
required |
crop_params |
CropParameters |
Species-specific crop parameters. |
required |
soil_params |
SoilParameters |
Soil-specific parameters. |
required |
site_params |
SiteParameters |
Site-level parameters (e.g. latitude). |
required |
Returns:
| Type | Description |
|---|---|
tuple[list[ModelState], list[dict[str, torch.Tensor]]] |
A |
Source code in torchcrop/engine.py
def run(
self,
state: ModelState,
weather: WeatherDriver,
start_doy: int,
crop_params: CropParameters,
soil_params: SoilParameters,
site_params: SiteParameters,
) -> tuple[list[ModelState], list[dict[str, torch.Tensor]]]:
"""Run the full trajectory.
Args:
state: Initial `ModelState` at day 0.
weather: `WeatherDriver` carrying the daily forcing.
start_doy: Day-of-year of the first simulated day.
crop_params: Species-specific crop parameters.
soil_params: Soil-specific parameters.
site_params: Site-level parameters (e.g. latitude).
Returns:
A ``(states, rates)`` tuple where ``states`` is a list of length
``T + 1`` of per-day `ModelState` snapshots (the first
entry is the initial state) and ``rates`` is a list of length
``T`` of per-day rate dicts.
"""
states: list[ModelState] = [state]
rates_all: list[dict[str, torch.Tensor]] = []
n_days = weather.n_days
for t in range(n_days):
weather_day = weather.day(t)
doy_t = torch.full_like(
state.dvs,
float(((start_doy - 1 + t) % 365) + 1),
)
result = self.step(
state=states[-1],
weather_day=weather_day,
doy=doy_t,
crop_params=crop_params,
soil_params=soil_params,
site_params=site_params,
)
states.append(result.state)
rates_all.append(result.rates)
return states, rates_all
StepResult
dataclass
¶
Outputs of a single simulation step.
Attributes:
| Name | Type | Description |
|---|---|---|
state |
ModelState |
The updated |
rates |
dict |
Dict of rate tensors produced by the process modules for the
current day, keyed by state-field name (e.g. |
Source code in torchcrop/engine.py
@dataclass
class StepResult:
"""Outputs of a single simulation step.
Attributes:
state: The updated `ModelState` after applying the Euler step.
rates: Dict of rate tensors produced by the process modules for the
current day, keyed by state-field name (e.g. ``dvs_rate``).
"""
state: ModelState
rates: dict[str, torch.Tensor]
euler_update(state, rates, dt)
¶
Forward-Euler update of a ModelState.
Rate keys must match state field names with a _rate suffix (e.g.
dvs_rate updates dvs). Fields without a matching rate are left
unchanged.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
state |
ModelState |
Current |
required |
rates |
dict[str, torch.Tensor] |
Dict of rate tensors keyed by |
required |
dt |
float |
Integration step size in days. |
required |
Returns:
| Type | Description |
|---|---|
ModelState |
A new |
Source code in torchcrop/engine.py
def euler_update(state: ModelState, rates: dict[str, torch.Tensor], dt: float) -> ModelState:
"""Forward-Euler update of a `ModelState`.
Rate keys must match state field names with a ``_rate`` suffix (e.g.
``dvs_rate`` updates ``dvs``). Fields without a matching rate are left
unchanged.
Args:
state: Current `ModelState`.
rates: Dict of rate tensors keyed by ``"<field>_rate"``.
dt: Integration step size in days.
Returns:
A new `ModelState` with all matched fields advanced by
``rates[field + "_rate"] * dt``. Physically non-negative fields are
clamped to ``>= 0`` and ``dvs`` is clamped to ``[0, 2]``.
"""
updates: dict[str, torch.Tensor] = {}
for f in fields(state):
rate_key = f"{f.name}_rate"
current = getattr(state, f.name)
if rate_key in rates and isinstance(current, torch.Tensor):
new_val = current + dt * rates[rate_key]
# Non-negative clamping where it makes physical sense
if f.name in {
"tsum",
"tsump",
"vern",
"wlv",
"wlvd",
"wst",
"wrt",
"wso",
"lai",
"rootd",
"wa",
"wa_lower",
"anlv",
"anst",
"anrt",
"anso",
"aplv",
"apst",
"aprt",
"apso",
"aklv",
"akst",
"akrt",
"akso",
"tran_cum",
"evap_cum",
}:
new_val = torch.clamp(new_val, min=0.0)
if f.name == "dvs":
new_val = torch.clamp(new_val, min=0.0, max=2.0)
if f.name == "dslr":
new_val = torch.clamp(new_val, min=1.0)
if f.name == "dsos":
new_val = torch.clamp(new_val, min=0.0, max=4.0)
updates[f.name] = new_val
return state.replace(**updates)