Databricks Connect for Scala のユーザー定義関数

Note

この記事では、Databricks Runtime 14.1 以降用の Databricks Connect について説明します。

この記事では、Databricks Connect for Scala でユーザー定義関数を実行する方法について説明します。 Databricks Connect を使用すると、一般的な IDE、ノートブック サーバー、カスタム アプリケーションを Azure Databricks クラスターに接続できます。 この記事の Python バージョンについては、「Databricks Connect for Python のユーザー定義関数」を参照してください。

Note

Databricks Connect の使用を開始する前に、Databricks Connect クライアントを設定必要があります。

Databricks Runtime 14.1 以降の場合、Databricks Connect for Scala ではユーザー定義関数 (UDF) の実行がサポートされています。

UDF を実行するには、UDF に必要なコンパイル済みクラスと JAR をクラスターにアップロードする必要があります。 addCompiledArtifacts() API を使用して、アップロードする必要があるコンパイル済みクラスと JAR ファイルを指定できます。

Note

クライアントによって使用される Scala は、Azure Databricks クラスターの Scala バージョンと一致している必要があります。 クラスターの Databricks Runtime バージョンを確認するには、「Databricks Runtime リリース ノートのバージョンと互換性」の「システム環境」セクションを参照してください。

次の Scala プログラムは、列の値を 2 乗するシンプルな UDF を設定します。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

    val spark = DatabricksSession.builder()
      .addCompiledArtifacts(sourceLocation)
      .getOrCreate()

    def squared(x: Int): Int = x * x

    val squared_udf = udf(squared _)

    spark.range(3)
      .withColumn("squared", squared_udf(col("id")))
      .select("squared")
      .show()
  }
}

前の例では、UDF が Main に完全に含まれているため、Main のコンパイル済みの成果物のみが追加されます。 UDF が他のクラスに分散している場合、または外部ライブラリ (JAR) を使用している場合は、これらのライブラリもすべて含める必要があります。

Spark セッションが既に初期化されている場合は、spark.addArtifact() API を使用してさらにコンパイル済みのクラスと JAR をアップロードできます。

Note

JAR をアップロードするときは、すべての推移的な依存関係の JAR をアップロードに含める必要があります。 API では、推移的な依存関係の自動検出は実行されません。

型指定されたデータセット API

前のセクションで UDF について説明したのと同じメカニズムが、型指定されたデータセット API にも適用されます。

型指定されたデータセット API を使用すると、結果のデータセットに対してマップ、フィルター、集計などの変換を実行できます。 これらは、Databricks クラスター上の UDF と同様に実行されます。

次の Scala アプリケーションでは、map() API を使用して、結果列の数値を接頭辞付き文字列に変更します。

import com.databricks.connect.DatabricksSession
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object Main {
  def main(args: Array[String]): Unit = {
    val sourceLocation = getClass.getProtectionDomain.getCodeSource.getLocation.toURI

    val spark = DatabricksSession.builder()
      .addCompiledArtifacts(sourceLocation)
      .getOrCreate()

    spark.range(3).map(f => s"row-$f").show()
  }
}

この例では map() API を使用していますが、これは、他の型指定されたデータセット API (filter()mapPartitions() など) にも適用されます。