DeepTrack#
A 4D extension of Ocetrac for volumetric tracking#
1. Imports#
[17]:
import sys
import os
# Go up from notebooks/ to repo root
repo_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, repo_root)
from ocetrac.preprocessing.preprocessing import clean_binary, compute_anomalies, threshold_features
from ocetrac.preprocessing.utils import compute_dask_quantile, get_xarray_memory_usage
from ocetrac.DeepTrack import DeepTracker
from ocetrac.DeepTrack import grid, tracker
from ocetrac.DeepTrack import _wrap_longitude
from ocetrac.preprocessing.cesm2_lens_utils import get_ds_var
[2]:
import gc
import warnings
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
# Silence two specific warnings by message pattern so they do not flood the
# cell outputs; every other warning is still shown.
# NOTE(review): both messages look like xarray decode/concat notices - confirm
# they are still the only ones being hidden after dependency upgrades.
warnings.filterwarnings("ignore", message=".*decode the variable.*")
warnings.filterwarnings("ignore", message=".*default value for data_vars.*")
2. Data loading#
This section loads CESM2 Large Ensemble (CESM2-LE) ocean temperature data for a single ensemble member. CESM2-LE provides 100 ensemble members spanning 1850-2100.
Component:
ocn (ocean model component, POP2)
Temporal resolution: Monthly means
[3]:
%%time
# Select the ensemble member index
ens_memb_index = 0
# Variable and model component to load:
# 'TEMP' = potential temperature, 'ocn' = ocean component of CESM
var, comp = 'TEMP', 'ocn'
# Path to the CESM2 Large Ensemble monthly time series on NCAR's GLADE filesystem
directory = f'/glade/campaign/cgd/cesm/CESM2-LE/{comp}/proc/tseries/month_1/{var}/'
# Load historical and future projection datasets for the chosen member
ds_hist, ds_fut = get_ds_var(directory, var, comp, ens_memb_index)
# Latitude slice indices; these are positions on the ocean model's native grid,
# not degrees of latitude
nlat_low, nlat_high = 26, 328
# Index the dataset with `var` instead of a hard-coded `.TEMP` attribute so the
# variable that is subset is always the one that was requested above.
ds = ds_hist[var].isel(
    z_t=slice(0, 20),                    # top 20 vertical levels (surface to ~200m)
    nlat=slice(nlat_low, nlat_high),     # latitude range, by index
).sel(time=slice('1979-01', '2015-01'))  # time range
# Load the subset from disk into memory (this is the slow step of the cell)
ds = ds.compute()
CPU times: user 27.2 s, sys: 2.28 s, total: 29.4 s
Wall time: 1min 42s
/glade/u/home/cassiacai/.conda/envs/test_ocetrac/lib/python3.9/site-packages/numpy/lib/nanfunctions.py:1563: RuntimeWarning: All-NaN slice encountered
return function_base._ureduce(a,
[4]:
# ds_fut is not needed for this analysis (only the historical period is used),
# so drop the reference to free memory.
# Guarded so the cell can be re-run: a bare `del` raises NameError once the
# name has already been removed.
try:
    del ds_fut
except NameError:
    pass
# Force garbage collection to release the memory immediately
gc.collect()
print(f"Loaded: {ds.shape}")
Loaded: (433, 20, 302, 320)
[5]:
## -------------- Figure of temperature at different z_t
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# Shared contour settings: both panels use the same levels and limits, which is
# what makes a single colorbar valid for the two of them.
contour_kwargs = dict(levels=36, vmin=-35, vmax=35, cmap='RdBu_r', add_colorbar=False)

# Left panel: first time step, depth index 0
im1 = ds[0, 0, :, :].plot.contourf(ax=ax1, **contour_kwargs)
ax1.set_title('TEMP at z_t=0 (t=0)', fontsize=12)
ax1.set_xlabel('Longitude')
ax1.set_ylabel('Latitude')

# Right panel: first time step, depth index 14
im2 = ds[0, 14, :, :].plot.contourf(ax=ax2, **contour_kwargs)
ax2.set_title('TEMP at z_t=14 (t=0)', fontsize=12)
ax2.set_xlabel('Longitude')
ax2.set_ylabel('Latitude')

cbar = plt.colorbar(im1, ax=[ax1, ax2], orientation='vertical', pad=0.05)
# `ds` is the raw TEMP field here; anomalies are only computed in the next
# section, so the colorbar must not be labelled as an anomaly.
cbar.set_label('Temperature (°C)', fontsize=10)
plt.show()
3. Anomaly computation#
This preprocessing step is separate from the Ocetrac tracking algorithm. It prepares the temperature field by removing the trend and seasonality.
This example notebook uses preprocessing.compute_anomalies, which fits a 6-coefficient harmonic model per grid cell and returns the residual. preprocessing.threshold_features masks values below the 90th percentile.
[6]:
%%time
# Residual temperature after the harmonic fit described in the text above
anom = compute_anomalies(ds)
# Keep only values above the per-cell 90th quantile; also returns the threshold field
features, threshold_map = threshold_features(anom, q=0.9)
print(f"features shape: {features.shape} ({features.nbytes/1e9:.2f} GB)")
[########################################] | 100% Completed | 57.95 s
[########################################] | 100% Completed | 4.69 sms
features shape: (433, 20, 302, 320) (6.70 GB)
CPU times: user 1min 8s, sys: 2.63 s, total: 1min 11s
Wall time: 1min 19s
[7]:
# Keep only the first 40 monthly time steps so the tutorial runs quickly
features = features.isel(time=slice(0, 40))
[8]:
## -------------- Figure of anom temperature at different z_t
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# (depth index, panel title) for each of the two panels, both at the first time step
panels = (
    (0, 'Anom Temp at t=0, z_t=0'),
    (14, 'Anom Temp at t=0, z_t=14'),
)
mappables = []
for ax, (z_idx, title) in zip(axes, panels):
    mappables.append(
        anom[0, z_idx, :, :].plot.contourf(
            ax=ax, levels=21, vmin=-5, vmax=5, cmap='RdBu_r', add_colorbar=False
        )
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
# One colorbar shared by both panels, built from the first panel's contours
cbar = plt.colorbar(mappables[0], ax=list(axes), orientation='vertical', pad=0.05)
cbar.set_label('Temperature anomaly (°C)', fontsize=10)
plt.show()
[9]:
## -------------- Figure of threshold map at different z_t
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# (depth index, panel title); the threshold map has no time axis
panels = (
    (0, 'Threshold Map at z_t=0'),
    (14, 'Threshold Map at z_t=14'),
)
mappables = []
for ax, (z_idx, title) in zip(axes, panels):
    mappables.append(
        threshold_map[z_idx, :, :].plot.contourf(
            ax=ax, levels=16, vmin=0, vmax=3, cmap='Reds', add_colorbar=False
        )
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
# One colorbar shared by both panels, built from the first panel's contours
cbar = plt.colorbar(mappables[0], ax=list(axes), orientation='vertical', pad=0.05)
cbar.set_label('Threshold Temp (°C)', fontsize=10)
plt.show()
[10]:
## -------------- Figure of 90% threshold features at different z_t
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# (depth index, panel title) for each of the two panels, both at the first time step
panels = (
    (0, 'Features at t=0, z_t=0'),
    (14, 'Features at t=0, z_t=14'),
)
mappables = []
for ax, (z_idx, title) in zip(axes, panels):
    mappables.append(
        features[0, z_idx, :, :].plot.contourf(
            ax=ax, levels=16, vmin=0, vmax=3, cmap='Reds', add_colorbar=False
        )
    )
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
# One colorbar shared by both panels, built from the first panel's contours
cbar = plt.colorbar(mappables[0], ax=list(axes), orientation='vertical', pad=0.05)
cbar.set_label('Features (°C)', fontsize=10)
plt.show()
4. Cell volume#
grid.build_cell_volume converts TAREA (cm²) × dz (cm) → m³ per voxel, using the POP grid.
[11]:
%%time
# Same depth / latitude / time subset that was used for the temperature field
grid_info = (
    ds_hist
    .isel(z_t=slice(0, 20), nlat=slice(nlat_low, nlat_high))
    .sel(time=slice('1979-01', '2015-01'))
)
# Grid variables: cell area taken at the first time step, and latitude as a plain array
TAREA = grid_info['TAREA'].isel(time=0)
lat = grid_info['TLAT'].values
# Per-voxel volume for the top 20 levels, evaluated eagerly
cell_volume = grid.build_cell_volume(TAREA, ds_hist.z_t, n_z=20).compute()
cell_vol_np = cell_volume.values
print("cell_volume shape:", cell_volume.shape)
print("lat shape :", lat.shape)
cell_volume shape: (20, 302, 320)
lat shape : (302, 320)
CPU times: user 187 ms, sys: 28.1 ms, total: 215 ms
Wall time: 267 ms
DeepTrack components#
Step 1 — Morphological cleaning#
preprocessing.clean_binary runs close→open per (t, z) slice. Closing fills holes and bridges small gaps; opening removes tiny residual blobs. radius controls the disk size.
[12]:
%%time
# Disk radius passed to the morphological close -> open cleaning
RADIUS = 3
binary_clean = clean_binary(features, radius=RADIUS, positive=True).compute()
CPU times: user 4.2 s, sys: 4.92 ms, total: 4.2 s
Wall time: 4.38 s
[13]:
# Figure - Compare radii (one panel per radius, at the time/depth indices below)
t_check, z_check = 0, 0
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for panel, r in zip(axes, [1, 2, 3]):
    # Re-run the cleaning on the full array for this radius, then show one slice
    cleaned = clean_binary(features, radius=r, positive=True).compute()
    panel.pcolormesh(cleaned.values[t_check, z_check].astype(int), cmap='Blues', vmin=0, vmax=1)
    panel.set_title(f'radius={r}')
    panel.set_xlabel('nlon')
    panel.set_ylabel('nlat')
plt.suptitle(f't={t_check} z={z_check} effect of radius', y=1.02)
plt.tight_layout()
plt.show()
plt.close()
[14]:
# Figure - Compare radii (one panel per radius, at the time/depth indices below)
t_check, z_check = 0, 14
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for panel, r in zip(axes, [1, 2, 3]):
    # Re-run the cleaning on the full array for this radius, then show one slice
    cleaned = clean_binary(features, radius=r, positive=True).compute()
    panel.pcolormesh(cleaned.values[t_check, z_check].astype(int), cmap='Blues', vmin=0, vmax=1)
    panel.set_title(f'radius={r}')
    panel.set_xlabel('nlon')
    panel.set_ylabel('nlat')
plt.suptitle(f't={t_check} z={z_check} effect of radius', y=1.02)
plt.tight_layout()
plt.show()
plt.close()
Step 2 — 2-D connected-component labelling#
tracker.label_2d_stack runs scipy.ndimage.label on each (t, z) slice. Background = 0.
[15]:
%%time
# Label connected regions of the cleaned mask slice by slice.
# NOTE(review): later cells read `.values`, `.dims` and `.coords` from the
# result, so it is treated as an xarray DataArray - confirm against the tracker API.
labeled_2d = tracker.label_2d_stack(binary_clean)
CPU times: user 314 ms, sys: 65 µs, total: 314 ms
Wall time: 357 ms
[16]:
# Number of distinct labels per time step at depth index 0; the "- 1" discounts
# the background label (assumes background 0 is present in every slice).
surface_slices = labeled_2d.values[:, 0]
n_surf = [len(np.unique(labels)) - 1 for labels in surface_slices]
print(f"2-D blobs at surface — mean: {np.mean(n_surf):.1f} min: {min(n_surf)} max: {max(n_surf)}")

# Turn the background into NaN so pcolormesh leaves it blank
labeled_2d_nan = xr.where(labeled_2d == 0, np.nan, labeled_2d)

# One figure per depth index, each showing the first three time steps
for z_idx in (0, 14):
    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    for i, ax in enumerate(axes):
        ax.pcolormesh(labeled_2d_nan.values[i, z_idx], cmap='tab20')
        if z_idx == 0:
            ax.set_title(f't={i} 2-D labels (surface)')
        else:
            ax.set_title(f't={i} 2-D labels (z_t=14)')
        ax.set_xlabel('nlon')
        ax.set_ylabel('nlat')
    plt.tight_layout()
    plt.show()
    plt.close()
2-D blobs at surface — mean: 15.1 min: 8 max: 22
[18]:
# Apply the longitude wrap to the raw label array, then rebuild a DataArray
# carrying the original dims and coords.
# NOTE(review): presumably this merges labels that touch across the date line -
# confirm against `_wrap_longitude`. This cell overwrites `labeled_2d`.
labeled_2d_np = _wrap_longitude(labeled_2d.values)
labeled_2d = xr.DataArray(labeled_2d_np, dims=labeled_2d.dims, coords=labeled_2d.coords)
Step 3 — 2-D area filter#
tracker.filter_area_2d_global_depth drops small blobs using threshold = max(min_area_cells, percentile(all_areas_at_z, q)) computed across ALL timesteps per depth level.
[19]:
%%time
# Absolute floor: blobs with fewer grid cells than this are removed
# regardless of the percentile threshold.
MIN_AREA_CELLS = 200
# Blobs below the 25th percentile of the area distribution are dropped as well.
MIN_QUANTILE = 0.25
filtered_np = tracker.filter_area_2d_global_depth(
    labeled_2d.values, min_quantile=MIN_QUANTILE, min_area_cells=MIN_AREA_CELLS
)
CPU times: user 1.01 s, sys: 4.09 ms, total: 1.02 s
Wall time: 1.06 s
[20]:
# Count 2-D blobs over every (t, z) slice before and after the area filter.
labels_before = labeled_2d.values
n_before = sum(
    len(np.unique(labels_before[t, z])) - 1  # "- 1" discounts the background label
    for t in range(labels_before.shape[0])
    for z in range(labels_before.shape[1])
)
n_after = sum(
    len(np.unique(filtered_np[t, z][filtered_np[t, z] > 0]))
    for t in range(filtered_np.shape[0])
    for z in range(filtered_np.shape[1])
)
print(f"Total 2-D blobs before: {n_before}")
print(f"After area filter : {n_after} (removed {n_before - n_after})")

# Turn the background into NaN so pcolormesh leaves it blank
filtered_np_nan = xr.where(filtered_np == 0, np.nan, filtered_np)

# Figure 1: depth index 0, first three time steps
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for i, ax in enumerate(axes):
    ax.pcolormesh(filtered_np_nan[i, 0], cmap='tab20')
    ax.set_title(f't={i} after area filter (surface)')
    ax.set_xlabel('nlon')
    ax.set_ylabel('nlat')
plt.tight_layout()
plt.show()
plt.close()

# Figure 2: depth index 14. A new figure is required here: the axes from the
# first figure were closed above, so reusing them draws nothing and only an
# empty "Figure ... with 0 Axes" is shown.
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for i, ax in enumerate(axes):
    ax.pcolormesh(filtered_np_nan[i, 14], cmap='tab20')
    ax.set_title(f't={i} after area filter (z_t=14)')
    ax.set_xlabel('nlon')
    ax.set_ylabel('nlat')
plt.tight_layout()
plt.show()
plt.close()
Total 2-D blobs before: 15690
After area filter : 7317 (removed 8373)
<Figure size 640x480 with 0 Axes>
Step 4 — 3-D depth connectivity#
tracker.build_3d_objects runs scipy.ndimage.label in 3D on each timestep’s (z, nlat, nlon) volume using an anisotropic structuring element: full 8-connectivity in (lat, lon), face-only in z. Labels reset to 1..N at each timestep independently.
[21]:
%%time
CONNECT_Z = True
# Relabel every (t, z) slice so IDs are compact before 3-D labelling
relabeled_np = np.zeros_like(filtered_np, dtype=int)
for t, z in np.ndindex(*filtered_np.shape[:2]):
    relabeled_np[t, z] = tracker.relabel_2d(filtered_np[t, z])

# Structuring element that defines which neighbours count as connected
struct = grid.make_anisotropic_struct(connect_xy=True, connect_z=CONNECT_Z)
labeled_3d = tracker.build_3d_objects(relabeled_np, struct)
# merge date-line objects
labeled_3d = _wrap_longitude(labeled_3d)
CPU times: user 955 ms, sys: 160 ms, total: 1.11 s
Wall time: 1.16 s
[22]:
# Number of distinct non-background 3-D labels at each time step
n_3d_per_t = [len(np.unique(labeled_3d[t][labeled_3d[t] > 0])) for t in range(labeled_3d.shape[0])]
print(f"connect_z={CONNECT_Z}")
print(f"3-D objects per timestep — mean: {np.mean(n_3d_per_t):.1f} "
      f"min: {min(n_3d_per_t)} max: {max(n_3d_per_t)}")
print(f"Total (sum across t) : {sum(n_3d_per_t)}")

# Turn the background into NaN so pcolormesh leaves it blank
labeled_3d_nan = xr.where(labeled_3d == 0, np.nan, labeled_3d)

# One colour scale for all three panels, so a given label gets the same colour
# in the map and in both cross-sections (same limits as the post-filter figure).
color_kwargs = dict(vmin=0, vmax=31, cmap='prism')

for t_check in range(3):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Map at depth index 0, with the two section positions marked
    axes[0].pcolormesh(labeled_3d_nan[t_check, 0], **color_kwargs)
    axes[0].axhline(y=150, color='red', linewidth=2.5, linestyle='--', alpha=0.8, label='Zonal section (nlat=150)')
    axes[0].axvline(x=150, color='blue', linewidth=2.5, linestyle='--', alpha=0.8, label='Meridional section (nlon=150)')
    axes[0].set_title(f't={t_check} surface')
    axes[0].set_xlabel('nlon')
    axes[0].set_ylabel('nlat')
    axes[0].grid(True, color='k', linewidth=0.5, alpha=0.5)

    # Depth vs. nlon at nlat index 150; y axis inverted so depth increases downward
    axes[1].pcolormesh(labeled_3d_nan[t_check, :, 150, :], **color_kwargs)
    axes[1].invert_yaxis()
    axes[1].set_title('Zonal cross-section nlat=150')
    axes[1].set_xlabel('nlon')
    axes[1].set_ylabel('depth level')
    axes[1].grid(True, color='k', linewidth=0.5, alpha=0.5)

    # Depth vs. nlat at nlon index 150
    axes[2].pcolormesh(labeled_3d_nan[t_check, :, :, 150], **color_kwargs)
    axes[2].invert_yaxis()
    axes[2].set_title('Meridional cross-section nlon=150')
    axes[2].set_xlabel('nlat')
    axes[2].set_ylabel('depth level')
    axes[2].grid(True, color='k', linewidth=0.5, alpha=0.5)

    plt.tight_layout()
    plt.show()
    plt.close()
connect_z=True
3-D objects per timestep — mean: 20.7 min: 13 max: 30
Total (sum across t) : 827
Step 5 — Global volume filter#
tracker.filter_preserve_labels_global drops the bottom frac_filter fraction of all 3-D objects by voxel count summed across ALL timesteps.
[26]:
%%time
# Fraction handed to the global volume filter as `frac`
FRAC_FILTER = 0.25
filtered_labels = tracker.filter_preserve_labels_global(labeled_3d, frac=FRAC_FILTER)
CPU times: user 25.6 s, sys: 282 ms, total: 25.9 s
Wall time: 28.6 s
[27]:
# Distinct non-background labels remaining at each time step after the filter
n_pre_per_t = []
for t in range(filtered_labels.shape[0]):
    volume = filtered_labels[t]
    n_pre_per_t.append(len(np.unique(volume[volume > 0])))

# Objects lost relative to the per-timestep counts taken before the filter
total_before = sum(n_3d_per_t)
total_after = sum(n_pre_per_t)
dropped = total_before - total_after
print(f"After prefilter — mean per t: {np.mean(n_pre_per_t):.1f} total: {sum(n_pre_per_t)}")
print(f"Dropped {dropped} objects ({dropped / max(sum(n_3d_per_t),1) * 100:.1f} %)")
After prefilter — mean per t: 19.8 total: 793
Dropped 34 objects (4.1 %)
[28]:
# Replace the background label 0 with NaN so it is left blank when plotted
filtered_labels_nan = xr.where(filtered_labels == 0, np.nan, filtered_labels)
[29]:
# For the first three time steps: a map at depth index 0 plus a zonal and a
# meridional cross-section through index 150.
for t_check in range(3):
    fig, (ax_map, ax_zonal, ax_merid) = plt.subplots(1, 3, figsize=(12, 4))
    volume = filtered_labels_nan[t_check]

    # Map, with the two section positions marked
    ax_map.pcolormesh(volume[0], vmin=0, vmax=31, cmap='prism')
    ax_map.axhline(y=150, color='red', linewidth=2.5, linestyle='--', alpha=0.8, label='Zonal section (nlat=150)')
    ax_map.axvline(x=150, color='blue', linewidth=2.5, linestyle='--', alpha=0.8, label='Meridional section (nlon=150)')
    ax_map.set_title(f't={t_check} surface')
    ax_map.set_xlabel('nlon')
    ax_map.set_ylabel('nlat')
    ax_map.grid(True, color='k', linewidth=0.5, alpha=0.5)

    # Depth vs. nlon; y axis inverted so depth increases downward
    ax_zonal.pcolormesh(volume[:, 150, :], vmin=0, vmax=31, cmap='prism')
    ax_zonal.invert_yaxis()
    ax_zonal.set_title('Zonal cross-section nlat=150')
    ax_zonal.set_xlabel('nlon')
    ax_zonal.set_ylabel('depth level')
    ax_zonal.grid(True, color='k', linewidth=0.5, alpha=0.5)

    # Depth vs. nlat
    ax_merid.pcolormesh(volume[:, :, 150], vmin=0, vmax=31, cmap='prism')
    ax_merid.invert_yaxis()
    ax_merid.set_title('Meridional cross-section nlon=150')
    ax_merid.set_xlabel('nlat')
    ax_merid.set_ylabel('depth level')
    ax_merid.grid(True, color='k', linewidth=0.5, alpha=0.5)

    plt.tight_layout()
    plt.show()
    plt.close()
Step 6 — Containment tracking with lineage preservation#
tracker.track_objects_with_splitting links objects across time using:
score = max(|A∩B|/|A|, |A∩B|/|B|)
Taking the maximum of the two ratios is the same as dividing the overlap by the smaller object, so a fragment fully contained in a parent scores 1.0, correctly linking split children. Each current object independently picks the parent with the smallest original ID above contain_thresh. Unmatched objects become new events.
[32]:
%%time
# Minimum containment score for linking an object to a parent (see text above)
CONTAIN_THRESH = 0.3
# NOTE(review): meaning of alpha is not described in this notebook - confirm in the tracker docs
ALPHA = 0.5
# The tracker works on the globally filtered 3-D labels
depth_connected = filtered_labels
tracking_kwargs = dict(
    volume_weights=cell_vol_np,
    contain_thresh=CONTAIN_THRESH,
    alpha=ALPHA,
    plot_results=False,
)
tracked, origin_map = tracker.track_objects_with_splitting(depth_connected, **tracking_kwargs)
CPU times: user 16.9 s, sys: 75.3 ms, total: 17 s
Wall time: 19.2 s
[33]:
# Replace the background label 0 with NaN so it is left blank when plotted
tracked_nan = xr.where(tracked == 0, np.nan, tracked)
[45]:
# For the first three time steps draw three panels (map at depth index 0, zonal
# section, meridional section) and write each object's ID at its centroid.
for t_check in range(3):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    # (2-D section, title, x label, y label, invert the y axis?)
    panels = (
        (tracked_nan[t_check, 0], f't={t_check} surface', 'nlon', 'nlat', False),
        (tracked_nan[t_check, :, 150, :], 'Zonal cross-section nlat=150', 'nlon', 'depth level', True),
        (tracked_nan[t_check, :, :, 150], 'Meridional cross-section nlon=150', 'nlat', 'depth level', True),
    )
    for ax, (section, title, xlabel, ylabel, flip_y) in zip(axes, panels):
        ax.pcolormesh(section, vmin=0, vmax=256, cmap='prism')
        if flip_y:
            # depth increases downward in the cross-sections
            ax.invert_yaxis()
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True, color='k', linewidth=0.5, alpha=0.5)

        # Positive IDs only; the comparison also discards the NaN background
        unique_ids = np.unique(section)
        unique_ids = unique_ids[unique_ids > 0]
        for obj_id in unique_ids:
            mask = (section == obj_id)
            if not mask.any():
                continue
            # Centroid in array indices: rows are the panel's y axis, columns its x axis
            row_center, col_center = np.mean(np.where(mask), axis=1)
            ax.text(col_center, row_center, str(int(obj_id)),
                    color='white', fontsize=12, fontweight='bold',
                    ha='center', va='center',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
    plt.tight_layout()
    plt.show()
    plt.close()
Dropout summary#
[37]:
# Object counts after each pipeline stage, gathered from the step-by-step cells.

# 2-D blobs over every (t, z) slice; "- 1" discounts the background label
labels_2d = labeled_2d.values
n_2d = sum(
    len(np.unique(labels_2d[t, z])) - 1
    for t in range(labels_2d.shape[0])
    for z in range(labels_2d.shape[1])
)

# n_3d_per_t was computed in Step 4 AFTER the longitude wrap had been applied
# to labeled_3d, so the wrap below is a second pass and is expected to merge
# nothing ("merged 0"); the true pre-wrap count is not available in this cell.
n_3d = sum(n_3d_per_t)
n_before_wrap = n_3d
# Keep the result under a new name so this summary cell does not overwrite
# pipeline state.
# NOTE(review): assumes _wrap_longitude returns a new array rather than
# modifying its argument in place - confirm.
labeled_3d_wrapped = _wrap_longitude(labeled_3d)
n_wrap = sum(
    len(np.unique(labeled_3d_wrapped[t][labeled_3d_wrapped[t] > 0]))
    for t in range(labeled_3d_wrapped.shape[0])
)

n_pre = sum(n_pre_per_t)

# Count positive event IDs explicitly. This neither depends on how np.unique
# treats the NaN background (collapsed to one value only in newer NumPy) nor
# on a `.data` attribute, and matches the `ids[ids > 0]` idiom used elsewhere.
event_ids = np.unique(np.asarray(tracked_nan))
n_final = int(np.count_nonzero(event_ids > 0))

print(f"2-D blobs (all t, all z) : {n_2d:>6}")
print(f"After 3-D labelling : {n_3d:>6} (area filter removed small blobs)")
print(f"After date-line wrap : {n_wrap:>6} (merged {n_before_wrap - n_wrap})")
print(f"After global filter : {n_pre:>6} (dropped {n_wrap - n_pre})")
print(f"After containment tracking : {n_final}")
2-D blobs (all t, all z) : 15690
After 3-D labelling : 827 (area filter removed small blobs)
After date-line wrap : 827 (merged 0)
After global filter : 793 (dropped 34)
After containment tracking : 243
OR RUN EVERYTHING VIA DeepTracker#
The DeepTracker class wraps the entire pipeline in one call.
[36]:
%%time
# Run the whole pipeline in one call, with the same parameter values that the
# step-by-step cells above use.
# The instance is named `deep_tracker` so it does not shadow the `tracker`
# module imported at the top of the notebook; shadowing it would make every
# earlier `tracker.<function>(...)` cell fail when re-run after this one.
deep_tracker = DeepTracker(
    features,
    radius=3,
    min_area_cells=200,
    min_quantile=0.25,
    contain_thresh=0.3,
    alpha=0.5,
    frac_filter=0.25,
    connect_z=True,
    positive=True,
    n_z=20,
    wrap_lon=True,  # enable the longitude-wrap steps of the pipeline
)
result_full = deep_tracker.run(cell_volume=cell_vol_np)
deep_tracker.summary()
Step 1 · morphological cleaning …
fraction flagged warm = 0.0899 (OK)
Step 2 · 2-D connected-component labelling …
Step 2b · 2-D longitude wrap
merged 0 date-line objects (35 → 35)
surface blobs — mean=14.1 min=7 max=20
Step 3a · area filtering …
2-D blobs: 15,690 → 7,317 (removed 8,373)
Step 3b · 3-D depth connectivity …
Step 3c · 3-D longitude wrap
merged 23 date-line objects (850 → 827)
3-D objects/timestep — mean=20.7 min=13 max=30 total=827
Step 4 · global volume filter …
3-D objects: 827 → 793 (removed 34)
Step 5 · containment tracking …
unique event IDs assigned: 243
Step 6 · wrapping result …
final events: 243
=======================================================
DeepTracker — Result Summary
=======================================================
Input shape : (40, 20, 302, 320)
Tracked events : 243
Duration min/median/max : 1 / 1 / 40
>= 1 ts : 243
>= 3 ts : 46
>= 6 ts : 14
>= 12 ts : 6
Parameters:
radius = 3
min_area_cells = 200
min_quantile = 0.25
frac_filter = 0.25
contain_thresh = 0.3
alpha = 0.5
connect_z = True
wrap_lon = True
=======================================================
CPU times: user 1min 5s, sys: 773 ms, total: 1min 6s
Wall time: 1min 13s
[38]:
# Attach the dims and coords of `features` to the tracker output
result_xr = xr.DataArray(
    data=result_full,
    coords=features.coords,
    dims=features.dims,
    name='events_tracked',
)
# Optional: save the tracked events to disk
# result_xr.to_netcdf('deeptrack_events.nc')
[39]:
# For the first three time steps draw three panels (map at depth index 0, zonal
# section, meridional section) of the DeepTracker result, write each event's ID
# at its centroid and add one colorbar per panel.
for t_check in range(3):
    fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    # (2-D section, title, x label, y label, invert the y axis?)
    panels = (
        (result_xr[t_check, 0], f't={t_check} surface', 'nlon', 'nlat', False),
        (result_xr[t_check, :, 150, :], 'Zonal cross-section nlat=150', 'nlon', 'depth level', True),
        (result_xr[t_check, :, :, 150], 'Meridional cross-section nlon=150', 'nlat', 'depth level', True),
    )
    mappables = []
    for ax, (section, title, xlabel, ylabel, flip_y) in zip(axes, panels):
        mappables.append(ax.pcolormesh(section, vmin=0, vmax=256, cmap='prism'))
        if flip_y:
            # depth increases downward in the cross-sections
            ax.invert_yaxis()
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True, color='k', linewidth=0.5, alpha=0.5)

        # Positive IDs only, i.e. everything except the background
        unique_ids = np.unique(section)
        unique_ids = unique_ids[unique_ids > 0]
        for obj_id in unique_ids:
            mask = (section == obj_id)
            if not mask.any():
                continue
            # Centroid in array indices: rows are the panel's y axis, columns its x axis
            row_center, col_center = np.mean(np.where(mask), axis=1)
            ax.text(col_center, row_center, str(int(obj_id)),
                    color='white', fontsize=8, fontweight='bold',
                    ha='center', va='center',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

    # Colorbars are added once all panels are drawn and annotated
    for ax, mappable in zip(axes, mappables):
        plt.colorbar(mappable, ax=ax, label='Object ID')
    plt.tight_layout()
    plt.show()
    plt.close()
Plotting 3D plots of events#
[40]:
from ocetrac.plotting.DeepTrack_plot_utils import plot_3d_labeled_feature
[41]:
# Render event 5 in 3-D at the first three time steps.
# Options that stay the same for every time step:
render_opts = dict(
    feature_id=5,
    threshold=0.5,
    sigma=0.5,  # Slight smoothing for cleaner surface
    alpha=0.7,
    cmap='plasma',
    elev=25,
    azim=-60,
    figsize=(12, 8),
)
for time_idx in range(3):
    print(f"\n{'='*50}")
    print(f"Plotting feature 5 at time {time_idx}")
    print(f"{'='*50}")
    fig, ax = plot_3d_labeled_feature(result_xr, time_idx=time_idx, **render_opts)
    # The helper may hand back no figure; only show/close when one exists
    if fig is not None:
        plt.show()
        plt.close(fig)
==================================================
Plotting feature 5 at time 0
==================================================
Applied Gaussian smoothing (sigma=0.5)
Extracted surface: 3167 vertices, 6110 faces
==================================================
Plotting feature 5 at time 1
==================================================
Applied Gaussian smoothing (sigma=0.5)
Extracted surface: 2997 vertices, 5794 faces
==================================================
Plotting feature 5 at time 2
==================================================
Applied Gaussian smoothing (sigma=0.5)
Extracted surface: 2934 vertices, 5654 faces
[44]:
# Render event 1 in 3-D at time steps 5-7 (event 4 is also an interesting one to try).
# Options that stay the same for every time step:
render_opts = dict(
    feature_id=1,
    threshold=0.5,
    sigma=0.2,  # lighter smoothing than in the previous figure
    alpha=0.7,
    cmap='plasma',
    elev=25,
    azim=-60,
    figsize=(12, 8),
)
for time_idx in range(5, 8):
    print(f"\n{'='*50}")
    print(f"Plotting feature 1 at time {time_idx}")
    print(f"{'='*50}")
    fig, ax = plot_3d_labeled_feature(result_xr, time_idx=time_idx, **render_opts)
    # The helper may hand back no figure; only show/close when one exists
    if fig is not None:
        plt.show()
        plt.close(fig)
==================================================
Plotting feature 1 at time 5
==================================================
Applied Gaussian smoothing (sigma=0.2)
Extracted surface: 11367 vertices, 22190 faces
==================================================
Plotting feature 1 at time 6
==================================================
Applied Gaussian smoothing (sigma=0.2)
Extracted surface: 13798 vertices, 27170 faces
==================================================
Plotting feature 1 at time 7
==================================================
Applied Gaussian smoothing (sigma=0.2)
Extracted surface: 13880 vertices, 27280 faces
[ ]: