
vis

Plotting helpers for state trajectories.

Matplotlib is an optional dependency. The functions in this module import it lazily and raise a helpful error if it is not installed.
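Lazy importing means Matplotlib is only imported when a plotting function is actually called, so the rest of the package works without it. The module's own helper, `_require_matplotlib`, is not shown on this page. The sketch below is a minimal, hypothetical version of the pattern built on `importlib.import_module`. Its names (`require_optional`, `require_matplotlib`) and its error message are assumptions, not torchcrop's code.

```python
from __future__ import annotations

import importlib
from types import ModuleType


def require_optional(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency on first use (hypothetical helper).

    Raises ImportError with an installation hint if the module is missing.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        # Re-raise with a message that tells the user how to fix the problem.
        raise ImportError(
            f"{module_name!r} is required for this function but is not installed; "
            f"try: {install_hint}"
        ) from exc


def require_matplotlib() -> ModuleType:
    """Hypothetical stand-in for the module's _require_matplotlib helper."""
    # Nothing is imported at module load time; the import happens here, on call.
    return require_optional("matplotlib.pyplot", "pip install matplotlib")
```

Defining `require_matplotlib` costs nothing. Only calling it triggers the import, which is what keeps Matplotlib optional.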

plot_trajectories(traj, title='', ylabel='', labels=None)

Plot one or more trajectories against their step index, with the x-axis labeled "day".

Parameters:

traj (torch.Tensor, required): Trajectory tensor of shape [T] (single series) or [B, T] (batched series, one line per batch element).

title (str, default ''): Figure title.

ylabel (str, default ''): Y-axis label.

labels (Iterable[str] | None, default None): Optional legend labels, one per batch element. Ignored for 1-D input. If None, the labels "batch 0", "batch 1", ... are used.

Returns:

Tuple (fig, ax) of the created Matplotlib figure and axes.
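The shape convention for `traj` and the default legend labels can be sketched without Matplotlib. `iter_series` below is a hypothetical helper written for illustration only and is not part of torchcrop. It takes a NumPy array, like the one `plot_trajectories` obtains via `traj.detach().cpu().numpy()`, and yields the (series, label) pairs that would each become one plotted line. Rejecting inputs that are neither 1-D nor 2-D is a choice made for this sketch.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator

import numpy as np


def iter_series(
    y: np.ndarray, labels: Iterable[str] | None = None
) -> Iterator[tuple[np.ndarray, str | None]]:
    """Yield (series, legend_label) pairs for a [T] or [B, T] array (hypothetical)."""
    if y.ndim == 1:
        # Single series: one line, no legend label; `labels` is ignored.
        yield y, None
    elif y.ndim == 2:
        # Batched series: one line per row. Convert to a list so that any
        # iterable (including a generator) can be indexed.
        names = list(labels) if labels is not None else [f"batch {i}" for i in range(y.shape[0])]
        for i, row in enumerate(y):
            yield row, names[i]
    else:
        raise ValueError(f"expected shape [T] or [B, T], got {y.shape}")


# Example: three batched series of length 5 get the default labels.
print([label for _, label in iter_series(np.zeros((3, 5)))])
# prints ['batch 0', 'batch 1', 'batch 2']
```

Because `labels` is materialized into a list before indexing, a generator of names works as well as a list or tuple.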

Source code in torchcrop/utils/vis.py
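from __future__ import annotations

from collections.abc import Iterable

import torch


def plot_trajectories(
    traj: torch.Tensor,
    title: str = "",
    ylabel: str = "",
    labels: Iterable[str] | None = None,
):
    """Plot one or more trajectories.

    Args:
        traj: Trajectory tensor of shape ``[T]`` (single series) or
            ``[B, T]`` (batched series, one line per batch element).
        title: Figure title.
        ylabel: Y-axis label.
        labels: Optional iterable of legend labels, one per batch element.
            Ignored for 1-D input. Defaults to ``"batch 0"``, ``"batch 1"``, ...

    Returns:
        Tuple ``(fig, ax)`` of the created Matplotlib figure and axes.

    Raises:
        ValueError: If ``traj`` is neither 1-D nor 2-D.
    """
    # Detach from autograd and move to host memory so NumPy/Matplotlib can read it.
    y = traj.detach().cpu().numpy()
    # Validate the shape before creating a figure, so no empty figure is leaked.
    if y.ndim not in (1, 2):
        raise ValueError(f"traj must have shape [T] or [B, T], got {tuple(traj.shape)}")
    # _require_matplotlib is this module's lazy-import helper: it returns
    # matplotlib.pyplot, or raises a helpful error if matplotlib is missing.
    plt = _require_matplotlib()
    fig, ax = plt.subplots(figsize=(8, 4))
    if y.ndim == 1:
        ax.plot(y)
    else:
        # Materialize the iterable first: a plain Iterable (e.g. a generator)
        # does not support indexing.
        labels = list(labels) if labels is not None else [f"batch {i}" for i in range(y.shape[0])]
        for i, row in enumerate(y):
            ax.plot(row, label=labels[i])
        ax.legend(fontsize=8)
    # x values are the step indices 0..T-1, labeled as days.
    ax.set_xlabel("day")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()
    return fig, ax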