vis
Plotting helpers for state trajectories.
matplotlib is an optional dependency — the functions here import it
lazily and raise a helpful error if it is not installed.
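The lazy-import behaviour described above could be sketched roughly as follows. The helper name `require_optional`, its parameters, and the error wording are illustrative assumptions; this is not the module's actual `_require_matplotlib` implementation, only one common way to write such a guard.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency at call time (hypothetical helper).

    Raises an ImportError with an install hint if the module is missing.
    """
    try:
        # Nothing is imported until a plotting function actually calls this.
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name!r} is required for this feature but is not "
            f"installed. Try: {install_hint}"
        ) from exc
```

A plotting function would then call something like `require_optional("matplotlib.pyplot", "pip install matplotlib")` as its first statement, so importing the package itself never fails when matplotlib is absent.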
plot_trajectories(traj, title='', ylabel='', labels=None)
Plot one or more trajectories.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `traj` | `torch.Tensor` | Trajectory tensor of shape `[T]` (single series) or `[B, T]` (batched series, one line per batch element). | required |
| `title` | `str` | Figure title. | `''` |
| `ylabel` | `str` | Y-axis label. | `''` |
| `labels` | `Iterable[str] \| None` | Optional iterable of legend labels, one per batch element. Ignored for 1-D input. Defaults to `"batch i"`. | `None` |
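As a minimal sketch of how the `labels` default works for batched input, the logic can be isolated in a small function. `resolve_labels` is a hypothetical name used only for illustration. Unlike the library source, which indexes `labels[i]` directly, this sketch first converts the input to a list so that generators and other non-indexable iterables are also accepted.

```python
from typing import Iterable, List, Optional


def resolve_labels(labels: Optional[Iterable[str]], n_series: int) -> List[str]:
    """Return one legend label per series (hypothetical helper)."""
    if labels is None:
        # Same default naming as the documented behaviour: "batch 0", "batch 1", ...
        return [f"batch {i}" for i in range(n_series)]
    # Assumption: materialise the iterable so it can be indexed safely.
    resolved = list(labels)
    if len(resolved) < n_series:
        raise ValueError(
            f"expected at least {n_series} labels, got {len(resolved)}"
        )
    return resolved
```

When calling `plot_trajectories` itself, passing a list or tuple of labels is the safe choice, since the function indexes into `labels` by position.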
Returns:
| Type | Description |
|---|---|
| `Tuple` | Tuple `(fig, ax)` of the created Matplotlib figure and axes. |
Source code in torchcrop/utils/vis.py
def plot_trajectories(
    traj: torch.Tensor,
    title: str = "",
    ylabel: str = "",
    labels: Iterable[str] | None = None,
):
    """Plot one or more trajectories.

    Args:
        traj: Trajectory tensor of shape ``[T]`` (single series) or
            ``[B, T]`` (batched series — one line per batch element).
        title: Figure title.
        ylabel: Y-axis label.
        labels: Optional iterable of legend labels, one per batch element.
            Ignored for 1-D input. Defaults to ``"batch i"``.

    Returns:
        Tuple ``(fig, ax)`` of the created Matplotlib figure and axes.
    """
    plt = _require_matplotlib()
    fig, ax = plt.subplots(figsize=(8, 4))
    y = traj.detach().cpu().numpy()
    if y.ndim == 1:
        ax.plot(y, label=None)
    else:
        labels = labels if labels is not None else [f"batch {i}" for i in range(y.shape[0])]
        for i, row in enumerate(y):
            ax.plot(row, label=labels[i])
        ax.legend(fontsize=8)
    ax.set_xlabel("day")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.tight_layout()
    return fig, ax