> 文章列表 > 一文解读pandas_udf

一文解读pandas_udf

一文解读pandas_udf

1.函数定义

pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)
Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations
Spark 使用 Apache Arrow 传输数据,并由 pandas 处理数据;由于基于 pandas,因此可以进行向量化操作

参数解读

  • f: user-defined function 用户自定义函数
  • returnType:the return type of the user-defined function 用户自定义函数输出类型
  • functionType:
type 说明 备注
SCALAR 以向量化方式逐元素处理数据:它以一个或多个 pandas Series(每次一个批次)作为输入,并返回一个与输入等长的 pandas Series。这种类型的 Pandas UDF 用于 DataFrame 的 select 和 withColumn 方法。适用于 element-wise 操作 default
SCALAR_ITER 类似于 SCALAR,但输入和输出都是 pandas Series(批次)的迭代器,因此可以在处理所有批次之前只执行一次昂贵的初始化(例如加载模型),从而更高效地处理大型数据集 -
GROUPED_MAP 用于分组操作:每个分组作为一个 pandas DataFrame 传入,函数返回一个 pandas DataFrame;返回的行数可以任意(不必与输入相同),但列需与声明的 schema 一致。配合 DataFrame 的 groupBy(...).apply(...) 使用(Spark 3.0+ 推荐改用 applyInPandas)。适用于分组转换操作 -
GROUPED_AGG 用于分组聚合操作,将一组值归约为一个标量值。配合 DataFrame 的 groupBy(...).agg(...) 使用。适用于分组聚合操作 与 GROUPED_MAP 的一个显著区别是:UDF 接收的是每个输入列对应的 pandas Series,而不是整个分组的 pdf;需要多列时,可以传入多个列,或用 struct 把多列打包成一个输入(见第 3 节)

2. 代码示例

2.1 SCALAR 操作

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StringType, StructField, StructType

spark = (
    SparkSession.builder
    .appName("pandas_udf scaler Example")
    .getOrCreate()
)

# Create a synthetic dataset: 20 rows alternating between two groups, a
# linear target (about 2x) and a quadratic target (about 3x^2), each with a
# little uniform noise.
g = np.tile(['group a', 'group b'], 10)
x = np.linspace(0, 10., 20)
np.random.seed(3)  # set seed for reproducibility
y_lin = 2 * x + np.random.rand(len(x)) / 10.
y_qua = 3 * x ** 2 + np.random.rand(len(x))
df = pd.DataFrame({'group': g, 'x': x, 'y_lin': y_lin, 'y_qua': y_qua})
schema = StructType([
    StructField('group', StringType(), nullable=False),
    StructField('x', DoubleType(), nullable=False),
    StructField('y_lin', DoubleType(), nullable=False),
    StructField('y_qua', DoubleType(), nullable=False),
])
df = spark.createDataFrame(df, schema=schema)
# First rows of df:
# +-------+------------------+-------------------+-------------------+
# |  group|                 x|              y_lin|              y_qua|
# +-------+------------------+-------------------+-------------------+
# |group a|               0.0|0.05507979025745755|0.28352508177131874|
# |group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|
# |group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|
# |group b|1.5789473684210527| 3.2089774973618717| 7.6360921152062655|
# |group a|2.1052631578947367|  4.299821011224239|   13.8410479099986|
# |group b| 2.631578947368421|  5.352787203630186| 21.555938033209422|
# |group a|3.1578947368421053|  6.328348004730595|  30.22326103930139|
# +-------+------------------+-------------------+-------------------+

# Operate on a single column
# series to series pandas UDF
# NOTE: Spark calls a SCALAR pandas UDF once per Arrow record batch, so the
# mean()/std() below are computed per batch, not over the whole column. They
# only coincide when the column fits in a single batch; use an aggregate or
# window if a true column-wide statistic is required.
@F.pandas_udf(DoubleType())
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean()) / col1.std()


res = df.select(standardise(F.col('y_lin')).alias('result'))
res.show(5)
# +-------------------+
# |             result|
# +-------------------+
# |-1.6054255151193093|
# |-1.4337009540623533|
# |-1.2712121491623172|
# | -1.098481817986802|
# |-0.9231444116198374|
# +-------------------+

# The same UDF without the decorator: wrap a plain function with pandas_udf().
def standardise(col1: pd.Series) -> pd.Series:
    return (col1 - col1.mean()) / col1.std()


standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("y_lin_standard", standard_udf(F.col('y_lin')))
df.show(3)
# +-------+------------------+-------------------+-------------------+-------------------+
# |  group|                 x|              y_lin|              y_qua|     y_lin_standard|
# +-------+------------------+-------------------+-------------------+-------------------+
# |group a|               0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093|
# |group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|-1.4337009540623533|
# |group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|-1.2712121491623172|
# +-------+------------------+-------------------+-------------------+-------------------+

# Multiple input columns: each column is passed to the function as its own
# pandas Series.
def standardise(col1: pd.Series, col2: pd.Series) -> pd.Series:
    return (col1 - col2.mean()) / col1.std()


standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("ret", standard_udf(F.col('y_lin'), F.col('y_qua')))
df.show(3)
# +-------+------------------+-------------------+-------------------+-------------------+-------------------+
# |  group|                 x|              y_lin|              y_qua|     y_lin_standard|                ret|
# +-------+------------------+-------------------+-------------------+-------------------+-------------------+
# |group a|               0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093| -16.57141348616838|
# |group b|0.5263157894736842|  1.123446361209179| 1.5241628490609185|-1.4337009540623533|-16.399688925111427|
# |group a|1.0526315789473684|  2.134353631786031| 3.7645534406624286|-1.2712121491623172| -16.23720012021139|
# +-------+------------------+-------------------+-------------------+-------------------+-------------------+

# Official example: the struct input column (struct_col) is received as a
# pandas DataFrame (s3), and returning a pandas DataFrame produces a struct
# column with the declared fields (col1, col2).
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3


# Create a Spark DataFrame that has three columns including a struct column.
# A separate name is used so the main `df` above (needed again in 2.4) is
# not overwritten.
nested_df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")
nested_df.show()
# +--------+----------+-----------------+
# |long_col|string_col|       struct_col|
# +--------+----------+-----------------+
# |       1|  a string|{a nested string}|
# +--------+----------+-----------------+
nested_df.select(func("long_col", "string_col", "struct_col").alias('ret')).show()
# +--------------------+
# |                 ret|
# +--------------------+
# |{a nested string, 9}|
# +--------------------+

# Output a DataFrame: the returned pandas DataFrame becomes a struct column
# with the declared fields (first, last).
@pandas_udf("first string, last string")
def split_expand(s: pd.Series) -> pd.DataFrame:
    return s.str.split(expand=True)


# A separate name is used so the main `df` above is not overwritten.
names_df = spark.createDataFrame([("John Doe",)], ("name",))
names_df.select(split_expand("name")).show()

2.2 SCALAR_ITER

import re

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StringType, LongType

spark = (
    SparkSession.builder
    .appName("Pandas UDF Example")
    .getOrCreate()
)

# Compiled once; matches any non-digit character.
_NON_DIGIT = re.compile(r'\D')


def extract_numbers(series: pd.Series) -> pd.Series:
    """Drop every non-digit character of each value and parse the rest as int.

    Values that contain no digits, and null values, become None (stored by
    Spark as null in the LongType result).
    """
    def _to_int(text):
        if text is None:
            return None
        digits = _NON_DIGIT.sub('', text)
        return int(digits) if digits else None

    return series.apply(_to_int)


@pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
def extract_numbers_udf(series_iter):
    # SCALAR_ITER receives an iterator of pandas Series (one per batch) and
    # yields one result Series per input batch.
    for series in series_iter:
        yield extract_numbers(series)


data = [("abc123",), ("def456",), ("ghi789",), ("jkl0",)]
schema = "text STRING"
input_df = spark.createDataFrame(data, schema=schema)
result_df = input_df.select(extract_numbers_udf("text").alias("numbers"))
result_df.show()
# +-------+
# |numbers|
# +-------+
# |    123|
# |    456|
# |    789|
# |      0|
# +-------+

2.3 GROUPED_MAP

# The UDF's output schema is taken from `id_df`, so that DataFrame must exist
# before the decorator runs. A separate name keeps the main `df` (needed in
# 2.4) intact.
id_df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))


# Input/output are both a pandas.DataFrame
@pandas_udf(id_df.schema, PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # Centre column v within each id group.
    return pdf.assign(v=pdf.v - pdf.v.mean())


# NOTE: since Spark 3.0, groupby(...).apply with a GROUPED_MAP pandas_udf is
# deprecated in favour of groupby(...).applyInPandas(func, schema=...).
id_df.groupby('id').apply(subtract_mean).show()

2.4 GROUPED_AGG

# Style 1: decorator. The type hints (pd.Series, ...) -> float let Spark infer
# a GROUPED_AGG (series-to-scalar) UDF, so no function type is passed.
@F.pandas_udf(DoubleType())
def average_column(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()


res = df.groupby('group').agg(
    average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))
res.show()


# Style 2: wrap a plain function and pass the function type explicitly.
def average_(col1: pd.Series, col2: pd.Series) -> float:
    return (col1 + col2).mean()


average_column = pandas_udf(average_, DoubleType(), PandasUDFType.GROUPED_AGG)
res = df.groupby('group').agg(
    average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))
res.show()
# Output (values shown rounded to 3 decimals):
# +-------+------------------------+
# |group  |average of y_lin + y_qua|
# +-------+------------------------+
# |group a|104.770                 |
# |group b|121.621                 |
# +-------+------------------------+

3.使用限制以及解决方案:

使用限制

  1. 自定义函数不接受额外的参数
  2. 不支持条件表达式(conditional expressions,例如 F.when(...))以及布尔表达式中的短路求值(short-circuiting,例如 a AND udf(b)):UDF 最终会对所有行执行;如果函数在某些特殊行上可能出错,应把判断条件写进 UDF 内部
  3. 由 pyspark.sql.types.TimestampType 组成的 pyspark.sql.types.ArrayType,以及嵌套的 pyspark.sql.types.StructType,目前不支持作为输出类型
  4. (此条有误)当函数类型是 GROUPED_AGG 时,只支持一列作为输入,所以无法将整个 pdf 输入到 UDF 函数里;后续验证表明这一限制并不成立,保留此条并加以标注

针对限制1:可以借助 Python 闭包解决——用一个外层函数接收额外参数,再返回实际交给 pandas_udf 的内层函数

def sum_pd(pdf):
    """Return ``pdf`` with a new column ``c`` holding the sum of column ``v``.

    Used as a GROUPED_MAP function, so every row of a group receives that
    group's total.
    """
    v = pdf.v
    return pdf.assign(c=v.sum())
# Sample data with columns (id, v), matching the output shown below.
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
sum_udf = pandas_udf(sum_pd, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()
# +---+----+----+
# | id|   v|   c|
# +---+----+----+
# |  1| 1.0| 3.0|
# |  1| 2.0| 3.0|
# |  2| 3.0|18.0|
# |  2| 5.0|18.0|
# |  2|10.0|18.0|
# +---+----+----+

# Example with an extra parameter (passed in through a closure)
def sum_pd(pp):
    """Build a GROUPED_MAP function that adds ``pp`` to each group's sum of v.

    The function given to pandas_udf only receives the group's DataFrame, so
    the extra parameter ``pp`` is captured in a closure instead.
    """
    def wrap(pdf):
        v = pdf.v
        return pdf.assign(c=v.sum() + pp)
    return wrap


pp = 1
# Bind the extra parameter first, then wrap the resulting function.
sum_p = sum_pd(pp)
sum_udf = pandas_udf(sum_p, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()
# +---+----+----+
# | id|   v|   c|
# +---+----+----+
# |  1| 1.0| 4.0|
# |  1| 2.0| 4.0|
# |  2| 3.0|19.0|
# |  2| 5.0|19.0|
# |  2|10.0|19.0|
# +---+----+----+

对于限制4,首先需要说明:是否支持多列输入取决于函数本身的写法。在我开始的例子中,由于入参是 pdf,所以无法直接传入多列;此种情况下,可以引入 StructType 解决,把需要输入的列整合到一个 struct 中,再输入到 UDF 函数中;
当入参本身就设定为多个列时,同样支持多列输入;不过为了代码的简洁性,个人更倾向于第一种写法

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, struct
import pandas as pd

spark = (
    SparkSession.builder
    .appName("Two Day Subtract Example")
    .getOrCreate()
)
data = [
    (1, "2023-04-12", 2.0),
    (1, "2023-04-11", 4.0),
    (2, "2023-04-12", 3.0),
    (2, "2023-04-11", 5.0),
]
columns = ["id", "dt", "v"]
df = spark.createDataFrame(data, columns)


def two_day_subtract(window_end, datapipe):
    """Build a GROUPED_AGG function returning value(window_end) - value(other day).

    ``window_end`` and ``datapipe`` are captured in a closure because the
    function handed to pandas_udf only receives column data.

    Args:
        window_end: date string of the day whose value counts as positive.
        datapipe: object whose ``stat_col`` and ``dt`` attributes name the
            value field and the date field inside the struct input.

    Returns:
        ``wrap(pdf)``, which reduces one group (1 or 2 rows) to a float.

    Raises:
        ValueError: (from ``wrap``) if a group does not have 1 or 2 rows.
    """
    def wrap(pdf):
        n_rows = len(pdf)
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        # NOTE(review): iloc[i][field] works whether the struct arrives as a
        # pandas DataFrame (one column per field) or as a Series of per-row
        # records; which one you get depends on the Spark version - confirm.
        first = pdf.iloc[0]
        # Compare the first row's *date* (not the whole row) with window_end,
        # so the result is value(window_end) - value(other day) regardless of
        # the order in which the rows arrive.
        sign = 1 if first[datapipe.dt] == window_end else -1
        if n_rows == 1:
            return float(first[datapipe.stat_col] * sign)
        second = pdf.iloc[1]
        return float((first[datapipe.stat_col] - second[datapipe.stat_col]) * sign)
    return wrap


class dd:
    """Holds the field names that two_day_subtract reads from the struct."""

    def __init__(self, stat_col, dt):
        self.stat_col = stat_col  # name of the value field, e.g. 'v'
        self.dt = dt  # name of the date field, e.g. 'dt'


window_end = '2023-04-12'
datapipe = dd(stat_col='v', dt='dt')
idd = 'id'
substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(
    two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(
    substract_udf(struct(df['v'], df['dt'])).alias("num"))
# Expected num: -2.0 for id 1 (2.0 - 4.0) and -2.0 for id 2 (3.0 - 5.0).
stat_df.show()

# Decorator form: apply @pandas_udf to the inner function that Spark calls,
# not to the outer factory that takes the extra parameters.
def two_day_subtract(window_end, datapipe):
    """Decorator variant: return a ready-to-use GROUPED_AGG pandas UDF.

    Decorating the outer factory would turn its (window_end, datapipe)
    parameters into UDF column arguments. Decorating the inner function
    instead keeps the closure and yields the UDF directly.

    Raises:
        ValueError: (from the UDF) if a group does not have 1 or 2 rows.
    """
    @pandas_udf("double", PandasUDFType.GROUPED_AGG)
    def wrap(pdf):
        n_rows = len(pdf)
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        first = pdf.iloc[0]
        # Sign comes from the first row's date, so the result is
        # value(window_end) - value(other day) whatever the row order.
        sign = 1 if first[datapipe.dt] == window_end else -1
        if n_rows == 1:
            return float(first[datapipe.stat_col] * sign)
        second = pdf.iloc[1]
        return float((first[datapipe.stat_col] - second[datapipe.stat_col]) * sign)
    return wrap


substract_udf = two_day_subtract(window_end, datapipe)
stat_df = df.groupby(idd).agg(
    substract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()

# Pass the columns directly as separate inputs
def two_day_subtract(window_end, datapipe):
    """Multi-column variant: ``v`` and ``dt`` arrive as two pandas Series.

    ``datapipe`` is not used here; it is kept so the signature matches the
    struct-based version.

    Raises:
        ValueError: (from ``wrap``) if a group does not have 1 or 2 rows.
    """
    def wrap(s1, s2):
        n_rows = len(s1)
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        # iloc is positional, so it does not depend on the Series' index labels.
        sign = 1 if s2.iloc[0] == window_end else -1
        if n_rows == 1:
            # A GROUPED_AGG UDF must return a scalar, not a Series.
            return float(s1.iloc[0] * sign)
        return float((s1.iloc[0] - s1.iloc[1]) * sign)
    return wrap


substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(
    two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(substract_udf(df['v'], df['dt']).alias("num"))
stat_df.show()

reference:
1. Documents
2. Pandas UDFs in PySpark
3. Blog on databricks