提交新活动

感谢您!您的提交已收到!
糟糕!提交表单时出了点问题。

提交新闻报道

感谢您!您的提交已收到!
糟糕!提交表单时出了点问题。

订阅新闻通讯

感谢您!您的提交已收到!
糟糕!提交表单时出了点问题。
2021年3月29日

将 Dask 与 PyTorch 结合用于大规模图像分析

作者

总结

这篇文章探讨了如何将预训练的 PyTorch 模型与 Dask Array 并行应用。

我们将介绍一个简单示例,演示如何将预训练的 UNet 应用于一叠图像,为每个像素生成特征。

示例详解

让我们从一个示例开始,将预训练的 UNet 应用于光片显微镜数据堆栈。

在此示例中,我们将:

     
  1. 从 Zarr 加载图像数据到多分块的 Dask 数组中
  2.  
  3. 加载一个用于提取图像特征的预训练 PyTorch 模型
  4.  
  5. 构建一个函数,将模型应用于每个分块
  6.  
  7. 使用 dask.array.map_blocks 函数将该函数应用于整个 Dask 数组。
  8.  
  9. 将结果重新存回 Zarr 格式

步骤 1:加载图像数据

首先,我们将图像数据加载到 Dask 数组中。

这里使用的示例数据集是斑马鱼胚胎尾部区域的点阵光片显微镜数据。这在这篇 Science 论文(参见图 4)中有所描述,并经 Srigokul Upadhyayula 许可提供。

Liu 2018 年 “Observing the cell in its native state: Imaging subcellular dynamics in multicellular organisms” Science, Vol. 360, Issue 6386, eaaq1392 DOI: 10.1126/science.aaq1392 (链接)

这是我们在上一篇关于 Dask 和 ITK 的博文中分析过的相同数据。您应该注意到它与该工作流程的相似之处,尽管我们现在使用的是新的库并进行不同的分析。

cd '/Users/nicholassofroniew/Github/image-demos/data/LLSM'
# Load our data
import dask.array as da
imgs = da.from_zarr("AOLLSM_m4_560nm.zarr")
imgs
dask.array<from-zarr, shape=(20, 199, 768, 1024), dtype=float32, chunksize=(1, 1, 768, 1024)>

步骤 2:加载预训练的 PyTorch 模型

接下来,我们加载预训练的 UNet 模型。

这个 UNet 模型接收一张 2D 图像,并返回一个 2D x 16 数组,其中每个像素现在都关联一个长度为 16 的特征向量。

我们感谢 Mars Huang 在大量生物图像上训练了这个特定的 UNet 模型,以生成生物相关的特征向量,这是他关于交互式生物图像分割工作的一部分。这些特征随后可用于更多下游图像处理任务,例如图像分割。

# Load our pretrained UNet¶
import torch
from segmentify.model import UNet, layers

def load_unet(path):
    """Load a pretrained UNet model."""

    # load in saved model
    pth = torch.load(path)
    model_args = pth['model_args']
    model_state = pth['model_state']
    model = UNet(**model_args)
    model.load_state_dict(model_state)

    # remove last layer and activation
    model.segment = layers.Identity()
    model.activate = layers.Identity()
    model.eval()

    return model

model = load_unet("HPA_3.pth")

步骤 3:构建一个函数,将模型应用于每个分块

我们构建一个函数,将预训练的 UNet 模型应用于 Dask 数组的每个分块。

由于 Dask 数组仅由可轻松转换为 Torch 数组的 Numpy 数组组成,因此我们能够大规模地利用机器学习的力量。

# Apply UNet featurization
import numpy as np

def unet_featurize(image, model):
    """Featurize pixels in an image using pretrained UNet model.
    """
    import numpy as np
    import torch

    # Extract the 2D image data from the Dask array
    # Original Dask array dimensions were (time, z-slice, y, x)
    img = image[0, 0, ...]

    # Put the data into a shape PyTorch expects
    # Expected dimensions are (Batch x Channel x Width x Height)
    img = img[None, None, ...]

    # convert image to torch Tensor
    img = torch.Tensor(img).float()

    # pass image through model
    with torch.no_grad():
        features = model(img).numpy()

    # generate feature vectors (w,h,f)
    features = np.transpose(features, (0,2,3,1))[0]

    # Add back the leading length-one dimensions
    result = features[None, None, ...]

    return result

注意:非常细心的读者可能会注意到,提取 2D 图像数据然后将其放入 PyTorch 期望的形状的步骤似乎是多余的。对于我们的特定示例来说,它是多余的,但情况很容易就不是这样。

为了更详细地解释这一点,UNet 期望 4D 输入,维度为(批量 x 通道 x 宽度 x 高度)。原始 Dask 数组的维度是(时间、z 切片、y、x)。在我们的示例中,这些维度恰好以一种使得移除然后添加前导维度变得多余的方式匹配,但根据原始 Dask 数组的形状,情况可能并非如此。

步骤 4:将该函数应用于整个 Dask 数组

现在我们使用 dask.array.map_blocks 函数将该函数应用于 Dask 数组中的数据。

# Apply UNet featurization
out = da.map_blocks(unet_featurize, imgs, model, dtype=np.float32, chunks=(1, 1, imgs.shape[2], imgs.shape[3], 16), new_axis=-1)
out
dask.array<unet_featurize, shape=(20, 199, 768, 1024, 16), dtype=float32, chunksize=(1, 1, 768, 1024, 16)>

步骤 5:将结果存回 Zarr 格式

最后,我们将 UNet 模型特征提取的结果存储为 zarr 数组。

# Trigger computation and store
out.to_zarr("AOLLSM_featurized.zarr", overwrite=True)

现在我们已经保存了输出,这些特征可用于更多下游图像处理任务,例如图像分割。

总结

至此,我们展示了如何将预训练的 PyTorch 模型应用于包含图像数据的 Dask 数组。

由于 Dask 数组的分块是 Numpy 数组,它们可以轻松转换为 Torch 数组。通过这种方式,我们能够大规模地利用机器学习的力量。

这个工作流程与我们使用 ITK 和 dask.array.map_blocks 函数执行图像反卷积的示例非常相似。这表明您可以轻松调整相同类型的工作流程,使用 Dask 实现许多不同类型的分析。