
water_balance

Two-zone soil water balance for Lintul5.

A faithful port of the SIMPLACE WATBALS routine: tracks water storage in a rooted zone and a lower zone, runs a percolation cascade to deep drainage, and produces the water-stress factor TRANRF that gates crop growth.

References

SIMPLACE WaterBalance.java (WATBALS routine) and LintulFunctions.SWEAF.

Design

Zones. The soil column is split into a rooted zone of depth rootd storing wa [mm] and a lower zone of depth rdm - rootd storing wa_lower [mm], where rdm = min(rdmso, rdmcr) is the soil-/crop-limited maximum rooting depth. Root-front advance rr moves water at content SMACTL from the lower zone into the rooted zone via the WDR/WDRA fluxes.

Percolation cascade. Net infiltration (after evaporation and transpiration) enters the rooted zone as PERC1. Water in excess of the rooted zone's field capacity descends to the lower zone as PERC2, and water in excess of the lower zone's field capacity leaves the profile as deep drainage PERC3. PERC1 and PERC2 are each capped at KSUB plus the saturation headroom of the receiving zone, PERC3 is capped at KSUB alone, and infiltration that the rooted zone cannot accept is added to runoff.

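Stress factors. Transpiration is reduced multiplicatively by a drought factor RDRY and an oxygen factor RWET. RWET ramps with DSOS, the number of days of oxygen shortage: DSOS increments by one on each day the rooted zone is at or above SMAIR = WCST - CRAIRC, is clipped to [0, 4], persists across days, and resets to 0 as soon as the soil drains below SMAIR. Crops with air ducts (params.iairdu > 0.5, e.g. rice) are exempt and keep RWET = 1.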
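The DSOS bookkeeping can be sketched per sample in plain Python. This is a scalar illustration rather than the batched torch code in WaterBalance.forward; the names update_dsos and dsos_series are hypothetical, and the moisture values are illustrative.

```python
from typing import Iterable, List


def update_dsos(dsos: float, smact: float, wcst: float, crairc: float) -> float:
    """Advance the days-of-oxygen-shortage counter by one day."""
    smair = wcst - crairc  # moisture content above which aeration limits roots
    if smact >= smair:
        # Another waterlogged day: count it, saturating at 4 days.
        return min(dsos + 1.0, 4.0)
    # The rooted zone drained below SMAIR: the counter resets.
    return 0.0


def dsos_series(smact_days: Iterable[float], wcst: float, crairc: float,
                dsos0: float = 0.0) -> List[float]:
    """Run update_dsos over a sequence of daily rooted-zone moisture contents."""
    history = []
    dsos = dsos0
    for smact in smact_days:
        dsos = update_dsos(dsos, smact, wcst, crairc)
        history.append(dsos)
    return history


# Five waterlogged days followed by one drained day (SMAIR = 0.40).
print(dsos_series([0.44, 0.45, 0.45, 0.46, 0.45, 0.30], wcst=0.45, crairc=0.05))
# -> [1.0, 2.0, 3.0, 4.0, 4.0, 0.0]
```

The counter saturates at 4 because the RWET ramp in the equations reaches its full reduction at DSOS = 4.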

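Soil evaporation. Stroosnijder model: when net infiltration PERC ≥ 5 mm d⁻¹, evaporation is set to the potential rate and DSLR (days since last rain) is reset to 1. Otherwise DSLR increments and the maximum evaporation follows the decay PEVAP · clip(CFEV · (sqrt(DSLR) - sqrt(DSLR-1)), 0, 1); the actual rate on such a dry day is the smallest of PEVAP, that decay term plus the day's infiltration, and the air-dry cap 100 · (SMACT - WCAD). On every day, evaporation is finally limited to the rooted-zone water above air-dry content.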
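A scalar sketch of one evaporation day, following the steps of WaterBalance.forward. The function name soil_evaporation and its argument order are hypothetical; cfev stands for params.cfev and wad_mm for the rooted-zone water at air-dry content (1000 · WCAD · rootd).

```python
import math


def soil_evaporation(pevap: float, perc: float, dslr: float, smact: float,
                     wcad: float, cfev: float, wa: float, wad_mm: float):
    """One day of Stroosnijder soil evaporation; returns (evap, dslr_new)."""
    wet_day = perc >= 5.0
    dslr_new = 1.0 if wet_day else dslr + 1.0
    if wet_day:
        evap = pevap  # wet topsoil evaporates at the potential rate
    else:
        # sqrt(DSLR) - sqrt(DSLR - 1): the day's share of cumulative evaporation
        decay = math.sqrt(max(dslr_new, 1e-8)) - math.sqrt(max(dslr_new - 1.0, 0.0))
        evmaxt = pevap * min(max(decay * cfev, 0.0), 1.0)
        evap_cap = 100.0 * max(smact - wcad, 0.0)  # air-dry capacity cap
        evap = max(min(pevap, evmaxt + perc, evap_cap), 0.0)
    # Never evaporate more than the rooted-zone water above air-dry content.
    evap = min(evap, max(wa - wad_mm, 0.0))
    return evap, dslr_new


# A three-day dry spell after rain: evaporation declines each day.
dslr = 1.0
for day in range(3):
    evap, dslr = soil_evaporation(4.0, 0.0, dslr, 0.25, 0.05, 1.0, 100.0, 15.0)
    print(day, dslr, round(evap, 3))
```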

Irrigation. Three modes are selected by params.irri: 0 = none; 1 = automatic refill when SMACT falls below SMCR + 0.02 and daily rain is below 10 mm; 2 = table look-up of irrtab scaled by scale_factor_irr. An externally supplied irrigation tensor overrides the selected mode.
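Only the trigger of mode 1 is fully specified above, so the sketch below covers just that condition; the refill amount is computed by the module's _irrigation_demand helper and is not reproduced here. The function name is hypothetical and the strict inequalities are an assumption.

```python
def auto_irrigation_triggered(smact: float, smcr: float, rain: float) -> bool:
    """Trigger test for params.irri == 1 (automatic refill mode).

    Fires when the rooted zone is drier than SMCR + 0.02 and the day's rain
    is below 10 mm. Strict inequalities are an assumption of this sketch.
    """
    return smact < smcr + 0.02 and rain < 10.0


# Soil just inside the trigger band on a dry day, then the same soil on a rainy day.
print(auto_irrigation_triggered(0.15, 0.14, 2.0))
print(auto_irrigation_triggered(0.15, 0.14, 12.0))
```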

Equations

Volumetric soil-moisture contents [m³ m⁻³]:

\[ \theta = \frac{W_a}{1000 \cdot D_\text{root}}, \qquad \theta_\ell = \frac{W_{a,\ell}}{1000 \cdot (D_\text{rdm} - D_\text{root})} \]
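The conversion from stored water to volumetric content can be sketched for a single sample. The 1e-4 m guards mirror the clamps at the top of WaterBalance.forward, which keep both zone depths strictly positive; the function name zone_moisture is hypothetical.

```python
def zone_moisture(wa: float, wa_lower: float, rootd: float,
                  rdmso: float, rdmcr: float):
    """Return (theta, theta_lower) in m3 m-3 for the two zones."""
    factor = 1000.0  # depth [m] x volumetric content -> water depth [mm]
    rdm = min(rdmso, rdmcr)  # soil-/crop-limited maximum rooting depth
    # Guards mirroring the module: keep both zone depths strictly positive.
    rootd = max(rootd, 1e-4)
    rdm_eff = max(rdm, rootd + 1e-4)
    rd_lower = max(rdm_eff - rootd, 1e-4)
    theta = wa / (factor * rootd)
    theta_lower = wa_lower / (factor * rd_lower)
    return theta, theta_lower


# 60 mm in a 0.3 m rooted zone, 150 mm in the remaining 0.9 m of a 1.2 m profile.
theta, theta_lower = zone_moisture(60.0, 150.0, 0.3, rdmso=1.2, rdmcr=1.5)
```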

Easily available soil-water fraction (SIMPLACE SWEAF) and the critical moisture content below which transpiration starts to decline. Here ETC_cm is the reference canopy ET expressed in cm d⁻¹, DEPNR is the crop group number for soil-water depletion, and A, B are empirical regression constants:

\[ f_\text{eaw} = \mathrm{clip}\left(\frac{1}{A + B\,\text{ETC}_\text{cm}} - (5-\text{DEPNR})\cdot 0.10,\ 0.10,\ 0.95\right), \qquad \theta_\text{crit} = (1 - f_\text{eaw})(\theta_\text{fc} - \theta_\text{wp}) + \theta_\text{wp}. \]
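A direct transcription of the two equations, as a sketch. The defaults a = 0.76 and b = 1.5 are the constants used by the WOFOST/PCSE SWEAF function and are an assumption about what torchcrop's _sweaf uses; the mm-to-cm conversion of ETC is likewise assumed. Some WOFOST-family implementations add a correction for crop groups DEPNR < 3; it is not part of the equation above and is left out here.

```python
def sweaf(etc_mm: float, depnr: float, a: float = 0.76, b: float = 1.5) -> float:
    """Easily available soil-water fraction, following the equation above.

    a and b default to the WOFOST/PCSE SWEAF constants (an assumption here);
    etc_mm is converted from mm d-1 to cm d-1 before use.
    """
    etc_cm = etc_mm / 10.0
    f = 1.0 / (a + b * etc_cm) - (5.0 - depnr) * 0.10
    return min(max(f, 0.10), 0.95)  # clip to [0.10, 0.95]


def theta_crit(f_eaw: float, wcfc: float, wcwp: float) -> float:
    """Critical moisture content below which transpiration declines."""
    return (1.0 - f_eaw) * (wcfc - wcwp) + wcwp


f = sweaf(5.0, depnr=4.5)  # 1/1.51 - 0.05
smcr = theta_crit(f, wcfc=0.30, wcwp=0.10)
```

Higher evaporative demand lowers f_eaw and therefore raises θ_crit, so drought stress sets in earlier on hot days.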

Drought and oxygen reduction factors:

\[ R_\text{dry} = \mathrm{clip}\left( \frac{\theta - \theta_\text{wp}}{\theta_\text{crit} - \theta_\text{wp}},\ 0,\ 1\right), \quad R_\text{wet,max} = \mathrm{clip}\left( \frac{\theta_\text{sat} - \theta}{\theta_\text{sat} - \theta_\text{air}},\ 0,\ 1\right), \quad R_\text{wet} = R_\text{wet,max} + \left(1 - \tfrac{\text{DSOS}}{4} \right)(1 - R_\text{wet,max}). \]
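A per-sample sketch of the two reduction factors and of the transpiration step that combines them with the water above wilting point. The function names are hypothetical; the zero guard on ptran is assumed to behave like the module's notnul helper.

```python
def clip01(x: float) -> float:
    return min(max(x, 0.0), 1.0)


def stress_factors(theta, theta_crit, wcwp, wcst, crairc, dsos, has_air_ducts=False):
    """Return (rdry, rwet) per the equations above."""
    rdry = clip01((theta - wcwp) / (theta_crit - wcwp))
    theta_air = wcst - crairc
    rwet_max = clip01((wcst - theta) / (wcst - theta_air))
    if has_air_ducts:
        # Crops with air ducts (params.iairdu > 0.5) escape oxygen stress.
        return rdry, 1.0
    return rdry, clip01(rwet_max + (1.0 - dsos / 4.0) * (1.0 - rwet_max))


def transpiration(ptran, rdry, rwet, wa, wcwp, rootd):
    """Actual transpiration [mm d-1] and TRANRF, limited by water above wilting point."""
    wavt = max(wa - 1000.0 * wcwp * rootd, 0.0)
    tran = max(min(wavt, rdry * rwet * ptran), 0.0)
    tranrf = tran / (ptran if ptran != 0.0 else 1.0)  # zero guard, assumed like notnul
    return tran, tranrf


rdry, rwet = stress_factors(0.15, 0.20, 0.10, 0.45, 0.05, dsos=0)
tran, tranrf = transpiration(4.0, rdry, rwet, wa=60.0, wcwp=0.10, rootd=0.3)
```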

Root-front water transfer (total / above wilting point):

\[ \text{WDR} = 1000 \cdot r_r \cdot \theta_\ell, \qquad \text{WDRA} = 1000 \cdot r_r \cdot (\theta_\ell - \theta_\text{wp})_+. \]
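The root-front transfer is a one-liner per flux; the sketch below also reproduces the documented default that a missing root-front velocity means no transfer. The function name root_front_transfer is hypothetical.

```python
from typing import Optional, Tuple


def root_front_transfer(rr: Optional[float], theta_lower: float,
                        wcwp: float) -> Tuple[float, float]:
    """Water [mm d-1] handed from the lower zone to the rooted zone."""
    rr = 0.0 if rr is None else rr  # no root growth -> no transfer
    wdr = 1000.0 * rr * theta_lower                    # all water in the newly rooted layer
    wdra = 1000.0 * rr * max(theta_lower - wcwp, 0.0)  # the part above wilting point
    return wdr, wdra


# Roots advancing 12 mm per day into a lower zone at 0.25 m3 m-3.
wdr, wdra = root_front_transfer(0.012, theta_lower=0.25, wcwp=0.10)
```

WDR is added to the rooted zone and subtracted from the lower zone, so it cancels in the whole-profile water balance.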

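Percolation cascade. CAP and CAP_ℓ are the field-capacity headrooms [mm] of the rooted and lower zones, CAP = 1000 · D_root · (θ_fc - θ)_+ and CAP_ℓ = 1000 · (D_rdm - D_root) · (θ_fc - θ_ℓ)_+; the subscript 0 denotes the corresponding saturation headroom, with θ_sat in place of θ_fc: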

\[ \begin{aligned} \text{PERC} &= (1 - \text{RUNFR})\cdot \text{RAIN} + \text{RIRR},\\ \text{PERC1P} &= \text{PERC} - E_a - T_a,\\ \text{PERC1} &= \min(\text{KSUB} + \text{CAP}_0,\ \text{PERC1P}),\\ \text{RUNOFF} &= \text{RUNFR}\cdot \text{RAIN} + (\text{PERC1P} - \text{PERC1})_+,\\ \text{PERC2} &= \mathbb{1}_{\text{CAP}\le \text{PERC1}}\cdot \min(\text{KSUB} + \text{CAP}_{\ell,0},\ (\text{PERC1} - \text{CAP})_+),\\ \text{PERC3} &= \mathbb{1}_{\text{CAP}_\ell\le \text{PERC2}}\cdot \min(\text{KSUB},\ (\text{PERC2} - \text{CAP}_\ell)_+). \end{aligned} \]
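The cascade can be checked by hand with a scalar sketch that follows the equations line by line. The function name percolation_cascade is hypothetical and the soil values in the example are illustrative, not calibrated parameters.

```python
def percolation_cascade(rain, rirr, evap, tran, runfr, ksub,
                        theta, theta_lower, rootd, rd_lower, wcfc, wcst):
    """One day of the PERC1 -> PERC2 -> PERC3 cascade; all fluxes in mm d-1."""
    factor = 1000.0
    # Field-capacity (cap) and saturation (cap0) headrooms of both zones [mm].
    cap = max(wcfc - theta, 0.0) * factor * rootd
    cap0 = max(wcst - theta, 0.0) * factor * rootd
    capl = max(wcfc - theta_lower, 0.0) * factor * rd_lower
    capl0 = max(wcst - theta_lower, 0.0) * factor * rd_lower

    perc = (1.0 - runfr) * rain + rirr   # net infiltration
    perc1p = perc - evap - tran          # may be negative on dry days
    perc1 = min(ksub + cap0, perc1p)
    runoff = runfr * rain + max(perc1p - perc1, 0.0)  # rejected infiltration runs off
    perc2 = min(ksub + capl0, max(perc1 - cap, 0.0)) if cap <= perc1 else 0.0
    perc3 = min(ksub, max(perc2 - capl, 0.0)) if capl <= perc2 else 0.0
    return {"perc1": perc1, "perc2": perc2, "perc3": perc3, "runoff": runoff}


# A 60 mm storm on a rooted zone just below field capacity (illustrative values).
flux = percolation_cascade(rain=60.0, rirr=0.0, evap=2.0, tran=3.0, runfr=0.1,
                           ksub=10.0, theta=0.28, theta_lower=0.29, rootd=0.3,
                           rd_lower=0.9, wcfc=0.30, wcst=0.45)
```

In this storm, 49 mm enter the rooted zone, 43 mm pass on to the lower zone, and deep drainage is throttled to KSUB = 10 mm even though 34 mm exceed the lower zone's field capacity.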

State rates (forward Euler, dt = 1 d):

\[ \dot W_a = \text{PERC1} - \text{PERC2} + \text{WDR}, \qquad \dot W_{a,\ell} = \text{PERC2} - \text{PERC3} - \text{WDR}. \]
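Because evaporation and transpiration are already removed inside PERC1, and WDR only moves water between the two zones, one Euler step conserves mass over the whole profile. The sketch below (hypothetical helper names, scalar floats) applies the rates and evaluates the same residual that the module returns as wbal.

```python
def euler_step(wa, wa_lower, perc1, perc2, perc3, wdr, dt=1.0):
    """Apply the two zone rates with forward Euler; returns new states and rates."""
    wa_rate = perc1 - perc2 + wdr        # rooted zone gains PERC1 and WDR, loses PERC2
    wa_lower_rate = perc2 - perc3 - wdr  # lower zone gains PERC2, loses PERC3 and WDR
    return wa + dt * wa_rate, wa_lower + dt * wa_lower_rate, wa_rate, wa_lower_rate


def mass_residual(rain, rirr, runoff, evap, tran, perc3, wa_rate, wa_lower_rate):
    """Whole-profile balance: inputs minus outputs minus storage change (should be ~0)."""
    return rain + rirr - runoff - evap - tran - perc3 - (wa_rate + wa_lower_rate)


# Illustrative fluxes [mm d-1] for one wet day.
wa, wa_lower, wa_rate, wa_lower_rate = euler_step(
    60.0, 150.0, perc1=49.0, perc2=43.0, perc3=10.0, wdr=3.0)
residual = mass_residual(60.0, 0.0, runoff=6.0, evap=2.0, tran=3.0, perc3=10.0,
                         wa_rate=wa_rate, wa_lower_rate=wa_lower_rate)
```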

WaterBalance (Module)

Two-zone port of SIMPLACE Lintul5 WATBALS.

Computes daily water fluxes (transpiration, evaporation, runoff, drainage), stress factors (TRANRF), and rate variables for the rooted / lower / Stroosnijder / oxygen states, in a fully batched and autograd-safe manner.

Source code in torchcrop/processes/water_balance.py
class WaterBalance(nn.Module):
    """Two-zone port of SIMPLACE Lintul5 ``WATBALS``.

    Computes daily water fluxes (transpiration, evaporation, runoff,
    drainage), stress factors (``TRANRF``), and rate variables for the
    rooted / lower / Stroosnijder / oxygen states, in a fully batched
    and autograd-safe manner.
    """

    def forward(
        self,
        state: ModelState,
        rain: torch.Tensor,
        pevap: torch.Tensor,
        ptran: torch.Tensor,
        params: SoilParameters,
        rdm: torch.Tensor,
        etc: torch.Tensor | None = None,
        rr: torch.Tensor | None = None,
        irrigation: torch.Tensor | None = None,
        doy: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute one day of water-balance rates and fluxes.

        Args:
            state: Current `ModelState`; uses ``wa``, ``wa_lower``,
                ``rootd``, ``dslr``, ``dsos``.
            rain: Daily precipitation [mm d⁻¹], shape ``[B]``.
            pevap: Potential soil evaporation [mm d⁻¹], shape ``[B]``.
            ptran: Potential transpiration [mm d⁻¹], shape ``[B]``.
            params: Soil parameter container.
            rdm: Soil-/crop-limited maximum rooting depth
                ``min(rdmso, rdmcr)`` [m], shape ``[B]``.
            etc: CO2-corrected reference canopy ET [mm d⁻¹], shape ``[B]``;
                falls back to ``ptran`` when ``None``.
            rr: Root-front velocity [m d⁻¹], shape ``[B]``; ``None`` → 0
                (no root-front water transfer).
            irrigation: Externally supplied irrigation [mm d⁻¹] that
                overrides ``params.irri`` mode.
            doy: Day-of-year tensor needed by the ``IRRI = 2`` table
                look-up; shape ``[B]``.

        Returns:
            Dict of ``[B]`` tensors.

            **Rate variables** (consumed by the engine):

            * ``wa_rate``       — ``perc1 - perc2 + wdr`` [mm d⁻¹].
            * ``wa_lower_rate`` — ``perc2 - perc3 - wdr`` [mm d⁻¹].
            * ``dslr_rate``     — ``dslr_new - dslr`` [d d⁻¹].
            * ``dsos_rate``     — ``dsos_new - dsos`` [d d⁻¹].

            **Fluxes / diagnostics**:

            * ``tran``   — actual transpiration [mm d⁻¹].
            * ``evap``   — actual soil evaporation [mm d⁻¹].
            * ``runoff`` — surface runoff (preliminary + rejected infiltration) [mm d⁻¹].
            * ``drain``  — deep drainage = ``perc3`` [mm d⁻¹].
            * ``perc1`` / ``perc2`` / ``perc3`` — cascade fluxes [mm d⁻¹].
            * ``wdr`` / ``wdra`` — root-front water transfer to the rooted zone (total / available) [mm d⁻¹].
            * ``rirr``   — effective irrigation [mm d⁻¹].
            * ``tranrf`` — water-stress factor in ``[0, 1]``.
            * ``smact`` / ``smactl`` — soil-moisture contents [m³ m⁻³].
            * ``smcr``   — critical soil-moisture content [m³ m⁻³].
            * ``rdry`` / ``rwet`` — drought / oxygen reduction factors in ``[0, 1]``.
            * ``wbal``   — whole-profile mass-balance residual [mm] (should be ≈ 0).
        """
        factor = 1000.0  # root [m] · volumetric water-content → water [mm]
        rootd = torch.clamp(state.rootd, min=1e-4)
        rdm_eff = torch.clamp(rdm, min=rootd + 1e-4)
        rd_lower = torch.clamp(rdm_eff - rootd, min=1e-4)

        # ---------------------------------------------------------------- #
        # 1. Actual volumetric soil-moisture contents [m³ m⁻³]
        # ---------------------------------------------------------------- #
        smact = state.wa / notnul(factor * rootd)
        smactl = state.wa_lower / notnul(factor * rd_lower)

        # ---------------------------------------------------------------- #
        # 2. Critical moisture content and drought reduction factor
        # ---------------------------------------------------------------- #
        etc_eff = ptran if etc is None else etc
        sweaf = _sweaf(etc_eff, params.depnr)
        smcr = (1.0 - sweaf) * (params.wcfc - params.wcwp) + params.wcwp
        rdry = limit(0.0, 1.0, (smact - params.wcwp) / notnul(smcr - params.wcwp))

        # ---------------------------------------------------------------- #
        # 3. Oxygen-shortage factor with DSOS accumulator
        # ---------------------------------------------------------------- #
        smair = params.wcst - params.crairc
        rwetmx = limit(0.0, 1.0, (params.wcst - smact) / notnul(params.wcst - smair))
        dsos_new = torch.where(
            smact >= smair,
            torch.clamp(state.dsos + 1.0, max=4.0),
            torch.zeros_like(state.dsos),
        )
        rwet_nonrice = rwetmx + (1.0 - dsos_new / 4.0) * (1.0 - rwetmx)
        is_aquatic = (params.iairdu > 0.5).to(rwet_nonrice.dtype)
        rwet = is_aquatic + (1.0 - is_aquatic) * rwet_nonrice
        rwet = limit(0.0, 1.0, rwet)

        # ---------------------------------------------------------------- #
        # 4. Actual transpiration and water-stress factor
        # ---------------------------------------------------------------- #
        wwp = factor * params.wcwp * rootd
        wavt = torch.clamp(state.wa - wwp, min=0.0)
        tran = torch.clamp(torch.minimum(wavt, rdry * rwet * ptran), min=0.0)
        tranrf = tran / notnul(ptran)

        # ---------------------------------------------------------------- #
        # 5. Irrigation demand
        # ---------------------------------------------------------------- #
        rirr = _irrigation_demand(
            params=params,
            smact=smact,
            smcr=smcr,
            wavt=wavt,
            rootd=rootd,
            rain=rain,
            doy=doy,
            external=irrigation,
        )

        # ---------------------------------------------------------------- #
        # 6. Stroosnijder soil evaporation with DSLR accumulator
        # ---------------------------------------------------------------- #
        perc = (1.0 - params.runfr) * rain + rirr
        runofp = params.runfr * rain

        wet_day = (perc >= 5.0).to(state.dslr.dtype)
        dslr_new = wet_day * torch.ones_like(state.dslr) + (1.0 - wet_day) * (
            state.dslr + 1.0
        )
        # Evaporation on dry days — Stroosnijder (1987): sqrt(t) - sqrt(t-1)
        dslr_prev = torch.clamp(dslr_new - 1.0, min=0.0)
        decay = torch.sqrt(torch.clamp(dslr_new, min=1e-8)) - torch.sqrt(dslr_prev)
        evmaxt = pevap * limit(0.0, 1.0, decay * params.cfev)
        # Cap by air-dry topsoil water capacity (SIMPLACE: 100·(SMACT - SMDRY))
        evap_cap = 100.0 * torch.clamp(smact - params.wcad, min=0.0)
        evap_dry = torch.clamp(
            torch.minimum(pevap, torch.minimum(evmaxt + perc, evap_cap)),
            min=0.0,
        )
        evap_wet = pevap
        evap = wet_day * evap_wet + (1.0 - wet_day) * evap_dry
        # Final safety clamp by available water above air-dry content
        wad_mm = factor * params.wcad * rootd
        evap = torch.minimum(evap, torch.clamp(state.wa - wad_mm, min=0.0))

        # ---------------------------------------------------------------- #
        # 7. Root-front water transfer (WDR / WDRA)
        # ---------------------------------------------------------------- #
        rr_eff = rr if rr is not None else torch.zeros_like(rain)
        wdr = factor * rr_eff * smactl
        wdra = factor * rr_eff * torch.clamp(smactl - params.wcwp, min=0.0)

        # ---------------------------------------------------------------- #
        # 8. Percolation cascade PERC1 → PERC2 → PERC3
        # ---------------------------------------------------------------- #
        cap = torch.clamp(params.wcfc - smact, min=0.0) * factor * rootd
        cap0 = torch.clamp(params.wcst - smact, min=0.0) * factor * rootd
        capl = torch.clamp(params.wcfc - smactl, min=0.0) * factor * rd_lower
        capl0 = torch.clamp(params.wcst - smactl, min=0.0) * factor * rd_lower

        perc1p = perc - evap - tran
        perc1 = torch.minimum(params.ksub + cap0, perc1p)
        extra_runoff = torch.clamp(perc1p - perc1, min=0.0)
        runoff = runofp + extra_runoff

        perc2_candidate = torch.minimum(
            params.ksub + capl0, torch.clamp(perc1 - cap, min=0.0)
        )
        perc2 = torch.where(cap <= perc1, perc2_candidate, torch.zeros_like(perc1))

        perc3_candidate = torch.minimum(
            params.ksub + torch.zeros_like(perc2),
            torch.clamp(perc2 - capl, min=0.0),
        )
        perc3 = torch.where(capl <= perc2, perc3_candidate, torch.zeros_like(perc2))

        # ---------------------------------------------------------------- #
        # 9. Rate variables
        # ---------------------------------------------------------------- #
        wa_rate = perc1 - perc2 + wdr
        wa_lower_rate = perc2 - perc3 - wdr
        dslr_rate = dslr_new - state.dslr
        dsos_rate = dsos_new - state.dsos

        # ---------------------------------------------------------------- #
        # 10. Mass-balance residual for diagnostics
        # ---------------------------------------------------------------- #
        wbal = rain + rirr - runoff - evap - tran - perc3 - (wa_rate + wa_lower_rate)

        return {
            # rates
            "wa_rate": wa_rate,
            "wa_lower_rate": wa_lower_rate,
            "dslr_rate": dslr_rate,
            "dsos_rate": dsos_rate,
            # fluxes
            "tran": tran,
            "evap": evap,
            "runoff": runoff,
            "drain": perc3,
            "perc1": perc1,
            "perc2": perc2,
            "perc3": perc3,
            "wdr": wdr,
            "wdra": wdra,
            "rirr": rirr,
            # stress factors / diagnostics
            "tranrf": tranrf,
            "smact": smact,
            "smactl": smactl,
            "smcr": smcr,
            "rdry": rdry,
            "rwet": rwet,
            "wbal": wbal,
        }

forward(self, state, rain, pevap, ptran, params, rdm, etc=None, rr=None, irrigation=None, doy=None)

Compute one day of water-balance rates and fluxes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | ModelState | Current ModelState; uses wa, wa_lower, rootd, dslr, dsos. | required |
| rain | torch.Tensor | Daily precipitation [mm d⁻¹], shape [B]. | required |
| pevap | torch.Tensor | Potential soil evaporation [mm d⁻¹], shape [B]. | required |
| ptran | torch.Tensor | Potential transpiration [mm d⁻¹], shape [B]. | required |
| params | SoilParameters | Soil parameter container. | required |
| rdm | torch.Tensor | Soil-/crop-limited maximum rooting depth min(rdmso, rdmcr) [m], shape [B]. | required |
| etc | torch.Tensor \| None | CO2-corrected reference canopy ET [mm d⁻¹], shape [B]; falls back to ptran when None. | None |
| rr | torch.Tensor \| None | Root-front velocity [m d⁻¹], shape [B]; None → 0 (no root-front water transfer). | None |
| irrigation | torch.Tensor \| None | Externally supplied irrigation [mm d⁻¹] that overrides params.irri mode. | None |
| doy | torch.Tensor \| None | Day-of-year tensor needed by the IRRI = 2 table look-up; shape [B]. | None |

Returns:

dict[str, torch.Tensor]: Dict of [B] tensors.

Rate variables (consumed by the engine):

  • wa_rate: perc1 - perc2 + wdr [mm d⁻¹].
  • wa_lower_rate: perc2 - perc3 - wdr [mm d⁻¹].
  • dslr_rate: dslr_new - dslr [d d⁻¹].
  • dsos_rate: dsos_new - dsos [d d⁻¹].

Fluxes / diagnostics:

  • tran: actual transpiration [mm d⁻¹].
  • evap: actual soil evaporation [mm d⁻¹].
  • runoff: surface runoff (preliminary runoff plus rejected infiltration) [mm d⁻¹].
  • drain: deep drainage, equal to perc3 [mm d⁻¹].
  • perc1 / perc2 / perc3: cascade fluxes [mm d⁻¹].
  • wdr / wdra: root-front water transfer into the rooted zone (total / above wilting point) [mm d⁻¹].
  • rirr: effective irrigation [mm d⁻¹].
  • tranrf: water-stress factor in [0, 1].
  • smact / smactl: rooted- / lower-zone soil-moisture contents [m³ m⁻³].
  • smcr: critical soil-moisture content [m³ m⁻³].
  • rdry / rwet: drought / oxygen reduction factors in [0, 1].
  • wbal: whole-profile mass-balance residual [mm] (should be ≈ 0).
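A caller can sanity-check the returned dictionary against the contract listed above. The helper below is a hypothetical sketch (check_water_outputs is not part of torchcrop) that works on one sample of per-key floats; for the real [B] tensors, apply it element-wise or replace the comparisons with tensor reductions such as .abs().max().

```python
from typing import Dict, List


def check_water_outputs(out: Dict[str, float], tol: float = 1e-6) -> List[str]:
    """Return a list of contract violations for one sample of WaterBalance output."""
    problems = []
    # Reduction factors are documented to lie in [0, 1].
    for key in ("tranrf", "rdry", "rwet"):
        if not 0.0 <= out[key] <= 1.0:
            problems.append(f"{key} outside [0, 1]: {out[key]}")
    # The whole-profile residual should vanish up to round-off.
    if abs(out["wbal"]) > tol:
        problems.append(f"mass-balance residual too large: {out['wbal']}")
    # drain is documented as an alias of perc3.
    if out["drain"] != out["perc3"]:
        problems.append("drain differs from perc3")
    # Storage change of both zones equals PERC1 - PERC3 (WDR cancels).
    storage = out["wa_rate"] + out["wa_lower_rate"]
    if abs(storage - (out["perc1"] - out["perc3"])) > tol:
        problems.append("zone rates inconsistent with perc1 - perc3")
    return problems
```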