使用 Ray 和 Apache Arrow 實現快速 Python 序列化


已發布 2017 年 10 月 15 日
作者 Philipp Moritz, Robert Nishihara

這篇文章最初發布在 Ray 部落格Philipp MoritzRobert Nishihara 是加州大學柏克萊分校的研究生。

這篇文章詳細闡述了 RayApache Arrow 之間的整合。它主要解決的問題是 資料序列化

根據 維基百科序列化

... 將資料結構或物件狀態轉換為可以儲存 ... 或傳輸 ... 並在稍後(可能在不同的電腦環境中)重建的格式的過程。

為什麼需要任何轉換? 嗯,當您建立一個 Python 物件時,它可能會有指向其他 Python 物件的指標,而這些物件都分配在記憶體的不同區域中,所有這些都必須在另一部機器上的另一個程序解包時才有意義。

序列化和反序列化是平行和分散式運算中的瓶頸,尤其是在具有大型物件和大量資料的機器學習應用程式中。

設計目標

由於 Ray 針對機器學習和 AI 應用程式進行了最佳化,我們非常關注序列化和資料處理,並具有以下設計目標

  1. 對於大型數值資料應該非常有效率(這包括 NumPy 陣列和 Pandas DataFrame,以及遞迴包含 NumPy 陣列和 Pandas DataFrame 的物件)。
  2. 對於一般 Python 類型,其速度應與 Pickle 相當。
  3. 它應該與共享記憶體相容,允許在多個程序之間使用相同的資料而無需複製。
  4. 反序列化應該非常快速(在可能的情況下,它不應該需要讀取整個序列化的物件)。
  5. 它應該是語言獨立的(最終我們希望能夠讓 Python 工作程序使用由 Java 或其他語言的工作程序建立的物件,反之亦然)。

我們的方法和替代方案

Python 中常用的序列化方法是 pickle 模組。Pickle 非常通用,特別是如果您使用像 cloudpickle 這樣的變體。但是,它不滿足要求 1、3、4 或 5。像 json 這樣的替代方案滿足了 5,但不滿足 1-4。

我們的方法: 為了滿足要求 1-5,我們選擇使用 Apache Arrow 格式作為我們的底層資料表示。與 Apache Arrow 團隊合作,我們建構了 函式庫,用於將一般 Python 物件對應到 Arrow 格式以及從 Arrow 格式對應回來。這種方法的一些特性

  • 資料佈局是語言獨立的(要求 5)。
  • 可以以常數時間計算序列化資料 Blob 的偏移量,而無需讀取整個物件(要求 1 和 4)。
  • Arrow 支援零複製讀取,因此物件可以自然地儲存在共享記憶體中,並由多個程序使用(要求 1 和 3)。
  • 對於任何我們無法很好處理的東西,我們可以自然地退回到 pickle(要求 2)。

Arrow 的替代方案: 我們本來可以基於 Protocol Buffers 構建,但 protocol buffers 並非真正為數值資料而設計,而且這種方法無法滿足 1、3 或 4。基於 Flatbuffers 構建實際上是可以實現的,但它需要實作 Arrow 已經擁有的許多功能,而且我們更喜歡針對大數據最佳化的欄狀資料佈局。

加速

在這裡,我們展示了一些相對於 Python 的 pickle 模組的效能改進。這些實驗是使用 pickle.HIGHEST_PROTOCOL 完成的。用於產生這些圖表的程式碼包含在文章末尾。

使用 NumPy 陣列: 在機器學習和 AI 應用程式中,資料(例如,影像、神經網路權重、文字文件)通常表示為包含 NumPy 陣列的資料結構。當使用 NumPy 陣列時,加速效果令人印象深刻。

Ray 的反序列化條形圖幾乎看不到並非錯誤。這是支援零複製讀取的結果(節省主要來自於缺乏記憶體移動)。

請注意,最大的優勢在於反序列化。此處的加速是多個數量級,並且隨著 NumPy 陣列變得更大而變得更好(感謝設計目標 1、3 和 4)。使反序列化快速非常重要,原因有二。首先,物件可能被序列化一次,然後被反序列化多次(例如,廣播到所有工作程序的物件)。其次,常見的模式是許多物件被平行序列化,然後在單個工作程序上一次一個地聚合和反序列化,這使得反序列化成為瓶頸。

不使用 NumPy 陣列: 當使用常規 Python 物件時,我們無法利用共享記憶體,結果與 pickle 相當。

這些只是一些有趣的 Python 物件的範例。最重要的情況是 NumPy 陣列巢狀在其他物件中的情況。請注意,我們的序列化函式庫適用於非常通用的 Python 類型,包括自訂 Python 類別和深度巢狀物件。

API

序列化函式庫可以直接透過 pyarrow 使用,如下所示。更多文件請參閱此處

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
serialized_x = pyarrow.serialize(x).to_buffer()
deserialized_x = pyarrow.deserialize(serialized_x)

它可以直接透過 Ray API 使用,如下所示。

x = [(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]
x_id = ray.put(x)
deserialized_x = ray.get(x_id)

資料表示

我們使用 Apache Arrow 作為底層的語言獨立資料佈局。物件儲存在兩個部分:schemadata blob。在高階層次上,data blob 大致是物件中遞迴包含的所有資料值的扁平串聯,而 schema 定義了 data blob 的類型和巢狀結構。

技術細節: Python 序列(例如,字典、列表、元組、集合)被編碼為其他類型的 Arrow UnionArrays(例如,布林值、整數、字串、位元組、浮點數、雙精度浮點數、date64s、張量(即 NumPy 陣列)、列表、元組、字典和集合)。巢狀序列使用 Arrow ListArrays 進行編碼。所有張量都被收集並附加到序列化物件的末尾,並且 UnionArray 包含對這些張量的引用。

為了給出具體的範例,請考慮以下物件。

[(1, 2), 'hello', 3, 4, np.array([5.0, 6.0])]

它將在 Arrow 中以以下結構表示。

UnionArray(type_ids=[tuple, string, int, int, ndarray],
           tuples=ListArray(offsets=[0, 2],
                            UnionArray(type_ids=[int, int],
                                       ints=[1, 2])),
           strings=['hello'],
           ints=[3, 4],
           ndarrays=[<offset of numpy array>])

Arrow 使用 Flatbuffers 來編碼序列化 schema。僅使用 schema,我們就可以計算資料 Blob 中每個值的偏移量,而無需掃描整個資料 Blob(與 Pickle 不同,這是實現快速反序列化的原因)。這表示我們可以避免在反序列化期間複製或以其他方式轉換大型陣列和其他值。張量被附加在 UnionArray 的末尾,並且可以使用共享記憶體有效地共享和存取。

請注意,實際物件在記憶體中的佈局如下所示。

Python 物件在堆積中的佈局。每個框都分配在不同的記憶體區域中,框之間的箭頭表示指標。


Arrow 序列化表示將如下所示。

Arrow 序列化物件的記憶體佈局。


參與其中

我們歡迎貢獻,尤其是在以下領域。

  • 使用 Arrow 的 C++ 和 Java 實作來為 C++ 和 Java 實作此版本的序列化函式庫。
  • 實作對更多 Python 類型的支援和更好的測試覆蓋率。

重現以上圖表

作為參考,可以使用以下程式碼重現這些圖表。基準測試 ray.putray.get 而不是 pyarrow.serializepyarrow.deserialize 會得到類似的圖表。這些圖表是在這個 commit 生成的。

import pickle
import pyarrow
import matplotlib.pyplot as plt
import numpy as np
import timeit


def benchmark_object(obj, number=10):
    # Time serialization and deserialization for pickle.
    pickle_serialize = timeit.timeit(
        lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
        number=number)
    serialized_obj = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    pickle_deserialize = timeit.timeit(lambda: pickle.loads(serialized_obj),
                                       number=number)

    # Time serialization and deserialization for Ray.
    ray_serialize = timeit.timeit(
        lambda: pyarrow.serialize(obj).to_buffer(), number=number)
    serialized_obj = pyarrow.serialize(obj).to_buffer()
    ray_deserialize = timeit.timeit(
        lambda: pyarrow.deserialize(serialized_obj), number=number)

    return [[pickle_serialize, pickle_deserialize],
            [ray_serialize, ray_deserialize]]


def plot(pickle_times, ray_times, title, i):
    fig, ax = plt.subplots()
    fig.set_size_inches(3.8, 2.7)

    bar_width = 0.35
    index = np.arange(2)
    opacity = 0.6

    plt.bar(index, pickle_times, bar_width,
            alpha=opacity, color='r', label='Pickle')

    plt.bar(index + bar_width, ray_times, bar_width,
            alpha=opacity, color='c', label='Ray')

    plt.title(title, fontweight='bold')
    plt.ylabel('Time (seconds)', fontsize=10)
    labels = ['serialization', 'deserialization']
    plt.xticks(index + bar_width / 2, labels, fontsize=10)
    plt.legend(fontsize=10, bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.yticks(fontsize=10)
    plt.savefig('plot-' + str(i) + '.png', format='png')


test_objects = [
    [np.random.randn(50000) for i in range(100)],
    {'weight-' + str(i): np.random.randn(50000) for i in range(100)},
    {i: set(['string1' + str(i), 'string2' + str(i)]) for i in range(100000)},
    [str(i) for i in range(200000)]
]

titles = [
    'List of large numpy arrays',
    'Dictionary of large numpy arrays',
    'Large dictionary of small sets',
    'Large list of strings'
]

for i in range(len(test_objects)):
    plot(*benchmark_object(test_objects[i]), titles[i], i)