Create and log AI agents

Important

This feature is in Public Preview.

This article shows you how to create and log AI agents such as RAG applications using the Mosaic AI Agent Framework.

What are chains and agents?

AI systems often have many components. For example, an AI system might retrieve documents from a vector index, use those documents to supplement prompt text, and use a foundation model to summarize the response. The code that links these components, also called steps, together is called a chain.

An agent is a much more advanced AI system that relies on large language models to make decisions on which steps to take based on input. In contrast, chains are hardcoded sequences of steps intended to achieve a specific outcome.

With Agent Framework, you can use any libraries or packages to create code. Agent Framework also makes it easy to iterate on your code as you develop and test it. You can set up configuration files that let you change code parameters in a traceable way without having to modify the actual code.

Requirements

For agents that use a Databricks-managed vector search index, MLflow version 2.13.1 or above is required to use automatic authorization with the vector index.

Input schema for the RAG agent

The following are supported input formats for your chain.

  • (Recommended) Queries using the OpenAI chat completion schema. The request should contain an array of message objects in a messages parameter. This format is best for RAG applications.

    question = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-Augmented Generation?",
            },
            {
                "role": "assistant",
                "content": "RAG, or Retrieval Augmented Generation, is a generative AI design pattern that combines a large language model (LLM) with external knowledge retrieval. This approach allows for real-time data connection to generative AI applications, improving their accuracy and quality by providing context from your data to the LLM during inference. Databricks offers integrated tools that support various RAG scenarios, such as unstructured data, structured data, tools & function calling, and agents.",
            },
            {
                "role": "user",
                "content": "How to build RAG for unstructured data",
            },
        ]
    }
    
  • SplitChatMessagesRequest. Recommended for multi-turn chat applications, especially when you want to manage the current query and history separately.

    {
        "query": "What is MLflow",
        "history": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?"
            },
            {
                "role": "assistant",
                "content": "RAG is"
            }
        ]
    }
    

For LangChain, Databricks recommends writing your chain in LangChain Expression Language. In your chain definition code, you can use an itemgetter to retrieve the messages object, or the query and history objects, depending on which input format you are using.
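
For example, the following sketch shows one way to pull the inputs out of each format with LangChain Expression Language. The helper functions extract_user_query_string and extract_chat_history are illustrative placeholders, not part of the Agent Framework API.

from operator import itemgetter

from langchain_core.runnables import RunnableLambda

def extract_user_query_string(messages):
    # The most recent user message is the last element of the messages array.
    return messages[-1]["content"]

def extract_chat_history(messages):
    # Everything before the most recent user message is treated as history.
    return messages[:-1]

# OpenAI chat completion schema: everything arrives under "messages".
inputs_from_messages = {
    "user_query": itemgetter("messages") | RunnableLambda(extract_user_query_string),
    "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
}

# SplitChatMessagesRequest schema: "query" and "history" arrive separately.
inputs_from_split_request = {
    "user_query": itemgetter("query"),
    "chat_history": itemgetter("history"),
}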

Output schema for the RAG agent

Your code must comply with one of the following supported output formats:

  • (Recommended) ChatCompletionResponse. This format is recommended for customers who want interoperability with the OpenAI response format.
  • StringResponse. This format is the simplest and easiest to interpret.

For LangChain, use StrOutputParser() as your final chain step. Your output must return a single string value.

  from operator import itemgetter

  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.runnables import RunnableLambda

  # extract_user_query_string, extract_chat_history, and fake_model are helper
  # functions you define in your chain code, such as the examples shown earlier.
  chain = (
      {
          "user_query": itemgetter("messages")
          | RunnableLambda(extract_user_query_string),
          "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
      }
      | RunnableLambda(fake_model)
      | StrOutputParser()
  )

If you are using PyFunc, Databricks recommends using type hints to annotate the predict() function with input and output data classes that are subclasses of classes defined in mlflow.models.rag_signatures.

You can construct an output object from the data class inside predict() to ensure the format is followed. The returned object must be transformed into a dictionary representation to ensure it can be serialized.
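
The following is a minimal sketch of this pattern. It assumes the ChatCompletionRequest, ChatCompletionResponse, ChainCompletionChoice, and Message data classes from mlflow.models.rag_signatures; check the MLflow documentation for the exact classes available in your MLflow version.

from dataclasses import asdict

from mlflow.pyfunc import PythonModel
from mlflow.models.rag_signatures import (
    ChainCompletionChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    Message,
)

class ExampleAgent(PythonModel):
    def predict(self, context, model_input: ChatCompletionRequest) -> ChatCompletionResponse:
        # Build the response from the data class so the output format is followed,
        # then convert it to a dictionary so it can be serialized.
        response = ChatCompletionResponse(
            choices=[ChainCompletionChoice(message=Message(role="assistant", content="Example response"))]
        )
        return asdict(response)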

Use parameters to control quality iteration

In the Agent Framework, you can use parameters to control how agents are executed. This allows you to quickly iterate by varying characteristics of your agent without changing the code. Parameters are key-value pairs that you define in a Python dictionary or a .yaml file.

To configure the code, you create a ModelConfig, a set of key-value parameters. The ModelConfig is either a Python dictionary or a .yaml file. For example, you could use a dictionary during development and then convert it to a .yaml file for production deployment and CI/CD. For details about ModelConfig, see the MLflow documentation.

An example ModelConfig is shown below.

llm_parameters:
  max_tokens: 500
  temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
  question that indicates your prompt template came from a YAML file. Your response
  must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question

To load the configuration in your code, use one of the following:

# Example for loading from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Example of using a dictionary
config_dict = {
    "prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}

model_config = mlflow.models.ModelConfig(development_config=config_dict)

# Use model_config.get() to retrieve a parameter value
value = model_config.get('sample_param')

Log the agent

Logging an agent is the basis of the development process. Logging captures a “point in time” of the agent’s code and configuration so you can evaluate the quality of the configuration. When developing agents, Databricks recommends that you use code-based logging instead of serialization-based logging. For more information about the pros and cons of each type of logging, see Code-based vs. serialization-based logging.

This section covers how to use code-based logging. For details about how to use serialization-based logging, see Serialization-based logging workflow.

Code-based logging workflow

For code-based logging, the code that logs your agent or chain must be in a separate notebook from your chain code. This notebook is called a driver notebook. For an example notebook, see Example notebooks.

Code-based logging workflow with LangChain

  1. Create a notebook or Python file with your code. For purposes of this example, the notebook or file is named chain.py. The notebook or file must contain a LangChain chain, referred to here as lc_chain.
  2. Include mlflow.models.set_model(lc_chain) in the notebook or file. A minimal sketch of such a file follows this list.
  3. Create a new notebook to serve as the driver notebook (called driver.py in this example).
  4. In the driver notebook, include the call mlflow.langchain.log_model(lc_model="/path/to/chain.py"). This call runs chain.py and logs the results to an MLflow model.
  5. Deploy the model.
  6. When the serving environment is loaded, chain.py is executed.
  7. When a serving request comes in, lc_chain.invoke(...) is called.
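
The following is a minimal sketch of the chain.py file described in steps 1 and 2. The prompt and the databricks-dbrx-instruct endpoint are placeholders; a real chain should also accept one of the supported input schemas described earlier.

# chain.py
import mlflow
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatDatabricks

prompt = ChatPromptTemplate.from_template("Answer the question: {question}")
llm = ChatDatabricks(endpoint="databricks-dbrx-instruct")

lc_chain = prompt | llm | StrOutputParser()

# Tell MLflow which object in this file is the chain to log.
mlflow.models.set_model(lc_chain)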

Code-based logging workflow with PyFunc

  1. Create a notebook or Python file with your code. For purposes of this example, the notebook or file is named chain.py. The notebook or file must contain a PyFunc class, referred to here as PyFuncClass.
  2. Include mlflow.models.set_model(PyFuncClass) in the notebook or file. A minimal sketch of such a file follows this list.
  3. Create a new notebook to serve as the driver notebook (called driver.py in this example).
  4. In the driver notebook, include the call mlflow.pyfunc.log_model(python_model="/path/to/chain.py"). This call runs chain.py and logs the results to an MLflow model.
  5. Deploy the model.
  6. When the serving environment is loaded, chain.py is executed.
  7. When a serving request comes in, PyFuncClass.predict(...) is called.
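
The following is a minimal sketch of the chain.py file described in steps 1 and 2, with placeholder predict() logic. In this sketch, set_model() is passed an instance of the class.

# chain.py
import mlflow

class PyFuncClass(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        # Placeholder logic; a real agent would call a retriever and an LLM here.
        return "This is a placeholder response."

# Tell MLflow which object in this file is the model to log.
mlflow.models.set_model(PyFuncClass())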

Example code for logging chains

import mlflow

code_path = "/Workspace/Users/first.last/chain.py"
config_path = "/Workspace/Users/first.last/config.yml"

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is Retrieval-augmented Generation?",
        }
    ]
}

# example using LangChain
with mlflow.start_run():
  logged_chain_info = mlflow.langchain.log_model(
    lc_model=code_path,
    model_config=config_path, # If you specify this parameter, this configuration is used when the model is logged, overriding development_config.
    artifact_path="chain", # This string is used as the path inside the MLflow model where artifacts are stored
    input_example=input_example, # Must be a valid input to your chain
    example_no_conversion=True, # Required
  )

# or use a PyFunc model
# with mlflow.start_run():
#   logged_chain_info = mlflow.pyfunc.log_model(
#     python_model=chain_notebook_path,
#     artifact_path="chain",
#     input_example=input_example,
#     example_no_conversion=True,
#   )

print(f"MLflow Run: {logged_chain_info.run_id}")
print(f"Model URI: {logged_chain_info.model_uri}")

To verify that the model has been logged correctly, load the chain and call it with the input example:

# Using LangChain
model = mlflow.langchain.load_model(logged_chain_info.model_uri)
model.invoke(input_example)

# Using PyFunc
model = mlflow.pyfunc.load_model(logged_chain_info.model_uri)
model.predict(input_example)

Register the chain to Unity Catalog

Before you deploy the chain, you must register the chain to Unity Catalog. When you register the chain, it is packaged as a model in Unity Catalog, and you can use Unity Catalog permissions to manage authorization for the resources in the chain.

import mlflow

mlflow.set_registry_uri("databricks-uc")

catalog_name = "test_catalog"
schema_name = "schema"
model_name = "chain_name"

model_name = catalog_name + "." + schema_name + "." + model_name
uc_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=model_name)

Example notebooks

These notebooks create a simple “Hello, world” chain to illustrate how to create a chain application in Databricks. The first example creates a simple chain. The second example notebook illustrates how to use parameters to minimize code changes during development.

Simple chain notebook

Get notebook

Simple chain driver notebook

Get notebook

Parameterized chain notebook

Get notebook

Parameterized chain driver notebook

Get notebook

Code-based vs. serialization-based logging

To create and log a chain, you can use code-based MLflow logging or serialization-based MLflow logging. Databricks recommends that you use code-based logging.

With code-based MLflow logging, the chain's code is captured as a Python file. The Python environment is captured as a list of packages. When the chain is deployed, the Python environment is restored, and the chain's code is executed to load the chain into memory so it can be invoked when the endpoint is called.

With serialization-based MLflow logging, the chain's code and current state in the Python environment are serialized to disk, often using libraries such as pickle or joblib. When the chain is deployed, the Python environment is restored, and the serialized object is loaded into memory so it can be invoked when the endpoint is called.

The advantages and disadvantages of each method are summarized below.

Code-based MLflow logging

  Advantages:
  • Overcomes inherent limitations of serialization, which is not supported by many popular GenAI libraries.
  • Saves a copy of the original code for later reference.
  • No need to restructure your code into a single object that can be serialized.

  Disadvantages:
  • log_model(...) must be called from a different notebook than the chain's code (called a driver notebook).

Serialization-based MLflow logging

  Advantages:
  • log_model(...) can be called from the same notebook where the model is defined.

  Disadvantages:
  • Original code is not available.
  • All libraries and objects used in the chain must support serialization.

Serialization-based logging workflow

Databricks recommends that you use code-based logging instead of serialization-based logging. For details about how to use code-based logging, see Code-based logging workflow.

This section describes how to use serialization-based logging.

Serialization-based logging workflow with LangChain

  1. Create a notebook or Python file with your code. The notebook or file must contain a LangChain chain, referred to here as lc_chain.
  2. Include mlflow.langchain.log_model(lc_model=lc_chain) in the notebook or file. A minimal sketch follows this list.
  3. A serialized copy of lc_chain is logged to an MLflow model.
  4. Deploy the model.
  5. When the serving environment is loaded, lc_chain is de-serialized.
  6. When a serving request comes in, lc_chain.invoke(...) is called.
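
The following is a minimal sketch of steps 1 and 2 in a single notebook. The chain definition is a placeholder; whether a given chain can be serialized depends on the libraries it uses.

import mlflow
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatDatabricks

# The chain is defined in the same notebook that logs it.
lc_chain = (
    ChatPromptTemplate.from_template("Answer the question: {question}")
    | ChatDatabricks(endpoint="databricks-dbrx-instruct")
    | StrOutputParser()
)

# Serialization-based logging: pass the chain object itself, not a file path.
with mlflow.start_run():
    logged_chain_info = mlflow.langchain.log_model(lc_model=lc_chain, artifact_path="chain")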

Serialization-based logging workflow with PyFunc

  1. Create a notebook or Python file with your code. For purposes of this example, the notebook or file is named notebook.py. The notebook or file must contain a PyFunc class, referred to here as PyFuncClass.
  2. Include mlflow.pyfunc.log_model(python_model=PyFuncClass()) in notebook.py. A minimal sketch follows this list.
  3. A serialized copy of PyFuncClass() is logged to an MLflow Model.
  4. Deploy the model.
  5. When the serving environment is loaded, PyFuncClass is de-serialized.
  6. When a serving request comes in, PyFuncClass.predict(...) is called.
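
The following is a minimal sketch of steps 1 and 2 in notebook.py, with placeholder predict() logic.

import mlflow

class PyFuncClass(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        # Placeholder logic for illustration only.
        return "This is a placeholder response."

# Serialization-based logging: pass an instance of the class, which is serialized.
with mlflow.start_run():
    logged_model_info = mlflow.pyfunc.log_model(python_model=PyFuncClass(), artifact_path="chain")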