Perform batch inference using ai_query

Important

This feature is in Public Preview.

This article describes how to perform batch inference using the built-in Databricks SQL function ai_query. See ai_query function for more detail about this AI function.

Databricks recommends using ai_query with Model Serving for batch inference. For quick experimentation, ai_query can be used with pay-per-token endpoints.

When you are ready to run batch inference on large or production data, Databricks recommends using provisioned throughput endpoints for faster performance. ai_query has been verified to reliably and consistently process datasets in the range of billions of tokens. See Provisioned throughput Foundation Model APIs for how to create a provisioned throughput endpoint.

To get started with batch inference with LLMs on Unity Catalog tables, see the notebook examples in Batch inference using Foundation Model APIs provisioned throughput.

Requirements

  • See the requirements of the ai_query function.
  • Query permission on the Delta table in Unity Catalog that contains the data you want to use.
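In Unity Catalog, reading a table generally requires the SELECT privilege on the table, plus USE CATALOG on its catalog and USE SCHEMA on its schema. The following is a minimal sketch of granting these privileges. It assumes the ml.sentiment.comments table used in the examples in this article and a hypothetical group named `data-analysts`.

```sql
-- Grant the privileges needed to read ml.sentiment.comments.
-- `data-analysts` is a hypothetical principal; replace it with your own user or group.
GRANT USE CATALOG ON CATALOG ml TO `data-analysts`;
GRANT USE SCHEMA ON SCHEMA ml.sentiment TO `data-analysts`;
GRANT SELECT ON TABLE ml.sentiment.comments TO `data-analysts`;
```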

Batch inference example queries

The examples in this section assume you have a model deployed to an existing endpoint that you want to query. If you are in the Serving UI, you can select your endpoint and click the Use button at the top right to select Use for batch inference. This selection opens a SQL editor where you can write and run your SQL query for batch inference using ai_query.

The following is a general example that uses the failOnError and modelParameters arguments, with max_tokens and temperature passed through modelParameters. It also shows how to concatenate the prompt for your model with the input column using concat(). You can concatenate in several ways, such as with the || operator, concat(), or format_string().


CREATE OR REPLACE TABLE ${output_table_name} AS (
  SELECT
    ${input_column_name},
    AI_QUERY(
      "${endpoint}",
      CONCAT("${prompt}", ${input_column_name}),
      failOnError => True,
      modelParameters => named_struct('max_tokens', ${num_output_tokens}, 'temperature', ${temperature})
    ) AS response
  FROM ${input_table_name}
  LIMIT ${input_num_rows}
)
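For non-NULL input, the three concatenation approaches described above build the same prompt. The following small sketch compares them on one inline row. The prompt text and the sample comment are hypothetical.

```sql
-- Build the same prompt three ways. The inline VALUES row stands in for a real input table.
SELECT
  concat('Classify the sentiment of this text: ', comment_text) AS prompt_concat,
  'Classify the sentiment of this text: ' || comment_text AS prompt_pipes,
  format_string('Classify the sentiment of this text: %s', comment_text) AS prompt_format
FROM VALUES ('The update fixed my issue.') AS t(comment_text);
```

concat() and || return NULL if any argument is NULL. As a result, you may want to filter out rows where the input column is NULL before calling ai_query.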

The following example uses the model behind the llama_3_1_8b_batch endpoint to classify the sentiment of the comment_text column in the ml.sentiment.comments table.

WITH data AS (
  SELECT *
  FROM ml.sentiment.comments
  LIMIT 10000
)
SELECT
  comment_text,
  ai_query(
    'llama_3_1_8b_batch',
    CONCAT('You are provided with text. Classify the text into one of these labels: "Positive", "Neutral", "Negative". Do not explain. Do not output any confidence score. Do not answer questions. Text: ', comment_text)
  ) AS label
FROM data
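For a single-label classification task like this one, you can also pass modelParameters, as in the general template, to make the output more consistent and to cap its length. The following sketch assumes the same endpoint and table as the example above. The values temperature 0.0 and max_tokens 10 are illustrative assumptions, not recommended settings.

```sql
-- Same sentiment classification, with a low temperature for more consistent labels
-- and a small max_tokens cap because the expected answer is a single word.
-- The parameter values are illustrative assumptions; tune them for your model.
WITH data AS (
  SELECT *
  FROM ml.sentiment.comments
  LIMIT 10000
)
SELECT
  comment_text,
  ai_query(
    'llama_3_1_8b_batch',
    CONCAT('You are provided with text. Classify the text into one of these labels: "Positive", "Neutral", "Negative". Do not explain. Do not output any confidence score. Do not answer questions. Text: ', comment_text),
    failOnError => true,
    modelParameters => named_struct('max_tokens', 10, 'temperature', 0.0)
  ) AS label
FROM data
```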

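The following example adds a preprocessing step, which keeps only comments longer than 50 characters, and a postprocessing step, which flags any response that is not one of the expected labels: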

WITH temp AS (
  SELECT *
  FROM ml.sentiment.comments
  LIMIT 10000
),
pre_process AS (
  SELECT comment_text
  FROM temp
  WHERE length(comment_text) > 50
),
sentiment AS (
  SELECT
    comment_text,
    ai_query(
      'llama_3_1_8b_batch',
      CONCAT('You are provided with text. Classify the text into one of these labels: "Positive", "Neutral", "Negative". Do not explain. Do not output any confidence score. Do not answer questions. Text: ', comment_text)
    ) AS label
  FROM pre_process
)
SELECT
  comment_text,
  label,
  CASE
    WHEN label NOT IN ('Positive', 'Neutral', 'Negative') THEN TRUE
    ELSE FALSE
  END AS error
FROM sentiment
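To check label quality across a run, you might summarize the error flag. The following sketch assumes that you saved the output of the query above to a table, for example with CREATE OR REPLACE TABLE ... AS. The table name ml.sentiment.labeled_comments is hypothetical. try_divide returns NULL instead of failing if the table is empty.

```sql
-- Summarize how many responses fell outside the expected label set.
-- ml.sentiment.labeled_comments is a hypothetical table holding the query output above.
SELECT
  count(*) AS total_rows,
  count_if(error) AS error_rows,
  round(try_divide(count_if(error), count(*)), 4) AS error_rate
FROM ml.sentiment.labeled_comments
```

If error_rate is high, selecting a few rows WHERE error can show whether the model adds extra text around the label.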

Schedule a job

After your SQL script is ready, you can schedule a job to run it at whatever frequency you need. See Create and manage scheduled notebook jobs.