Add traces to your agents
Important
This feature is in Public Preview.
This article shows how to add traces to your agents using the Fluent and MlflowClient APIs made available with MLflow Tracing.
Note
For detailed API reference and code examples for MLflow Tracing, see the MLflow documentation.
Requirements
- MLflow 2.13.1
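For example, one way to meet this requirement in a Databricks notebook is to upgrade the package and restart the Python process. This is a minimal sketch; it assumes the version listed above is a minimum and that you are running in a Databricks notebook:

%pip install -U "mlflow>=2.13.1"
dbutils.library.restartPython()  # restart so the upgraded package is picked up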
Use autologging to add traces to your agents
If you are using a GenAI library that supports tracing (such as LangChain, LlamaIndex, or OpenAI), you can enable MLflow autologging for that library's integration to turn on tracing. For example, use mlflow.langchain.autolog() to automatically add traces to your LangChain-based agent.
Note
As of Databricks Runtime 15.4 LTS ML, MLflow tracing is enabled by default within notebooks. To disable tracing, for example with LangChain, you can execute mlflow.langchain.autolog(log_traces=False) in your notebook.
mlflow.langchain.autolog()
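With autologging enabled, chain invocations are traced without any manual spans. The following is a minimal, illustrative sketch; it assumes the langchain-core and langchain-community packages are installed and uses FakeListLLM as a stand-in model so the example runs without a model endpoint:

import mlflow
from langchain_core.prompts import PromptTemplate
from langchain_community.llms.fake import FakeListLLM  # stand-in LLM so the sketch runs offline

# Enable trace autologging for the LangChain integration
mlflow.langchain.autolog()

prompt = PromptTemplate.from_template("Answer briefly: {question}")
chain = prompt | FakeListLLM(responses=["MLflow Tracing records this call."])

# This invocation is traced automatically; no manual spans are needed
chain.invoke({"question": "What does autologging capture?"})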
MLflow supports additional libraries for trace autologging. See the MLflow Tracing documentation for a full list of integrated libraries.
Use Fluent APIs to manually add traces to your agent
The following is a quick example that uses the Fluent APIs mlflow.trace and mlflow.start_span to add traces to the quickstart-agent. This is recommended for PyFunc models.
import mlflow
from mlflow.deployments import get_deploy_client

class QAChain(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.client = get_deploy_client("databricks")

    @mlflow.trace(name="quickstart-agent")
    def predict(self, model_input, system_prompt, params):
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": model_input[0]["query"]
            }
        ]

        traced_predict = mlflow.trace(self.client.predict)
        output = traced_predict(
            endpoint=params["model_name"],
            inputs={
                "temperature": params["temperature"],
                "max_tokens": params["max_tokens"],
                "messages": messages,
            },
        )

        with mlflow.start_span(name="_final_answer") as span:
            # Initiate another span generation
            span.set_inputs({"query": model_input[0]["query"]})
            answer = output["choices"][0]["message"]["content"]
            span.set_outputs({"generated_text": answer})
            # Attributes computed at runtime can be set using the set_attributes() method.
            span.set_attributes({
                "model_name": params["model_name"],
                "prompt_tokens": output["usage"]["prompt_tokens"],
                "completion_tokens": output["usage"]["completion_tokens"],
                "total_tokens": output["usage"]["total_tokens"]
            })

        return answer
Perform inference
After you’ve instrumented your code, you can run your function as you normally would. The following continues the example with the predict() function from the previous section. The traces are automatically shown when you run the invocation method, predict().
SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration question related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""

model = QAChain()

prediction = model.predict(
    [
        {"query": "What is in MLflow 5.0"},
    ],
    SYSTEM_PROMPT,
    {
        # Using Databricks Foundation Model for easier testing, feel free to replace it.
        "model_name": "databricks-dbrx-instruct",
        "temperature": 0.1,
        "max_tokens": 1000,
    }
)
Fluent APIs
The Fluent APIs in MLflow automatically construct the trace hierarchy based on where and when the code is executed. The following sections describe the supported tasks using the MLflow Tracing Fluent APIs.
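The sketch below (the function and span names are illustrative) shows how this hierarchy emerges: a span opened inside a decorated function is automatically nested under the decorator's span, because it is opened while that function is executing:

import mlflow

@mlflow.trace(name="parent")
def run(x):
    # A span opened while "parent" is active becomes its child automatically
    with mlflow.start_span(name="child") as span:
        span.set_inputs({"x": x})
        y = x * 2
        span.set_outputs({"y": y})
    return y

run(3)  # yields a trace with "parent" as the root span and "child" beneath it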
Decorate your function
You can decorate your function with the @mlflow.trace decorator to create a span for the scope of the decorated function. The span starts when the function is invoked and ends when it returns. MLflow automatically records the input and output of the function, as well as any exceptions raised from the function. For example, running the following code creates a span named "agent" (without the name argument, the span would default to the function name, "my_function"), capturing the input arguments x and y as well as the output of the function.
@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
    return x + y
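Because exceptions are also recorded, a failing call still produces a span with the error attached. A minimal sketch (the function name and error are illustrative):

import mlflow

@mlflow.trace(name="may-fail")
def divide(x, y):
    return x / y  # a ZeroDivisionError raised here is recorded on the span

try:
    divide(1, 0)
except ZeroDivisionError:
    pass  # the trace still captures the failed span and its exception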
Use the tracing context manager
If you want to create a span for an arbitrary block of code, not just a function, you can use mlflow.start_span() as a context manager that wraps the code block. The span starts when the context is entered and ends when the context is exited. The span's inputs and outputs should be set manually via setter methods of the span object yielded by the context manager.
import mlflow

x = 1  # example values so the snippet runs standalone
y = 2

with mlflow.start_span("my_span") as span:
    span.set_inputs({"x": x, "y": y})
    result = x + y
    span.set_outputs(result)
    span.set_attribute("key", "value")
Wrap an external function
The mlflow.trace function can be used as a wrapper to trace a function of your choice. This is useful when you want to trace functions imported from external libraries. It generates the same span as you would get by decorating that function.
import mlflow
from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

# Wrapping accuracy_score produces the same span as decorating it would
traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)
MLflow Client APIs
MlflowClient exposes granular, thread-safe APIs to start and end traces, manage spans, and set span fields. It provides full control over the trace lifecycle and structure. These APIs are useful when the Fluent APIs are not sufficient for your requirements, such as multi-threaded applications and callbacks (see the threaded sketch at the end of this section).
The following are steps to create a complete trace using the MLflow Client.
1. Create an instance of MlflowClient with client = MlflowClient().
2. Start a trace using the client.start_trace() method. This initiates the trace context, starts an absolute root span, and returns a root span object. This method must be run before the start_span() API.
   - Set your attributes, inputs, and outputs for the trace in client.start_trace().

   Note
   There is not an equivalent to the start_trace() method in the Fluent APIs. This is because the Fluent APIs automatically initialize the trace context and determine whether a span is the root span based on the managed state.
3. The start_trace() API returns a span. Get the request ID, a unique identifier of the trace also referred to as the trace_id, and the ID of the returned span using span.request_id and span.span_id.
4. Start a child span using client.start_span(name, request_id, parent_id) to set your attributes, inputs, and outputs for the span.
   - This method requires request_id and parent_id to associate the span with the correct position in the trace hierarchy. It returns another span object.
5. End the child span by calling client.end_span(request_id, span_id).
6. Repeat steps 3 through 5 for any child spans you want to create.
7. After all of the child spans are ended, call client.end_trace(request_id) to close the entire trace and record it.
from mlflow.client import MlflowClient

mlflow_client = MlflowClient()

root_span = mlflow_client.start_trace(
    name="simple-rag-agent",
    inputs={
        "query": "Demo",
        "model_name": "DBRX",
        "temperature": 0,
        "max_tokens": 200
    }
)

request_id = root_span.request_id

# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)

span_ss = mlflow_client.start_span(
    "search",
    # Specify request_id and parent_id to create the span at the right position in the trace
    request_id=request_id,
    parent_id=root_span.span_id,
    inputs=similarity_search_input
)
retrieved = ["Test Result"]

# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)

# End the trace through the client, as described in step 7 above
mlflow_client.end_trace(request_id, outputs={"output": retrieved})
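As noted above, the client APIs also work when spans must be created from multiple threads, where the Fluent APIs' implicit context is unavailable. The following is a minimal sketch under that assumption; the trace and span names are illustrative:

import threading
from mlflow.client import MlflowClient

client = MlflowClient()

root = client.start_trace(name="parallel-fanout", inputs={"n_workers": 2})
request_id = root.request_id

def worker(i):
    # Each thread attaches its span to the shared trace explicitly via
    # request_id and parent_id, which is why the client API is used here
    span = client.start_span(
        f"worker-{i}",
        request_id=request_id,
        parent_id=root.span_id,
        inputs={"i": i},
    )
    client.end_span(request_id, span_id=span.span_id, outputs={"result": i * i})

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

client.end_trace(request_id, outputs={"status": "done"})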