Share via

azure databrick graphframe recursive child

Pankaj Joshi 411 Reputation points
2025-09-23T06:56:34.1133333+00:00

I am looking for Databricks GraphFrames breadth-first search PySpark code that implements the logic explained below:

(EXAMPLE)

For each EMP_ROWID in df_data (for example "6780001"), look for matching records in df_rltnshp on ROWID2. Here it matches and returns ROWID1 values "6669300", "6661974" and "6661975", each with level "1".

Next, we need to look up each returned ROWID1 ("6669300", "6661974" and "6661975") in ROWID2; only ROWID2 "6669300" matches,

and return ROWID1 as "1239300" with level as "2"

The lookup should continue recursively in the same way until no further matching records are found.

(PLEASE SEE INPUT DATA AND EXPECTED OUTPUT DATA BELOW)

(1)

# Employee row ids used as the starting points of the hierarchy lookup.
columns_co = ['EMP_ROWID']
co_data = [["6780001"], ["6063024"], ["6780002"], ["6780011"]]

# One single-column row per starting employee id.
df_data = spark.createDataFrame(co_data, columns_co)

(2)

# Relationship rows as [ROWID1, ROWID2]; a row is matched on ROWID2 and
# yields ROWID1 as the next id to look up.
columns_rl = ['ROWID1', 'ROWID2']
rltnshp = [
    ["6669300", "6780001"], ["6661974", "6780001"], ["6661975", "6780001"],
    ["1239300", "6669300"],
    ["5555555", "6063024"],
    ["6666666", "6780002"],
    ["4444444", "6780011"],
    ["3333333", "4444444"],
]

df_rltnshp = spark.createDataFrame(rltnshp, columns_rl)

EXPECTED OUTPUT

+--------+--------+-----+

|parent |child |level|

+--------+--------+-----+

|6780001 |6669300 |1 |

|6780001 |6661974 |1 |

|6780001 |6661975 |1 |

|6669300 |1239300 |2 |

|6063024 |5555555 |1 |

|6780002 |6666666 |1 |

|6780011 |4444444 |1 |

|4444444 |3333333 |2 |

+--------+--------+-----+

Azure Databricks
Azure Databricks

An Apache Spark-based analytics platform optimized for Azure.

0 comments No comments

Answer accepted by question author

Pratyush Vashistha 5,135 Reputation points Microsoft External Staff Moderator
2025-09-23T08:38:03.9666667+00:00

Hello Pankaj Joshi, thank you for reaching out.

Here's a comprehensive solution using Databricks GraphFrames to perform breadth-first search traversal for your hierarchical relationship data:

from graphframes import GraphFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when, max as spark_max
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
# Get a Spark session; getOrCreate() hands back the existing one if the
# environment (e.g. a notebook) has already initialized it.
spark = SparkSession.builder.appName("GraphFrameBFS").getOrCreate()

# ---- Sample input: starting employee ids ----
columns_co = ['EMP_ROWID']
co_data = [["6780001"], ["6063024"], ["6780002"], ["6780011"]]
df_data = spark.createDataFrame(co_data, columns_co)

# ---- Sample input: relationships as [ROWID1 (child), ROWID2 (parent)] ----
columns_rl = ['ROWID1', 'ROWID2']
rltnshp = [
    ["6669300", "6780001"], ["6661974", "6780001"], ["6661975", "6780001"],
    ["1239300", "6669300"],
    ["5555555", "6063024"],
    ["6666666", "6780002"],
    ["4444444", "6780011"],
    ["3333333", "4444444"],
]
df_rltnshp = spark.createDataFrame(rltnshp, columns_rl)

# ---- Vertices: every id that appears in either input, de-duplicated ----
vertices_from_data = df_data.select(df_data["EMP_ROWID"].alias("id"))
vertices_from_rltnshp1 = df_rltnshp.select(df_rltnshp["ROWID1"].alias("id"))
vertices_from_rltnshp2 = df_rltnshp.select(df_rltnshp["ROWID2"].alias("id"))
vertices = (
    vertices_from_data
    .union(vertices_from_rltnshp1)
    .union(vertices_from_rltnshp2)
    .dropDuplicates()
)

# ---- Edges: directed parent (ROWID2) -> child (ROWID1) ----
edges = df_rltnshp.select(
    df_rltnshp["ROWID2"].alias("src"),
    df_rltnshp["ROWID1"].alias("dst"),
)

# ---- Assemble the graph from the vertex and edge DataFrames ----
graph = GraphFrame(vertices, edges)
# Function to perform BFS for each starting node
def perform_bfs_for_all_nodes(graph, start_nodes_df, max_depth=10):
    """
    Traverse the graph breadth-first from every start node and return each
    parent -> child edge that is reached, tagged with its depth.

    The traversal expands a frontier one level at a time by joining it to
    ``graph.edges``. ``GraphFrame.bfs`` is deliberately not used: it only
    returns the *shortest* paths between the ``from`` and ``to`` vertex sets
    (and names the final vertex column ``to``, not ``vN``), so it cannot list
    every reachable descendant, and reading fixed ``v1``/``v2``/``v3`` columns
    from its result fails whenever those columns are absent.

    Args:
        graph: GraphFrame whose edges have ``src`` (parent) and ``dst``
            (child) columns.
        start_nodes_df: DataFrame with an ``EMP_ROWID`` column holding the
            ids to start from.
        max_depth: Maximum number of levels to expand. Also bounds the work
            if the edges contain a cycle.

    Returns:
        DataFrame with columns ``parent``, ``child`` and ``level`` (1 for the
        direct children of a start node). An edge reachable at several depths
        appears once per depth. Empty, with the same three columns, when no
        start node has children.
    """
    # Only the edge endpoints matter for the traversal.
    edges = graph.edges.select(col("src"), col("dst"))

    # Level-0 frontier: the distinct start ids.
    frontier = start_nodes_df.select(
        col("EMP_ROWID").alias("frontier_id")
    ).distinct()

    level_frames = []
    for level in range(1, max_depth + 1):
        # Every edge leaving the current frontier belongs to this level.
        level_df = frontier.join(
            edges, col("frontier_id") == col("src"), "inner"
        ).select(
            col("src").alias("parent"),
            col("dst").alias("child"),
            lit(level).alias("level"),
        )

        # head(1) is a cheap emptiness probe; stop once nothing new is reached.
        if not level_df.head(1):
            break

        level_frames.append(level_df)
        # The children found at this level are the parents of the next one.
        frontier = level_df.select(col("child").alias("frontier_id")).distinct()

    if not level_frames:
        # Return empty DataFrame with correct schema
        schema = StructType([
            StructField("parent", StringType(), True),
            StructField("child", StringType(), True),
            StructField("level", IntegerType(), True)
        ])
        return spark.createDataFrame([], schema)

    # Stack the per-level frames into one result.
    final_result = level_frames[0]
    for level_df in level_frames[1:]:
        final_result = final_result.union(level_df)
    return final_result
# Alternative Simpler Approach using iterative BFS
def iterative_bfs_approach(df_data, df_rltnshp, max_depth=10):
    """
    Build the parent/child hierarchy level by level with repeated joins.

    Starting from the ids in ``df_data``, each pass joins the current set of
    parents to ``df_rltnshp`` on ``ROWID2`` and emits the matching ``ROWID1``
    values as children; those children become the parents of the next pass.

    Args:
        df_data: DataFrame with an ``EMP_ROWID`` column (the start ids).
        df_rltnshp: DataFrame with ``ROWID1`` (child) and ``ROWID2`` (parent)
            columns.
        max_depth: Maximum number of levels to expand; guards against an
            endless loop if the relationships contain a cycle.

    Returns:
        DataFrame with columns ``parent``, ``child`` and ``level``, where
        level 1 holds the direct children of the start ids.
    """
    # Start from an empty, explicitly typed frame so the result schema is
    # fixed even when nothing matches.
    result_schema = StructType([
        StructField("parent", StringType(), True),
        StructField("child", StringType(), True),
        StructField("level", IntegerType(), True)
    ])
    result_df = spark.createDataFrame([], result_schema)

    # De-duplicate the start ids, as is done for every later frontier, so a
    # repeated EMP_ROWID does not duplicate every output row.
    current_parents = df_data.select(col("EMP_ROWID").alias("parent")).distinct()

    for level in range(1, max_depth + 1):
        # Find children for current parents
        level_result = current_parents.join(
            df_rltnshp,
            current_parents.parent == df_rltnshp.ROWID2,
            "inner"
        ).select(
            col("parent"),
            col("ROWID1").alias("child"),
            lit(level).alias("level")
        )

        # One cheap probe per level instead of two full count() jobs.
        if not level_result.head(1):
            break

        result_df = result_df.union(level_result)

        # Children of this level are the parents of the next.
        current_parents = level_result.select(col("child").alias("parent")).distinct()

    return result_df
# Run the iterative traversal on the sample data
print("=== Using Iterative BFS Approach ===")
# Order rows by parent, then level, then child so the output is easy to read
final_result = iterative_bfs_approach(df_data, df_rltnshp).orderBy(
    "parent", "level", "child"
)
# Print the resulting rows
final_result.show()
# Report how many relationships were found, then the result's schema
relationship_count = final_result.count()
print(f"Total relationships found: {relationship_count}")
print("\nSchema:")
final_result.printSchema()

Output:

User's image

Please "Accept as Answer" if the answer provided is useful, so that you can help others in the community looking for remediation for similar issues.

Thanks

Pratyush
User's image

Was this answer helpful?

1 person found this answer helpful.

0 additional answers

Sort by: Most helpful

Your answer

Answers can be marked as 'Accepted' by the question author and 'Recommended' by moderators, which helps users know the answer solved the author's problem.