Skip to content

model

Top-level Lintul5 model.

Wires all process sub-modules together and provides both a high-level and a low-level API.

Lintul5Model (Module)

Differentiable reimplementation of the Lintul5 crop growth model.

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `crop_params` | `CropParameters \| None` | Crop parameter container (see `torchcrop.parameters`). | `None` |
| `soil_params` | `SoilParameters \| None` | Soil parameter container. | `None` |
| `site_params` | `SiteParameters \| None` | Site parameter container (e.g. latitude, altitude). | `None` |
| `smooth` | `bool` | If `True`, use smooth (sigmoid-blend) replacements for stage-based branching. | `False` |
| `stress_module` | `nn.Module \| None` | Optional replacement for the default `StressFactors` combiner. | `None` |
| `residual_modules` | `dict[str, nn.Module] \| None` | Optional neural residual corrections keyed by process name (`"photosynthesis"` adds to `gtotal`; `"partitioning"` adds to the four allocation fractions; `"leaf_dynamics"` adds to `lai_rate`). Note: in the source below, only the `"photosynthesis"` residual is actually applied during rate computation. | `None` |
Source code in torchcrop/model.py
class Lintul5Model(nn.Module):
    """Differentiable reimplementation of the Lintul5 crop growth model.

    Args:
        crop_params: Crop parameter container (see `torchcrop.parameters`).
            A default ``CropParameters()`` is built when omitted.
        soil_params: Soil parameter container. A default
            ``SoilParameters()`` is built when omitted.
        site_params: Site parameter container (e.g. latitude, altitude).
            A default ``SiteParameters()`` is built when omitted.
        smooth: If ``True``, use smooth (sigmoid-blend) replacements for
            stage-based branching.
        stress_module: Optional replacement for the default
            `StressFactors` combiner.
        residual_modules: Optional neural residual corrections keyed by
            process name. ``"photosynthesis"`` adds to ``gtotal``. Any other
            key (e.g. ``"partitioning"`` or ``"leaf_dynamics"``) is stored on
            the module but is not yet applied during rate computation; a
            ``UserWarning`` is emitted for such keys so the omission is not
            silent.
    """

    # Residual keys that `_compute_rates_dispatch` actually consumes.
    _WIRED_RESIDUALS = frozenset({"photosynthesis"})

    def __init__(
        self,
        crop_params: CropParameters | None = None,
        soil_params: SoilParameters | None = None,
        site_params: SiteParameters | None = None,
        smooth: bool = False,
        stress_module: nn.Module | None = None,
        residual_modules: dict[str, nn.Module] | None = None,
    ) -> None:
        super().__init__()
        # Explicit ``is None`` checks rather than ``or``: ``or`` consults
        # ``__len__`` on containers such as ``nn.Sequential``/``nn.ModuleList``
        # and would silently swap in the default for an empty-but-given value.
        self.crop_params = CropParameters() if crop_params is None else crop_params
        self.soil_params = SoilParameters() if soil_params is None else soil_params
        self.site_params = SiteParameters() if site_params is None else site_params
        self.smooth = smooth

        self.astro = Astro()
        self.phenology = Phenology(smooth=smooth)
        self.irradiation = Irradiation()
        self.evapotranspiration = PotentialEvapoTranspiration()
        self.water_balance = WaterBalance()
        self.photosynthesis = Photosynthesis()
        self.partitioning = Partitioning()
        self.leaf_dynamics = LeafDynamics()
        self.root_dynamics = RootDynamics()
        self.nutrient_demand = NutrientDemand()
        self.stress = StressFactors() if stress_module is None else stress_module

        self.residual_modules = nn.ModuleDict(residual_modules or {})
        # Residual modules for processes the dispatcher does not consult
        # would otherwise be ignored without any signal to the user.
        unwired = sorted(set(self.residual_modules) - self._WIRED_RESIDUALS)
        if unwired:
            import warnings

            warnings.warn(
                f"residual_modules keys {unwired} are not applied by "
                "Lintul5Model; only 'photosynthesis' is currently wired.",
                UserWarning,
                stacklevel=2,
            )

        self.engine = SimulationEngine(
            compute_rates=self._compute_rates_dispatch,
            update_state=euler_update,
            dt=1.0,
        )

    # ------------------------------------------------------------------ #
    # High-level API
    # ------------------------------------------------------------------ #

    def initialize(
        self,
        batch_size: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device | str = "cpu",
    ) -> ModelState:
        """Build an initial state for a batch, using ``dvsi`` from crop params.

        Parameter values are read as detached Python floats, so no gradient
        flows from the initial condition back into the parameter containers.

        Args:
            batch_size: Number of parallel simulation instances ``B``.
            dtype: Tensor dtype.
            device: Torch device (e.g. ``"cpu"``, ``"cuda"``).

        Returns:
            A fresh `ModelState` with initial DVS, root depth, soil
            water at field capacity, and a seeded leaf mass so that LAI
            growth has a substrate post-emergence.
        """
        dvsi = float(self.crop_params.dvsi.detach().cpu().item())
        rootdi = float(self.crop_params.rootdi.detach().cpu().item())
        # Initialise at field capacity × initial rooting depth (mm)
        wfc = float(self.soil_params.wcfc.detach().cpu().item())
        wai = 1000.0 * wfc * rootdi
        # Lower-zone initial water — SIMPLACE WTOTL = factor·(RDM − RDI)·SMLOWI
        rdmso = float(self.soil_params.rdmso.detach().cpu().item())
        rdmcr = float(self.crop_params.rdmcr.detach().cpu().item())
        rdm_val = min(rdmso, rdmcr)
        wci_lower = float(self.soil_params.wci_lower.detach().cpu().item())
        # Floor the lower-zone thickness so it never starts empty/negative.
        wa_lower_i = 1000.0 * max(rdm_val - rootdi, 1e-4) * wci_lower
        state = ModelState.initial(
            batch_size=batch_size,
            dtype=dtype,
            device=device,
            dvsi=dvsi,
            wai=wai,
            rootdi=rootdi,
            wa_lower_i=wa_lower_i,
            dslri=3.0,
            dsosi=0.0,
        )
        # Seed leaf mass so that LAI growth has a substrate post-emergence
        laii = float(self.crop_params.laii.detach().cpu().item())
        sla = float(self.crop_params.sla.detach().cpu().item())
        wlv0 = torch.full_like(state.wlv, laii / max(sla, 1e-6))
        lai0 = torch.full_like(state.lai, laii)
        return state.replace(wlv=wlv0, lai=lai0)

    def forward(
        self,
        weather: WeatherDriver | torch.Tensor,
        start_doy: int = 1,
        initial_state: ModelState | None = None,
    ) -> ModelOutput:
        """Run a full simulation and return trajectories plus final yield.

        Args:
            weather: `WeatherDriver` or a raw ``[B, T, C]`` tensor of
                daily weather forcing.
            start_doy: Day-of-year of the first simulated day.
            initial_state: Optional pre-built `ModelState`. When
                omitted, `initialize` is called automatically.

        Returns:
            A `ModelOutput` containing the full state/rate trajectories
            and summary variables (``lai``, ``dvs``, ``biomass``, ``yield_``).
        """
        if isinstance(weather, torch.Tensor):
            weather = WeatherDriver(weather)
        batch_size = weather.batch_size
        if initial_state is None:
            state = self.initialize(
                batch_size=batch_size,
                dtype=weather.data.dtype,
                device=weather.data.device,
            )
        else:
            state = initial_state

        states, rates = self.engine.run(
            state=state,
            weather=weather,
            start_doy=start_doy,
            crop_params=self.crop_params,
            soil_params=self.soil_params,
            site_params=self.site_params,
        )

        lai = torch.stack([s.lai for s in states], dim=1)  # [B, T+1]
        dvs = torch.stack([s.dvs for s in states], dim=1)
        # Sum of green leaves, stems and storage organs.
        biomass = torch.stack([s.wlv + s.wst + s.wso for s in states], dim=1)
        yield_ = states[-1].wso

        return ModelOutput(
            states=states,
            rates=rates,
            yield_=yield_,
            lai=lai,
            dvs=dvs,
            biomass=biomass,
        )

    # ------------------------------------------------------------------ #
    # Low-level API — single-step rate + state update
    # ------------------------------------------------------------------ #

    def compute_rates(
        self,
        state: ModelState,
        weather_day: dict[str, torch.Tensor],
        doy: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Compute the rate vector for a single day (low-level API).

        Args:
            state: Current `ModelState`.
            weather_day: Dict of named weather channels for the current day
                (see `WEATHER_CHANNELS`), each of shape ``[B]``.
            doy: Day-of-year tensor of shape ``[B]``.

        Returns:
            Dict of rate tensors keyed by ``"<field>_rate"`` plus diagnostics
            (``tranrf``, ``nstress``, ``gtotal``).
        """
        return self._compute_rates_dispatch(
            state=state,
            weather_day=weather_day,
            doy=doy,
            crop_params=self.crop_params,
            soil_params=self.soil_params,
            site_params=self.site_params,
        )

    def update_state(
        self,
        state: ModelState,
        rates: dict[str, torch.Tensor],
        dt: float = 1.0,
    ) -> ModelState:
        """Apply a forward-Euler step to advance the state by ``dt`` days.

        Args:
            state: Current `ModelState`.
            rates: Dict of rate tensors produced by `compute_rates`.
            dt: Integration step in days.

        Returns:
            A new `ModelState` advanced by one step.
        """
        return euler_update(state, rates, dt)

    # ------------------------------------------------------------------ #
    # Internal: one-day rate computation in the SIMPLACE execution order
    # ------------------------------------------------------------------ #

    def _compute_rates_dispatch(
        self,
        state: ModelState,
        weather_day: dict[str, torch.Tensor],
        doy: torch.Tensor,
        crop_params: CropParameters,
        soil_params: SoilParameters,
        site_params: SiteParameters,
    ) -> dict[str, torch.Tensor]:
        """Compute one day's rates following the SIMPLACE execution order.

        Args:
            state: Current `ModelState`.
            weather_day: Named weather channels for the day; the keys read
                are ``davtmp``, ``tmin``, ``tmax``, ``irrad``, ``rain``,
                ``vp`` and ``wind``.
            doy: Day-of-year tensor (may be an integer tensor).
            crop_params: Crop parameter container.
            soil_params: Soil parameter container.
            site_params: Site parameter container (``latitude`` is read).

        Returns:
            Dict of integrable ``"<field>_rate"`` tensors plus the
            non-integrated diagnostics ``tranrf``, ``nstress`` and ``gtotal``.
        """
        # Extract weather variables (SIMPLACE order)
        davtmp = weather_day["davtmp"]
        tmin = weather_day["tmin"]
        tmax = weather_day["tmax"]
        dtr = weather_day["irrad"]
        rain = weather_day["rain"]
        vap = weather_day["vp"]  # [kPa] from weather
        wind = weather_day["wind"]

        # Day-of-year in the working floating dtype: avoids forcing float32
        # in float64 runs and keeps the residual context single-dtype.
        doy_f = doy.to(davtmp.dtype)

        # 1. Astro — solar declination, daylength
        lat_b = (
            site_params.latitude.expand_as(doy)
            if site_params.latitude.dim() > 0
            else site_params.latitude
        )
        astro = self.astro(doy=doy, latitude=lat_b)
        dayl = astro["daylength"]
        sinld = astro["sinld"]
        cosld = astro["cosld"]
        ddlp = astro["ddlp"]

        # 2. Irradiation — daily total irradiation and PAR interception
        irrad_out = self.irradiation(
            state=state,
            doy=doy_f,
            dayl=dayl,
            sinld=sinld,
            cosld=cosld,
            dtr=dtr,
            params=crop_params,
        )
        avrad = irrad_out["avrad"]
        atmtr = irrad_out["atmtr"]
        frac_int = irrad_out["frac_intercepted"]

        # 3. Phenology
        pheno = self.phenology(state, davtmp, ddlp, crop_params)

        # 4. Evapotranspiration — PENMAN formula
        et = self.evapotranspiration(
            tmin=tmin,
            tmax=tmax,
            wind=wind,
            vap=vap,
            avrad=avrad,
            atmtr=atmtr,
            frac_int=frac_int,
        )

        # 5. Water balance (two-zone, with SIMPLACE percolation cascade)
        # Maximum rooting depth = min(soil limit, crop limit), broadcast to [B].
        rdm = torch.minimum(
            soil_params.rdmso + torch.zeros_like(state.rootd),
            crop_params.rdmcr + torch.zeros_like(state.rootd),
        )
        water = self.water_balance(
            state=state,
            rain=rain,
            pevap=et["pevap"],
            ptran=et["ptran"],
            params=soil_params,
            rdm=rdm,
            etc=et["etc"],
            doy=doy,
        )
        tranrf = water["tranrf"]

        # 6+7. Nutrient preliminary step — we first estimate partitioning
        #      using a "no nutrient stress" GTOTAL to compute demand, then
        #      finalise with the resulting nstress.
        photo_pre = self.photosynthesis(
            parint=irrad_out["parint"],
            davtmp=davtmp,
            tranrf=tranrf,
            nstress=torch.ones_like(davtmp),
            params=crop_params,
        )
        part_pre = self.partitioning(
            state=state, gtotal=photo_pre["gtotal"], params=crop_params
        )
        nut = self.nutrient_demand(
            state=state,
            g_lv=part_pre["g_lv"],
            g_st=part_pre["g_st"],
            g_rt=part_pre["g_root"],
            g_so=part_pre["g_so"],
            crop_params=crop_params,
            soil_params=soil_params,
        )
        nstress = nut["nstress"]

        # 8. Photosynthesis (final) with nutrient + water stress.
        # The combined stress factor is divided by ``tranrf`` because
        # ``tranrf`` is also passed separately (presumably applied inside
        # Photosynthesis); the clamp guards against division by zero.
        photo = self.photosynthesis(
            parint=irrad_out["parint"],
            davtmp=davtmp,
            tranrf=tranrf,
            nstress=self.stress(tranrf, nstress) / torch.clamp(tranrf, min=1e-6),
            params=crop_params,
        )
        gtotal = photo["gtotal"]

        # Residual correction on gtotal (non-negative after correction)
        if "photosynthesis" in self.residual_modules:
            ctx = torch.stack(
                [state.lai, state.dvs, davtmp, dtr, tranrf, nstress, state.wa, doy_f],
                dim=-1,
            )
            gtotal = gtotal + self.residual_modules["photosynthesis"](ctx).squeeze(-1)
            gtotal = torch.clamp(gtotal, min=0.0)

        # 9. Partitioning
        part = self.partitioning(state=state, gtotal=gtotal, params=crop_params)

        # 10. Leaf dynamics
        leaf = self.leaf_dynamics(
            state=state,
            g_lv=part["g_lv"],
            dtsu=pheno["dtsu"],
            tranrf=tranrf,
            nstress=nstress,
            params=crop_params,
        )

        # 11. Root dynamics
        root = self.root_dynamics(
            state=state,
            g_root=part["g_root"],
            tranrf=tranrf,
            params=crop_params,
        )

        # Gate all growth/senescence rates post-maturity (DVS >= 2). This is
        # a hard step, so it carries no gradient with respect to ``dvs``.
        active = (state.dvs < 2.0).to(davtmp.dtype)

        def gate(x: torch.Tensor) -> torch.Tensor:
            return x * active

        rates: dict[str, torch.Tensor] = {
            "dvs_rate": pheno["dvs_rate"],
            "tsum_rate": pheno["tsum_rate"],
            "tsump_rate": pheno["tsump_rate"],
            "vern_rate": pheno["vern_rate"],
            "wlv_rate": gate(leaf["wlv_rate"]),
            "wlvd_rate": gate(leaf["wlvd_rate"]),
            "wst_rate": gate(part["g_st"]),
            "wrt_rate": gate(root["wrt_rate"]),
            "wso_rate": gate(part["g_so"]),
            "lai_rate": gate(leaf["lai_rate"]),
            "rootd_rate": root["rootd_rate"],
            "wa_rate": water["wa_rate"],
            "wa_lower_rate": water["wa_lower_rate"],
            "dslr_rate": water["dslr_rate"],
            "dsos_rate": water["dsos_rate"],
            "anlv_rate": gate(nut["n_lv_rate"]),
            "anst_rate": gate(nut["n_st_rate"]),
            "anrt_rate": gate(nut["n_rt_rate"]),
            "anso_rate": gate(nut["n_so_rate"]),
            "aplv_rate": gate(nut["p_lv_rate"]),
            "apst_rate": gate(nut["p_st_rate"]),
            "aprt_rate": gate(nut["p_rt_rate"]),
            "apso_rate": gate(nut["p_so_rate"]),
            "aklv_rate": gate(nut["k_lv_rate"]),
            "akst_rate": gate(nut["k_st_rate"]),
            "akrt_rate": gate(nut["k_rt_rate"]),
            "akso_rate": gate(nut["k_so_rate"]),
            "tran_cum_rate": water["tran"],
            "evap_cum_rate": water["evap"],
            # Diagnostics (not integrated)
            "tranrf": tranrf,
            "nstress": nstress,
            "gtotal": gtotal,
        }
        return rates

    # ------------------------------------------------------------------ #
    # Convenience: flatten all learnable parameters across dataclasses
    # ------------------------------------------------------------------ #

    def learnable_parameter_groups(self) -> dict[str, Any]:
        """Return a dict of named `nn.Parameter` tensors.

        Walks the ``crop``/``soil``/``site`` parameter containers and
        collects every field that is an `nn.Parameter`.

        Returns:
            Dict keyed by ``"<container>.<field>"`` mapping to the
            corresponding `nn.Parameter`.
        """
        out: dict[str, Any] = {}
        for name, params in (
            ("crop", self.crop_params),
            ("soil", self.soil_params),
            ("site", self.site_params),
        ):
            for f in fields(params):
                v = getattr(params, f.name)
                if isinstance(v, nn.Parameter):
                    out[f"{name}.{f.name}"] = v
        return out

compute_rates(self, state, weather_day, doy)

Compute the rate vector for a single day (low-level API).

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `ModelState` | Current `ModelState`. | *required* |
| `weather_day` | `dict[str, torch.Tensor]` | Dict of named weather channels for the current day (see `WEATHER_CHANNELS`), each of shape `[B]`. | *required* |
| `doy` | `torch.Tensor` | Day-of-year tensor of shape `[B]`. | *required* |

**Returns:**

| Type | Description |
|------|-------------|
| `dict[str, torch.Tensor]` | Dict of rate tensors keyed by `"<field>_rate"` plus diagnostics (`tranrf`, `nstress`, `gtotal`). |

Source code in torchcrop/model.py
def compute_rates(
    self,
    state: ModelState,
    weather_day: dict[str, torch.Tensor],
    doy: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Evaluate every process rate for one simulated day (low-level API).

    Thin forwarder to the internal dispatcher, supplying the crop, soil and
    site parameter containers held by the model.

    Args:
        state: Current `ModelState`.
        weather_day: Dict of named weather channels for the current day
            (see `WEATHER_CHANNELS`), each of shape ``[B]``.
        doy: Day-of-year tensor of shape ``[B]``.

    Returns:
        Dict of rate tensors keyed by ``"<field>_rate"`` plus diagnostics
        (``tranrf``, ``nstress``, ``gtotal``).
    """
    parameter_sets = {
        "crop_params": self.crop_params,
        "soil_params": self.soil_params,
        "site_params": self.site_params,
    }
    dispatch = self._compute_rates_dispatch
    return dispatch(state=state, weather_day=weather_day, doy=doy, **parameter_sets)

forward(self, weather, start_doy=1, initial_state=None)

Run a full simulation and return trajectories plus final yield.

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `weather` | `WeatherDriver \| torch.Tensor` | `WeatherDriver` or a raw `[B, T, C]` tensor of daily weather forcing. | *required* |
| `start_doy` | `int` | Day-of-year of the first simulated day. | `1` |
| `initial_state` | `ModelState \| None` | Optional pre-built `ModelState`. When omitted, `initialize` is called automatically. | `None` |

**Returns:**

| Type | Description |
|------|-------------|
| `ModelOutput` | A `ModelOutput` containing the full state/rate trajectories and summary variables (`lai`, `dvs`, `biomass`, `yield_`). |

Source code in torchcrop/model.py
def forward(
    self,
    weather: WeatherDriver | torch.Tensor,
    start_doy: int = 1,
    initial_state: ModelState | None = None,
) -> ModelOutput:
    """Simulate the whole weather horizon and collect the results.

    Args:
        weather: `WeatherDriver` or a raw ``[B, T, C]`` tensor of
            daily weather forcing (raw tensors are wrapped automatically).
        start_doy: Day-of-year of the first simulated day.
        initial_state: Optional pre-built `ModelState`. When
            omitted, `initialize` is called automatically using the
            weather tensor's batch size, dtype and device.

    Returns:
        A `ModelOutput` containing the full state/rate trajectories
        and summary variables (``lai``, ``dvs``, ``biomass``, ``yield_``).
    """
    driver = WeatherDriver(weather) if isinstance(weather, torch.Tensor) else weather
    n_batch = driver.batch_size

    start_state = initial_state
    if start_state is None:
        start_state = self.initialize(
            batch_size=n_batch,
            dtype=driver.data.dtype,
            device=driver.data.device,
        )

    states, rates = self.engine.run(
        state=start_state,
        weather=driver,
        start_doy=start_doy,
        crop_params=self.crop_params,
        soil_params=self.soil_params,
        site_params=self.site_params,
    )

    # Each trajectory stacks per-day snapshots along dim 1 -> [B, T+1].
    lai_traj = torch.stack([snap.lai for snap in states], dim=1)
    dvs_traj = torch.stack([snap.dvs for snap in states], dim=1)
    # Sum of green leaves, stems and storage organs.
    biomass_traj = torch.stack(
        [snap.wlv + snap.wst + snap.wso for snap in states], dim=1
    )

    return ModelOutput(
        states=states,
        rates=rates,
        yield_=states[-1].wso,
        lai=lai_traj,
        dvs=dvs_traj,
        biomass=biomass_traj,
    )

initialize(self, batch_size, dtype=torch.float32, device='cpu')

Build an initial state for a batch, using dvsi from crop params.

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `batch_size` | `int` | Number of parallel simulation instances `B`. | *required* |
| `dtype` | `torch.dtype` | Tensor dtype. | `torch.float32` |
| `device` | `torch.device \| str` | Torch device (e.g. `"cpu"`, `"cuda"`). | `'cpu'` |

**Returns:**

| Type | Description |
|------|-------------|
| `ModelState` | A fresh `ModelState` with initial DVS, root depth, soil water at field capacity, and a seeded leaf mass so that LAI growth has a substrate post-emergence. |

Source code in torchcrop/model.py
def initialize(
    self,
    batch_size: int,
    dtype: torch.dtype = torch.float32,
    device: torch.device | str = "cpu",
) -> ModelState:
    """Create the day-0 `ModelState` for ``batch_size`` parallel runs.

    All parameter values are read as detached Python floats.

    Args:
        batch_size: Number of parallel simulation instances ``B``.
        dtype: Tensor dtype.
        device: Torch device (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        A fresh `ModelState` with initial DVS, root depth, soil
        water at field capacity, and a seeded leaf mass so that LAI
        growth has a substrate post-emergence.
    """

    def as_float(value: torch.Tensor) -> float:
        # Detached scalar read of a (possibly learnable) parameter.
        return float(value.detach().cpu().item())

    crop, soil = self.crop_params, self.soil_params

    init_dvs = as_float(crop.dvsi)
    init_root_depth = as_float(crop.rootdi)

    # Upper zone: field capacity over the initial rooting depth (mm).
    upper_water = 1000.0 * as_float(soil.wcfc) * init_root_depth

    # Lower zone — SIMPLACE WTOTL = factor·(RDM − RDI)·SMLOWI, where RDM is
    # the shallower of the soil and crop rooting limits; the thickness is
    # floored at 1e-4 so it never starts empty or negative.
    max_root_depth = min(as_float(soil.rdmso), as_float(crop.rdmcr))
    lower_thickness = max(max_root_depth - init_root_depth, 1e-4)
    lower_water = 1000.0 * lower_thickness * as_float(soil.wci_lower)

    state = ModelState.initial(
        batch_size=batch_size,
        dtype=dtype,
        device=device,
        dvsi=init_dvs,
        wai=upper_water,
        rootdi=init_root_depth,
        wa_lower_i=lower_water,
        dslri=3.0,
        dsosi=0.0,
    )

    # Seed leaf mass (LAI / SLA) so LAI growth has a substrate post-emergence.
    init_lai = as_float(crop.laii)
    init_leaf_mass = init_lai / max(as_float(crop.sla), 1e-6)
    return state.replace(
        wlv=torch.full_like(state.wlv, init_leaf_mass),
        lai=torch.full_like(state.lai, init_lai),
    )

learnable_parameter_groups(self)

Return a dict of named nn.Parameter tensors.

Walks the crop/soil/site parameter containers and collects every field that is an nn.Parameter.

**Returns:**

| Type | Description |
|------|-------------|
| `dict[str, Any]` | Dict keyed by `"<container>.<field>"` mapping to the corresponding `nn.Parameter`. |

Source code in torchcrop/model.py
def learnable_parameter_groups(self) -> dict[str, Any]:
    """Collect every `nn.Parameter` field from the parameter containers.

    The ``crop``, ``soil`` and ``site`` dataclasses are visited in that
    order, and their fields in declaration order; non-parameter fields
    are skipped.

    Returns:
        Dict keyed by ``"<container>.<field>"`` mapping to the
        corresponding `nn.Parameter`.
    """
    containers = {
        "crop": self.crop_params,
        "soil": self.soil_params,
        "site": self.site_params,
    }
    return {
        f"{prefix}.{fld.name}": value
        for prefix, container in containers.items()
        for fld in fields(container)
        if isinstance(value := getattr(container, fld.name), nn.Parameter)
    }

update_state(self, state, rates, dt=1.0)

Apply a forward-Euler step to advance the state by dt days.

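**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `ModelState` | Current `ModelState`. | *required* |
| `rates` | `dict[str, torch.Tensor]` | Dict of rate tensors produced by `compute_rates`. | *required* |
| `dt` | `float` | Integration step in days. | `1.0` |

**Returns:**

| Type | Description |
|------|-------------|
| `ModelState` | A new `ModelState` advanced by one step. |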

Source code in torchcrop/model.py
def update_state(
    self,
    state: ModelState,
    rates: dict[str, torch.Tensor],
    dt: float = 1.0,
) -> ModelState:
    """Advance ``state`` by ``dt`` days with a single forward-Euler step.

    Delegates to `euler_update`, the same integrator the simulation
    engine is configured with.

    Args:
        state: Current `ModelState`.
        rates: Dict of rate tensors produced by `compute_rates`.
        dt: Integration step in days.

    Returns:
        A new `ModelState` advanced by one step.
    """
    next_state = euler_update(state, rates, dt)
    return next_state

ModelOutput dataclass

Container for a full simulation run.

**Attributes:**

| Name | Type | Description |
|------|------|-------------|
| `states` | `list` | Per-day state snapshots (length `T + 1`; the first entry is the initial condition). |
| `rates` | `list` | Per-day rate dicts (length `T`). |
| `yield_` | `Tensor` | Final storage-organ dry weight `WSO` at the last step [g m⁻²]. |
| `lai` | `Tensor` | LAI trajectory of shape `[B, T + 1]`. |
| `dvs` | `Tensor` | DVS trajectory of shape `[B, T + 1]`. |
| `biomass` | `Tensor` | Above-ground biomass trajectory of shape `[B, T + 1]`. |

Source code in torchcrop/model.py
@dataclass
class ModelOutput:
    """Container for a full simulation run.

    Produced by `Lintul5Model.forward`. Field order defines the positional
    constructor, so new fields should be appended at the end.

    Attributes:
        states: Per-day state snapshots (length ``T + 1``; the first entry is
            the initial condition).
        rates: Per-day rate dicts (length ``T``).
        yield_: Final storage-organ dry weight ``WSO`` at the last step
            [g m-2].
        lai: LAI trajectory of shape ``[B, T + 1]``.
        dvs: DVS trajectory of shape ``[B, T + 1]``.
        biomass: Above-ground biomass trajectory of shape ``[B, T + 1]``.
    """

    # Snapshot list; entry 0 is the initial condition.
    states: list[ModelState]
    # One rate dict per simulated day (integrated rates plus diagnostics).
    rates: list[dict[str, torch.Tensor]]
    # Trailing underscore avoids the reserved keyword ``yield``.
    yield_: torch.Tensor
    lai: torch.Tensor
    dvs: torch.Tensor
    # In `forward` this is built as wlv + wst + wso (green leaves, stems,
    # storage organs); dead leaves and roots are not included.
    biomass: torch.Tensor