# Source code for ocetrac.SurfTrack

# ============================================================
# SurfTrack/__init__.py
# ============================================================
"""
SurfTrack — one layer (2-D + time) tracker.

Tracks geospatial features in 3-D data (time, lat, lon) using
morphological operations, connected-component labelling, and
area filtering.

Public API
----------
    from ocetrac.SurfTrack import SurfTracker
"""
from __future__ import annotations

import numpy as np
import xarray as xr

from .tracker import (
    apply_mask,
    filter_area,
    label_3d,
    label_temporal_neighbor,
    morphological_operations,
    wrap_labels,
)

from ..preprocessing.preprocessing import clean_binary

# Public API of this package: only the SurfTracker class is exported by
# a star-import; the helper functions imported above stay internal.
__all__ = ["SurfTracker"]


class SurfTracker:
    """
    End-to-end surface marine heatwave tracker.

    Operates on 3-D data (time, ydim, xdim). Each step stores its output
    as an attribute so intermediate results can be inspected.

    Pipeline
    --------
        clean → filter_area → track → postprocess

    Parameters
    ----------
    da : DataArray (time, ydim, xdim)
        Anomaly or thresholded feature field.
    mask : DataArray (ydim, xdim)
        Binary ocean mask — 1 = valid ocean cell, 0 = land/ignored.
    radius : int
        Disk radius for morphological close→open. ``radius=1`` is a
        no-op. Larger values fill wider gaps but risk bridging separate
        events. Default 2.
    min_size_quartile : float
        Relative area percentile threshold (0–1). Combined with
        ``min_area_cells`` via ``max()``. Default 0.25.
    min_area_cells : int
        Absolute minimum object area in grid cells. Objects smaller than
        this are always removed regardless of the percentile.
        Default 100.
    timedim : str
        Name of the time dimension. Default ``"time"``.
    xdim : str
        Name of the x (longitude) dimension. Default ``"nlon"``.
    ydim : str
        Name of the y (latitude) dimension. Default ``"nlat"``.
    positive : bool
        ``True`` → track warm anomalies (>0).
        ``False`` → track cold anomalies (<0). Default ``True``.
    method : str
        Labelling method. ``"3d"`` (default) runs connected-component
        labelling on the full (time, lat, lon) volume.
        ``"temporal_neighbor"`` labels each 2-D frame independently and
        only links blobs between adjacent frames if they spatially
        overlap. A gap of even one timestep starts a new event.
    contain_thresh : float
        Only used when ``method="temporal_neighbor"``. Containment
        threshold for merging two blobs using
        max(|A∩B|/|A|, |A∩B|/|B|) >= contain_thresh.
        0.0 means any spatial overlap is sufficient. Default 0.0.

    Raises
    ------
    ValueError
        If ``method`` is not ``"3d"`` or ``"temporal_neighbor"``.
    """

    def __init__(
        self,
        da: xr.DataArray,
        mask: xr.DataArray,
        *,
        radius: int = 2,
        min_size_quartile: float = 0.25,
        min_area_cells: int = 100,
        timedim: str = "time",
        xdim: str = "nlon",
        ydim: str = "nlat",
        positive: bool = True,
        method: str = "3d",
        contain_thresh: float = 0.0,
    ) -> None:
        _valid_methods = {"3d", "temporal_neighbor"}
        if method not in _valid_methods:
            raise ValueError(
                f"method={method!r} not recognised. "
                f"Choose from {_valid_methods}."
            )

        self.da = da
        self.mask = mask
        self.radius = radius
        self.min_size_quartile = min_size_quartile
        self.min_area_cells = min_area_cells
        self.timedim = timedim
        self.xdim = xdim
        self.ydim = ydim
        self.positive = positive
        self.method = method
        self.contain_thresh = contain_thresh

        # Intermediate state, filled in by the pipeline steps below.
        self.binary_clean: xr.DataArray | None = None
        self.binary_filtered: xr.DataArray | None = None
        self.area: xr.DataArray | None = None
        self.min_area: float | None = None
        self.N_initial: int | None = None
        self.labels_raw: np.ndarray | None = None
        self.result: xr.DataArray | None = None

    # ── Internal helpers ─────────────────────────────────────────────────────

    @staticmethod
    def _count_events(labels) -> int:
        """Count distinct positive, non-NaN label values in ``labels``."""
        flat = np.asarray(labels, dtype=float).ravel()
        flat = flat[~np.isnan(flat) & (flat > 0)]
        return int(len(np.unique(flat)))

    # ── Step 1 ───────────────────────────────────────────────────────────────

    def clean(self) -> "SurfTracker":
        """
        Binarise the input field, apply morphological close→open per
        (lat, lon) slice, then apply the ocean mask.

        Outputs
        -------
        self.binary_clean : DataArray — binary field after cleaning and masking

        Raises
        ------
        ValueError
            If ``mask`` contains only zeros.
        """
        if (self.mask == 0).all():
            raise ValueError(
                "Found only zeros in `mask`. "
                "Mask should have 1 = valid ocean cells."
            )

        print("Step 1 · morphological cleaning …")
        binary = morphological_operations(
            self.da, self.radius, self.xdim, self.ydim, self.positive
        )
        self.binary_clean = apply_mask(binary, self.mask)

        frac = float(self.binary_clean.values.mean())
        print(f" fraction flagged = {frac:.4f} "
              f"({'OK' if frac > 0 else 'WARNING: nothing flagged'})")
        return self

    # ── Step 2 ───────────────────────────────────────────────────────────────

    def filter(self) -> "SurfTracker":
        """
        Label 2-D slices, make IDs consecutive across time, wrap across
        the date line, then filter by area using:

            max(min_area_cells, percentile(areas, min_size_quartile))

        Runs :meth:`clean` first if it has not been run yet.

        Outputs
        -------
        self.binary_filtered : DataArray — binary field with small objects removed
        self.area : DataArray — area of every detected object
        self.min_area : float — effective threshold applied
        self.N_initial : int — object count before filtering
        """
        if self.binary_clean is None:
            self.clean()

        print("Step 2 · area filtering …")
        self.area, self.min_area, self.binary_filtered, self.N_initial = filter_area(
            self.binary_clean,
            xdim=self.xdim,
            ydim=self.ydim,
            min_size_quartile=self.min_size_quartile,
            min_area_cells=self.min_area_cells,
        )
        return self

    # ── Step 3 ───────────────────────────────────────────────────────────────

    def track(self) -> "SurfTracker":
        """
        Label connected objects and wrap across the date line.

        Uses method='3d' (full volume CCL) or method='temporal_neighbor'
        (strict adjacent-timestep linking). See class parameters for
        details. Runs :meth:`filter` first if it has not been run yet.

        Outputs
        -------
        self.labels_raw : ndarray — integer labels before postprocessing
        """
        if self.binary_filtered is None:
            self.filter()

        print(f"Step 3 · labelling (method={self.method!r}) …")
        if self.method == "temporal_neighbor":
            labels, num = label_temporal_neighbor(
                self.binary_filtered,
                contain_thresh=self.contain_thresh,
            )
        else:
            labels, num = label_3d(self.binary_filtered, connectivity=3)

        # Only stitch labels across the x-edge when the x coordinate spans
        # (to within one grid cell) a full 360 degrees. A single-column
        # axis has no grid spacing to measure and cannot wrap, so it is
        # skipped instead of raising IndexError on element [1].
        xcoord = self.da[self.xdim]
        spans_globe = False
        if xcoord.size > 1:
            grid_res = abs(xcoord[1] - xcoord[0])
            spans_globe = bool(xcoord[-1] - xcoord[0] >= 360 - grid_res)

        if spans_globe:
            self.labels_raw, N_final = wrap_labels(labels)
        else:
            self.labels_raw = labels
            N_final = int(np.max(labels))

        print(f" initial objects : {self.N_initial}")
        print(f" final objects : {N_final}")
        return self

    # ── Step 4 ───────────────────────────────────────────────────────────────

    def postprocess(self) -> "SurfTracker":
        """
        Wrap the tracked array as an ``xr.DataArray`` with NaN background
        and attach tracking diagnostics as attributes.

        Outputs
        -------
        self.result : DataArray — final event labels, NaN = background

        Raises
        ------
        RuntimeError
            If :meth:`track` has not produced ``labels_raw`` yet.
        """
        if self.labels_raw is None:
            raise RuntimeError("Call .track() before .postprocess()")

        print("Step 4 · wrapping result …")
        result = xr.DataArray(
            self.labels_raw, dims=self.da.dims, coords=self.da.coords
        )
        result = result.where(result != 0, drop=False, other=np.nan)

        # Diagnostics.
        # Count events from the labels produced by *this* run. Going
        # through n_events() would read a stale self.result left over
        # from a previous run of the same tracker.
        N_final = self._count_events(self.labels_raw)

        # Despite the "percent" attribute names these are fractions (0–1).
        sum_tot = int(np.sum(self.area.values))
        if sum_tot > 0:
            rejected = int(np.sum(
                self.area.where(self.area <= self.min_area, drop=True).values
            ))
            pct_reject = rejected / sum_tot
            pct_accept = 1.0 - pct_reject
        else:
            # No object area at all: nothing was rejected or accepted.
            # Avoids a ZeroDivisionError on an empty detection.
            pct_reject = 0.0
            pct_accept = 0.0

        result = result.rename("labels")
        result.attrs["initial objects identified"] = self.N_initial
        result.attrs["final objects tracked"] = N_final
        result.attrs["radius"] = self.radius
        result.attrs["size quantile threshold"] = self.min_size_quartile
        result.attrs["min area cells"] = self.min_area_cells
        result.attrs["min area (effective)"] = self.min_area
        result.attrs["percent area reject"] = pct_reject
        result.attrs["percent area accept"] = pct_accept
        result.attrs["method"] = self.method
        result.attrs["contain_thresh"] = self.contain_thresh

        self.result = result
        return self

    # ── Full pipeline ────────────────────────────────────────────────────────

    def run(self) -> xr.DataArray:
        """
        Execute the full pipeline: clean → filter → track → postprocess

        Returns
        -------
        result : DataArray (time, ydim, xdim) — event labels, NaN = background
        """
        return (
            self.clean()
            .filter()
            .track()
            .postprocess()
            .result
        )

    # ── Diagnostics ──────────────────────────────────────────────────────────

    def n_events(self) -> int:
        """
        Number of unique tracked events in the final result.

        Falls back to ``labels_raw`` when :meth:`postprocess` has not
        been run yet.

        Raises
        ------
        RuntimeError
            If neither ``result`` nor ``labels_raw`` is available.
        """
        arr = self.labels_raw if self.result is None else self.result.values
        if arr is None:
            raise RuntimeError("Run the tracker first.")
        return self._count_events(arr)

    def event_duration(self) -> dict[int, int]:
        """
        Return ``{event_id: n_timesteps_present}`` from the final result.

        The first axis of the result is iterated as the time axis.

        Raises
        ------
        RuntimeError
            If the tracker has not produced a result yet.
        """
        if self.result is None:
            raise RuntimeError("Run the tracker first.")

        arr = self.result.values.astype(float)
        duration: dict[int, int] = {}
        for t in range(arr.shape[0]):
            plane = arr[t].ravel()
            for fid in np.unique(plane[~np.isnan(plane)]):
                if fid > 0:
                    duration[int(fid)] = duration.get(int(fid), 0) + 1
        return duration

    def summary(self) -> None:
        """Print event count, duration distribution, and parameters."""
        if self.result is None:
            print("Not run yet — call .run() first.")
            return

        durations = self.event_duration()
        n = len(durations)

        print("=" * 55)
        print("SurfTracker — Result Summary")
        print("=" * 55)
        print(f" Input shape : {tuple(self.da.shape)}")
        print(f" Tracked events : {n}")
        if n > 0:
            durs = np.array(list(durations.values()))
            print(f" Duration min/median/max : "
                  f"{durs.min()} / {int(np.median(durs))} / {durs.max()}")
            for thr in [1, 3, 6, 12]:
                print(f" >= {thr:2d} ts : {(durs >= thr).sum()}")
        print()
        print(" Parameters:")
        for k in ["radius", "min_area_cells", "min_size_quartile",
                  "positive", "method", "contain_thresh"]:
            print(f" {k:<20} = {getattr(self, k)}")

    def __repr__(self) -> str:
        n = self.n_events() if self.result is not None else "(not run yet)"
        return (
            f"SurfTracker(shape={tuple(self.da.shape)}, "
            f"radius={self.radius}, "
            f"method={self.method!r}, "
            f"events={n})"
        )