Source code for phenocoder.sampling
from __future__ import annotations
import anndata as ad
import numpy as np
import pandas as pd
[docs]
class SpatialSubunitSampler:
    """
    Group spatial observations into a regular grid of subunits.

    The intended workflow is ``partition()`` first, followed by any of
    ``filter()``, ``sample()`` and ``to_df()``, which all operate on the
    ``self.subunits`` mapping that ``partition()`` builds.

    Args:
        adata: Annotated data object. ``adata.obsm[spatial_key]`` is read as the
            coordinate array (one row per observation) and ``adata.obs.index``
            supplies the observation labels.
        dim_subunit: Edge length of a subunit. It is divided element-wise into
            the coordinates, so it must broadcast against the coordinate
            columns (e.g. one value per spatial axis).
        min_obs: Minimum number of observations a subunit needs to survive
            ``filter()``.
        spatial_key: Key of the coordinate array in ``adata.obsm``.
        verbose: If True, print summary information while processing.
    """

    def __init__(
        self,
        adata: ad.AnnData,
        dim_subunit: tuple[int],
        min_obs: int,
        spatial_key: str,
        verbose: bool = False,
    ):
        self.adata = adata
        self.dim_subunit = dim_subunit
        self.min_obs = min_obs
        self.spatial_key = spatial_key
        self.verbose = verbose
        # Populated by partition(); None means "not partitioned yet".
        self.subunits = None

    def _require_subunits(self) -> dict:
        """Return ``self.subunits``, raising if ``partition()`` has not run yet."""
        if self.subunits is None:
            raise RuntimeError(
                'No subunits available: call partition() before this method.'
            )
        return self.subunits

    def partition(self) -> None:
        """
        Partition the observations into a uniform grid of spatial subunits.

        Divides the bounding box of the spatial coordinates
        (``adata.obsm[spatial_key]``) into cells of edge length ``dim_subunit``
        and assigns each observation to the cell it falls in. Cells are
        half-open: an observation lying exactly on a cell's upper boundary
        belongs to the next cell along that axis.

        The result is stored on ``self.subunits`` as a dict keyed by the integer
        grid index (a tuple) of each non-empty cell. Each value is a dict with:

        * ``'obs_indices'``: array of the member observation labels,
        * ``'obs_spatial'``: array of their coordinates,
        * ``'bb_box'``: dict with the cell's ``'min'`` and ``'max'`` corners,
        * ``'id'``: integer id, numbered in order of first appearance.

        An empty coordinate array yields an empty ``self.subunits``.

        Returns:
            None
        """
        coords = self.adata.obsm[self.spatial_key]
        if len(coords) == 0:
            # min()/max() are undefined on an empty array; nothing to assign.
            self.subunits = {}
            return

        min_bounds = coords.min(axis=0)
        max_bounds = coords.max(axis=0)
        extent = max_bounds - min_bounds

        if self.verbose:
            # The largest grid index along an axis is floor(extent / edge), so
            # the cell count is that plus one. ceil() would undercount whenever
            # the extent is an exact multiple of the edge (including zero).
            n_subunits = np.floor(extent / self.dim_subunit).astype(int) + 1
            print(f'Sample extent: {extent}')
            print(f'Grid dimensions: {n_subunits} subunits')
            print(f'Total potential subunits: {np.prod(n_subunits)}')

        # Grid index of every observation along each axis.
        obs_indices = np.floor((coords - min_bounds) / self.dim_subunit).astype(int)

        # Build into a local dict so self.subunits is only replaced on success.
        subunits = {}
        for obs_name, spatial_pos, grid_index in zip(
            self.adata.obs.index, coords, obs_indices
        ):
            subunit_key = tuple(map(int, grid_index))
            subunit = subunits.get(subunit_key)
            if subunit is None:
                corner = np.array(subunit_key)
                subunit = {
                    'obs_indices': [],  # observation labels, filled below
                    'obs_spatial': [],  # matching coordinates, filled below
                    'bb_box': {
                        'min': min_bounds + corner * self.dim_subunit,
                        'max': min_bounds + (corner + 1) * self.dim_subunit,
                    },
                }
                subunits[subunit_key] = subunit
            subunit['obs_indices'].append(obs_name)
            subunit['obs_spatial'].append(spatial_pos)

        # Freeze the per-subunit lists into arrays and number the subunits.
        for subunit_id, subunit in enumerate(subunits.values()):
            subunit['id'] = subunit_id
            subunit['obs_indices'] = np.array(subunit['obs_indices'])
            subunit['obs_spatial'] = np.array(subunit['obs_spatial'])

        self.subunits = subunits

    def filter(self) -> None:
        """
        Filter subunits based on minimum number of observations.

        Drops any subunit with fewer than ``self.min_obs`` observations (the
        threshold set at construction). Surviving subunits keep the ``'id'``
        assigned by ``partition()``, so ids may no longer be contiguous.

        Returns:
            None

        Raises:
            RuntimeError: If ``partition()`` has not been called yet.
        """
        subunits = self._require_subunits()
        self.subunits = {
            subunit_key: subunit_data
            for subunit_key, subunit_data in subunits.items()
            if len(subunit_data['obs_indices']) >= self.min_obs
        }

    def sample(self, max_obs: int | None) -> None:
        """
        Sample observations within each subunit based on max_obs threshold.

        Randomly subsamples (without replacement) the observations of every
        subunit that holds more than ``max_obs`` observations. Subunits at or
        below the threshold are left unchanged. The value is also stored on
        ``self.max_obs``.

        Args:
            max_obs: Maximum number of observations per subunit, or ``None`` to
                skip subsampling entirely.

        Returns:
            None

        Raises:
            RuntimeError: If ``max_obs`` is not ``None`` and ``partition()`` has
                not been called yet.
        """
        self.max_obs = max_obs
        if self.max_obs is None:
            if self.verbose:
                print('No max_obs specified, skipping subsampling')
            return

        subunits = self._require_subunits()
        n_subsampled = 0
        total_before = sum(len(data['obs_indices']) for data in subunits.values())

        for subunit_data in subunits.values():
            n_obs = len(subunit_data['obs_indices'])
            if n_obs > self.max_obs:
                # Apply the same selection to labels and coordinates so the two
                # arrays stay aligned.
                keep_indices = self._random_sample(n_obs, self.max_obs)
                subunit_data['obs_indices'] = subunit_data['obs_indices'][keep_indices]
                subunit_data['obs_spatial'] = subunit_data['obs_spatial'][keep_indices]
                n_subsampled += 1

        total_after = sum(len(data['obs_indices']) for data in subunits.values())
        if self.verbose:
            print(
                f'Subsampled {n_subsampled} subunits with >{self.max_obs} observations'
            )
            print(f'Total observations: {total_before} → {total_after}')

    def _random_sample(self, n_points: int, n_samples: int) -> np.ndarray:
        """Draw ``n_samples`` distinct positions from ``range(n_points)``.

        Uses NumPy's global random state (``np.random.choice``).
        """
        return np.random.choice(n_points, n_samples, replace=False)

    def to_df(self) -> pd.DataFrame:
        """
        Build a per-observation table mapping each object to its spatial subunit.

        Flattens ``self.subunits`` into one row per observation. Observation
        labels are kept verbatim and converted to ``str``, so arbitrary
        (non-numeric) labels are supported.

        Returns:
            pd.DataFrame: One row per observation, indexed by the observation
            label (index name ``obs_index``), with a ``subunit_id`` column
            holding the string form of the subunit's grid-index key
            (e.g. ``'(0, 1, 2)'``).

        Raises:
            RuntimeError: If ``partition()`` has not been called yet.
        """
        subunits = self._require_subunits()

        subunit_ids = []
        obs_labels = []
        for subunit_key, subunit_data in subunits.items():
            members = subunit_data['obs_indices']
            subunit_ids.extend([str(subunit_key)] * len(members))
            # Collect labels as-is: forcing them into an integer buffer would
            # fail for (or mangle) string observation names.
            obs_labels.extend(members)

        df = pd.DataFrame(
            {
                'subunit_id': subunit_ids,
                'obs_index': obs_labels,
            }
        )
        df['obs_index'] = df['obs_index'].astype(str)
        return df.set_index('obs_index')