
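Quickstart: Vector search with Go in Azure DocumentDB

Use vector search in Azure DocumentDB with the Go client library to store and query vector data efficiently.

This quickstart uses a sample hotel dataset in a JSON file that contains precomputed vectors from the text-embedding-3-small model. The dataset includes hotel names, locations, descriptions, and vector embeddings.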
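
As a rough illustration of the dataset's shape, the sketch below decodes a JSON array of hotel records with the standard library. The `Hotel` type and `parseHotels` helper are hypothetical names used only here, the sample record is made up, and only the fields that this quickstart names (HotelName, Description, DescriptionVector) are modeled.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Hotel models a few fields of one dataset record (hypothetical type).
// Real records carry more fields, which json.Unmarshal simply ignores here.
type Hotel struct {
	HotelName         string    `json:"HotelName"`
	Description       string    `json:"Description"`
	DescriptionVector []float64 `json:"DescriptionVector"`
}

// parseHotels decodes a JSON array of hotel records.
func parseHotels(raw []byte) ([]Hotel, error) {
	var hotels []Hotel
	if err := json.Unmarshal(raw, &hotels); err != nil {
		return nil, fmt.Errorf("parse hotels: %w", err)
	}
	return hotels, nil
}

func main() {
	// Made-up sample record: the real vectors have 1536 dimensions, not 3.
	raw := []byte(`[{"HotelName":"Example Inn","Description":"Quiet rooms near a park.","DescriptionVector":[0.1,-0.2,0.3]}]`)

	hotels, err := parseHotels(raw)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	for _, h := range hotels {
		fmt.Printf("%s: %d dimensions\n", h.HotelName, len(h.DescriptionVector))
	}
}
```

Checking the vector length of each record against the expected dimension count is a cheap way to catch a mismatched data file before creating an index.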

Find the sample code on GitHub.

Prerequisites

  • Go version 1.24 or later

Create a data file with vectors

  1. Create a new data directory for the hotel data file:

    mkdir data
    
  2. Copy the Hotels_Vector.json raw data file, which contains the vectors, to the data directory.

Create a Go project

  1. Create a new directory for the project and open it in Visual Studio Code:

    mkdir vector-search-quickstart
    cd vector-search-quickstart
    code .
    
  2. Initialize a Go module:

    go mod init vector-search-quickstart
    
  3. Install the required Go packages:

    go get go.mongodb.org/mongo-driver
    go get github.com/Azure/azure-sdk-for-go/sdk/azcore
    go get github.com/Azure/azure-sdk-for-go/sdk/azidentity
    go get github.com/openai/openai-go/v3
    go get github.com/joho/godotenv
    
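    • go.mongodb.org/mongo-driver: MongoDB Go driver
    • github.com/Azure/azure-sdk-for-go/sdk/azcore: Azure SDK core utilities for the HTTP pipeline and authentication
    • github.com/Azure/azure-sdk-for-go/sdk/azidentity: Azure Identity library for passwordless token authentication
    • github.com/openai/openai-go/v3: OpenAI Go client library for creating vectors
    • github.com/joho/godotenv: Loads environment variables from a .env file
  4. Create a .env file in the project root for environment variables: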

    # Identity for local developer authentication with Azure CLI
    AZURE_TOKEN_CREDENTIALS=AzureCliCredential
    
    # Azure OpenAI Embedding Settings
    AZURE_OPENAI_EMBEDDING_MODEL=text-embedding-3-small
    AZURE_OPENAI_EMBEDDING_API_VERSION=2023-05-15
    AZURE_OPENAI_EMBEDDING_ENDPOINT=<AZURE_OPENAI_ENDPOINT>
    EMBEDDING_SIZE_BATCH=16
    
    # Azure DocumentDB configuration
    MONGO_CLUSTER_NAME=<DOCUMENTDB_NAME>
    
    # Data file
    DATA_FILE_WITH_VECTORS=../data/Hotels_Vector.json
    EMBEDDED_FIELD=DescriptionVector
    EMBEDDING_DIMENSIONS=1536
    LOAD_SIZE_BATCH=50
    

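    Replace the placeholder values in the .env file with your own information:

    • AZURE_OPENAI_EMBEDDING_ENDPOINT: The URL of your Azure OpenAI resource endpoint
    • MONGO_CLUSTER_NAME: The name of your Azure DocumentDB resource

    Always prefer passwordless authentication, even though it requires additional setup. For more information about setting up managed identity and the full range of authentication options, see Authenticate Go apps to Azure services by using the Azure Identity library.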

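Continue the project by creating the code files for vector search.

Create a src directory for your Go files. Add two files, diskann.go and utils.go, for the DiskANN index implementation.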

mkdir src    
touch src/diskann.go
touch src/utils.go

When finished, the project structure should look like this:

data
└── Hotels_Vector.json
vector-search-quickstart
├── .env
├── go.mod
└── src
    ├── diskann.go
    └── utils.go

Add the following code to the src/diskann.go file:

package main

import (
    "context"
    "fmt"
    "log"
    "strings"
    "time"

    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/mongo"

    "github.com/openai/openai-go/v3"
)

// CreateDiskANNVectorIndex creates a DiskANN vector index on the specified field
func CreateDiskANNVectorIndex(ctx context.Context, collection *mongo.Collection, vectorField string, dimensions int) error {
    fmt.Printf("Creating DiskANN vector index on field '%s'...\n", vectorField)

    // Drop any existing vector indexes on this field first
    err := DropVectorIndexes(ctx, collection, vectorField)
    if err != nil {
        fmt.Printf("Warning: Could not drop existing indexes: %v\n", err)
    }

    // Use the native MongoDB command for DocumentDB vector indexes
    // Note: Must use bson.D for commands to preserve order and avoid "multi-key map" errors
    indexCommand := bson.D{
        {"createIndexes", collection.Name()},
        {"indexes", []bson.D{
            {
                {"name", fmt.Sprintf("diskann_index_%s", vectorField)},
                {"key", bson.D{
                    {vectorField, "cosmosSearch"}, // DocumentDB vector search index type
                }},
                {"cosmosSearchOptions", bson.D{
                    // DiskANN algorithm configuration
                    {"kind", "vector-diskann"},

                    // Vector dimensions must match the embedding model
                    {"dimensions", dimensions},

                    // Vector similarity metric - cosine is good for text embeddings
                    {"similarity", "COS"},

                    // Maximum degree: number of edges per node in the graph
                    // Higher values improve accuracy but increase memory usage
                    {"maxDegree", 20},

                    // Build parameter: candidates evaluated during index construction
                    // Higher values improve index quality but increase build time
                    {"lBuild", 10},
                }},
            },
        }},
    }

    // Execute the createIndexes command directly
    var result bson.M
    err = collection.Database().RunCommand(ctx, indexCommand).Decode(&result)
    if err != nil {
        // Check if it's a tier limitation and suggest alternatives
        if strings.Contains(err.Error(), "not enabled for this cluster tier") {
            fmt.Println("\nDiskANN indexes require a higher cluster tier.")
            fmt.Println("Try one of these alternatives:")
            fmt.Println("  • Upgrade your DocumentDB cluster to a higher tier")
            fmt.Println("  • Use HNSW instead: go run src/hnsw.go")
            fmt.Println("  • Use IVF instead: go run src/ivf.go")
        }
        return fmt.Errorf("error creating DiskANN vector index: %v", err)
    }

    fmt.Println("DiskANN vector index created successfully")
    return nil
}

// PerformDiskANNVectorSearch performs a vector search using DiskANN algorithm
func PerformDiskANNVectorSearch(ctx context.Context, collection *mongo.Collection, openAIClient openai.Client, queryText, vectorField, modelName string, topK int) ([]SearchResult, error) {
    fmt.Printf("Performing DiskANN vector search for: '%s'\n", queryText)

    // Generate embedding for the query text
    queryEmbedding, err := GenerateEmbedding(ctx, openAIClient, queryText, modelName)
    if err != nil {
        return nil, fmt.Errorf("error generating embedding: %v", err)
    }

    // Construct the aggregation pipeline for vector search
    // DocumentDB uses $search with cosmosSearch
    pipeline := []bson.M{
        {
            "$search": bson.M{
                // Use cosmosSearch for vector operations in DocumentDB
                "cosmosSearch": bson.M{
                    // The query vector to search for
                    "vector": queryEmbedding,

                    // Field containing the document vectors to compare against
                    "path": vectorField,

                    // Number of final results to return
                    "k": topK,
                },
            },
        },
        {
            // Add similarity score to the results
            "$project": bson.M{
                "document": "$$ROOT",
                // Add search score from metadata
                "score": bson.M{"$meta": "searchScore"},
            },
        },
    }

    // Execute the aggregation pipeline
    cursor, err := collection.Aggregate(ctx, pipeline)
    if err != nil {
        return nil, fmt.Errorf("error performing DiskANN vector search: %v", err)
    }
    defer cursor.Close(ctx)

    var results []SearchResult
    for cursor.Next(ctx) {
        var result SearchResult
        if err := cursor.Decode(&result); err != nil {
            fmt.Printf("Warning: Could not decode result: %v\n", err)
            continue
        }
        results = append(results, result)
    }

    if err := cursor.Err(); err != nil {
        return nil, fmt.Errorf("cursor error: %v", err)
    }

    return results, nil
}

// main function demonstrates DiskANN vector search functionality
func main() {
    fmt.Println("Starting DiskANN vector search demonstration...")

    ctx := context.Background()

    // Load configuration from environment variables
    config := LoadConfig()

    fmt.Println("\nInitializing MongoDB and Azure OpenAI clients...")
    mongoClient, azureOpenAIClient, err := GetClientsPasswordless()
    if err != nil {
        log.Fatalf("Failed to initialize clients: %v", err)
    }
    defer mongoClient.Disconnect(ctx)

    // Get database and collection
    database := mongoClient.Database(config.DatabaseName)
    collection := database.Collection("hotels_diskann")

    // Load data with embeddings
    fmt.Printf("\nLoading data from %s...\n", config.DataFile)
    data, err := ReadFileReturnJSON(config.DataFile)
    if err != nil {
        log.Fatalf("Failed to load data: %v", err)
    }
    fmt.Printf("Loaded %d documents\n", len(data))

    // Verify embeddings are present
    var documentsWithEmbeddings []map[string]interface{}
    for _, doc := range data {
        if _, exists := doc[config.VectorField]; exists {
            documentsWithEmbeddings = append(documentsWithEmbeddings, doc)
        }
    }

    if len(documentsWithEmbeddings) == 0 {
        log.Fatalf("No documents found with embeddings in field '%s'. Please run create_embeddings.go first.", config.VectorField)
    }

    // Insert data into collection
    fmt.Printf("\nInserting data into collection '%s'...\n", collection.Name())

    // Clear existing data to ensure clean state
    deleteResult, err := collection.DeleteMany(ctx, bson.M{})
    if err != nil {
        log.Fatalf("Failed to clear existing data: %v", err)
    }
    if deleteResult.DeletedCount > 0 {
        fmt.Printf("Cleared %d existing documents from collection\n", deleteResult.DeletedCount)
    }

    // Insert the hotel data
    stats, err := InsertData(ctx, collection, documentsWithEmbeddings, config.BatchSize, nil)
    if err != nil {
        log.Fatalf("Failed to insert data: %v", err)
    }

    if stats.Inserted == 0 {
        log.Fatalf("No documents were inserted successfully")
    }

    fmt.Printf("Insertion completed: %d inserted, %d failed\n", stats.Inserted, stats.Failed)

    // Create DiskANN vector index
    err = CreateDiskANNVectorIndex(ctx, collection, config.VectorField, config.Dimensions)
    if err != nil {
        log.Fatalf("Failed to create DiskANN vector index: %v", err)
    }

    // Wait briefly for index to be ready
    fmt.Println("Waiting for index to be ready...")
    time.Sleep(2 * time.Second)

    // Perform sample vector search
    query := "quintessential lodging near running trails, eateries, retail"

    results, err := PerformDiskANNVectorSearch(
        ctx,
        collection,
        azureOpenAIClient,
        query,
        config.VectorField,
        config.ModelName,
        5,
    )
    if err != nil {
        log.Fatalf("Failed to perform vector search: %v", err)
    }

    // Display results
    PrintSearchResults(results, 5, true)

    fmt.Println("\nDiskANN demonstration completed successfully!")
}

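This main module provides these features:

  • Includes utility functions
  • Creates a configuration struct for environment variables
  • Creates clients for Azure OpenAI and Azure DocumentDB
  • Connects to MongoDB, creates a database and collection, inserts data, and creates standard indexes
  • Creates a vector index using IVF, HNSW, or DiskANN
  • Creates an embedding for a sample query text using the OpenAI client. You can change the query in the main function
  • Runs a vector search using the embedding and prints the results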

Create utility functions

Add the following code to src/utils.go:

package main

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "os"
    "strconv"
    "strings"
    "time"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    "github.com/joho/godotenv"
    "github.com/openai/openai-go/v3"
    "github.com/openai/openai-go/v3/azure"
    "github.com/openai/openai-go/v3/option"
    "go.mongodb.org/mongo-driver/bson"
    "go.mongodb.org/mongo-driver/mongo"
    "go.mongodb.org/mongo-driver/mongo/options"
)

// Config holds the application configuration
type Config struct {
    ClusterName    string
    DatabaseName   string
    CollectionName string
    DataFile       string
    VectorField    string
    ModelName      string
    Dimensions     int
    BatchSize      int
}

// SearchResult represents a search result document
type SearchResult struct {
    Document interface{} `bson:"document"`
    Score    float64     `bson:"score"`
}

// HotelData represents a hotel document structure
type HotelData struct {
    HotelName         string    `bson:"HotelName" json:"HotelName"`
    Description       string    `bson:"Description" json:"Description"`
    DescriptionVector []float64 `bson:"DescriptionVector,omitempty" json:"DescriptionVector,omitempty"`
    // Add other fields as needed
}

// InsertStats holds statistics about data insertion
type InsertStats struct {
    Total    int `json:"total"`
    Inserted int `json:"inserted"`
    Failed   int `json:"failed"`
}

// LoadConfig loads configuration from environment variables
func LoadConfig() *Config {
    // Load environment variables from .env file
    // For production use, prefer Azure Key Vault or similar secret management
    // services instead of .env files. For development/demo purposes only.
    err := godotenv.Load()
    if err != nil {
        log.Printf("Warning: Error loading .env file: %v", err)
    }

    dimensions, _ := strconv.Atoi(getEnvOrDefault("EMBEDDING_DIMENSIONS", "1536"))
    batchSize, _ := strconv.Atoi(getEnvOrDefault("LOAD_SIZE_BATCH", "100"))

    return &Config{
        ClusterName:    getEnvOrDefault("MONGO_CLUSTER_NAME", "vectorSearch"),
        DatabaseName:   "Hotels",
        CollectionName: "vectorSearchCollection",
        DataFile:       getEnvOrDefault("DATA_FILE_WITH_VECTORS", "../data/Hotels_Vector.json"),
        VectorField:    getEnvOrDefault("EMBEDDED_FIELD", "DescriptionVector"),
        ModelName:      getEnvOrDefault("AZURE_OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"),
        Dimensions:     dimensions,
        BatchSize:      batchSize,
    }
}

// getEnvOrDefault returns environment variable value or default if not set
func getEnvOrDefault(key, defaultValue string) string {
    if value := os.Getenv(key); value != "" {
        return value
    }
    return defaultValue
}

// GetClients creates MongoDB and Azure OpenAI clients with connection string authentication
func GetClients() (*mongo.Client, openai.Client, error) {
    ctx := context.Background()

    // Get MongoDB connection string
    mongoConnectionString := os.Getenv("MONGO_CONNECTION_STRING")
    if mongoConnectionString == "" {
        return nil, openai.Client{}, fmt.Errorf("MONGO_CONNECTION_STRING environment variable is required. " +
            "Set it to your DocumentDB connection string or use GetClientsPasswordless() for OIDC auth")
    }

    // Create MongoDB client with optimized settings for DocumentDB
    clientOptions := options.Client().
        ApplyURI(mongoConnectionString).
        SetMaxPoolSize(50).
        SetMinPoolSize(5).
        SetMaxConnIdleTime(30 * time.Second).
        SetServerSelectionTimeout(5 * time.Second).
        SetSocketTimeout(20 * time.Second)

    mongoClient, err := mongo.Connect(ctx, clientOptions)
    if err != nil {
        return nil, openai.Client{}, fmt.Errorf("failed to connect to MongoDB: %v", err)
    }

    // Test the connection
    err = mongoClient.Ping(ctx, nil)
    if err != nil {
        return nil, openai.Client{}, fmt.Errorf("failed to ping MongoDB: %v", err)
    }

    // Get Azure OpenAI configuration
    azureOpenAIEndpoint := os.Getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
    azureOpenAIKey := os.Getenv("AZURE_OPENAI_EMBEDDING_KEY")

    if azureOpenAIEndpoint == "" || azureOpenAIKey == "" {
        return nil, openai.Client{}, fmt.Errorf("Azure OpenAI endpoint and key are required")
    }

    // Create Azure OpenAI client
    openAIClient := openai.NewClient(
        option.WithBaseURL(fmt.Sprintf("%s/openai/v1", azureOpenAIEndpoint)),
        option.WithAPIKey(azureOpenAIKey))

    return mongoClient, openAIClient, nil
}

// GetClientsPasswordless creates MongoDB and Azure OpenAI clients with passwordless authentication
func GetClientsPasswordless() (*mongo.Client, openai.Client, error) {
    ctx := context.Background()

    // Get MongoDB cluster name
    clusterName := os.Getenv("MONGO_CLUSTER_NAME")
    if clusterName == "" {
        return nil, openai.Client{}, fmt.Errorf("MONGO_CLUSTER_NAME environment variable is required")
    }

    // Create Azure credential
    credential, err := azidentity.NewDefaultAzureCredential(nil)
    if err != nil {
        return nil, openai.Client{}, fmt.Errorf("failed to create Azure credential: %v", err)
    }

    // Attempt OIDC authentication
    mongoURI := fmt.Sprintf("mongodb+srv://%s.global.mongocluster.cosmos.azure.com/", clusterName)

    fmt.Println("Attempting OIDC authentication...")
    mongoClient, err := connectWithOIDC(ctx, mongoURI, credential)
    if err != nil {
        return nil, openai.Client{}, fmt.Errorf("OIDC authentication failed: %v", err)
    }
    fmt.Println("OIDC authentication successful!")

    // Get Azure OpenAI endpoint
    azureOpenAIEndpoint := os.Getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT")
    if azureOpenAIEndpoint == "" {
        return nil, openai.Client{}, fmt.Errorf("AZURE_OPENAI_EMBEDDING_ENDPOINT environment variable is required")
    }

    // Create Azure OpenAI client with credential-based authentication
    openAIClient := openai.NewClient(
        option.WithBaseURL(fmt.Sprintf("%s/openai/v1", azureOpenAIEndpoint)),
        azure.WithTokenCredential(credential))

    return mongoClient, openAIClient, nil
}

// connectWithOIDC attempts to connect using OIDC authentication
func connectWithOIDC(ctx context.Context, mongoURI string, credential *azidentity.DefaultAzureCredential) (*mongo.Client, error) {
    // Create OIDC machine callback using Azure credential
    oidcCallback := func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) {
        scope := "https://ossrdbms-aad.database.windows.net/.default"
        fmt.Printf("Getting token with scope: %s\n", scope)
        token, err := credential.GetToken(ctx, policy.TokenRequestOptions{
            Scopes: []string{scope},
        })
        if err != nil {
            return nil, fmt.Errorf("failed to get token with scope %s: %v", scope, err)
        }

        fmt.Println("Successfully obtained token")

        return &options.OIDCCredential{
            AccessToken: token.Token,
        }, nil
    }
    // Set up MongoDB client options with OIDC authentication
    clientOptions := options.Client().
        ApplyURI(mongoURI).
        SetConnectTimeout(30 * time.Second).
        SetServerSelectionTimeout(30 * time.Second).
        SetRetryWrites(true).
        SetAuth(options.Credential{
            AuthMechanism: "MONGODB-OIDC",
            // For local development, don't set ENVIRONMENT=azure to allow custom callbacks
            AuthMechanismProperties: map[string]string{
                "TOKEN_RESOURCE": "https://ossrdbms-aad.database.windows.net",
            },
            OIDCMachineCallback: oidcCallback,
        })

    mongoClient, err := mongo.Connect(ctx, clientOptions)
    if err != nil {
        return nil, err
    }

    return mongoClient, nil
}

// connectWithConnectionString attempts to connect using a connection string
func connectWithConnectionString(ctx context.Context, connectionString string) (*mongo.Client, error) {
    clientOptions := options.Client().
        ApplyURI(connectionString).
        SetMaxPoolSize(50).
        SetMinPoolSize(5).
        SetMaxConnIdleTime(30 * time.Second).
        SetServerSelectionTimeout(5 * time.Second).
        SetSocketTimeout(20 * time.Second)

    mongoClient, err := mongo.Connect(ctx, clientOptions)
    if err != nil {
        return nil, err
    }

    return mongoClient, nil
}

// ReadFileReturnJSON reads a JSON file and returns the data as a slice of maps
func ReadFileReturnJSON(filePath string) ([]map[string]interface{}, error) {
    file, err := os.ReadFile(filePath)
    if err != nil {
        return nil, fmt.Errorf("error reading file '%s': %v", filePath, err)
    }

    var data []map[string]interface{}
    err = json.Unmarshal(file, &data)
    if err != nil {
        return nil, fmt.Errorf("error parsing JSON in file '%s': %v", filePath, err)
    }

    return data, nil
}

// WriteFileJSON writes data to a JSON file
func WriteFileJSON(data []map[string]interface{}, filePath string) error {
    jsonData, err := json.MarshalIndent(data, "", "  ")
    if err != nil {
        return fmt.Errorf("error marshalling data to JSON: %v", err)
    }

    err = os.WriteFile(filePath, jsonData, 0644)
    if err != nil {
        return fmt.Errorf("error writing to file '%s': %v", filePath, err)
    }

    fmt.Printf("Data successfully written to '%s'\n", filePath)
    return nil
}

// InsertData inserts data into a MongoDB collection in batches
func InsertData(ctx context.Context, collection *mongo.Collection, data []map[string]interface{}, batchSize int, indexFields []string) (*InsertStats, error) {
    totalDocuments := len(data)
    insertedCount := 0
    failedCount := 0

    fmt.Printf("Starting batch insertion of %d documents...\n", totalDocuments)

    // Create indexes if specified
    if len(indexFields) > 0 {
        for _, field := range indexFields {
            indexModel := mongo.IndexModel{
                Keys: bson.D{{Key: field, Value: 1}},
            }
            _, err := collection.Indexes().CreateOne(ctx, indexModel)
            if err != nil {
                fmt.Printf("Warning: Could not create index on %s: %v\n", field, err)
            } else {
                fmt.Printf("Created index on field: %s\n", field)
            }
        }
    }

    // Process data in batches
    for i := 0; i < totalDocuments; i += batchSize {
        end := i + batchSize
        if end > totalDocuments {
            end = totalDocuments
        }

        batch := data[i:end]
        batchNum := (i / batchSize) + 1

        // Convert to []interface{} for MongoDB driver
        documents := make([]interface{}, len(batch))
        for j, doc := range batch {
            documents[j] = doc
        }

        // Insert batch
        result, err := collection.InsertMany(ctx, documents, options.InsertMany().SetOrdered(false))
        if err != nil {
            // Handle bulk write errors
            if bulkErr, ok := err.(mongo.BulkWriteException); ok {
                failed := len(bulkErr.WriteErrors)
                insertedCount += len(batch) - failed
                failedCount += failed

                fmt.Printf("Batch %d had errors: %d inserted, %d failed\n", batchNum, len(batch)-failed, failed)

                // Print specific error details
                for _, writeErr := range bulkErr.WriteErrors {
                    fmt.Printf("  Error: %s\n", writeErr.Message)
                }
            } else {
                // Handle unexpected errors
                failedCount += len(batch)
                fmt.Printf("Batch %d failed completely: %v\n", batchNum, err)
            }
        } else {
            insertedCount += len(result.InsertedIDs)
            fmt.Printf("Batch %d completed: %d documents inserted\n", batchNum, len(result.InsertedIDs))
        }

        // Small delay between batches
        time.Sleep(100 * time.Millisecond)
    }

    return &InsertStats{
        Total:    totalDocuments,
        Inserted: insertedCount,
        Failed:   failedCount,
    }, nil
}

// DropVectorIndexes drops existing vector indexes on the specified field
func DropVectorIndexes(ctx context.Context, collection *mongo.Collection, vectorField string) error {
    // Get all indexes for the collection
    cursor, err := collection.Indexes().List(ctx)
    if err != nil {
        return fmt.Errorf("could not list indexes: %v", err)
    }
    defer cursor.Close(ctx)

    var vectorIndexes []string
    for cursor.Next(ctx) {
        var index bson.M
        if err := cursor.Decode(&index); err != nil {
            continue
        }

        // Check if this is a vector index on the specified field
        if key, ok := index["key"].(bson.M); ok {
            if indexType, exists := key[vectorField]; exists && indexType == "cosmosSearch" {
                if name, ok := index["name"].(string); ok {
                    vectorIndexes = append(vectorIndexes, name)
                }
            }
        }
    }

    // Drop each vector index found
    for _, indexName := range vectorIndexes {
        fmt.Printf("Dropping existing vector index: %s\n", indexName)
        _, err := collection.Indexes().DropOne(ctx, indexName)
        if err != nil {
            fmt.Printf("Warning: Could not drop index %s: %v\n", indexName, err)
        }
    }

    if len(vectorIndexes) > 0 {
        fmt.Printf("Dropped %d existing vector index(es)\n", len(vectorIndexes))
    } else {
        fmt.Println("No existing vector indexes found to drop")
    }

    return nil
}

// PrintSearchResults prints search results in a formatted way
func PrintSearchResults(results []SearchResult, maxResults int, showScore bool) {
    if len(results) == 0 {
        fmt.Println("No search results found.")
        return
    }

    if maxResults > len(results) {
        maxResults = len(results)
    }

    fmt.Printf("\nSearch Results (showing top %d):\n", maxResults)
    fmt.Println(strings.Repeat("=", 80))

    for i := 0; i < maxResults; i++ {
        result := results[i]

        // Extract HotelName from the document; the checked assertion avoids a
        // panic if the driver decodes the document as a type other than bson.D
        var hotelName string
        if doc, ok := result.Document.(bson.D); ok {
            for _, elem := range doc {
                if elem.Key == "HotelName" {
                    hotelName = fmt.Sprintf("%v", elem.Value)
                    break
                }
            }
        }

        // Display results
        fmt.Printf("%d. HotelName: %s", i+1, hotelName)

        if showScore {
            fmt.Printf(", Score: %.4f", result.Score)
        }

        fmt.Println()
    }
}

// GenerateEmbedding generates an embedding for the given text using Azure OpenAI
func GenerateEmbedding(ctx context.Context, client openai.Client, text, modelName string) ([]float64, error) {
    resp, err := client.Embeddings.New(ctx, openai.EmbeddingNewParams{
        Input: openai.EmbeddingNewParamsInputUnion{
            OfString: openai.String(text),
        },
        Model: modelName,
    })
    if err != nil {
        return nil, fmt.Errorf("failed to generate embedding: %v", err)
    }

    if len(resp.Data) == 0 {
        return nil, fmt.Errorf("no embedding data received")
    }

    // Copy the embedding values into a []float64 slice
    embedding := make([]float64, len(resp.Data[0].Embedding))
    for i, v := range resp.Data[0].Embedding {
        embedding[i] = float64(v)
    }

    return embedding, nil
}

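This utility module provides these features:

  • Config: Configuration struct for environment variables
  • SearchResult: Struct for search result documents with scores
  • HotelData: Struct that represents a hotel document
  • GetClients: Creates and returns clients for Azure OpenAI and Azure DocumentDB
  • GetClientsPasswordless: Creates and returns clients using passwordless (OIDC) authentication. Enable RBAC on both resources and sign in to Azure CLI
  • ReadFileReturnJSON: Reads a JSON file and returns its contents as a slice of maps
  • WriteFileJSON: Writes data to a JSON file
  • InsertData: Inserts data in batches into a MongoDB collection and creates standard indexes on the specified fields
  • PrintSearchResults: Prints the results of a vector search, including the score and hotel name
  • GenerateEmbedding: Creates an embedding using Azure OpenAI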
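
The batching in InsertData comes down to walking the data in fixed-size windows and clamping the last one. The sketch below isolates that arithmetic. `batchBounds` is a hypothetical helper, not part of utils.go, and returning nil for a non-positive size is an assumed safeguard.

```go
package main

import "fmt"

// batchBounds returns half-open [start, end) index pairs that cover n items
// in groups of at most size. It returns nil when n or size is not positive.
func batchBounds(n, size int) [][2]int {
	if n <= 0 || size <= 0 {
		return nil
	}
	bounds := make([][2]int, 0, (n+size-1)/size)
	for start := 0; start < n; start += size {
		end := start + size
		if end > n {
			end = n // clamp the final, possibly shorter, batch
		}
		bounds = append(bounds, [2]int{start, end})
	}
	return bounds
}

func main() {
	// 120 documents with a batch size of 50 yield batches of 50, 50, and 20.
	for i, b := range batchBounds(120, 50) {
		fmt.Printf("batch %d: documents %d to %d\n", i+1, b[0], b[1]-1)
	}
}
```

With the 50-document dataset and LOAD_SIZE_BATCH=50 from the .env file, this arithmetic produces a single batch.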

Authenticate with Azure CLI

Sign in to Azure CLI before you run the application so that it can access Azure resources securely.

az login

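The code uses your local developer identity to access Azure DocumentDB and Azure OpenAI. When you set AZURE_TOKEN_CREDENTIALS=AzureCliCredential, the code authenticates with Azure CLI credentials deterministically. Authentication relies on DefaultAzureCredential from azidentity to find your Azure credentials in the environment. To learn more, see Authenticate Go apps to Azure services by using the Azure Identity library.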

Build and run the application

Build and run the Go application:

go mod tidy
go run src/diskann.go src/utils.go

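The app logging and output show:

  • Collection creation and data insertion status
  • Vector index creation
  • Search results with hotel names and similarity scores

Starting DiskANN vector search demonstration...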

Initializing MongoDB and Azure OpenAI clients...
Attempting OIDC authentication...
OIDC authentication successful!

Loading data from ../data/Hotels_Vector.json...
Loaded 50 documents

Inserting data into collection 'hotels_diskann'...
Getting token with scope: https://ossrdbms-aad.database.windows.net/.default
Successfully obtained token
Starting batch insertion of 50 documents...
Batch 1 completed: 50 documents inserted
Insertion completed: 50 inserted, 0 failed
Creating DiskANN vector index on field 'DescriptionVector'...
No existing vector indexes found to drop
DiskANN vector index created successfully
Waiting for index to be ready...
Performing DiskANN vector search for: 'quintessential lodging near running trails, eateries, retail'

Search Results (showing top 5):
================================================================================
1. HotelName: Royal Cottage Resort, Score: 0.4991
2. HotelName: Country Comfort Inn, Score: 0.4785
3. HotelName: Nordick's Valley Motel, Score: 0.4635
4. HotelName: Economy Universe Motel, Score: 0.4461
5. HotelName: Roach Motel, Score: 0.4388

DiskANN demonstration completed successfully!

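View and manage data in Visual Studio Code

  1. Select the DocumentDB extension in Visual Studio Code to connect to your Azure DocumentDB account.

  2. View the data and indexes in the Hotels database.

    Screenshot of the DocumentDB extension showing the Azure DocumentDB collection.

Clean up resources

Delete the resource group, DocumentDB account, and Azure OpenAI resource when you no longer need them, to avoid extra costs.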