
smoothing

Smooth replacements for discontinuous operations.

These are optional drop-ins that preserve gradient flow through both branches of a discontinuity. They are useful when the objective must stay smooth for second-order optimizers (L-BFGS, Hessian-based methods) and when gradients must propagate through threshold events in the model.
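
A minimal usage sketch (assuming the functions are importable from torchcrop.functions.smoothing, inferred from the source path shown below):

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import soft_if

x = torch.tensor(-0.01, requires_grad=True)

# Hard branch: torch.where picks one branch, so the switch itself contributes
# no gradient and the loss surface has a kink at the threshold.
hard = torch.where(x >= 0, x * 2.0, x * 0.5)

# Smooth branch: both branches and the sigmoid blend contribute, so the
# surface stays smooth around the threshold.
soft = soft_if(x, x * 2.0, x * 0.5, k=50.0)
soft.backward()
print(hard.item(), soft.item(), x.grad.item())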

smooth_step(x, x0=0.0, k=50.0)

Differentiable Heaviside step.

Returns \(\sigma(k (x - x_0))\) which tends to 0 for \(x \ll x_0\) and to 1 for \(x \gg x_0\).

Parameters:

x (torch.Tensor): Input tensor. Required.
x0 (float | torch.Tensor): Step location (scalar or broadcastable tensor). Default: 0.0.
k (float): Sharpness of the transition; larger k yields a sharper step. Default: 50.0.

Returns:

torch.Tensor: Tensor with the same shape as x, element-wise in (0, 1).

Source code in torchcrop/functions/smoothing.py
def smooth_step(
    x: torch.Tensor,
    x0: float | torch.Tensor = 0.0,
    k: float = 50.0,
) -> torch.Tensor:
    """Differentiable Heaviside step.

    Returns $\sigma(k (x - x_0))$ which tends to 0 for
    $x \ll x_0$ and to 1 for $x \gg x_0$.

    Args:
        x: Input tensor.
        x0: Step location (scalar or broadcastable tensor).
        k: Sharpness of the transition; larger ``k`` yields a sharper step.

    Returns:
        Tensor with the same shape as ``x``, element-wise in ``(0, 1)``.
    """
    return torch.sigmoid(k * (x - x0))
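
A quick sketch of typical usage (the import path is an assumption based on the source location above):

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import smooth_step

x = torch.linspace(-0.2, 0.2, 5, requires_grad=True)
y = smooth_step(x, x0=0.0, k=50.0)
print(y)          # ~0 well below x0, 0.5 at x0, ~1 well above x0

y.sum().backward()
print(x.grad)     # peaks at x0 (value k / 4) and decays smoothly away from it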

soft_clamp(x, lo, hi, k=50.0)

Smooth clamp between lo and hi using softplus.

The output asymptotes to lo below the lower bound and to hi above the upper bound, with a smooth transition of sharpness k.

Parameters:

x (torch.Tensor): Input tensor. Required.
lo (float | torch.Tensor): Lower bound (scalar or broadcastable tensor). Required.
hi (float | torch.Tensor): Upper bound (scalar or broadcastable tensor). Required.
k (float): Sharpness of the transition at the bounds. Default: 50.0.

Returns:

torch.Tensor: Tensor with the same shape as x, smoothly clamped into [lo, hi].

Source code in torchcrop/functions/smoothing.py
def soft_clamp(
    x: torch.Tensor,
    lo: float | torch.Tensor,
    hi: float | torch.Tensor,
    k: float = 50.0,
) -> torch.Tensor:
    """Smooth clamp between ``lo`` and ``hi`` using softplus.

    The output asymptotes to ``lo`` below the lower bound and to ``hi`` above
    the upper bound, with a smooth transition of sharpness ``k``.

    Args:
        x: Input tensor.
        lo: Lower bound (scalar or broadcastable tensor).
        hi: Upper bound (scalar or broadcastable tensor).
        k: Sharpness of the transition at the bounds.

    Returns:
        Tensor with the same shape as ``x``, smoothly clamped into
        ``[lo, hi]``.
    """
    softplus = torch.nn.functional.softplus
    upper = hi - softplus(k * (hi - x)) / k
    return lo + softplus(k * (upper - lo)) / k
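
A quick sketch (import path assumed as above):

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import soft_clamp

x = torch.tensor([-1.0, 0.25, 2.0], requires_grad=True)
y = soft_clamp(x, lo=0.0, hi=1.0, k=50.0)
print(y)          # ~[0.0, 0.25, 1.0]: interior values pass through almost unchanged

y.sum().backward()
print(x.grad)     # ~0 far outside the bounds, ~1 well inside, smooth in between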

soft_if(condition, true_val, false_val, k=50.0)

Smooth replacement for torch.where(condition >= 0, true_val, false_val).

Mirrors the FST INSW function but keeps gradient flow through both branches via a sigmoid blend of sharpness k.

Parameters:

condition (torch.Tensor): Selector tensor; the sign (and magnitude) controls the sigmoid blend. Required.
true_val (torch.Tensor): Value used when condition >= 0. Required.
false_val (torch.Tensor): Value used when condition < 0. Required.
k (float): Sharpness of the sigmoid blend. Default: 50.0.

Returns:

torch.Tensor: Sigmoid-blended combination sigmoid(k * condition) * true_val + (1 - sigmoid(k * condition)) * false_val.

Source code in torchcrop/functions/smoothing.py
def soft_if(
    condition: torch.Tensor,
    true_val: torch.Tensor,
    false_val: torch.Tensor,
    k: float = 50.0,
) -> torch.Tensor:
    """Smooth replacement for ``torch.where(condition >= 0, true_val, false_val)``.

    Mirrors the FST ``INSW`` function but keeps gradient flow through both
    branches via a sigmoid blend of sharpness ``k``.

    Args:
        condition: Selector tensor; the sign (and magnitude) controls the
            sigmoid blend.
        true_val: Value used when ``condition >= 0``.
        false_val: Value used when ``condition < 0``.
        k: Sharpness of the sigmoid blend.

    Returns:
        Sigmoid-blended combination
        ``sigmoid(k * condition) * true_val +
        (1 - sigmoid(k * condition)) * false_val``.
    """
    alpha = torch.sigmoid(k * condition)
    return alpha * true_val + (1.0 - alpha) * false_val
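
A quick sketch comparing the blend with the hard switch (import path assumed as above):

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import soft_if

cond = torch.tensor([-0.5, -0.001, 0.001, 0.5])
t = torch.full_like(cond, 10.0)
f = torch.full_like(cond, -10.0)

# Far from zero the blend matches the hard switch; near zero it interpolates
# instead of jumping, which is what keeps the output differentiable in cond.
print(soft_if(cond, t, f, k=50.0))     # ~[-10.0, -0.25, 0.25, 10.0]
print(torch.where(cond >= 0, t, f))    # [-10., -10., 10., 10.]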

soft_max(a, b, k=50.0)

Differentiable maximum using the log-sum-exp trick.

Parameters:

a (torch.Tensor): First input tensor. Required.
b (torch.Tensor): Second input tensor (same shape as a). Required.
k (float): Sharpness; larger k approximates the hard maximum more closely. Default: 50.0.

Returns:

torch.Tensor: Element-wise smooth maximum of a and b.

Source code in torchcrop/functions/smoothing.py
def soft_max(a: torch.Tensor, b: torch.Tensor, k: float = 50.0) -> torch.Tensor:
    """Differentiable maximum using the log-sum-exp trick.

    Args:
        a: First input tensor.
        b: Second input tensor (same shape as ``a``).
        k: Sharpness; larger ``k`` approximates the hard maximum more
            closely.

    Returns:
        Element-wise smooth maximum of ``a`` and ``b``.
    """
    stacked = torch.stack([a, b], dim=-1)
    return torch.logsumexp(k * stacked, dim=-1) / k
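
A quick sketch (import path assumed as above); the log-sum-exp form sits at or slightly above the hard maximum, by at most log(2) / k when the inputs tie:

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import soft_max

a = torch.tensor([0.0, 1.0, 3.0])
b = torch.tensor([2.0, 1.0, -3.0])

print(soft_max(a, b, k=50.0))    # ~[2.0, 1.0139, 3.0]
print(torch.maximum(a, b))       # [2., 1., 3.]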

soft_min(a, b, k=50.0)

Differentiable minimum using the log-sum-exp trick.

\(\text{softmin}(a,b) = -\frac{1}{k}\log\left(e^{-k a} + e^{-k b}\right)\).

Parameters:

a (torch.Tensor): First input tensor. Required.
b (torch.Tensor): Second input tensor (same shape as a). Required.
k (float): Sharpness; larger k approximates the hard minimum more closely. Default: 50.0.

Returns:

torch.Tensor: Element-wise smooth minimum of a and b.

Source code in torchcrop/functions/smoothing.py
def soft_min(a: torch.Tensor, b: torch.Tensor, k: float = 50.0) -> torch.Tensor:
    """Differentiable minimum using the log-sum-exp trick.

    $\text{softmin}(a,b) = -\frac{1}{k}\log(e^{-k a} + e^{-k b})$.

    Args:
        a: First input tensor.
        b: Second input tensor (same shape as ``a``).
        k: Sharpness; larger ``k`` approximates the hard minimum more
            closely.

    Returns:
        Element-wise smooth minimum of ``a`` and ``b``.
    """
    stacked = torch.stack([a, b], dim=-1)
    return -torch.logsumexp(-k * stacked, dim=-1) / k
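
A quick sketch (import path assumed as above); mirroring soft_max, the result sits at or slightly below the hard minimum, by at most log(2) / k when the inputs tie:

import torch

# Import path inferred from the documented source location (assumption).
from torchcrop.functions.smoothing import soft_min

a = torch.tensor([0.0, 1.0, 3.0])
b = torch.tensor([2.0, 1.0, -3.0])

print(soft_min(a, b, k=50.0))    # ~[0.0, 0.9861, -3.0]
print(torch.minimum(a, b))       # [0., 1., -3.]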