Добавление трассировок в агенты
Внимание
Эта функция предоставляется в режиме общедоступной предварительной версии.
В этой статье показано, как добавить трассировки в агенты с помощью API Fluent и MLflowClient, доступных с помощью трассировки MLflow.
Примечание.
Подробные справочные материалы по API и примеры кода для трассировки MLflow см. в документации по MLflow.
Требования
- MLflow 2.13.1
Использование автолога для добавления трассировок в агенты
Если вы используете библиотеку GenAI, которая поддерживает трассировку (например, LangChain, LlamaIndex или OpenAI), вы можете включить автоматическую журналирование MLflow для интеграции библиотеки, чтобы включить трассировку.
Например, используйте mlflow.langchain.autolog()
для автоматического добавления трассировок в агент на основе LangChain.
Примечание.
По состоянию на Databricks Runtime 15.4 LTS ML трассировка MLflow включена по умолчанию в записных книжках. Чтобы отключить трассировку, например с LangChain, можно выполнить mlflow.langchain.autolog(log_traces=False)
в записной книжке.
mlflow.langchain.autolog()
MLflow поддерживает дополнительные библиотеки для автологирования трассировки. Полный список интегрированных библиотек см. в документации по трассировке MLflow.
Использование API Fluent для ручного добавления трассировок в агент
Ниже приведен краткий пример, использующий API Fluent: mlflow.trace
и mlflow.start_span
добавление трассировок в .quickstart-agent
Это рекомендуется для моделей PyFunc.
import mlflow
from mlflow.deployments import get_deploy_client
class QAChain(mlflow.pyfunc.PythonModel):
def __init__(self):
self.client = get_deploy_client("databricks")
@mlflow.trace(name="quickstart-agent")
def predict(self, model_input, system_prompt, params):
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": model_input[0]["query"]
}
]
traced_predict = mlflow.trace(self.client.predict)
output = traced_predict(
endpoint=params["model_name"],
inputs={
"temperature": params["temperature"],
"max_tokens": params["max_tokens"],
"messages": messages,
},
)
with mlflow.start_span(name="_final_answer") as span:
# Initiate another span generation
span.set_inputs({"query": model_input[0]["query"]})
answer = output["choices"][0]["message"]["content"]
span.set_outputs({"generated_text": answer})
# Attributes computed at runtime can be set using the set_attributes() method.
span.set_attributes({
"model_name": params["model_name"],
"prompt_tokens": output["usage"]["prompt_tokens"],
"completion_tokens": output["usage"]["completion_tokens"],
"total_tokens": output["usage"]["total_tokens"]
})
return answer
Выполнение вывода
После инструментирования кода можно запустить функцию, как обычно. Следующий пример продолжается с predict()
функцией в предыдущем разделе. Трассировки автоматически отображаются при запуске метода вызова. predict()
SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""
model = QAChain()
prediction = model.predict(
[
{"query": "What is in MLflow 5.0"},
],
SYSTEM_PROMPT,
{
# Using Databricks Foundation Model for easier testing, feel free to replace it.
"model_name": "databricks-dbrx-instruct",
"temperature": 0.1,
"max_tokens": 1000,
}
)
API Fluent
API Fluent в MLflow автоматически создают иерархию трассировки на основе того, где и когда выполняется код. В следующих разделах описываются поддерживаемые задачи с помощью API-интерфейсов трассировки MLflow Fluent.
Декорируйте функцию
Вы можете декорировать функцию с @mlflow.trace
помощью декоратора, чтобы создать диапазон для области украшенной функции. Диапазон начинается при вызове функции и заканчивается при возвращении. MLflow автоматически записывает входные и выходные данные функции, а также любые исключения, вызванные функцией. Например, при выполнении следующего кода будет создан диапазон с именем "my_function", запись входных аргументов x и y, а также выходные данные функции.
@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
return x + y
Использование диспетчера контекстов трассировки
Если вы хотите создать диапазон для произвольного блока кода, а не только функции, можно использовать mlflow.start_span()
в качестве диспетчера контекста, который упаковывает блок кода. Диапазон начинается при вводе контекста и заканчивается при выходе контекста. Входные и выходные данные диапазона должны быть предоставлены вручную с помощью методов задания объекта диапазона, возвращаемого диспетчером контекстов.
with mlflow.start_span("my_span") as span:
span.set_inputs({"x": x, "y": y})
result = x + y
span.set_outputs(result)
span.set_attribute("key", "value")
Оболочка внешней функции
Функцию mlflow.trace
можно использовать в качестве оболочки для трассировки выбранной функции. Это полезно, если вы хотите отслеживать функции, импортированные из внешних библиотек. Он создает тот же диапазон, что и вы получаете, декорируя эту функцию.
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)
API клиента MLflow
MlflowClient
предоставляет детализированные, потокобезопасные API для запуска и завершения трассировки, управления диапазонами и задания полей диапазона. Он обеспечивает полный контроль жизненного цикла трассировки и структуры. Эти API полезны, если api Fluent недостаточно для ваших требований, например многопоточных приложений и обратных вызовов.
Ниже приведены шаги по созданию полной трассировки с помощью клиента MLflow.
Создание экземпляра MLflowClient по
client = MlflowClient()
.Запустите трассировку с помощью
client.start_trace()
метода. Это инициирует контекст трассировки и запускает абсолютный корневой диапазон и возвращает объект корневого диапазона. Этот метод должен выполняться передstart_span()
API.- Задайте атрибуты, входные данные и выходные данные для трассировки
client.start_trace()
.
Примечание.
В API Fluent нет эквивалента методу
start_trace()
. Это связано с тем, что API Fluent автоматически инициализировать контекст трассировки и определить, является ли он корневым диапазоном на основе управляемого состояния.- Задайте атрибуты, входные данные и выходные данные для трассировки
API start_trace() возвращает диапазон. Получите идентификатор запроса, уникальный идентификатор трассировки, также
trace_id
называемый идентификатором возвращаемого диапазона иspan.request_id
span.span_id
.Запустите дочерний диапазон, используя для
client.start_span(request_id, parent_id=span_id)
задания атрибутов, входных данных и выходных данных для диапазона.- Для этого метода требуется
request_id
parent_id
связать диапазон с правильной позицией в иерархии трассировки. Он возвращает другой объект диапазона.
- Для этого метода требуется
Завершение дочернего диапазона путем вызова
client.end_span(request_id, span_id)
.Повторите 3 – 5 для всех дочерних диапазонов, которые нужно создать.
После завершения всех дочерних диапазонов вызовите
client.end_trace(request_id)
закрытие всей трассировки и запишите его.
from mlflow.client import MlflowClient
mlflow_client = MlflowClient()
root_span = mlflow_client.start_trace(
name="simple-rag-agent",
inputs={
"query": "Demo",
"model_name": "DBRX",
"temperature": 0,
"max_tokens": 200
}
)
request_id = root_span.request_id
# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)
span_ss = mlflow_client.start_span(
"search",
# Specify request_id and parent_id to create the span at the right position in the trace
request_id=request_id,
parent_id=root_span.span_id,
inputs=similarity_search_input
)
retrieved = ["Test Result"]
# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)
root_span.end_trace(request_id, outputs={"output": retrieved})