pandas_udf

Creates a pandas user-defined function (also known as a vectorized UDF).

Pandas UDFs are user-defined functions that Spark executes using Arrow to transfer data and pandas to operate on the data, which enables vectorized operations. A pandas UDF is defined by using pandas_udf either as a decorator or to wrap a function, and no additional configuration is required. In general, a pandas UDF behaves like a regular PySpark function API.

Syntax

import pyspark.sql.functions as sf

# As a decorator
@sf.pandas_udf(returnType=<returnType>, functionType=<functionType>)
def function_name(col):
    # function body
    pass

# As a function wrapper
sf.pandas_udf(f=<function>, returnType=<returnType>, functionType=<functionType>)
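
For instance, the wrapper form turns an existing plain Python function into a pandas UDF. A minimal sketch; to_upper_udf is an illustrative name:

import pandas as pd
import pyspark.sql.functions as sf

# A plain pandas Series-to-Series function...
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

# ...wrapped into a pandas UDF without the decorator syntax.
to_upper_udf = sf.pandas_udf(to_upper, returnType="string")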

Parameters

f (function)
    Optional. User-defined function. A Python function if used as a standalone function.

returnType (pyspark.sql.types.DataType or str)
    Optional. The return type of the user-defined function. The value can be either a DataType object or a DDL-formatted type string.

functionType (int)
    Optional. An enum value in PandasUDFType. Default: SCALAR. This parameter exists for compatibility; using Python type hints is encouraged.
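
The two returnType forms are interchangeable; a minimal sketch (plus_one_dt and plus_one_ddl are illustrative names):

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType

# Return type given as a DataType object.
@pandas_udf(LongType())
def plus_one_dt(s: pd.Series) -> pd.Series:
    return s + 1

# The same return type given as a DDL-formatted type string.
@pandas_udf("long")
def plus_one_ddl(s: pd.Series) -> pd.Series:
    return s + 1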

Examples

Example 1: Series to Series - Convert strings to uppercase.

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+
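
A pandas UDF can also be registered for use in SQL with spark.udf.register, like a regular UDF. A small sketch reusing to_upper and df from above; the registered name and view name are illustrative only:

# Register the pandas UDF under a SQL-callable name (hypothetical name).
spark.udf.register("to_upper_sql", to_upper)
df.createOrReplaceTempView("people")
spark.sql("SELECT to_upper_sql(name) FROM people").show()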

Example 2: Series to Series with keyword arguments.

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import IntegerType
from pyspark.sql import functions as sf

@pandas_udf(returnType=IntegerType())
def calc(a: pd.Series, b: pd.Series) -> pd.Series:
    return a + 10 * b

spark.range(2).select(calc(b=sf.col("id") * 10, a=sf.col("id"))).show()
+-----------------------------+
|calc(b => (id * 10), a => id)|
+-----------------------------+
|                            0|
|                          101|
+-----------------------------+

Example 3: Iterator of Series to Iterator of Series.

import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for s in iterator:
        yield s + 1

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
df.select(plus_one(df.v)).show()
+-----------+
|plus_one(v)|
+-----------+
|          2|
|          3|
|          4|
+-----------+
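
The iterator form is useful when the UDF needs expensive initialization, since the setup can run once for the whole stream of batches rather than once per batch. A minimal sketch under that assumption, with a simple constant standing in for, say, loading a model:

import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def scale(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Hypothetical one-time setup, reused for every batch in the iterator.
    factor = 10
    for s in iterator:
        yield s * factor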

Example 4: Series to Scalar - Grouped aggregation.

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(mean_udf(df['v'])).show()
+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+

Example 5: Series to Scalar with window functions.

import pandas as pd
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
df.withColumn('mean_v', mean_udf("v").over(w)).show()
+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.0|
|  1| 2.0|   1.5|
|  2| 3.0|   3.0|
|  2| 5.0|   4.0|
|  2|10.0|   7.5|
+---+----+------+

Example 6: Iterator of Series to Scalar - Memory-efficient grouped aggregation.

import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += v.sum()
        cnt += len(v)
    return sum_val / cnt

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(pandas_mean_iter(df['v'])).show()
+---+-------------------+
| id|pandas_mean_iter(v)|
+---+-------------------+
|  1|                1.5|
|  2|                6.0|
+---+-------------------+