本工作由 Anaconda Inc 提供支持
太长不看: Pickle 不慢,它是一个协议。协议对于生态系统至关重要。
最近的一个 Dask 问题表明,将 Dask 与 PyTorch 结合使用很慢,因为在 Dask worker 之间发送 PyTorch 模型需要很长时间(Dask GitHub 问题)。
事实证明,这是因为使用 pickle 序列化 PyTorch 模型非常慢(基于 GPU 的模型为 1 MB/s,基于 CPU 的模型为 50 MB/s)。没有任何架构原因需要如此慢。硬件管道的每个部分都比这快得多。
我们本可以通过对 PyTorch 模型进行特殊处理来解决 Dask 中的这个问题(Dask 有自己的可选序列化系统以提高性能),但作为良好的生态系统公民,我们决定在上游提出性能问题(PyTorch Github 问题)。这导致 PyTorch 进行了五行代码的修复,将 1-50 MB/s 的序列化带宽变成了 1 GB/s 的带宽,这对于许多用例来说已经足够快了(PyTorch 的 PR)。
def __reduce__(self)
- return type(self), (self.tolist(),)
+ b = io.BytesIO()
+ torch.save(self, b)
+ return (_load_from_bytes, (b.getvalue(),))
+def _load_from_bytes(b)
+ return torch.load(io.BytesIO(b))
感谢 PyTorch 维护者,这个问题很容易就解决了。PyTorch 张量和模型现在可以在 Dask 或 任何其他 Python 库中高效地序列化,这些库可能希望在 PySpark、IPython parallel、Ray 或其他任何分布式系统中使用它们,而无需添加特殊处理代码或做任何特殊的事情。我们没有解决一个 Dask 问题,我们解决了一个生态系统问题。
然而,在解决这个问题之前,我们进行了一些讨论。这条评论让我印象深刻
这条评论包含了两种非常普遍的观点,我发现它们有些事倍功半
我在这里有点吹毛求疵地说了 PyTorch 维护者(抱歉!),但我发现这些观点非常普遍,所以想在这里阐述一下。
Pickle 不 慢。Pickle 是一个协议。我们 实现了 pickle。如果它慢,那是 我们的 错,而不是 Pickle 的错。
需要澄清的是,有很多理由不使用 Pickle。
因此,你不应该使用 Pickle 存储数据或创建公共服务,但对于诸如在线路上传输数据之类的任务,如果你是在受信任且统一的环境中严格地从 Python 进程到 Python 进程进行传输,那么它是一个很好的默认选择。
它之所以很棒,是因为它可以做到像内存复制一样快,并且生态系统中的其他库无需特殊处理你的代码即可使用它。
这是我们为 PyTorch 所做的更改。
def __reduce__(self)
- return type(self), (self.tolist(),)
+ b = io.BytesIO()
+ torch.save(self, b)
+ return (_load_from_bytes, (b.getvalue(),))
+def _load_from_bytes(b)
+ return torch.load(io.BytesIO(b))
慢的部分不是 Pickle,而是 __reduce__ 中的 .tolist() 调用,它将 PyTorch 张量转换为 Python 整数和浮点数的列表。我怀疑“Pickle 就是慢”的普遍看法阻止了其他人调查这里的糟糕性能。我很惊讶地得知,像 PyTorch 这样活跃且维护良好的项目竟然还没有解决这个问题。
提醒一下,你可以通过在类中提供 __reduce__ 方法来实现 pickle 协议。__reduce__ 函数返回一个加载函数和足够的参数来重构你的对象。这里我们使用了 torch 现有的 save/load 函数来创建一个我们可以传递的字节字符串。
专门选项可能很棒。它们可以提供带有许多选项的良好 API,如果存在专门的通信硬件(如 RDMA 或 NVLink),它们可以根据硬件进行调整,等等。但人们需要先了解它们,而了解它们可能有两个难点。
如今,我们使用着大量且快速变化的库。用户很难精通所有这些库。我们越来越依赖新的库通过遵循标准 API、提供有用的错误消息来引导良好行为等方式使我们更容易使用。
需要交互的其他库肯定不会阅读文档,即使阅读了,也不合理让每个库都为其他每个库喜欢的将对象转换为字节的方法进行特殊处理。库的生态系统在很大程度上依赖于协议的存在以及围绕一致高效地实现协议的强烈共识。
支持专门选项确实有很好的理由。有时你需要超过 1GB/s 的带宽。虽然这通常很少见(很少有管道处理速度超过 1GB/s/节点),但在 PyTorch 的特定情况下,当它们在具有多个进程的单台机器上进行并行训练时,情况确实如此。Soumith(PyTorch 维护者)写道:
当通过多进程发送张量时,我们的自定义序列化程序实际上通过共享内存来快捷处理它们,即将底层 Storage 移动到共享内存,并在另一个进程中恢复张量以指向共享内存。我们这样做是出于以下原因: