エージェントにトレースを追加する

重要

この機能はパブリック プレビュー段階にあります。

この記事では、MLflow トレースで使用できる Fluent および MLflowClient API を使用して、エージェントにトレースを追加する方法を示します。

Note

MLflow トレースの詳細な API リファレンスとコード例については、MLflow のドキュメントを参照してください。

要件

  • MLflow 2.13.1

自動ログ記録を使用して Langchain エージェントにトレースを追加する

Langchain を使用している場合は、MLflow の langchain.autolog() を使用して、エージェントにトレースを自動的に追加します。 これは、Langchain エージェントに対して推奨されます。

mlflow.langchain.autolog()

Fluent API を使用してエージェントにトレースを手動で追加する

Fluent APImlflow.tracemlflow.start_span を使用してトレースを quickstart-agent に追加する簡単な例を次に示します。 これは PyFunc モデルに対して推奨されます。


import mlflow
from mlflow.deployments import get_deploy_client

class QAChain(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.client = get_deploy_client("databricks")

    @mlflow.trace(name="quickstart-agent")
    def predict(self, model_input, system_prompt, params):
        messages = [
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content":  model_input[0]["query"]
                }
          ]

        traced_predict = mlflow.trace(self.client.predict)
        output = traced_predict(
            endpoint=params["model_name"],
            inputs={
                "temperature": params["temperature"],
                "max_tokens": params["max_tokens"],
                "messages": messages,
            },
        )

        with mlflow.start_span(name="_final_answer") as span:
          # Initiate another span generation
            span.set_inputs({"query": model_input[0]["query"]})

            answer = output["choices"][0]["message"]["content"]

            span.set_outputs({"generated_text": answer})
            # Attributes computed at runtime can be set using the set_attributes() method.
            span.set_attributes({
              "model_name": params["model_name"],
                        "prompt_tokens": output["usage"]["prompt_tokens"],
                        "completion_tokens": output["usage"]["completion_tokens"],
                        "total_tokens": output["usage"]["total_tokens"]
                    })
              return answer

推論を実行する

コードをインストルメント化したら、通常どおりに関数を実行できます。 前のセクションの predict() 関数を使用して例を進めます。 呼び出しメソッド predict() を実行すると、トレースが自動的に表示されます。


SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""

model = QAChain()

prediction = model.predict(
  [
      {"query": "What is in MLflow 5.0"},
  ],
  SYSTEM_PROMPT,
  {
    # Using Databricks Foundation Model for easier testing, feel free to replace it.
    "model_name": "databricks-dbrx-instruct",
    "temperature": 0.1,
    "max_tokens": 1000,
  }
)

Fluent API

MLflow の Fluent API は、コードを実行する場所とタイミングに基づいてトレース階層を自動的に構築します。 以降のセクションでは、MLflow トレース Fluent API を使用してサポートされるタスクについて説明します。

関数を装飾する

@mlflow.trace デコレーターを使用して関数を装飾し、装飾された関数のスコープのスパンを作成できます。 スパンは、関数が呼び出されたときに開始し、戻ったときに終了します。 MLflow は、関数の入力と出力、および関数から発生した例外を自動的に記録します。 たとえば、次のコードを実行すると、"my_function" という名前のスパンが作成され、入力引数 x と y、および関数の出力がキャプチャされます。

@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
    return x + y

トレース コンテキスト マネージャーを使用する

関数だけでなく、任意のコード ブロックのスパンを作成する場合は、コード ブロックをラップするコンテキスト マネージャーとして mlflow.start_span() を使用できます。 スパンは、コンテキストが入力されたときに開始し、コンテキストが終了したときに終了します。 スパンの入力と出力は、コンテキスト マネージャーから生成されるスパン オブジェクトのセッター メソッドを使用して手動で提供する必要があります。

with mlflow.start_span("my_span") as span:
    span.set_inputs({"x": x, "y": y})
    result = x + y
    span.set_outputs(result)
    span.set_attribute("key", "value")

外部関数をラップする

mlflow.trace 関数は、選択した関数をトレースするためのラッパーとして使用できます。 これは、外部ライブラリからインポートされた関数をトレースする場合に便利です。 その関数を修飾することによって取得するのと同じスパンが生成されます。


from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)

MLflow クライアント API

MlflowClient では、トレースの開始と終了、スパンの管理、スパン フィールドの設定を行うために、きめ細かいスレッド セーフの API を公開しています。 トレースのライフサイクルと構造を完全に制御できます。 これらの API は、マルチスレッド アプリケーションやコールバックなどの要件に対して Fluent API が十分でない場合に便利です。

MLflow クライアントを使用して完全なトレースを作成する手順を次に示します。

  1. client = MlflowClient() によって MLflowClient のインスタンスを作成します。

  2. client.start_trace() メソッドを使用してトレースを開始します。 これにより、トレース コンテキストが開始し、絶対ルート スパンが開始し、ルート スパン オブジェクトが返されます。 このメソッドは start_span() API の前に実行する必要があります。

    1. client.start_trace() でトレースの属性、入力、出力を設定します。

    Note

    Fluent API の start_trace() メソッドと同等のものはありません。 これは、Fluent API によってトレース コンテキストが自動的に初期化され、それがルート スパンであるかどうかがマネージド状態に基づいて判断されるためです。

  3. start_trace() API はスパンを返します。 要求 ID、trace_id とも呼ばれるトレースの一意識別子を取得し、span.request_idspan.span_id を使用して返されるスパンの ID を取得します。

  4. client.start_span(request_id, parent_id=span_id) を使用して子スパンを開始して、スパンの属性、入力、出力を設定します。

    1. このメソッドでは、トレース階層内の正しい位置にスパンを関連付けるために、request_idparent_id が必要です。 別のスパン オブジェクトを返します。
  5. client.end_span(request_id, span_id) を呼び出して子スパンを終了します。

  6. 作成するすべての子スパンに対して 3 - 5 を繰り返します。

  7. すべての子スパンが終了したら、client.end_trace(request_id) を呼び出してトレース全体を終了し、記録します。

from mlflow.client import MlflowClient

mlflow_client = MlflowClient()

root_span = mlflow_client.start_trace(
  name="simple-rag-agent",
  inputs={
          "query": "Demo",
          "model_name": "DBRX",
          "temperature": 0,
          "max_tokens": 200
         }
  )

request_id = root_span.request_id

# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)

span_ss = mlflow_client.start_span(
      "search",
      # Specify request_id and parent_id to create the span at the right position in the trace
        request_id=request_id,
        parent_id=root_span.span_id,
        inputs=similarity_search_input
  )
retrieved = ["Test Result"]

# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)

root_span.end_trace(request_id, outputs={"output": retrieved})