Create and log AI agents

Important

This feature is in Public Preview.

This article shows you how to create and log AI agents such as RAG applications using the Mosaic AI Agent Framework.

What are chains and agents?

AI systems often have many components. For example, an AI system might retrieve documents from a vector index, use those documents to supplement prompt text, and use a foundation model to summarize the response. The code that links these components (also called steps) together is called a chain.

An agent is a more advanced AI system that relies on a large language model to decide which steps to take based on the input. In contrast, chains are hardcoded sequences of steps intended to achieve a specific outcome.
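To make the distinction concrete, the following sketch contrasts a chain, which always runs the same hardcoded sequence of steps, with an agent, which picks its steps at run time. It is a simplified illustration under stated assumptions: the step functions are hypothetical placeholders, and choose_steps uses a keyword rule where a real agent would ask a large language model to decide.

```python
from typing import Callable, Dict, List

Step = Callable[[str], str]


# Hypothetical placeholder steps standing in for retrieval, prompt augmentation, and generation
def retrieve(text: str) -> str:
    return text + " [retrieved docs]"


def augment(text: str) -> str:
    return "Context: " + text


def summarize(text: str) -> str:
    return "Summary of (" + text + ")"


def run_chain(query: str, steps: List[Step]) -> str:
    # Chain: the same hardcoded sequence of steps runs for every input
    result = query
    for step in steps:
        result = step(result)
    return result


def choose_steps(query: str) -> List[str]:
    # Stand-in for an LLM decision; a real agent would ask the model which steps to take
    if "docs" in query.lower():
        return ["retrieve", "augment", "summarize"]
    return ["summarize"]


def run_agent(query: str, tools: Dict[str, Step]) -> str:
    # Agent: the steps are selected at run time based on the input
    result = query
    for name in choose_steps(query):
        result = tools[name](result)
    return result


tools = {"retrieve": retrieve, "augment": augment, "summarize": summarize}
chain_output = run_chain("What is RAG?", [retrieve, augment, summarize])
agent_output = run_agent("hello", tools)
```

Here the chain always performs retrieval, while the agent skips retrieval for an input that does not need it.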

With Agent Framework, you can write your code with any library or package. Agent Framework also makes it easy to iterate on your code as you develop and test it. You can set up configuration files that let you change code parameters in a traceable way without having to modify the actual code.

Requirements

For agents that use a Databricks-managed vector search index, MLflow version 2.13.1 or above is required to use automatic authorization with the vector index.

Input schema for the RAG agent

The following are supported input formats for your chain.

  • (Recommended) Queries using the OpenAI chat completion schema. The request should include a messages parameter containing an array of message objects. This format is best for RAG applications.

    question = {
        "messages": [
            {
                "role": "user",
                "content": "What is Retrieval-Augmented Generation?",
            },
            {
                "role": "assistant",
                "content": "RAG, or Retrieval Augmented Generation, is a generative AI design pattern that combines a large language model (LLM) with external knowledge retrieval. This approach allows for real-time data connection to generative AI applications, improving their accuracy and quality by providing context from your data to the LLM during inference. Databricks offers integrated tools that support various RAG scenarios, such as unstructured data, structured data, tools & function calling, and agents.",
            },
            {
                "role": "user",
                "content": "How to build RAG for unstructured data",
            },
        ]
    }
    
  • SplitChatMessagesRequest. Recommended for multi-turn chat applications, especially when you want to manage the current query and the history separately.

    question = {
        "query": "What is MLflow",
        "history": [
            {
                "role": "user",
                "content": "What is Retrieval-augmented Generation?"
            },
            {
                "role": "assistant",
                "content": "RAG is"
            }
        ]
    }
    

For LangChain, Databricks recommends writing your chain in LangChain Expression Language. In your chain definition code, you can use itemgetter to extract the messages object, or the query and history objects, depending on which input format you are using.
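As a sketch, the two helpers that the LangChain example in the output schema section references but does not define, extract_user_query_string and extract_chat_history, could be plain Python functions that operate on the list returned by itemgetter("messages"). These implementations are hypothetical: they assume the last message is the current user query and everything before it is history.

```python
from operator import itemgetter
from typing import Dict, List

Message = Dict[str, str]


def extract_user_query_string(messages: List[Message]) -> str:
    # Hypothetical helper: the current question is the content of the last message
    return messages[-1]["content"]


def extract_chat_history(messages: List[Message]) -> List[Message]:
    # Hypothetical helper: every message before the last one is prior conversation
    return messages[:-1]


# itemgetter pulls the relevant fields out of each input format
get_messages = itemgetter("messages")
get_query_and_history = itemgetter("query", "history")

# Chat completion format
request = {
    "messages": [
        {"role": "user", "content": "What is Retrieval-augmented Generation?"},
        {"role": "assistant", "content": "RAG is"},
        {"role": "user", "content": "What is MLflow?"},
    ]
}
user_query = extract_user_query_string(get_messages(request))
chat_history = extract_chat_history(get_messages(request))

# SplitChatMessagesRequest format
split_request = {"query": "What is MLflow", "history": []}
query, history = get_query_and_history(split_request)
```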

Output schema for the RAG agent

Your code must comply with one of the following supported output formats:

  • (Recommended) ChatCompletionResponse. This format is recommended for customers who need interoperability with the OpenAI response format.
  • StringObjectResponse. This format is the simplest to interpret.

For LangChain, use StringResponseOutputParser() or ChatCompletionsOutputParser() from MLflow as your final chain step. Doing so formats the LangChain AI message into an Agent-compatible format.


  from operator import itemgetter

  from langchain_core.runnables import RunnableLambda
  from mlflow.langchain.output_parsers import StringResponseOutputParser, ChatCompletionsOutputParser

  # extract_user_query_string, extract_chat_history, and fake_model are placeholder
  # functions that you define in your own chain code.
  chain = (
      {
          "user_query": itemgetter("messages") | RunnableLambda(extract_user_query_string),
          "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
      }
      | RunnableLambda(fake_model)
      | StringResponseOutputParser()  # use this for StringObjectResponse
      # | ChatCompletionsOutputParser()  # or use this for ChatCompletionResponse
  )

If you are using PyFunc, Databricks recommends using type hints to annotate the predict() function with input and output data classes that are subclasses of classes defined in mlflow.models.rag_signatures.

Constructing the output object from the data class inside predict() ensures that the response follows the expected format. Convert the returned object to a dictionary (for example, with dataclasses.asdict) so that it can be serialized.


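  from dataclasses import asdict

  from mlflow.models.rag_signatures import ChatCompletionRequest, ChatCompletionResponse, ChainCompletionChoice, Message
  from mlflow.pyfunc import PythonModel

  class RAGModel(PythonModel):
      def predict(self, context, model_input: ChatCompletionRequest) -> ChatCompletionResponse:
          text = ...  # placeholder: generate the response text from model_input
          # Build the response from the data classes, then convert it to a dict so it can be serialized
          return asdict(ChatCompletionResponse(
              choices=[ChainCompletionChoice(message=Message(role="assistant", content=text))]
          ))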
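The reason for the asdict() call can be shown with the standard library alone. This sketch uses simplified stand-in data classes (hypothetical, not the mlflow.models.rag_signatures classes) that mirror the general shape of a chat completion response: json.dumps rejects a data class instance, but accepts the nested dictionary that asdict() produces.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import List


# Simplified stand-ins for the response data classes (NOT the MLflow classes)
@dataclass
class Message:
    role: str = "assistant"
    content: str = ""


@dataclass
class Choice:
    index: int = 0
    message: Message = field(default_factory=Message)
    finish_reason: str = "stop"


@dataclass
class Response:
    choices: List[Choice] = field(default_factory=list)


response = Response(choices=[Choice(message=Message(content="RAG combines retrieval with generation."))])

# json.dumps cannot serialize a data class instance directly...
try:
    json.dumps(response)
    direct_ok = True
except TypeError:
    direct_ok = False

# ...but the nested dict produced by asdict() serializes cleanly.
payload = json.dumps(asdict(response))
```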

Use parameters to control quality iteration

In the Agent Framework, you can use parameters to control how agents are executed. This allows you to quickly iterate by varying characteristics of your agent without changing the code. Parameters are key-value pairs that you define in a Python dictionary or a .yaml file.

To configure the code, you create a ModelConfig from these key-value parameters, defined in either a Python dictionary or a .yaml file. For example, you could use a dictionary during development and then convert it to a .yaml file for production deployment and CI/CD. For details about ModelConfig, see the MLflow documentation.

An example ModelConfig is shown below.

llm_parameters:
  max_tokens: 500
  temperature: 0.01
model_serving_endpoint: databricks-dbrx-instruct
vector_search_index: ml.docs.databricks_docs_index
prompt_template: 'You are a hello world bot. Respond with a reply to the user''s
  question that indicates your prompt template came from a YAML file. Your response
  must use the word "YAML" somewhere. User''s question: {question}'
prompt_template_input_vars:
- question

To load the configuration in your code, use one of the following:

import mlflow

# Example of loading the configuration from a .yml file
config_file = "configs/hello_world_config.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Example of using a dictionary
config_dict = {
    "prompt_template": "You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}",
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}

model_config = mlflow.models.ModelConfig(development_config=config_dict)

# Use model_config.get() to retrieve a parameter value by its key
value = model_config.get("model_serving_endpoint")
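To show how these parameters can change behavior without code changes, here is a hypothetical helper, build_prompt, that fills prompt_template using the variable names listed in prompt_template_input_vars. It operates on a plain dictionary shaped like config_dict above; in an agent you would read the same values with model_config.get(...). The helper name and its validation are assumptions, not part of the ModelConfig API.

```python
from typing import Any, Dict

# Plain dictionary mirroring the config_dict example; stands in for ModelConfig values
config_dict: Dict[str, Any] = {
    "prompt_template": (
        "You are a hello world bot. Respond with a reply to the user's question "
        "that is fun and interesting to the user. User's question: {question}"
    ),
    "prompt_template_input_vars": ["question"],
    "model_serving_endpoint": "databricks-dbrx-instruct",
    "llm_parameters": {"temperature": 0.01, "max_tokens": 500},
}


def build_prompt(config: Dict[str, Any], **inputs: str) -> str:
    """Hypothetical helper: fill the prompt template from its declared input variables."""
    expected = set(config["prompt_template_input_vars"])
    missing = expected - set(inputs)
    if missing:
        raise KeyError(f"missing prompt inputs: {sorted(missing)}")
    # Pass only the declared variables to str.format
    return config["prompt_template"].format(**{name: inputs[name] for name in expected})


prompt = build_prompt(config_dict, question="What is MLflow?")

# Changing behavior is a configuration edit, not a code edit: swap in a different template
yaml_style_config = dict(
    config_dict,
    prompt_template="Your response must use the word YAML. User's question: {question}",
)
yaml_prompt = build_prompt(yaml_style_config, question="What is MLflow?")
```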

Log the agent

Logging an agent is the basis of the development process. Logging captures a “point in time” of the agent’s code and configuration so you can evaluate the quality of the configuration. When developing agents, Databricks recommends that you use code-based logging instead of serialization-based logging. For more information about the pros and cons of each type of logging, see Code-based vs. serialization-based logging.

This section covers how to use code-based logging.

Code-based logging workflow

For code-based logging, the code that logs your agent or chain must be in a separate notebook from your chain code. This notebook is called a driver notebook. For an example notebook, see Example notebooks.

Code-based logging workflow with LangChain

  1. Create a notebook or Python file with your code. For purposes of this example, the notebook or file is named chain.py. The notebook or file must contain a LangChain chain, referred to here as lc_chain.
  2. Include mlflow.models.set_model(lc_chain) in the notebook or file.
  3. Create a new notebook to serve as the driver notebook (called driver.py in this example).
  4. In the driver notebook, use mlflow.langchain.log_model(lc_model="/path/to/chain.py") to run chain.py and log the results to an MLflow model.
  5. Deploy the model. See Deploy an agent for generative AI application. The deployment of your agent might depend on other Databricks resources such as a vector search index and model serving endpoints. For LangChain agents:
    • MLflow's log_model infers the dependencies required by the chain and logs them to the MLmodel file in the logged model artifact.
    • During deployment, databricks.agents.deploy automatically creates the M2M OAuth tokens required to access and communicate with these inferred resource dependencies.
  6. When the serving environment is loaded, chain.py is executed.
  7. When a serving request comes in, lc_chain.invoke(...) is called.

Code-based logging workflow with PyFunc

  1. Create a notebook or Python file with your code. For purposes of this example, the notebook or file is named chain.py. The notebook or file must contain a PyFunc class, referred to here as PyFuncClass.
  2. Include mlflow.models.set_model(PyFuncClass()) in the notebook or file. set_model expects an instance of the class, not the class itself.
  3. Create a new notebook to serve as the driver notebook (called driver.py in this example).
  4. In the driver notebook, use mlflow.pyfunc.log_model(python_model="/path/to/chain.py", resources="/path/to/resources.yaml") to run chain.py and log the results to an MLflow model. The resources parameter declares any resources needed to serve the model, such as a vector search index or a serving endpoint that serves a foundation model. See an example resources file for PyFunc.
  5. Deploy the model. See Deploy an agent for generative AI application.
  6. When the serving environment is loaded, chain.py is executed.
  7. When a serving request comes in, PyFuncClass.predict(...) is called.

Example code for logging chains

import mlflow

code_path = "/Workspace/Users/first.last/chain.py"
config_path = "/Workspace/Users/first.last/config.yml"

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is Retrieval-augmented Generation?",
        }
    ]
}

# example using LangChain
with mlflow.start_run():
  logged_chain_info = mlflow.langchain.log_model(
    lc_model=code_path,
    model_config=config_path,  # If specified, this configuration is logged with the model and overrides the development_config used in the chain code
    artifact_path="chain", # This string is used as the path inside the MLflow model where artifacts are stored
    input_example=input_example, # Must be a valid input to your chain
    example_no_conversion=True, # Required
  )

# or use a PyFunc model

# resources_path = "/Workspace/Users/first.last/resources.yml"

# with mlflow.start_run():
#   logged_chain_info = mlflow.pyfunc.log_model(
#     python_model=code_path,
#     artifact_path="chain",
#     input_example=input_example,
#     resources=resources_path,
#     example_no_conversion=True,
#   )

print(f"MLflow Run: {logged_chain_info.run_id}")
print(f"Model URI: {logged_chain_info.model_uri}")

To verify that the model has been logged correctly, load the chain and run it on the input example:

# Using LangChain
model = mlflow.langchain.load_model(logged_chain_info.model_uri)
model.invoke(input_example)

# Using PyFunc
model = mlflow.pyfunc.load_model(logged_chain_info.model_uri)
model.predict(input_example)

Specify resources for PyFunc agent

You can specify resources, such as a vector search index and a serving endpoint, that are required to serve the model. For LangChain, resources are automatically picked up and logged along with the model.

When deploying a PyFunc-flavored agent, you must manually add any resource dependencies of the deployed agent. An M2M OAuth token with access to all the resources specified in the resources parameter is created and provided to the deployed agent.

Note

You can override the resources that your endpoint has permission to access by manually specifying the resources when logging the chain.

The following shows how to add serving endpoint and vector search index dependencies by specifying them in the resources parameter.

import mlflow
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

# code_path and input_example are defined as in the earlier logging example
with mlflow.start_run():
    logged_chain_info = mlflow.pyfunc.log_model(
        python_model=code_path,
        artifact_path="chain",
        input_example=input_example,
        example_no_conversion=True,
        resources=[
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
        ],
    )

You can also add resources by specifying them in a resources.yaml file. You can reference that file path in the resources parameter. An M2M OAuth token with access to all the specified resources in the resources.yaml is created and provided to the deployed agent.

The following is an example resources.yaml file that defines model serving endpoints and a vector search index.


api_version: "1"
databricks:
  vector_search_index:
    - name: "catalog.schema.my_vs_index"
  serving_endpoint:
    - name: databricks-dbrx-instruct
    - name: databricks-bge-large-en

Register the chain to Unity Catalog

Before you deploy the chain, you must register the chain to Unity Catalog. When you register the chain, it is packaged as a model in Unity Catalog, and you can use Unity Catalog permissions for authorization for resources in the chain.

import mlflow

mlflow.set_registry_uri("databricks-uc")

catalog_name = "test_catalog"
schema_name = "schema"
model_name = "chain_name"

model_name = catalog_name + "." + schema_name + "." + model_name
uc_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=model_name)

Example notebooks

These notebooks create a simple “Hello, world” chain to illustrate how to create a chain application in Databricks. The first example creates a simple chain. The second example notebook illustrates how to use parameters to minimize code changes during development.

Simple chain notebook

Simple chain driver notebook

Parameterized chain notebook

Parameterized chain driver notebook

Code-based vs. serialization-based logging

To create and log a chain, you can use code-based MLflow logging or serialization-based MLflow logging. Databricks recommends that you use code-based logging.

With code-based MLflow logging, the chain’s code is captured as a Python file, and the Python environment is captured as a list of packages. When the chain is deployed, the Python environment is restored, and the chain’s code is executed to load the chain into memory so it can be invoked when the endpoint is called.

With serialization-based MLflow logging, the chain’s code and its current state in the Python environment are serialized to disk, often using libraries such as pickle or joblib. When the chain is deployed, the Python environment is restored, and the serialized object is loaded into memory so it can be invoked when the endpoint is called.
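The difference can be illustrated with standard-library Python only. This sketch simulates the two approaches and does not call MLflow: runpy re-executes a saved chain.py to rebuild the chain (the code-based approach), while pickle tries to serialize the live object (the serialization-based approach). The chain here is a closure created inside a function, a common pattern that pickle cannot handle. The file contents and function names are hypothetical.

```python
import pickle
import runpy
import tempfile
from pathlib import Path

# Hypothetical contents of a chain.py file. The chain is a closure over local
# state, which the standard pickle module cannot serialize.
CHAIN_SOURCE = '''
def _make_chain():
    prefix = "Echo: "
    return lambda request: prefix + request["messages"][-1]["content"]

chain = _make_chain()
'''


def load_from_code(path: Path):
    """Code-based style: re-execute the saved source and return the object it defines."""
    namespace = runpy.run_path(str(path))
    return namespace["chain"]


def round_trip_pickle(obj):
    """Serialization-based style: pickle and unpickle the live object; None if that fails."""
    try:
        return pickle.loads(pickle.dumps(obj))
    except (pickle.PicklingError, AttributeError, TypeError):
        return None


request = {"messages": [{"role": "user", "content": "hi"}]}

with tempfile.TemporaryDirectory() as tmp_dir:
    chain_file = Path(tmp_dir) / "chain.py"
    chain_file.write_text(CHAIN_SOURCE)
    chain_from_code = load_from_code(chain_file)
    # The original source stays available for later reference
    saved_source = chain_file.read_text()

code_based_answer = chain_from_code(request)
pickled_chain = round_trip_pickle(chain_from_code)
```

The code-based path rebuilds a working chain and keeps its source, while the pickle round trip fails for this object.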

The following lists the advantages and disadvantages of each method.

Code-based MLflow logging
  • Advantages:
    • Overcomes inherent limitations of serialization, which is not supported by many popular GenAI libraries.
    • Saves a copy of the original code for later reference.
    • No need to restructure your code into a single object that can be serialized.
  • Disadvantages:
    • log_model(...) must be called from a different notebook than the chain’s code (called a driver notebook).

Serialization-based MLflow logging
  • Advantages:
    • log_model(...) can be called from the same notebook where the model is defined.
  • Disadvantages:
    • Original code is not available.
    • All libraries and objects used in the chain must support serialization.