[go: nahoru, domu]

Skip to content

Commit

Permalink
fix visium
Browse files (browse the repository at this point in the history)
  • Loading branch information
giovp committed Jan 15, 2023
1 parent 727ae8e commit b2d6c60
Showing 1 changed file with 83 additions and 42 deletions.
125 changes: 83 additions & 42 deletions src/spatialdata_io/readers/cosmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import re
from collections.abc import Mapping
from copy import deepcopy
from pathlib import Path
from types import MappingProxyType
from typing import Any
Expand All @@ -13,19 +12,26 @@
from anndata import AnnData
from dask_image.imread import imread
from scipy.sparse import csr_matrix
from skimage.transform import estimate_transform

# from skimage.transform import estimate_transform
from spatialdata import SpatialData
from spatialdata._core.core_utils import xy_cs
from spatialdata._core.models import Image2DModel, Labels2DModel, ShapesModel
from spatialdata._core.transformations import Affine
from spatialdata._core.coordinate_system import Axis # , CoordinateSystem

# from spatialdata._core.core_utils import xy_cs
from spatialdata._core.models import Image2DModel, Labels2DModel, TableModel

# from spatialdata._core.transformations import Affine
from spatialdata._logging import logger
from spatialdata._types import ArrayLike

from spatialdata_io._constants._constants import CosmxKeys
from spatialdata_io._docs import inject_docs

__all__ = ["cosmx"]

# Axis descriptors for building coordinate systems: two spatial axes ("x", "y")
# with "discrete" units and one channel axis ("c") with "index" units.
# NOTE(review): in the visible code these are referenced only by the
# commented-out coordinate-system setup inside `cosmx`; confirm before removing.
x_axis = Axis(type="space", name="x", unit="discrete")
y_axis = Axis(type="space", name="y", unit="discrete")
c_axis = Axis(type="channel", name="c", unit="index")


@inject_docs(cx=CosmxKeys)
def cosmx(
Expand Down Expand Up @@ -104,51 +110,93 @@ def cosmx(
)
adata.var_names = counts.columns

fovs_counts = set(adata.obs.fov.astype(str).unique())
table = TableModel.parse(
adata,
region=adata.obs.fov.astype(str).tolist(),
region_key=CosmxKeys.REGION_KEY,
instance_key=CosmxKeys.INSTANCE_KEY,
)

fovs_counts = set(table.obs.fov.astype(str).unique())

# TODO(giovp): uncomment once transform is ready
# input_cs = CoordinateSystem("cxy", axes=[c_axis, y_axis, x_axis])
# input_cs_labels = CoordinateSystem("cxy", axes=[y_axis, x_axis])
# output_cs = CoordinateSystem("global", axes=[c_axis, y_axis, x_axis])
# output_cs_labels = CoordinateSystem("global", axes=[y_axis, x_axis])

# affine_transforms_images = {}
# affine_transforms_labels = {}

# for fov in fovs_counts:
# idx = table.obs.fov.astype(str) == fov
# loc = table[idx, :].obs[[CosmxKeys.X_LOCAL, CosmxKeys.Y_LOCAL]].values
# glob = table[idx, :].obs[[CosmxKeys.X_GLOBAL, CosmxKeys.Y_GLOBAL]].values
# out = estimate_transform(ttype="affine", src=loc, dst=glob)
# affine_transforms_images[fov] = Affine(
# out.params, input_coordinate_system=input_cs, output_coordinate_system=output_cs
# )
# affine_transforms_labels[fov] = Affine(
# out.params, input_coordinate_system=input_cs_labels, output_coordinate_system=output_cs_labels
# )

table.obsm["global"] = table.obs[[CosmxKeys.X_GLOBAL, CosmxKeys.Y_GLOBAL]].to_numpy()
table.obsm["spatial"] = table.obs[[CosmxKeys.X_LOCAL, CosmxKeys.Y_LOCAL]].to_numpy()
table.obs.drop(columns=[CosmxKeys.X_LOCAL, CosmxKeys.Y_LOCAL, CosmxKeys.X_GLOBAL, CosmxKeys.Y_GLOBAL], inplace=True)

# prepare to read images and labels
file_extensions = (".jpg", ".png", ".jpeg", ".tif", ".tiff")
pat = re.compile(r".*_F(\d+)")

# read images
images = {}
# check if fovs are correct for images and labels
fovs_images = []
for fname in os.listdir(path / CosmxKeys.IMAGES_DIR):
if fname.endswith(file_extensions):
fov = str(int(pat.findall(fname)[0]))
images[fov] = Image2DModel.parse(
imread(path / CosmxKeys.IMAGES_DIR / fname, **imread_kwargs).squeeze(), name=fov, **image_models_kwargs
)
fovs_images.append(str(int(pat.findall(fname)[0])))

# read labels
labels = {}
fovs_labels = []
for fname in os.listdir(path / CosmxKeys.LABELS_DIR):
if fname.endswith(file_extensions):
fov = str(int(pat.findall(fname)[0]))
labels[fov] = Labels2DModel.parse(
imread(path / CosmxKeys.LABELS_DIR / fname, **imread_kwargs).squeeze(), name=fov, **image_models_kwargs
)
fovs_labels.append(str(int(pat.findall(fname)[0])))

fovs_images = set(images.keys()).intersection(set(labels.keys()))
fovs_diff = fovs_images.difference(fovs_counts)
fovs_images_and_labels = set(fovs_images).intersection(set(fovs_labels))
fovs_diff = fovs_images_and_labels.difference(set(fovs_counts))
if len(fovs_diff):
logger.warning(
raise logger.warning(
f"Found images and labels for {len(fovs_images)} FOVs, but only {len(fovs_counts)} FOVs in the counts file.\n"
+ f"The following FOVs are missing: {fovs_diff} \n"
+ "`SpatialData` returns intersection of FOVs for counts and images/labels.",
+ "... will use only fovs in Table."
)

circles = {}
for fov in fovs_images:
idx = adata.obs.fov.astype(str) == fov
loc = adata[idx, :].obs[[CosmxKeys.X_LOCAL, CosmxKeys.Y_LOCAL]].values
glob = adata[idx, :].obs[[CosmxKeys.X_GLOBAL, CosmxKeys.Y_GLOBAL]].values
loc_to_glob_transform = _estimate_transform(loc, glob)
circ = ShapesModel.parse(loc, shape_type="circle", shape_size=shape_size)
implicit_transform = circ.uns["transform"]
circ.uns["transform"] = [implicit_transform, loc_to_glob_transform]
circles[fov] = circ
# read images
images = {}
for fname in os.listdir(path / CosmxKeys.IMAGES_DIR):
if fname.endswith(file_extensions):
fov = str(int(pat.findall(fname)[0]))
if fov in fovs_counts:
images[fov] = Image2DModel.parse(
imread(path / CosmxKeys.IMAGES_DIR / fname, **imread_kwargs).squeeze(),
name=fov,
# transform=affine_transforms_images[fov],
**image_models_kwargs,
)
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping image {fname}.")

adata.obs.drop(columns=[CosmxKeys.X_LOCAL, CosmxKeys.Y_LOCAL, CosmxKeys.X_GLOBAL, CosmxKeys.Y_GLOBAL], inplace=True)
# read labels
labels = {}
for fname in os.listdir(path / CosmxKeys.LABELS_DIR):
if fname.endswith(file_extensions):
fov = str(int(pat.findall(fname)[0]))
if fov in fovs_counts:
labels[fov] = Labels2DModel.parse(
imread(path / CosmxKeys.LABELS_DIR / fname, **imread_kwargs).squeeze(),
name=fov,
# transform=affine_transforms_labels[fov],
**image_models_kwargs,
)
else:
logger.warning(f"FOV {fov} not found in counts file. Skipping labels {fname}.")

# TODO: what to do with fov file?
# if fov_file is not None:
Expand All @@ -160,11 +208,4 @@ def cosmx(
# logg.warning(f"FOV `{str(fov)}` does not exist, skipping it.")
# continue

return SpatialData(images=images, labels=labels, shapes=circles, table=adata)


def _estimate_transform(src: ArrayLike, tgt: ArrayLike) -> Affine:
    """Fit an affine map that sends the ``src`` points onto the ``tgt`` points.

    Parameters
    ----------
    src
        Source point coordinates. In ``cosmx`` these are the local x/y columns
        of the table rows belonging to one FOV.
    tgt
        Target point coordinates matching ``src`` row-for-row. In ``cosmx``
        these are the corresponding global x/y columns.

    Returns
    -------
    An :class:`Affine` built from the fitted parameter matrix. Its input
    coordinate system is ``xy_cs``, and its output coordinate system is a copy
    of ``xy_cs`` renamed to ``"xy_global"``.

    NOTE(review): the fit's success is not checked here. Confirm how
    ``estimate_transform`` behaves on degenerate input (e.g. fewer than three
    non-collinear points) before relying on the result.
    """
    # Copy before renaming so the shared ``xy_cs`` object is left untouched.
    global_cs = deepcopy(xy_cs)
    global_cs.name = "xy_global"

    fitted = estimate_transform(ttype="affine", src=src, dst=tgt)
    return Affine(
        fitted.params,
        input_coordinate_system=xy_cs,
        output_coordinate_system=global_cs,
    )
return SpatialData(images=images, labels=labels, table=table)

0 comments on commit b2d6c60

Please sign in to comment.