ディープ ラーニング モデル推論のワークフロー
Azure Databricks では、ディープ ラーニング アプリケーションのモデル推論に次のワークフローをお勧めしています。 TensorFlow と PyTorch を使用するノートブックの例については、「ディープ ラーニング モデル推論の例」を参照してください。
データを Spark DataFrames に読み込みます。 Azure Databricks では、データ型に応じて次のようなデータの読み込み方法をお勧めしています。
- 画像ファイル (JPG、PNG): イメージ パスを Spark DataFrame に読み込みます。 画像の読み込みと入力データの前処理は、pandas UDF で行われます。
files_df = spark.createDataFrame(map(lambda path: (path,), file_paths), ["path"])
- TFRecords: spark-tensorflow-connector を使用してデータを読み込みます。
df = spark.read.format("tfrecords").load(image_path)
- Parquet、CSV、JSON、JDBC、その他のメタデータなどのデータ ソース: Spark データ ソースを使用してデータを読み込みます。
pandas UDF を使用してモデル推論を実行します。pandas UDF では、Apache Arrow を使用してデータを転送し、pandas を使用してデータを処理します。 モデル推論を行うための、pandas UDF を使用したワークフローの大まかな手順を次に示します。
- トレーニング済みモデルを読み込む: 効率を高めるため、Azure Databricks では、ドライバーからモデルの重みをブロードキャストし、モデル グラフを読み込んで、pandas UDF でブロードキャストされた変数から重みを取得することをおお勧めしています。
- 入力データを読み込んで前処理する: データをバッチで読み込むために、Azure Databricks では、TensorFlow には tf.data API を、PyTorch には DataLoader クラスを使用することをお勧めしています。 どちらも、IO バインドの待機時間がわからなくなるように、プリフェッチとマルチスレッド読み込みもサポートしています。
- モデル予測を実行する: データ バッチでモデル推論を実行します。
- 予測を Spark DataFrames に送り返す: 予測結果を収集し、
pd.Series
として返します。
ディープ ラーニング モデル推論の例
このセクションの例は、推奨されているディープ ラーニング推論ワークフローに従っています。 これらの例では、事前トレーニング済みのディープ残差ネットワーク (ResNets) ニューラル ネットワーク モデルを使用してモデル推論を実行する方法を示します。