# Source code for atlannot.atlas.align

# Copyright 2021, Blue Brain Project, EPFL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions related to atlas alignment."""
import warnings

import numpy as np


def unfurl_regions(atlas, meta, progress_bar=None):
    """Expand an atlas into one copy per region hierarchy level.

    The first entry of the result is a copy of the input atlas. Every
    following entry is obtained from the previous one by replacing all
    region IDs of the currently deepest remaining level with the IDs of
    their parents. Levels are collapsed from the maximal level found in
    `meta` down to level 1, so level 0 is never remapped.

    For instance, with a maximal level of 2 the result holds three volumes:

    - the original annotation,
    - the annotation with level-2 regions merged into their parents,
    - the annotation with level-2 and level-1 regions merged into their
      parents.

    Parameters
    ----------
    atlas : np.ndarray
        An annotation atlas volume, typically of shape
        `(n_slices, height, width)`. The input array is not modified.
    meta : atlannot.region_meta.RegionMeta
        The region metadata. This function reads `meta.level` (a mapping
        whose values are hierarchy levels), `meta.ids_at_level(level)` and
        `meta.parent_id` (a mapping from region ID to parent region ID).
    progress_bar : callable, optional
        A function that takes an iterable and returns an iterable over the
        same items, displaying progress as a side effect, for example
        `tqdm.tqdm` or `tqdm.notebook.tqdm`.

    Returns
    -------
    unfurled_atlas : np.ndarray
        The unfurled atlas of shape `(max_level + 1, *atlas.shape)`, where
        `max_level` is the largest value in `meta.level`.
    """
    deepest_level = max(meta.level.values())

    # Collapse the hierarchy from the deepest level towards the root.
    levels = range(deepest_level, 0, -1)
    levels = levels if progress_bar is None else progress_bar(levels)

    current = atlas.copy()
    volumes = [current]
    for level in levels:
        # Work on a fresh copy so already collected volumes stay untouched.
        current = current.copy()
        for region_id in meta.ids_at_level(level):
            current[current == region_id] = meta.parent_id[region_id]
        volumes.append(current)

    return np.stack(volumes)
def get_misalignment(data_1, data_2, fg_only=False):
    """Measure the fraction of positions where two annotations disagree.

    Parameters
    ----------
    data_1 : np.ndarray
        The first annotation data. Can have any shape.
    data_2 : np.ndarray
        The second annotation data. Its shape must equal that of `data_1`.
    fg_only : bool, optional
        If true, only foreground positions are evaluated. The background is
        the set of positions where both arrays are zero; everything else is
        foreground.

    Returns
    -------
    misalignment : float
        The number of evaluated positions with differing values divided by
        the number of evaluated positions. If no position is evaluated
        (e.g. empty input, or `fg_only` with no foreground) the result is 0.

    Raises
    ------
    ValueError
        If the shapes of the data don't match.
    """
    if data_1.shape != data_2.shape:
        raise ValueError("Data have to be of the same shape")

    differs = data_1 != data_2
    if fg_only:
        # Drop every position that is background in both annotations.
        background = (data_1 == 0) & (data_2 == 0)
        differs = differs[~background]

    # Guard against division by zero when nothing is left to compare.
    n_compared = max(differs.size, 1)
    return differs.sum() / n_compared
def specific_label_iou(data_1, data_2, specific_label):
    """Compute intersection over union for a given label.

    Parameters
    ----------
    data_1 : np.ndarray
        The first annotation data. Can have any shape.
    data_2 : np.ndarray
        The second annotation data. Shape should match that of `data_1`.
    specific_label : int
        Label for which it is wanted to compute the IOU.

    Returns
    -------
    iou : float
        The IOU for the given label: the number of positions where both
        arrays equal `specific_label` divided by the number of positions
        where at least one of them does. If the label occurs in neither
        array, `np.nan` is returned and a warning is issued.

    Raises
    ------
    ValueError
        If the shapes of the data don't match.
    """
    if data_1.shape != data_2.shape:
        raise ValueError("Data have to be of the same shape")

    mask_1 = data_1 == specific_label
    mask_2 = data_2 == specific_label

    n_union = np.logical_or(mask_1, mask_2).sum()
    if n_union == 0:
        # stacklevel=2 attributes the warning to the caller's line, which
        # is where the offending label was passed in.
        warnings.warn(
            f"It seems the specific label "
            f"{specific_label} does not exist on the input images.",
            stacklevel=2,
        )
        return np.nan

    n_intersection = np.logical_and(mask_1, mask_2).sum()
    return n_intersection / n_union
def warp(atlas, df_per_slice):
    """Apply one displacement field to each slice of an atlas.

    The displacement fields and atlas slices are paired up in order; each
    slice is warped by calling `warp_annotation` on its displacement field
    and the results are stacked along a new leading axis.

    Parameters
    ----------
    atlas : iterable of np.ndarray
        An annotation atlas. Can be an `np.ndarray` of shape
        `(n_slices, ...)` or any iterable over atlas slices.
    df_per_slice : iterable of atldld.base.DisplacementField
        The displacement fields, one for each brain slice.

    Returns
    -------
    warped_atlas : np.ndarray
        The warped atlas.
    """
    warped_slices = []
    for displacement, annotation in zip(df_per_slice, atlas):
        warped_slices.append(displacement.warp_annotation(annotation))

    return np.stack(warped_slices)