How to solve an Invalid SessionHandle error with Azure Databricks?

Feng YIJUN
2024-06-17T14:38:14.4333333+00:00

I am building a chatbot with LangChain's SQLDatabaseChain and GPT-4. I first created the model in an Azure Databricks notebook as follows:

import json
import os
import langchain
import mlflow
from mlflow.models import infer_signature
from langchain.chains import LLMChain
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_experimental.sql import SQLDatabaseChain

api_key = os.getenv("OPENAI_API_KEY", "openai_api_key")
llm = ChatOpenAI(temperature=0, openai_api_key=api_key, model_name='gpt-4-0125-preview')

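def get_db(_=None):
    # MLflow calls this loader_fn with a persist_dir argument when the logged
    # model is loaded; the argument is not used here.
    db = SQLDatabase.from_databricks(
        catalog="catalog",
        schema="default",
        host="databricks_workspace_url",
        api_token="databricks_access_token",
        cluster_id="cluster_id",
        sample_rows_in_table_info=3,
    )
    return db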

db = get_db()

def get_chain():
    db = get_db()
    chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)
    return chain

chain = get_chain()

# Imported only so their versions can be pinned in pip_requirements below.
# (Re-importing ChatOpenAI from langchain_community here would shadow the
# langchain_openai class imported at the top.)
import langchain_community
import langchain_experimental
import langchain_core
import openai

mlflow.set_registry_uri("databricks-uc")
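# With the "databricks-uc" registry, registered model names are three-level: <catalog>.<schema>.<model>.
model_name = "SQL_chain"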
with mlflow.start_run(run_name="sql_chain_test") as run:
    user_query = "..."
    answer = chain.invoke(user_query)
    answer_json = json.dumps(answer)
    signature = infer_signature(user_query, answer_json)
    model_info = mlflow.langchain.log_model(
        chain,
        loader_fn=get_db,
        artifact_path="chain",
        registered_model_name=model_name,
        pip_requirements=[
            f"mlflow=={mlflow.__version__}",
            f"langchain=={langchain.__version__}",
            f"langchain_community=={langchain_community.__version__}",
            f"langchain_experimental=={langchain_experimental.__version__}",
            f"openai=={openai.__version__}",
            f"langchain_core=={langchain_core.__version__}",
            f"databricks",
            f"databricks-sql-connector==2.9.3",
            ],
        input_example=user_query,
        signature=signature
    )

With this code, I log the model to MLflow and register it in Unity Catalog, and then I create a serving endpoint for the Unity Catalog model.

Finally, I build a frontend application with Streamlit.

Starting about 15 minutes after the serving endpoint is created, every new question I send to the chatbot fails with an Invalid SessionHandle error:

{"error_code": "BAD_REQUEST", "message": "1 tasks failed. Errors: {0: 'error: DatabaseError(\\'(databricks.sql.exc.DatabaseError) Invalid SessionHandle: SessionHandle [ee17e137-26ea-4677-a7c3-69e81be048bc]\\')

Cluster: Databricks Runtime 14.2 (includes Apache Spark 3.5.0, Scala 2.12)

I think this happens because the session has expired, but I don't know how to create a new one without re-creating the serving endpoint. The goal is for the chatbot to be usable immediately whenever we launch it, but because of the session timeout it does not work reliably. I haven't found any solution to this issue. Can anyone help?
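One way to act on that hypothesis without re-creating the endpoint is to rebuild the chain (and with it the SQLDatabase and its connection) when a call fails with this specific error, then retry once. The sketch below assumes that the substring "Invalid SessionHandle" reliably identifies a stale session and that calling the get_chain() factory from the code above opens a new session. SelfHealing and STALE_SESSION_MARKER are hypothetical names, not LangChain or MLflow APIs. To help the deployed endpoint, this logic would have to run inside the served model (for example in the predict method of a custom mlflow.pyfunc.PythonModel), not in the Streamlit frontend; that packaging is not shown.

```python
from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")

# Substring taken from the error message in the question. Treating it as the
# signal that the cached Databricks session is stale is an assumption.
STALE_SESSION_MARKER = "Invalid SessionHandle"


class SelfHealing(Generic[T]):
    """Lazily builds a resource and rebuilds it once if a call fails with a
    stale-session error. Hypothetical helper, not part of LangChain or MLflow."""

    def __init__(self, factory: Callable[[], T], marker: str = STALE_SESSION_MARKER):
        self._factory = factory
        self._marker = marker
        self._resource: Optional[T] = None

    def _get(self) -> T:
        # Build the resource on first use (or after it was discarded).
        if self._resource is None:
            self._resource = self._factory()
        return self._resource

    def call(self, fn: Callable[[T], object]) -> object:
        try:
            return fn(self._get())
        except Exception as exc:
            if self._marker not in str(exc):
                raise  # unrelated error: do not mask it
            # Drop the stale resource (and its session), rebuild, retry once.
            # If the retry also fails, that exception propagates to the caller.
            self._resource = None
            return fn(self._get())
```

Usage would look like chain_holder = SelfHealing(get_chain) followed by chain_holder.call(lambda chain: chain.invoke(user_query)). A lighter alternative worth testing first, assuming your langchain_community version's SQLDatabase.from_databricks accepts engine_args (recent versions forward it to sqlalchemy.create_engine), is passing engine_args={"poolclass": NullPool} (from sqlalchemy.pool import NullPool). SQLAlchemy then does not keep a pooled connection, and therefore no Databricks session, alive between queries; each query opens a new session at the cost of some extra latency.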

Azure Databricks

1 answer

  1. Feng YIJUN
    2024-07-01T07:36:45.54+00:00

    @PRADEEPCHEEKATLA

    I have tried the suggested solutions, but they don't work.

    When I tried:

    from pyspark.sql import SparkSession
    from databricks import sql

    def get_db(_=None):
        db = SQLDatabase.from_databricks(...)
    
        
        spark = SparkSession.builder.appName("NewSession").getOrCreate()
        sql_context = sql.Context(spark.sparkContext)
        sql_context.sessionState().newSession()
    
        return db
    

    It returned: AttributeError: module 'databricks.sql' has no attribute 'Context'.
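The AttributeError is expected: databricks.sql is the Databricks SQL Connector for Python, a DB-API driver that has no Context class (the Spark SQLContext lives in pyspark.sql). With this connector, a new session comes from a new connection opened by sql.connect. The sketch below runs each query on its own short-lived connection so that no session can outlive a request. The server_hostname, http_path, and access_token values are placeholders, run_with_fresh_session is a hypothetical name, the connect parameter exists only so the function can be exercised without a workspace, and whether the per-query overhead is acceptable for the chatbot is untested.

```python
from typing import Any, Callable, List, Optional


def run_with_fresh_session(
    query: str,
    server_hostname: str,
    http_path: str,
    access_token: str,
    connect: Optional[Callable[..., Any]] = None,
) -> List[Any]:
    """Run one query on a brand-new Databricks SQL connection (and session)."""
    if connect is None:
        # Imported lazily so this module loads even without the connector installed.
        from databricks import sql

        connect = sql.connect
    # Both the connection and the cursor are closed when the blocks exit,
    # so no session is left behind to expire on the server.
    with connect(
        server_hostname=server_hostname,
        http_path=http_path,
        access_token=access_token,
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall()
```

This bypasses SQLDatabaseChain, which sends its queries through SQLDatabase and SQLAlchemy; it mainly shows what the connector actually exposes.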

    When I tried:

    def get_chain():
        global db
        if db is None or not db.is_valid():
            db = get_db()
        chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)
        return chain
    

    It returned: AttributeError: 'SQLDatabase' object has no attribute 'is_valid'.
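SQLDatabase has no is_valid() method, so the validity check has to be done another way. One option, sketched below under assumptions, is to probe the cached object with a trivial query and rebuild it if the probe raises. With a real SQLDatabase the probe could be lambda db: db.run("SELECT 1"), since run raises on database errors (run_no_throw instead returns the error as a string). get_fresh_db and the module-level _db_cache are hypothetical names. The probe costs one extra round trip per request, and it only helps if it runs inside the serving endpoint, not in the Streamlit frontend.

```python
from typing import Any, Callable

_db_cache: Any = None  # hypothetical module-level cache for the database object


def get_fresh_db(factory: Callable[[], Any], probe: Callable[[Any], Any]) -> Any:
    """Return the cached database object, rebuilding it if the probe fails.

    Stands in for the non-existent SQLDatabase.is_valid(): `probe` should run
    a cheap query and raise if the underlying session is no longer usable.
    """
    global _db_cache
    if _db_cache is not None:
        try:
            probe(_db_cache)
            return _db_cache
        except Exception:
            # Any probe failure (e.g. "Invalid SessionHandle") discards the
            # cached object; if the problem persists, factory() will raise.
            _db_cache = None
    _db_cache = factory()
    return _db_cache


# Hypothetical wiring inside the model:
#   db = get_fresh_db(get_db, lambda d: d.run("SELECT 1"))
#   chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)
```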

