我们介绍了一个关于如何将数据从 Dask 这样的松耦合并行计算系统传递到 MPI 这样的紧耦合并行计算系统的实验。
我们提供了动机和一个完整易懂的例子。
免责声明:本文中的内容尚未完善,也未准备好用于生产环境。这是一个旨在引发讨论的实验。不提供长期支持。
我们经常收到以下问题
如何使用 Dask 预处理我的数据,然后将结果传递给传统的 MPI 应用程序?
你可能想这样做,因为你在维护用 MPI 编写的遗留代码,或者因为你的计算需要只有 MPI 才能提供的紧耦合并行性。
最简单的方法当然是将 Dask 结果写入磁盘,然后用 MPI 从磁盘重新加载。考虑到你的计算与数据加载的相对成本,这可能是一个不错的选择。
对于本文的其余部分,我们将假设情况并非如此。
我们有一个用 MPI4Py 编写的简单 MPI 库,其中每个 rank 只打印它获得的所有数据。原则上它也可以调用 C++ 代码,并执行任意的 MPI 操作。
# my_mpi_lib.py
from mpi4py import MPI
comm = MPI.COMM_WORLD
def print_data_and_rank(chunks: list)
""" 模拟 MPI 函数如何运行的假函数
- 它接收一个包含此机器上数据块的列表
- 它可以使用这些数据和 MPI 做任何它想做的事情
这里为了简单起见,我们只打印数据和 rank
- 也许它会返回一些东西
"""
rank = comm.Get_rank()
for chunk in chunks
print("on rank:", rank)
print(chunk)
return sum(chunk.sum() for chunk in chunks)
在我们的 Dask 程序中,我们将正常使用 Dask 加载数据,进行一些预处理,然后将所有数据传递给每个 MPI rank,该 rank 将调用上面的 print_data_and_rank 函数来初始化 MPI 计算。
# my_dask_script.py
# 使用 dask_mpi 项目在 MPI 作业中设置 Dask workers
# 请参阅 https://dask-mpi.readthedocs.io/en/latest/
from dask_mpi import initialize
initialize()
from dask.distributed import Client, wait, futures_of
client = Client()
# 使用 Dask Array “加载”数据(实际上这里只是创建随机数据)
import dask.array as da
x = da.random.random(100000000, chunks=(1000000,))
x = x.persist()
wait(x)
# 找出每个 worker 上的数据位置
# TODO:这可以在 Dask 方面改进以减少样板代码
from toolz import first
from collections import defaultdict
key_to_part_dict = {str(part.key): part for part in futures_of(x)}
who_has = client.who_has(x)
worker_map = defaultdict(list)
for key, workers in who_has.items()
worker_map[first(workers)].append(key_to_part_dict[key])
# 对每个 worker 上的数据列表调用一个启用 MPI 的函数
from my_mpi_lib import print_data_and_rank
futures = [client.submit(print_data_and_rank, list_of_parts, workers=worker)
for worker, list_of_parts in worker_map.items()]
wait(futures)
client.close()
然后我们可以使用普通的 mpirun 或 mpiexec 命令调用这个 Dask 和 MPI 程序混合体。
mpirun -np 5 python my_dask_script.py
所以 MPI 启动并运行了我们的脚本。dask-mpi 项目在 rank 0 上设置了一个 Dask scheduler,在 rank 1 上运行我们的客户端代码,然后在 ranks 2+ 上运行一批 workers。
我们的脚本接着创建了一个 Dask 数组,尽管推测在这里它会从某个源读取数据,在继续之前进行更复杂的 Dask 操作。
然后我们等待所有的 Dask 工作完成并进入安静状态。然后我们查询 scheduler 中的状态,找出所有数据所在的位置。就是这里的这段代码
# 找出每个 worker 上的数据位置
# TODO:这可以在 Dask 方面改进以减少样板代码
from toolz import first
from collections import defaultdict
key_to_part_dict = {str(part.key): part for part in futures_of(x)}
who_has = client.who_has(x)
worker_map = defaultdict(list)
for key, workers in who_has.items()
worker_map[first(workers)].append(key_to_part_dict[key])
诚然,这段代码很糟糕,对非 Dask 专家(甚至 Dask 专家自己)来说也不特别友好或显而易见,我不得不从执行相同技巧的 Dask XGBoost 项目中“借用”这段代码。
但之后我们只需使用 Dask 的 Futures 接口,对所有数据调用我们的 MPI 库的 initialize 函数 print_data_and_rank。该函数直接从本地内存获取数据(Dask workers 和 MPI ranks 在同一个进程中),并执行 MPI 应用程序想要的任何操作。
这可以在几个方面得到改进