Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 106 additions & 19 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import numpy as np
from .utils import add_month_day_dims, calc_stats
from .geo_embedding_utils import (
calculate_sh_geo_pos_embeddings,
compute_patch_geo_pos_embedding,
)
from .geo_embedding_utils import compute_patch_scale_features
import xarray as xr
import torch
from torch.utils.data import Dataset
Expand All @@ -20,13 +25,19 @@ def __init__(
spatial_dims: Tuple[str, str] = ("lat", "lon"),
patch_size: Tuple[int, int] = (16, 16), # (lat, lon)
stride: Tuple[int, int] = None,
sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
sh_order_L: int = 10,
):
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.daily_da = daily_da
self.monthly_da = monthly_da
self.stride = stride if stride is not None else patch_size

self.sh_embed_dim = sh_embed_dim
self.sh_order_L = sh_order_L

# Check that the input data has the expected dimensions
Comment thread
SarahAlidoost marked this conversation as resolved.
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
raise ValueError(f"Time dimension '{time_dim}' not found in input data")
Expand All @@ -52,7 +63,7 @@ def __init__(
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4)
self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
Expand Down Expand Up @@ -84,6 +95,28 @@ def __init__(
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
self.patch_indices = self._compute_patch_indices(H, W)

# Precompute geoposition and scale embeddings for patches
self.sh_geo_pos = None
self.geo_pos_table = self._get_geo_pos(sh_pos_table)
self.patch_geo_embeddings, self.patch_scale_features = (
self._compute_geoscalepatch_embeddings()
)

def _get_geo_pos(self, sh_pos_table: str):
    """Calculate or retrieve spherical-harmonics-based geo position embeddings.

    When ``sh_pos_table`` is ``None`` the embeddings are computed from the
    dataset's own ``lat_coords`` / ``lon_coords`` using ``sh_order_L`` and
    ``sh_embed_dim``, and stored on ``self.sh_geo_pos``.

    Args:
        sh_pos_table: Path to a precomputed spherical-harmonics table, or
            ``None`` to compute the embeddings on the fly. Loading from a
            path is not implemented yet.

    Returns:
        The ``sh_pos_table`` argument, so the caller can record where the
        embeddings came from. This is ``None`` whenever the embeddings
        were computed rather than loaded.

    Raises:
        NotImplementedError: If ``sh_pos_table`` is not ``None``. This is a
            subclass of ``RuntimeError``, so existing handlers that catch
            ``RuntimeError`` keep working.
    """
    if sh_pos_table is not None:
        # TODO: implement load functionality. The loaded tensor should be
        # placed in self.sh_geo_pos, and sh_embed_dim / sh_order_L should
        # be set from the loaded table. IMPORTANT: check compatibility of
        # L and sh_dim between the requested and the loaded values, and
        # raise an error if they are not consistent.
        raise NotImplementedError("load method not implemented")

    self.sh_geo_pos = calculate_sh_geo_pos_embeddings(
        self.lat_coords, self.lon_coords, self.sh_order_L, self.sh_embed_dim
    )
    # Nothing was loaded, so there is no table path to report (None).
    return sh_pos_table

def _compute_patch_indices(self, H: int, W: int) -> list:
"""Generate patch start indices with coverage warning (overlap support)."""
ph, pw = self.patch_size
Expand Down Expand Up @@ -122,11 +155,43 @@ def _compute_patch_indices(self, H: int, W: int) -> list:

overlap_h = ph - sh if sh < ph else 0
overlap_w = pw - sw if sw < pw else 0
print(f"Patch grid: {len(i_starts)} x {len(j_starts)} = {len(i_starts) * len(j_starts)} patches")
print(
f"Patch grid: {len(i_starts)} x {len(j_starts)} = {len(i_starts) * len(j_starts)} patches"
)
print(f"Overlap: {overlap_h} pixels (height), {overlap_w} pixels (width)")

return [(i, j) for i in i_starts for j in j_starts]

def _compute_geoscalepatch_embeddings(self):
    """Precompute one geo-position embedding and one scale feature per patch.

    For every ``(i, j)`` start index in ``self.patch_indices`` this slices
    the matching window out of ``self.sh_geo_pos`` and the matching
    latitude / longitude coordinate ranges, then passes them to
    ``compute_patch_geo_pos_embedding`` and ``compute_patch_scale_features``.

    Returns:
        A ``(patch_geo_embeddings, patch_scale_features)`` tuple. Each is
        the per-patch results stacked along a new leading dimension, in
        the same order as ``self.patch_indices``.
        NOTE(review): both helpers are presumably expected to return
        tensors of a fixed shape per patch, since ``torch.stack`` needs
        that; confirm against ``geo_embedding_utils``.
    """
    patch_h, patch_w = self.patch_size
    geo_per_patch = []
    scale_per_patch = []

    for row_start, col_start in self.patch_indices:
        # Row / column windows covered by this patch.
        rows = slice(row_start, row_start + patch_h)
        cols = slice(col_start, col_start + patch_w)

        lats = self.lat_coords[rows]
        lons = self.lon_coords[cols]

        geo_per_patch.append(
            compute_patch_geo_pos_embedding(self.sh_geo_pos[rows, cols], lats)
        )
        scale_per_patch.append(compute_patch_scale_features(lats, lons))

    return torch.stack(geo_per_patch), torch.stack(scale_per_patch)

def __len__(self):
    """Return the number of patches, i.e. the length of ``self.patch_indices``."""
    patch_count = len(self.patch_indices)
    return patch_count

Expand All @@ -140,49 +205,71 @@ def __getitem__(self, idx):
ph, pw = self.patch_size

# Extract spatial patch via numpy slicing — faster than xarray indexing
daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W)
monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W)
daily_patch = self.daily_np[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W) -> (M,T,pH, pW)
monthly_patch = self.monthly_np[
:, i : i + ph, j : j + pw
] # (M, H, W) -> (M, pH, pW)
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W)
] # (M, T, H, W) -> (M, T, pH, pW)

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) -> (pH,pW)
land_tensor = torch.from_numpy(land_patch.copy()).bool()
else:
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)

# geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)

# Convert to tensors (from_numpy is zero-copy on contiguous arrays)
# (1, M, T, H, W)
# (1, M, T, pH, pW)
daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0)
# (M, H, W)
# (M, pH, pW)
monthly_tensor = torch.from_numpy(monthly_patch).float()
# (1, M, T, H, W)
# (1, M, T, pH, pW)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
# Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW)
daily_mask_tensor = daily_nan_mask & (
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
)

# Extract lat/lon coordinates for this patch
lat_patch = self.lat_coords[i : i + ph]
lon_patch = self.lon_coords[j : j + pw]
lat_patch = self.lat_coords[i : i + ph] # (H,) -> (pH,)
lon_patch = self.lon_coords[j : j + pw] # (W,) -> (pW,)

# get patch geo pos embedding
geo_pos_embedding_tensor = self.patch_geo_embeddings[idx] # (sh_dim,)

# get scale feature for patch
scale_feature_tensor = self.patch_scale_features[idx] # (10,)

# create tensors to pass sh embedding dimension, harmonic order, and scale feature dim
sh_embed_dim = torch.tensor(self.sh_embed_dim)
harmonic_order = torch.tensor(self.sh_order_L)
scale_f_dim = torch.tensor(len(scale_feature_tensor))

# Convert to tensors
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
"monthly_patch": monthly_tensor, # (M, H, W)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
"land_mask_patch": land_tensor, # (H,W) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 2)
"daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW)
"monthly_patch": monthly_tensor, # (M, pH, pW)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW)
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, # (M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"scale_feature_patch": scale_feature_tensor, # (10,)
"geo_pos_embedding_patch": geo_pos_embedding_tensor, # (sh_embed_dim,)
"sh_embed_dim": sh_embed_dim,
"harmonic_order": harmonic_order,
"scale_f_dim": scale_f_dim,
"coords": (i, j),
"lat_patch": lat_patch, # (H,)
"lon_patch": lon_patch, # (W,)
"lat_patch": lat_patch, # (pH,)
"lon_patch": lon_patch, # (pW,)
}

def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
Loading