Add traces to your agents

Important

This feature is in Public Preview.

This article shows how to add traces to your agents using the Fluent and MlflowClient APIs made available with MLflow Tracing.

Note

For detailed API reference and code examples for MLflow Tracing, see the MLflow documentation.

Requirements

  • MLflow 2.13.1

Use autologging to add traces to your LangChain agent

If you are using LangChain, call mlflow.langchain.autolog() to automatically add traces to your agent. This is the recommended approach for LangChain agents.

import mlflow

mlflow.langchain.autolog()

MLflow supports additional libraries for trace autologging. See the MLflow Tracing documentation for a full list of integrated libraries.

Use Fluent APIs to manually add traces to your agent

The following quick example uses the Fluent APIs mlflow.trace and mlflow.start_span to add traces to the quickstart-agent. This approach is recommended for PyFunc models.


import mlflow
from mlflow.deployments import get_deploy_client

class QAChain(mlflow.pyfunc.PythonModel):
    def __init__(self):
        self.client = get_deploy_client("databricks")

    @mlflow.trace(name="quickstart-agent")
    def predict(self, model_input, system_prompt, params):
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": model_input[0]["query"],
            },
        ]

        traced_predict = mlflow.trace(self.client.predict)
        output = traced_predict(
            endpoint=params["model_name"],
            inputs={
                "temperature": params["temperature"],
                "max_tokens": params["max_tokens"],
                "messages": messages,
            },
        )

        # Start a child span that records the final answer
        with mlflow.start_span(name="_final_answer") as span:
            span.set_inputs({"query": model_input[0]["query"]})

            answer = output["choices"][0]["message"]["content"]

            span.set_outputs({"generated_text": answer})
            # Attributes computed at runtime can be set using the set_attributes() method.
            span.set_attributes({
                "model_name": params["model_name"],
                "prompt_tokens": output["usage"]["prompt_tokens"],
                "completion_tokens": output["usage"]["completion_tokens"],
                "total_tokens": output["usage"]["total_tokens"],
            })

        return answer

Perform inference

After you instrument your code, run your function as you normally would. The following continues the example from the previous section by calling predict(). The traces are displayed automatically when you run the invocation method, predict().


SYSTEM_PROMPT = """
You are an assistant for Databricks users. You are answering python, coding, SQL, data engineering, spark, data science, DW and platform, API or infrastructure administration questions related to Databricks. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible. Use the following pieces of context to answer the question at the end:
"""

model = QAChain()

prediction = model.predict(
    [
        {"query": "What is in MLflow 5.0"},
    ],
    SYSTEM_PROMPT,
    {
        # Using Databricks Foundation Model for easier testing, feel free to replace it.
        "model_name": "databricks-dbrx-instruct",
        "temperature": 0.1,
        "max_tokens": 1000,
    },
)

Fluent APIs

The Fluent APIs in MLflow automatically construct the trace hierarchy based on where and when the code is executed. The following sections describe the supported tasks using the MLflow Tracing Fluent APIs.

Decorate your function

You can decorate your function with the @mlflow.trace decorator to create a span for the scope of the decorated function. The span starts when the function is invoked and ends when it returns. MLflow automatically records the input and output of the function, as well as any exceptions raised from the function. For example, running the following code creates a span named “agent”, because the name argument overrides the default span name (the function name, my_function). The span captures the input arguments x and y, as well as the output of the function.

@mlflow.trace(name="agent", span_type="TYPE", attributes={"key": "value"})
def my_function(x, y):
    return x + y

Use the tracing context manager

If you want to create a span for an arbitrary block of code, not just a function, you can use mlflow.start_span() as a context manager that wraps the code block. The span starts when the context is entered and ends when the context is exited. The span's inputs and outputs must be set manually using the setter methods of the span object that the context manager yields.

import mlflow

# Example values; replace with your own data
x, y = 1, 2

with mlflow.start_span("my_span") as span:
    span.set_inputs({"x": x, "y": y})
    result = x + y
    span.set_outputs(result)
    span.set_attribute("key", "value")

Wrap an external function

The mlflow.trace function can be used as a wrapper to trace a function of your choice. This is useful when you want to trace functions imported from external libraries. It generates the same span as you would get by decorating that function.


import mlflow
from sklearn.metrics import accuracy_score

y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]

traced_accuracy_score = mlflow.trace(accuracy_score)
traced_accuracy_score(y_true, y_pred)

MLflow Client APIs

MlflowClient exposes granular, thread-safe APIs to start and end traces, manage spans, and set span fields. It provides full control of the trace lifecycle and structure. These APIs are useful when the Fluent APIs are not sufficient for your requirements, such as in multi-threaded applications and callbacks.

The following are steps to create a complete trace using the MLflow Client.

  1. Create an instance of MlflowClient: client = MlflowClient().

  2. Start a trace using the client.start_trace() method. This initiates the trace context, starts an absolute root span, and returns the root span object. Call this method before the start_span() API.

    1. Set your attributes, inputs, and outputs for the trace in client.start_trace().

    Note

    There is no equivalent to the start_trace() method in the Fluent APIs. This is because the Fluent APIs automatically initialize the trace context and determine whether a span is the root span based on the managed state.

  3. The start_trace() API returns a span. Get the request ID, a unique identifier of the trace also referred to as the trace_id, and the ID of the returned span using span.request_id and span.span_id.

  4. Start a child span using client.start_span(name, request_id=request_id, parent_id=span_id), and set your attributes, inputs, and outputs for the span.

    1. This method requires request_id and parent_id to associate the span with the correct position in the trace hierarchy. It returns another span object.

  5. End the child span by calling client.end_span(request_id, span_id).

  6. Repeat steps 3 - 5 for any child spans you want to create.

  7. After all of the child spans have ended, call client.end_trace(request_id) to close the entire trace and record it.

from mlflow.client import MlflowClient

mlflow_client = MlflowClient()

root_span = mlflow_client.start_trace(
    name="simple-rag-agent",
    inputs={
        "query": "Demo",
        "model_name": "DBRX",
        "temperature": 0,
        "max_tokens": 200,
    },
)

request_id = root_span.request_id

# Retrieve documents that are similar to the query
similarity_search_input = dict(query_text="demo", num_results=3)

span_ss = mlflow_client.start_span(
    "search",
    # Specify request_id and parent_id to create the span at the right position in the trace
    request_id=request_id,
    parent_id=root_span.span_id,
    inputs=similarity_search_input,
)
retrieved = ["Test Result"]

# Span has to be ended explicitly
mlflow_client.end_span(request_id, span_id=span_ss.span_id, outputs=retrieved)

# End the trace through the client; span objects do not have an end_trace() method
mlflow_client.end_trace(request_id, outputs={"output": retrieved})