跟踪实体 SDK 参考

本综合指南介绍了 MLflow 跟踪实体 SDK,演示如何访问和分析跟踪数据的各个方面,包括元数据、跨度、评估等。

概述

MLflow Trace 对象由两个主要组件组成:

  • TraceInfo:有关跟踪的元数据(ID、计时、状态、标记、评估)
  • TraceData:实际执行数据(跨度、请求/响应)

创建复杂的示例跟踪

让我们创建一个全面的示例来演示所有功能:

import mlflow
import time
from mlflow.entities import SpanType

# Create a complex RAG application trace
@mlflow.trace(span_type=SpanType.CHAIN)
def rag_pipeline(question: str):
    """Run the demo RAG flow: tag the trace, retrieve, generate, then fact-check.

    Args:
        question: The user question passed to retrieval and generation.

    Returns:
        A dict with the generated ``answer``, the ``fact_check`` result and
        the ``sources`` (the ``doc_uri`` of every retrieved document).
    """
    # Attach identifying tags to the trace that is currently being recorded.
    trace_tags = {
        "environment": "production",
        "version": "2.1.0",
        "user_id": "U12345",
        "session_id": "S98765",
        "mlflow.traceName": "rag_pipeline",
    }
    mlflow.update_current_trace(tags=trace_tags)

    # Each step below is its own traced function, so it shows up as a child span.
    context_docs = retrieve_documents(question)
    answer = generate_answer(question, context_docs)
    verdict = fact_check_tool(answer)

    source_uris = []
    for entry in context_docs:
        source_uris.append(entry["metadata"]["doc_uri"])

    return {"answer": answer, "fact_check": verdict, "sources": source_uris}

@mlflow.trace(span_type=SpanType.RETRIEVER)
def retrieve_documents(query: str):
    """Simulate a vector-store lookup and record the hits on the active span.

    Args:
        query: The search query (unused by this stub; the results are fixed).

    Returns:
        The retrieved documents, each converted with ``Document.to_dict()``.
    """
    from mlflow.entities import Document

    time.sleep(0.1)  # stand-in for retrieval latency

    active_span = mlflow.get_current_active_span()

    # (page_content, doc_uri, chunk_id, relevance_score) for each canned hit.
    canned_hits = (
        (
            "MLflow Tracing provides observability for GenAI apps...",
            "docs/mlflow/tracing_guide.md",
            "chunk_001",
            0.95,
        ),
        (
            "Traces consist of spans that capture execution steps...",
            "docs/mlflow/trace_concepts.md",
            "chunk_042",
            0.87,
        ),
    )
    hits = [
        Document(
            page_content=text,
            metadata={
                "doc_uri": uri,
                "chunk_id": chunk,
                "relevance_score": score,
            },
        )
        for text, uri, chunk, score in canned_hits
    ]

    # Record the Document objects themselves as the span's outputs.
    active_span.set_outputs(hits)

    return [hit.to_dict() for hit in hits]

@mlflow.trace(span_type=SpanType.CHAT_MODEL)
def generate_answer(question: str, documents: list):
    """Simulate an LLM call and record chat messages, tools and token counts.

    Args:
        question: The user question embedded in the user message.
        documents: Retrieved context, interpolated verbatim into the prompt.

    Returns:
        A fixed answer string.
    """
    from mlflow.tracing import set_span_chat_messages, set_span_chat_tools

    time.sleep(0.2)  # stand-in for model latency

    system_turn = {
        "role": "system",
        "content": "You are a helpful assistant. Use the provided context to answer questions.",
    }
    user_turn = {
        "role": "user",
        "content": f"Context: {documents}\n\nQuestion: {question}",
    }

    # Tool definition advertised to the (simulated) model.
    fact_check_params = {
        "type": "object",
        "properties": {
            "statement": {"type": "string"}
        },
        "required": ["statement"],
    }
    fact_check_spec = {
        "type": "function",
        "function": {
            "name": "fact_check",
            "description": "Verify facts in the response",
            "parameters": fact_check_params,
        },
    }

    active_span = mlflow.get_current_active_span()
    set_span_chat_messages(active_span, [system_turn, user_turn])
    set_span_chat_tools(active_span, [fact_check_spec])

    # Hard-coded token counters standing in for real usage numbers.
    simulated_usage = (
        ("llm.token_usage.input_tokens", 150),
        ("llm.token_usage.output_tokens", 75),
        ("llm.token_usage.total_tokens", 225),
    )
    for attr_key, token_count in simulated_usage:
        active_span.set_attribute(attr_key, token_count)

    return "MLflow Tracing provides comprehensive observability for GenAI applications by capturing detailed execution information through spans."

@mlflow.trace(span_type=SpanType.TOOL)
def fact_check_tool(statement: str):
    """Pretend to verify a statement; fails on purpose for the demo answer.

    Args:
        statement: Text to check.

    Returns:
        A dict with ``verified`` and ``confidence`` when the check succeeds.

    Raises:
        ValueError: If ``statement`` contains "comprehensive" (simulated outage).
    """
    time.sleep(0.05)

    # Happy path first; the failure below exists only to demonstrate an error span.
    if "comprehensive" not in statement:
        return {"verified": True, "confidence": 0.92}

    raise ValueError("Fact verification service unavailable")

# Execute the pipeline
try:
    result = rag_pipeline("What is MLflow Tracing?")
except Exception as e:
    print(f"Pipeline error: {e}")

# Get the trace
trace_id = mlflow.get_last_active_trace_id()
trace = mlflow.get_trace(trace_id)

# Log assessments to the trace
from mlflow.entities import AssessmentSource, AssessmentSourceType

# Add human feedback
mlflow.log_feedback(
    trace_id=trace_id,
    name="helpfulness",
    value=4,
    source=AssessmentSource(
        source_type=AssessmentSourceType.HUMAN,
        source_id="reviewer_alice@company.com"
    ),
    rationale="Clear and accurate response with good context usage"
)

# Add LLM judge assessment
mlflow.log_feedback(
    trace_id=trace_id,
    name="relevance_score",
    value=0.92,
    source=AssessmentSource(
        source_type=AssessmentSourceType.LLM_JUDGE,
        source_id="gpt-4-evaluator"
    ),
    metadata={"evaluation_prompt_version": "v2.1"}
)

# Add ground truth expectation
mlflow.log_expectation(
    trace_id=trace_id,
    name="expected_facts",
    value=["observability", "spans", "GenAI applications"],
    source=AssessmentSource(
        source_type=AssessmentSourceType.HUMAN,
        source_id="subject_matter_expert"
    )
)

# Add span-specific feedback
retriever_span = trace.search_spans(name="retrieve_documents")[0]
mlflow.log_feedback(
    trace_id=trace_id,
    span_id=retriever_span.span_id,
    name="retrieval_quality",
    value="excellent",
    source=AssessmentSource(
        source_type=AssessmentSourceType.CODE,
        source_id="retrieval_evaluator.py"
    )
)

# Refresh trace to get assessments
trace = mlflow.get_trace(trace_id)

访问跟踪元数据 (TraceInfo)

基本元数据属性

# Primary identifiers
print(f"Trace ID: {trace.info.trace_id}")
print(f"Client Request ID: {trace.info.client_request_id}")

# Status information
print(f"State: {trace.info.state}")  # OK, ERROR, IN_PROGRESS
print(f"Status (deprecated): {trace.info.status}")  # Use state instead

# Request/response previews (truncated)
print(f"Request preview: {trace.info.request_preview}")
print(f"Response preview: {trace.info.response_preview}")
# Timestamps (milliseconds since epoch)
print(f"Start time (ms): {trace.info.request_time}")
print(f"Timestamp (ms): {trace.info.timestamp_ms}")  # Alias for request_time

# Duration
print(f"Execution duration (ms): {trace.info.execution_duration}")
print(f"Execution time (ms): {trace.info.execution_time_ms}")  # Alias

# Convert to human-readable format
import datetime
start_time = datetime.datetime.fromtimestamp(trace.info.request_time / 1000)
print(f"Started at: {start_time}")

位置和实验信息

# Trace storage location
location = trace.info.trace_location
print(f"Location type: {location.type}")

# If stored in MLflow experiment
if location.mlflow_experiment:
    print(f"Experiment ID: {location.mlflow_experiment.experiment_id}")
    # Shortcut property
    print(f"Experiment ID: {trace.info.experiment_id}")

# If stored in Databricks inference table
if location.inference_table:
    print(f"Table: {location.inference_table.full_table_name}")

标记和元数据

# Tags (mutable, can be updated after creation)
print("Tags:")
for key, value in trace.info.tags.items():
    print(f"  {key}: {value}")

# Access specific tags
print(f"Environment: {trace.info.tags.get('environment')}")
print(f"User ID: {trace.info.tags.get('user_id')}")

# Trace metadata (immutable, set at creation)
print("\nTrace metadata:")
for key, value in trace.info.trace_metadata.items():
    print(f"  {key}: {value}")

# Deprecated alias
print(f"Request metadata: {trace.info.request_metadata}")  # Same as trace_metadata

令牌使用情况信息

# Get aggregated token usage (if available)
token_usage = trace.info.token_usage
if token_usage:
    print(f"Input tokens: {token_usage.get('input_tokens')}")
    print(f"Output tokens: {token_usage.get('output_tokens')}")
    print(f"Total tokens: {token_usage.get('total_tokens')}")

访问跟踪数据 (TraceData)

处理跨度

# Access all spans
spans = trace.data.spans
print(f"Total spans: {len(spans)}")

# Iterate through spans
for span in spans:
    print(f"\nSpan: {span.name}")
    print(f"  ID: {span.span_id}")
    print(f"  Type: {span.span_type}")
    print(f"  Status: {span.status}")
    print(f"  Start time: {span.start_time_ns}")
    print(f"  End time: {span.end_time_ns}")
    print(f"  Duration (ns): {span.end_time_ns - span.start_time_ns}")

    # Parent-child relationships
    if span.parent_id:
        print(f"  Parent ID: {span.parent_id}")

    # Inputs and outputs
    if span.inputs:
        print(f"  Inputs: {span.inputs}")
    if span.outputs:
        print(f"  Outputs: {span.outputs}")

请求和响应数据

# Get root span request/response (backward compatibility)
request_json = trace.data.request
response_json = trace.data.response

# Parse JSON strings
import json
if request_json:
    request_data = json.loads(request_json)
    print(f"Request: {request_data}")

if response_json:
    response_data = json.loads(response_json)
    print(f"Response: {response_data}")

中间输出

# Get intermediate outputs from non-root spans
intermediate = trace.data.intermediate_outputs
if intermediate:
    print("\nIntermediate outputs:")
    for span_name, output in intermediate.items():
        print(f"  {span_name}: {output}")

在跟踪中搜索

使用 search_spans() 查找跨度

import re
from mlflow.entities import SpanType

# 1. Search by exact name
retriever_spans = trace.search_spans(name="retrieve_documents")
print(f"Found {len(retriever_spans)} retriever spans")

# 2. Search by regex pattern
pattern = re.compile(r".*_tool$")
tool_spans = trace.search_spans(name=pattern)
print(f"Found {len(tool_spans)} tool spans")

# 3. Search by span type
chat_spans = trace.search_spans(span_type=SpanType.CHAT_MODEL)
llm_spans = trace.search_spans(span_type="CHAT_MODEL")  # String also works
print(f"Found {len(chat_spans)} chat model spans")

# 4. Search by span ID
specific_span = trace.search_spans(span_id=retriever_spans[0].span_id)
print(f"Found span: {specific_span[0].name if specific_span else 'Not found'}")

# 5. Combine criteria
tool_fact_check = trace.search_spans(
    name="fact_check_tool",
    span_type=SpanType.TOOL
)
print(f"Found {len(tool_fact_check)} fact check tool spans")

# 6. Get all spans of a type
all_tools = trace.search_spans(span_type=SpanType.TOOL)
for tool in all_tools:
    print(f"Tool: {tool.name}")

访问跨度属性

from mlflow.tracing.constant import SpanAttributeKey

# Get a chat model span
chat_span = trace.search_spans(span_type=SpanType.CHAT_MODEL)[0]

# Access chat-specific attributes
messages = chat_span.get_attribute(SpanAttributeKey.CHAT_MESSAGES)
tools = chat_span.get_attribute(SpanAttributeKey.CHAT_TOOLS)

print(f"Chat messages: {messages}")
print(f"Available tools: {tools}")

# Access token usage from span
input_tokens = chat_span.get_attribute("llm.token_usage.input_tokens")
output_tokens = chat_span.get_attribute("llm.token_usage.output_tokens")
print(f"Span token usage - Input: {input_tokens}, Output: {output_tokens}")

# Access all attributes
print("\nAll span attributes:")
for key, value in chat_span.attributes.items():
    print(f"  {key}: {value}")

使用评估

使用 search_assessments() 查找评估

# 1. Get all assessments
all_assessments = trace.search_assessments()
print(f"Total assessments: {len(all_assessments)}")

# 2. Search by name
helpfulness = trace.search_assessments(name="helpfulness")
if helpfulness:
    assessment = helpfulness[0]
    print(f"Helpfulness: {assessment.value}")
    print(f"Source: {assessment.source.source_type} - {assessment.source.source_id}")
    print(f"Rationale: {assessment.rationale}")

# 3. Search by type
feedback_only = trace.search_assessments(type="feedback")
expectations_only = trace.search_assessments(type="expectation")
print(f"Feedback assessments: {len(feedback_only)}")
print(f"Expectation assessments: {len(expectations_only)}")

# 4. Search by span ID
span_assessments = trace.search_assessments(span_id=retriever_span.span_id)
print(f"Assessments for retriever span: {len(span_assessments)}")

# 5. Get all assessments including overridden ones
all_including_invalid = trace.search_assessments(all=True)
print(f"All assessments (including overridden): {len(all_including_invalid)}")

# 6. Combine criteria
human_feedback = trace.search_assessments(
    type="feedback",
    name="helpfulness"
)
for fb in human_feedback:
    print(f"Human feedback: {fb.name} = {fb.value}")

访问评估详细信息

# Get detailed assessment information
for assessment in trace.info.assessments:
    print(f"\nAssessment: {assessment.name}")
    print(f"  Type: {type(assessment).__name__}")
    print(f"  Value: {assessment.value}")
    print(f"  Source: {assessment.source.source_type.value}")
    print(f"  Source ID: {assessment.source.source_id}")

    # Optional fields
    if assessment.rationale:
        print(f"  Rationale: {assessment.rationale}")
    if assessment.metadata:
        print(f"  Metadata: {assessment.metadata}")
    if assessment.error:
        print(f"  Error: {assessment.error}")
    if hasattr(assessment, 'span_id') and assessment.span_id:
        print(f"  Span ID: {assessment.span_id}")

数据导出和转换

转换为字典

# Convert entire trace to dictionary
trace_dict = trace.to_dict()
print(f"Trace dict keys: {trace_dict.keys()}")
print(f"Info keys: {trace_dict['info'].keys()}")
print(f"Data keys: {trace_dict['data'].keys()}")

# Convert individual components
info_dict = trace.info.to_dict()
data_dict = trace.data.to_dict()

# Reconstruct trace from dictionary
from mlflow.entities import Trace
reconstructed_trace = Trace.from_dict(trace_dict)
print(f"Reconstructed trace ID: {reconstructed_trace.info.trace_id}")

JSON 序列化

# Convert to JSON string
trace_json = trace.to_json()
print(f"JSON length: {len(trace_json)} characters")

# Pretty print JSON
trace_json_pretty = trace.to_json(pretty=True)
print("Pretty JSON (first 500 chars):")
print(trace_json_pretty[:500])

# Load trace from JSON
from mlflow.entities import Trace
loaded_trace = Trace.from_json(trace_json)
print(f"Loaded trace ID: {loaded_trace.info.trace_id}")

Pandas 数据帧转换

# Convert trace to DataFrame row
row_data = trace.to_pandas_dataframe_row()
print(f"DataFrame row keys: {list(row_data.keys())}")

# Create DataFrame from multiple traces
import pandas as pd

# Get multiple traces
traces = mlflow.search_traces(max_results=5)

# If you have individual trace objects
trace_rows = [t.to_pandas_dataframe_row() for t in [trace]]
df = pd.DataFrame(trace_rows)

print(f"DataFrame shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")

# Access specific data from DataFrame
print(f"Trace IDs: {df['trace_id'].tolist()}")
print(f"States: {df['state'].tolist()}")
print(f"Durations: {df['execution_duration'].tolist()}")

跨度实体 SDK 参考

跨度是跟踪的基础单元,表示单个操作或工作单位。MLflow 提供了多个与跨度相关的类和实用工具。

跨度类型

MLflow 定义了用于对操作进行分类的标准跨度类型:

from mlflow.entities import SpanType

# Predefined span types
print("Available span types:")
print(f"  CHAIN: {SpanType.CHAIN}")        # Orchestration/workflow spans
print(f"  LLM: {SpanType.LLM}")            # LLM inference spans
print(f"  CHAT_MODEL: {SpanType.CHAT_MODEL}")  # Chat completion spans
print(f"  RETRIEVER: {SpanType.RETRIEVER}")    # Document retrieval spans
print(f"  TOOL: {SpanType.TOOL}")              # Tool/function execution spans
print(f"  EMBEDDING: {SpanType.EMBEDDING}")    # Embedding generation spans
print(f"  PARSER: {SpanType.PARSER}")          # Output parsing spans
print(f"  RERANKER: {SpanType.RERANKER}")      # Document reranking spans
print(f"  AGENT: {SpanType.AGENT}")            # Agent execution spans
print(f"  UNKNOWN: {SpanType.UNKNOWN}")        # Default/unspecified type

# You can also use custom string values
custom_type = "CUSTOM_PROCESSOR"

使用不可变跨度(Span 类)

Span 类表示从跟踪检索到的不可变已完成跨度:

# Get a span from a trace
spans = trace.data.spans
span = spans[0]

# Basic properties
print(f"Span ID: {span.span_id}")
print(f"Name: {span.name}")
print(f"Type: {span.span_type}")
print(f"Trace ID: {span.trace_id}")  # Which trace this span belongs to
print(f"Parent ID: {span.parent_id}")  # None for root spans

# Timing information (nanoseconds)
print(f"Start time: {span.start_time_ns}")
print(f"End time: {span.end_time_ns}")
duration_ms = (span.end_time_ns - span.start_time_ns) / 1_000_000
print(f"Duration: {duration_ms:.2f}ms")

# Status information
print(f"Status: {span.status}")
print(f"Status code: {span.status.status_code}")
print(f"Status description: {span.status.description}")

# Inputs and outputs
print(f"Inputs: {span.inputs}")
print(f"Outputs: {span.outputs}")

# Get all attributes
attributes = span.attributes
print(f"Total attributes: {len(attributes)}")

# Get specific attribute
specific_attr = span.get_attribute("custom_attribute")
print(f"Custom attribute: {specific_attr}")

# Access events
for event in span.events:
    print(f"Event: {event.name} at {event.timestamp}")
    print(f"  Attributes: {event.attributes}")

在字典与跨度之间进行转换

# Convert span to dictionary
span_dict = span.to_dict()
print(f"Span dict keys: {span_dict.keys()}")

# Recreate span from dictionary
from mlflow.entities import Span
reconstructed_span = Span.from_dict(span_dict)
print(f"Reconstructed span: {reconstructed_span.name}")

使用实时跨度 (LiveSpan 类)

在执行期间创建跨度时,可以使用可修改的 LiveSpan 对象:

import mlflow
from mlflow.entities import SpanType, SpanStatus, SpanStatusCode

@mlflow.trace(span_type=SpanType.CHAIN)
def process_data(data: dict):
    """Show the mutators available on the active span while it is running.

    Args:
        data: Payload to "process"; recorded as a span input.

    Returns:
        A dict with ``processed`` set to True and ``count`` set to ``len(data)``.
    """
    # Inside a traced function this is the span the decorator opened.
    live_span = mlflow.get_current_active_span()

    # Explicit type assignment (the decorator above already passes the same type).
    live_span.set_span_type(SpanType.CHAIN)

    live_span.set_inputs({"data": data, "timestamp": time.time()})

    # Attributes can be set one at a time ...
    live_span.set_attribute("processing_version", "2.0")
    live_span.set_attribute("data_size", len(str(data)))

    # ... or in bulk from a mapping.
    bulk_attributes = {
        "environment": "production",
        "region": "us-west-2",
        "custom_metadata": {"key": "value"},
    }
    live_span.set_attributes(bulk_attributes)

    try:
        summary = {"processed": True, "count": len(data)}
        live_span.set_outputs(summary)
        live_span.set_status(SpanStatusCode.OK)
    except Exception as exc:
        # Attach the exception to the span, then let it propagate to the caller.
        live_span.record_exception(exc)
        raise
    else:
        return summary

# Example with manual span creation
with mlflow.start_span(name="manual_span", span_type=SpanType.TOOL) as span:
    # Add events during execution
    from mlflow.entities import SpanEvent

    span.add_event(SpanEvent(
        name="processing_started",
        attributes={
            "stage": "initialization",
            "memory_usage_mb": 256
        }
    ))

    # Do some work...
    time.sleep(0.1)

    # Add another event
    span.add_event(SpanEvent(
        name="checkpoint_reached",
        attributes={"progress": 0.5}
    ))

    # Record the final outputs, attributes and status on the live span.
    # The `with` block ends the span on exit, so span.end() is not called
    # here; calling it as well would end the same span twice.
    span.set_outputs({"result": "success"})
    span.set_attribute("final_metric", 0.95)
    span.set_status(SpanStatusCode.OK)

跨度事件

事件记录跨度生命周期中的具体发生事件:

from mlflow.entities import SpanEvent
import time

# Create an event with current timestamp
event = SpanEvent(
    name="validation_completed",
    attributes={
        "records_validated": 1000,
        "errors_found": 3,
        "validation_type": "schema"
    }
)

# Create an event with specific timestamp (nanoseconds)
specific_time_event = SpanEvent(
    name="data_checkpoint",
    timestamp=int(time.time() * 1e9),
    attributes={"checkpoint_id": "ckpt_123"}
)

# Create an event from an exception
try:
    raise ValueError("Invalid input format")
except Exception as e:
    error_event = SpanEvent.from_exception(e)
    # This creates an event with name="exception" and attributes containing:
    # - exception.message
    # - exception.type
    # - exception.stacktrace

    # Add to the current span. get_current_active_span() returns None when
    # no span is active (e.g. this snippet is run outside a traced function),
    # so guard before attaching the event.
    span = mlflow.get_current_active_span()
    if span is not None:
        span.add_event(error_event)

跨度状态

控制和查询跨度执行状态:

from mlflow.entities import SpanStatus, SpanStatusCode

# Create status objects
success_status = SpanStatus(SpanStatusCode.OK)
error_status = SpanStatus(
    SpanStatusCode.ERROR,
    description="Failed to connect to database"
)

# Set status on a live span
span = mlflow.get_current_active_span()
span.set_status(success_status)

# Or use string shortcuts
span.set_status("OK")
span.set_status("ERROR")

# Query status from completed spans
for span in trace.data.spans:
    if span.status.status_code == SpanStatusCode.ERROR:
        print(f"Error in {span.name}: {span.status.description}")

特殊跨度属性

MLflow 使用特定属性键进行特殊用途:

from mlflow.tracing.constant import SpanAttributeKey

# Common span attributes
span = mlflow.get_current_active_span()

# These are set automatically but can be accessed
request_id = span.get_attribute(SpanAttributeKey.REQUEST_ID)  # Trace ID
span_type = span.get_attribute(SpanAttributeKey.SPAN_TYPE)

# For CHAT_MODEL spans
from mlflow.tracing import set_span_chat_messages, set_span_chat_tools

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
]

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather for a location"
    }
}]

span = mlflow.get_current_active_span()
set_span_chat_messages(span, messages)
set_span_chat_tools(span, tools)

# Access these special attributes
chat_messages = span.get_attribute(SpanAttributeKey.CHAT_MESSAGES)
chat_tools = span.get_attribute(SpanAttributeKey.CHAT_TOOLS)

# For token usage tracking
span.set_attribute("llm.token_usage.input_tokens", 150)
span.set_attribute("llm.token_usage.output_tokens", 75)
span.set_attribute("llm.token_usage.total_tokens", 225)

使用 RETRIEVER 跨度

RETRIEVER 跨度具有特殊的输出要求:

from mlflow.entities import Document, SpanType

@mlflow.trace(span_type=SpanType.RETRIEVER)
def retrieve_documents(query: str):
    """Return two canned documents and record them as the span's outputs.

    Args:
        query: The search query (unused by this stub; the results are fixed).

    Returns:
        The documents, each converted with ``Document.to_dict()``.
    """
    active_span = mlflow.get_current_active_span()

    # First hit carries an explicit document ID and an extra "source" field.
    primary_hit = Document(
        page_content="The content of the document...",
        metadata={
            "doc_uri": "path/to/document.md",
            "chunk_id": "chunk_001",
            "relevance_score": 0.95,
            "source": "knowledge_base",
        },
        id="doc_123",
    )
    # Second hit relies on the default for ``id``.
    secondary_hit = Document(
        page_content="Another relevant section...",
        metadata={
            "doc_uri": "path/to/other.md",
            "chunk_id": "chunk_042",
            "relevance_score": 0.87,
        },
    )
    hits = [primary_hit, secondary_hit]

    # The span receives the Document objects; the caller receives plain dicts.
    active_span.set_outputs(hits)

    return [hit.to_dict() for hit in hits]

# Accessing retriever outputs
retriever_span = trace.search_spans(span_type=SpanType.RETRIEVER)[0]
if retriever_span.outputs:
    for doc in retriever_span.outputs:
        if isinstance(doc, dict):
            content = doc.get('page_content', '')
            uri = doc.get('metadata', {}).get('doc_uri', '')
            score = doc.get('metadata', {}).get('relevance_score', 0)
            print(f"Document from {uri} (score: {score})")

高级跨度分析

def analyze_span_tree(trace):
    """Print the span hierarchy, timing statistics and critical path of a trace.

    Args:
        trace: Object exposing ``data.spans``. Each span is read for
            ``span_id``, ``parent_id``, ``name``, ``span_type``,
            ``status.status_code``, ``start_time_ns`` and ``end_time_ns``.

    Returns:
        None. All results are written to stdout.
    """
    spans = trace.data.spans

    def duration_ms(span):
        """Return the span's duration in milliseconds."""
        return (span.end_time_ns - span.start_time_ns) / 1_000_000

    # Group child spans under their parent's span ID.
    children = {}
    for span in spans:
        if span.parent_id:
            children.setdefault(span.parent_id, []).append(span)

    # Root spans are those without a parent.
    roots = [s for s in spans if s.parent_id is None]

    def print_tree(span, indent=0):
        """Print ``span`` and, recursively, its children ordered by start time."""
        status_icon = "✓" if span.status.status_code == SpanStatusCode.OK else "✗"
        print(f"{'  ' * indent}{status_icon} {span.name} ({span.span_type}) - {duration_ms(span):.1f}ms")

        for child in sorted(children.get(span.span_id, []),
                            key=lambda s: s.start_time_ns):
            print_tree(child, indent + 1)

    print("Span Hierarchy:")
    for root in roots:
        print_tree(root)

    # Span statistics. NOTE: total_time adds up every span's own duration,
    # parents and children alike, so it is a sum over spans rather than the
    # duration of the root span.
    total_time = sum(duration_ms(s) for s in spans)
    llm_time = sum(duration_ms(s) for s in spans
                   if s.span_type in [SpanType.LLM, SpanType.CHAT_MODEL])
    retrieval_time = sum(duration_ms(s) for s in spans
                         if s.span_type == SpanType.RETRIEVER)

    def share_of_total(part):
        """Return ``part`` as a percentage of total_time (0.0 when total is 0)."""
        # Guard against division by zero for empty or zero-duration traces.
        return part / total_time * 100 if total_time else 0.0

    print("\nSpan Statistics:")
    print(f"  Total spans: {len(spans)}")
    print(f"  Total time: {total_time:.1f}ms")
    print(f"  LLM time: {llm_time:.1f}ms ({share_of_total(llm_time):.1f}%)")
    print(f"  Retrieval time: {retrieval_time:.1f}ms ({share_of_total(retrieval_time):.1f}%)")

    def find_critical_path(span):
        """Return (path, duration): the root-to-leaf path with the largest summed duration."""
        child_paths = [find_critical_path(child)
                       for child in children.get(span.span_id, [])]

        span_duration = duration_ms(span)
        if child_paths:
            best_path, best_duration = max(child_paths, key=lambda x: x[1])
            return [span] + best_path, span_duration + best_duration
        return [span], span_duration

    if roots:
        critical_paths = [find_critical_path(root) for root in roots]
        critical_path, critical_duration = max(critical_paths, key=lambda x: x[1])

        print(f"\nCritical Path ({critical_duration:.1f}ms total):")
        for span in critical_path:
            print(f"  → {span.name} ({duration_ms(span):.1f}ms)")
# Use the analyzer
analyze_span_tree(trace)

实用示例:全面的跟踪分析

让我们构建一个完整的跟踪分析实用工具,用于提取所有有意义的信息:

def analyze_trace(trace_id: str):
    """Print a multi-section report for one trace and return the trace.

    The report covers basic info, tags, token usage, span counts and errors,
    retriever outputs, assessments, time per span type and intermediate
    outputs.

    Args:
        trace_id: ID of the trace to load with ``mlflow.get_trace``.

    Returns:
        The loaded trace, or None when no trace is found for ``trace_id``.
    """
    # Get the trace
    trace = mlflow.get_trace(trace_id)
    if trace is None:
        # get_trace yields None for an unknown ID; report it instead of
        # failing later with an AttributeError.
        print(f"Trace not found: {trace_id}")
        return None

    print(f"=== TRACE ANALYSIS: {trace_id} ===\n")

    # 1. Basic Information
    print("1. BASIC INFORMATION")
    print(f"   State: {trace.info.state}")
    print(f"   Duration: {trace.info.execution_duration}ms")
    print(f"   Start time: {datetime.datetime.fromtimestamp(trace.info.request_time/1000)}")

    if trace.info.experiment_id:
        print(f"   Experiment: {trace.info.experiment_id}")

    # 2. Tags Analysis
    print("\n2. TAGS")
    for key, value in sorted(trace.info.tags.items()):
        print(f"   {key}: {value}")

    # 3. Token Usage
    print("\n3. TOKEN USAGE")
    if tokens := trace.info.token_usage:
        print(f"   Input: {tokens.get('input_tokens', 0)}")
        print(f"   Output: {tokens.get('output_tokens', 0)}")
        print(f"   Total: {tokens.get('total_tokens', 0)}")

    # Sum the per-span counters as well. This runs whether or not aggregated
    # usage exists, so it also serves as the fallback when the trace carries
    # no aggregated token usage.
    total_input = 0
    total_output = 0
    for span in trace.data.spans:
        # Count both LLM-style span types, matching the other analysers in this guide.
        if span.span_type in (SpanType.CHAT_MODEL, SpanType.LLM):
            total_input += span.get_attribute("llm.token_usage.input_tokens") or 0
            total_output += span.get_attribute("llm.token_usage.output_tokens") or 0

    if total_input or total_output:
        print(f"   (From spans - Input: {total_input}, Output: {total_output})")

    # 4. Span Analysis
    print("\n4. SPAN ANALYSIS")
    span_types = {}
    error_spans = []

    for span in trace.data.spans:
        # Count by type
        span_types[span.span_type] = span_types.get(span.span_type, 0) + 1

        # Collect errors
        if span.status.status_code.name == "ERROR":
            error_spans.append(span)

    print("   Span counts by type:")
    for span_type, count in sorted(span_types.items()):
        print(f"     {span_type}: {count}")

    if error_spans:
        print(f"\n   Error spans ({len(error_spans)}):")
        for span in error_spans:
            print(f"     - {span.name}: {span.status.description}")

    # 5. Retrieval Analysis
    print("\n5. RETRIEVAL ANALYSIS")
    for r_span in trace.search_spans(span_type=SpanType.RETRIEVER):
        docs = r_span.outputs
        if not docs:
            continue
        print(f"   Retrieved {len(docs)} documents:")
        for doc in docs[:3]:  # Show first 3
            if isinstance(doc, dict):
                metadata = doc.get('metadata', {})
                uri = metadata.get('doc_uri', 'Unknown')
                score = metadata.get('relevance_score', 'N/A')
                print(f"     - {uri} (score: {score})")

    # 6. Assessment Summary
    print("\n6. ASSESSMENTS")
    assessments = trace.search_assessments()

    # Group by source type
    by_source = {}
    for assessment in assessments:
        source_type = assessment.source.source_type.value
        by_source.setdefault(source_type, []).append(assessment)

    for source_type, items in by_source.items():
        print(f"\n   {source_type} ({len(items)}):")
        for assessment in items:
            value_str = f"{assessment.value}"
            rationale = assessment.rationale
            if rationale:
                # Only mark the rationale with an ellipsis when it was actually cut.
                suffix = "..." if len(rationale) > 50 else ""
                value_str += f" - {rationale[:50]}{suffix}"
            print(f"     {assessment.name}: {value_str}")

    # 7. Performance Breakdown
    print("\n7. PERFORMANCE BREAKDOWN")
    root_span = next((s for s in trace.data.spans if s.parent_id is None), None)
    if root_span:
        total_duration_ms = (root_span.end_time_ns - root_span.start_time_ns) / 1_000_000

        # Calculate time spent in each span type
        time_by_type = {}
        for span in trace.data.spans:
            duration_ms = (span.end_time_ns - span.start_time_ns) / 1_000_000
            time_by_type[span.span_type] = time_by_type.get(span.span_type, 0) + duration_ms

        print("   Time by span type:")
        for span_type, duration_ms in sorted(time_by_type.items(),
                                             key=lambda x: x[1], reverse=True):
            # Avoid dividing by zero when the root span has no measurable duration.
            percentage = (duration_ms / total_duration_ms) * 100 if total_duration_ms else 0.0
            print(f"     {span_type}: {duration_ms:.1f}ms ({percentage:.1f}%)")

    # 8. Data Flow
    print("\n8. DATA FLOW")
    if intermediate := trace.data.intermediate_outputs:
        print("   Intermediate outputs:")
        for name, output in intermediate.items():
            # Stringify once, then cap long values at 100 characters.
            text = str(output)
            output_str = text[:100] + "..." if len(text) > 100 else text
            print(f"     {name}: {output_str}")

    return trace

# Run the analysis
analysis_result = analyze_trace(trace_id)

构建可重用的跟踪实用工具

class TraceAnalyzer:
    """Utility class for advanced trace analysis.

    Wraps a single trace and exposes read-only helpers that summarize its
    errors, LLM token usage, retriever outputs and span tree, plus an export
    of the trace as a plain dict suitable for evaluation datasets.
    """

    # The annotation is quoted so that defining the class does not require
    # resolving ``mlflow.entities.Trace`` at definition time.
    def __init__(self, trace: "mlflow.entities.Trace"):
        """Store the trace to analyze.

        Args:
            trace: The trace whose ``info`` and ``data`` the helpers read.
        """
        self.trace = trace

    def get_error_summary(self):
        """Get summary of all errors in the trace.

        Returns:
            A list of dicts ordered trace-level first, then spans, then
            assessments. Every entry has a ``"level"`` key (``"trace"``,
            ``"span"`` or ``"assessment"``); the remaining keys depend on
            the level.
        """
        errors = []

        # Trace-level failure
        if self.trace.info.state == "ERROR":
            errors.append({
                "level": "trace",
                "message": "Trace failed",
                "details": self.trace.info.response_preview
            })

        # Spans whose status code is ERROR
        errors.extend(
            {
                "level": "span",
                "span_name": span.name,
                "span_type": span.span_type,
                "message": span.status.description,
                "span_id": span.span_id
            }
            for span in self.trace.data.spans
            if span.status.status_code.name == "ERROR"
        )

        # Assessments that carry an error
        errors.extend(
            {
                "level": "assessment",
                "assessment_name": assessment.name,
                "error": str(assessment.error)
            }
            for assessment in self.trace.info.assessments
            if assessment.error
        )

        return errors

    def get_llm_usage_summary(self):
        """Aggregate LLM usage across all spans.

        Only spans whose type is ``SpanType.CHAT_MODEL`` or ``"LLM"`` are
        counted. Token counts are read from the
        ``llm.token_usage.input_tokens`` / ``llm.token_usage.output_tokens``
        span attributes; a missing (or falsy) attribute counts as 0.

        Returns:
            A dict with ``total_llm_calls``, ``total_input_tokens``,
            ``total_output_tokens``, ``total_tokens`` and ``spans`` (one
            ``{"name", "input_tokens", "output_tokens"}`` dict per counted
            span).
        """
        usage = {
            "total_llm_calls": 0,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "spans": []
        }

        for span in self.trace.data.spans:
            if span.span_type not in (SpanType.CHAT_MODEL, "LLM"):
                continue

            usage["total_llm_calls"] += 1

            input_tokens = span.get_attribute("llm.token_usage.input_tokens") or 0
            output_tokens = span.get_attribute("llm.token_usage.output_tokens") or 0

            usage["total_input_tokens"] += input_tokens
            usage["total_output_tokens"] += output_tokens
            usage["spans"].append({
                "name": span.name,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens
            })

        usage["total_tokens"] = usage["total_input_tokens"] + usage["total_output_tokens"]
        return usage

    def get_retrieval_metrics(self):
        """Extract retrieval quality metrics.

        For every retriever span with non-empty outputs, collects the
        ``relevance_score`` found under each dict document's ``"metadata"``.

        Returns:
            A list with one dict per retriever span: ``span_name``,
            ``num_documents`` (``len`` of the span outputs) and
            ``avg_relevance`` / ``max_relevance`` / ``min_relevance``, which
            are ``None`` when no document carried a score.
        """
        metrics = []

        for span in self.trace.search_spans(span_type=SpanType.RETRIEVER):
            docs = span.outputs
            if not docs:
                continue

            # Compare against None rather than relying on truthiness: a
            # score of 0 / 0.0 is a real value and must count toward the
            # average and the minimum.
            relevance_scores = [
                score
                for doc in docs
                if isinstance(doc, dict) and 'metadata' in doc
                and (score := doc['metadata'].get('relevance_score')) is not None
            ]

            has_scores = bool(relevance_scores)
            metrics.append({
                "span_name": span.name,
                "num_documents": len(docs),
                "avg_relevance": sum(relevance_scores) / len(relevance_scores) if has_scores else None,
                "max_relevance": max(relevance_scores) if has_scores else None,
                "min_relevance": min(relevance_scores) if has_scores else None
            })

        return metrics

    def get_span_hierarchy(self):
        """Build a hierarchical view of spans.

        Returns:
            A flat, depth-first list of dicts (``indent``, ``name``,
            ``type``, ``duration_ms``, ``status``). ``indent`` is the depth
            below a root span (a span whose ``parent_id`` is ``None``), and
            siblings are ordered by ``start_time_ns``. ``duration_ms`` is
            ``None`` for a span that has no ``end_time_ns``. Spans that are
            not reachable from a root are not listed.
        """
        spans = self.trace.data.spans

        # Index children by parent id once, instead of rescanning every span
        # for each node while walking the tree.
        children_by_parent = {}
        for span in spans:
            if span.parent_id is not None:
                children_by_parent.setdefault(span.parent_id, []).append(span)

        def build_tree(span, indent=0):
            # A span without an end time has no measurable duration;
            # report None rather than failing on the subtraction.
            if span.end_time_ns is None:
                duration_ms = None
            else:
                duration_ms = (span.end_time_ns - span.start_time_ns) / 1_000_000

            result = [{
                "indent": indent,
                "name": span.name,
                "type": span.span_type,
                "duration_ms": duration_ms,
                "status": span.status.status_code.name
            }]

            children = children_by_parent.get(span.span_id, [])
            for child in sorted(children, key=lambda s: s.start_time_ns):
                result.extend(build_tree(child, indent + 1))

            return result

        hierarchy = []
        for root in spans:
            if root.parent_id is None:
                hierarchy.extend(build_tree(root))

        return hierarchy

    def export_for_evaluation(self):
        """Export trace data in a format suitable for evaluation.

        Returns:
            A dict with ``trace_id``, the JSON-decoded ``request`` and
            ``response`` (``None`` when the trace has none), the
            ``retrieved_context`` (``page_content`` of every dict document
            output by a retriever span), ``expected_facts`` (value of the
            expectation named ``"expected_facts"``, default ``[]``) and a
            ``metadata`` dict with user/session tags, duration and request
            time.

        Raises:
            json.JSONDecodeError: If the stored request or response is not
                valid JSON.
        """
        # Imported locally so this method does not rely on a module-level
        # ``import json`` being present elsewhere in the file.
        import json

        # Root request/response are stored as JSON strings
        raw_request = self.trace.data.request
        raw_response = self.trace.data.response
        request = json.loads(raw_request) if raw_request else None
        response = json.loads(raw_response) if raw_response else None

        # Expected values come from expectation-type assessments
        expectations = self.trace.search_assessments(type="expectation")
        expected_values = {exp.name: exp.value for exp in expectations}

        # Retrieval context: text of every retrieved document
        retrieved_context = [
            doc['page_content']
            for span in self.trace.search_spans(span_type=SpanType.RETRIEVER)
            if span.outputs
            for doc in span.outputs
            if isinstance(doc, dict) and 'page_content' in doc
        ]

        info = self.trace.info
        return {
            "trace_id": info.trace_id,
            "request": request,
            "response": response,
            "retrieved_context": retrieved_context,
            "expected_facts": expected_values.get("expected_facts", []),
            "metadata": {
                "user_id": info.tags.get("user_id"),
                "session_id": info.tags.get("session_id"),
                "duration_ms": info.execution_duration,
                "timestamp": info.request_time
            }
        }

# Wrap the trace in the reusable analyzer
analyzer = TraceAnalyzer(trace)

# Error report: a header line, then one bullet per collected error
errors = analyzer.get_error_summary()
print("\nErrors found:", len(errors))
for error in errors:
    # Fall back to the "error" key when an entry has no "message" key
    print("  -", f"{error['level']}:", error.get('message', error.get('error')))

# Token totals aggregated over the LLM spans
llm_usage = analyzer.get_llm_usage_summary()
print(
    f"\nLLM Usage: {llm_usage['total_tokens']} total tokens "
    f"across {llm_usage['total_llm_calls']} calls"
)

# Per-retriever document counts and average relevance
retrieval_metrics = analyzer.get_retrieval_metrics()
print("\nRetrieval Metrics:")
for metric in retrieval_metrics:
    print(
        f"  - {metric['span_name']}: {metric['num_documents']} docs, "
        f"avg relevance: {metric['avg_relevance']}"
    )

# Flatten the trace into an evaluation record
eval_data = analyzer.export_for_evaluation()
print(
    "\nExported evaluation data with",
    len(eval_data['retrieved_context']),
    "context chunks",
)

后续步骤

参考指南