weather¶
Weather driver tensor.
Shape [B, T, C_weather] with channels, in order:
1 2 3 4 5 6 7 8 | |
WeatherDriver
dataclass
¶
Thin wrapper around a [B, T, C] weather tensor.
Provides named channel access and input validation. The underlying tensor
is accessible as .data for direct use in process modules.
Attributes:
| Name | Type | Description |
|---|---|---|
| `data` | `Tensor` | The underlying weather tensor of shape `[B, T, C]`. |
Exceptions:
| Type | Description |
|---|---|
| `ValueError` | If `data` is not 3-D or does not carry exactly `N_WEATHER_CHANNELS` channels. |
Source code in torchcrop/drivers/weather.py
@dataclass
class WeatherDriver:
"""Thin wrapper around a ``[B, T, C]`` weather tensor.
Provides named channel access and input validation. The underlying tensor
is accessible as ``.data`` for direct use in process modules.
Attributes:
data: The underlying weather tensor of shape ``[B, T, C]``.
Raises:
ValueError: If ``data`` is not 3-D or does not carry exactly
`N_WEATHER_CHANNELS` channels.
"""
data: torch.Tensor
def __post_init__(self) -> None:
if self.data.dim() != 3:
raise ValueError(
f"weather tensor must be 3-D [B, T, C]; got {tuple(self.data.shape)}"
)
if self.data.shape[-1] != N_WEATHER_CHANNELS:
raise ValueError(
f"weather tensor must have {N_WEATHER_CHANNELS} channels; "
f"got {self.data.shape[-1]}"
)
@property
def batch_size(self) -> int:
"""Leading batch dimension ``B``."""
return self.data.shape[0]
@property
def n_days(self) -> int:
"""Number of daily time steps ``T``."""
return self.data.shape[1]
def day(self, t: int) -> dict[str, torch.Tensor]:
"""Return a dict of named channels for day ``t``.
Args:
t: Time index in ``[0, T)``.
Returns:
Dict mapping channel name to a ``[B]`` tensor.
"""
slice_t = self.data[:, t, :]
return {name: slice_t[:, i] for i, name in enumerate(WEATHER_CHANNELS)}
def channel(self, name: str) -> torch.Tensor:
"""Return the full ``[B, T]`` trajectory of a named channel.
Args:
name: Channel name (see `WEATHER_CHANNELS`).
Returns:
A ``[B, T]`` tensor view of the requested channel.
Raises:
KeyError: If ``name`` is not a known weather channel.
"""
try:
idx = WEATHER_CHANNELS.index(name)
except ValueError as e:
raise KeyError(f"unknown weather channel: {name!r}") from e
return self.data[:, :, idx]
def to(self, dtype: torch.dtype | None = None, device: torch.device | str | None = None) -> "WeatherDriver":
"""Return a new `WeatherDriver` cast/moved to ``dtype``/``device``.
Args:
dtype: Target tensor dtype, or ``None`` to leave unchanged.
device: Target torch device, or ``None`` to leave unchanged.
Returns:
A new `WeatherDriver` wrapping the cast tensor.
"""
return WeatherDriver(self.data.to(dtype=dtype, device=device))
batch_size: int
property
readonly
¶
Leading batch dimension B.
n_days: int
property
readonly
¶
Number of daily time steps T.
channel(self, name)
¶
Return the full [B, T] trajectory of a named channel.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Channel name (see `WEATHER_CHANNELS`). | required |
Returns:
| Type | Description |
|---|---|
| `torch.Tensor` | A `[B, T]` tensor view of the requested channel. |
Exceptions:
| Type | Description |
|---|---|
| `KeyError` | If `name` is not a known weather channel. |
Source code in torchcrop/drivers/weather.py
def channel(self, name: str) -> torch.Tensor:
    """Fetch the whole ``[B, T]`` series for one named channel.

    Args:
        name: Channel name (one of `WEATHER_CHANNELS`).

    Returns:
        The ``[B, T]`` slice of ``data`` for that channel.

    Raises:
        KeyError: If ``name`` is not listed in `WEATHER_CHANNELS`.
    """
    try:
        position = WEATHER_CHANNELS.index(name)
    except ValueError as err:
        # Callers treat channels like mapping keys, so surface a KeyError
        # while keeping the original lookup failure chained.
        raise KeyError(f"unknown weather channel: {name!r}") from err
    else:
        return self.data[:, :, position]
day(self, t)
¶
Return a dict of named channels for day t.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `t` | `int` | Time index in `[0, T)`. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, torch.Tensor]` | Dict mapping channel name to a `[B]` tensor. |
Source code in torchcrop/drivers/weather.py
def day(self, t: int) -> dict[str, torch.Tensor]:
    """Collect every named channel for a single day.

    Args:
        t: Time index in ``[0, T)``.

    Returns:
        Dict keyed by channel name; each value is a ``[B]`` tensor.
    """
    snapshot = self.data[:, t, :]
    named: dict[str, torch.Tensor] = {}
    for column, label in enumerate(WEATHER_CHANNELS):
        named[label] = snapshot[:, column]
    return named
to(self, dtype=None, device=None)
¶
Return a new WeatherDriver cast/moved to dtype/device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dtype` | `torch.dtype \| None` | Target tensor dtype, or `None` to leave unchanged. | `None` |
| `device` | `torch.device \| str \| None` | Target torch device, or `None` to leave unchanged. | `None` |
Returns:
| Type | Description |
|---|---|
| `'WeatherDriver'` | A new `WeatherDriver` wrapping the cast tensor. |
Source code in torchcrop/drivers/weather.py
def to(
    self,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
) -> "WeatherDriver":
    """Build a new `WeatherDriver` on the requested ``dtype``/``device``.

    Args:
        dtype: Target tensor dtype; ``None`` keeps the current one.
        device: Target torch device; ``None`` keeps the current one.

    Returns:
        A new `WeatherDriver` wrapping the converted tensor.
    """
    converted = self.data.to(dtype=dtype, device=device)
    return WeatherDriver(converted)