一文解读pandas_udf
1.函数定义
pyspark.sql.functions.pandas_udf(f=None, returnType=None, functionType=None)
Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and Pandas to work with the data, which allows vectorized operations
Spark 使用 Apache Arrow 在 JVM 与 Python 进程之间传输数据,再由 pandas 处理数据;由于数据以 pandas 对象(按批次)的形式处理,所以可以进行向量化操作
参数解读
- f: user-defined function 用户自定义函数
- returnType:the return type of the user-defined function 用户自定义函数输出类型
- functionType:
type | 说明 | 备注 |
---|---|---|
SCALAR | 逐元素(element-wise)语义:以批次为单位接收一个或多个 pandas Series,并返回一个等长的 pandas Series。这种类型的 Pandas UDF 用于 DataFrame 的 select 和 withColumn 等方法。适用于 element-wise 操作 | default |
SCALAR_ITER | 类似于 SCALAR,但函数接收的是 pandas Series 批次的迭代器,并逐批 yield 结果;适合在处理前需要做一次性昂贵初始化(如加载模型)的场景,处理大型数据集时更高效 | - |
GROUPED_MAP | 用于分组操作:每个分组以一个 pandas DataFrame 传入,函数返回一个 pandas DataFrame,其列需与声明的 schema 一致(返回的行数可以与输入不同)。应用于 DataFrame 的 groupBy 和 apply 方法(Spark 3.x 推荐使用 groupBy().applyInPandas)。适用于分组转换操作 | - |
GROUPED_AGG | 用于分组聚合操作,将一组值归约为一个标量值。应用于 DataFrame 的 groupBy 和 agg 方法。适用于分组聚合操作 | 原文认为它和 GROUPED_MAP 的一个显著区别是只支持一列作为输入,因而无法将整个 pdf 输入到 UDF 函数里;第 3 节说明这一点并不成立(可以直接传入多列,或用 struct 把多列打包后传入) |
2. 代码示例
2.1 SCALAR 操作
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StringType, StructField, StructType

spark = (
    SparkSession.builder
    .appName("pandas_udf scaler Example")
    .getOrCreate()
)

# Create a synthetic dataset: two alternating groups, a linear target and a
# quadratic target, each with a little uniform noise added.
g = np.tile(['group a', 'group b'], 10)
x = np.linspace(0, 10., 20)
np.random.seed(3)  # set seed for reproducibility
y_lin = 2*x + np.random.rand(len(x))/10.
y_qua = 3*x**2 + np.random.rand(len(x))
df = pd.DataFrame({'group': g, 'x': x, 'y_lin': y_lin, 'y_qua': y_qua})

# Explicit schema so Spark does not have to infer column types from pandas.
schema = StructType([
    StructField('group', StringType(), nullable=False),
    StructField('x', DoubleType(), nullable=False),
    StructField('y_lin', DoubleType(), nullable=False),
    StructField('y_qua', DoubleType(), nullable=False),
])
df = spark.createDataFrame(df, schema=schema)
+-------+------------------+-------------------+-------------------+
| group| x| y_lin| y_qua|
+-------+------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|
|group b|1.5789473684210527| 3.2089774973618717| 7.6360921152062655|
|group a|2.1052631578947367| 4.299821011224239| 13.8410479099986|
|group b| 2.631578947368421| 5.352787203630186| 21.555938033209422|
|group a|3.1578947368421053| 6.328348004730595| 30.22326103930139|
+-------+------------------+-------------------+-------------------+# 对一列进行操作
# series to series pandas UDF
@F.pandas_udf(DoubleType())
def standardise(col1: pd.Series) -> pd.Series:return (col1 - col1.mean())/col1.std()
res = df.select(standardise(F.col('y_lin')).alias('result'))
res.show(5)
+-------------------+
| result|
+-------------------+
|-1.6054255151193093|
|-1.4337009540623533|
|-1.2712121491623172|
| -1.098481817986802|
|-0.9231444116198374|
+-------------------+def standardise(col1: pd.Series) -> pd.Series:return (col1 - col1.mean())/col1.std()standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("y_lin_standard", standard_udf(F.col('y_lin')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+
| group| x| y_lin| y_qua| y_lin_standard|
+-------+------------------+-------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|-1.4337009540623533|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|-1.2712121491623172|
+-------+------------------+-------------------+-------------------+-------------------+def standardise(col1: pd.Series, col2: pd.Series) -> pd.Series:return (col1 - col2.mean())/col1.std()standard_udf = pandas_udf(standardise, DoubleType())
df = df.withColumn("ret", standard_udf(F.col('y_lin'), F.col('y_qua')))
df.show(3)
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
| group| x| y_lin| y_qua| y_lin_standard| ret|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+
|group a| 0.0|0.05507979025745755|0.28352508177131874|-1.6054255151193093| -16.57141348616838|
|group b|0.5263157894736842| 1.123446361209179| 1.5241628490609185|-1.4337009540623533|-16.399688925111427|
|group a|1.0526315789473684| 2.134353631786031| 3.7645534406624286|-1.2712121491623172| -16.23720012021139|
+-------+------------------+-------------------+-------------------+-------------------+-------------------+# 官方
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:s3['col2'] = s1 + s2.str.len()return s3# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame([[1, "a string", ("a nested string",)]],"long_col long, string_col string, struct_col struct<col1:string>")
df.show()
+--------+----------+-----------------+
|long_col|string_col| struct_col|
+--------+----------+-----------------+
| 1| a string|{a nested string}|
+--------+----------+-----------------+
# Columns are passed by name; the struct result is aliased as `ret`.
df.select(
    func("long_col", "string_col", "struct_col").alias('ret')
).show()
+--------------------+
| ret|
+--------------------+
|{a nested string, 9}|
+--------------------+# 输出dataframe
@pandas_udf("first string, last string")
def split_expand(s: pd.Series) -> pd.DataFrame:return s.str.split(expand=True)
df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(split_expand("name")).show()
2.2 SCALAR_ITER
import re

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StringType, LongType

spark = (
    SparkSession.builder
    .appName("Pandas UDF Example")
    .getOrCreate()
)

# Matches any non-digit character; compiled once and reused for every row.
_NON_DIGIT = re.compile(r'\D')


def extract_numbers(series: pd.Series) -> pd.Series:
    """Keep only the digits of each string and parse them as an int.

    Strings without any digit, and null inputs, map to None.
    """
    def _to_int(text):
        if pd.isna(text):
            return None
        digits = _NON_DIGIT.sub('', text)
        return int(digits) if digits else None

    return series.apply(_to_int)


@pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
def extract_numbers_udf(series_iter):
    """SCALAR_ITER UDF: consume an iterator of batches, yield one result per batch."""
    for series in series_iter:
        yield extract_numbers(series)


data = [("abc123",), ("def456",), ("ghi789",), ("jkl0",)]
schema = "text STRING"
input_df = spark.createDataFrame(data, schema=schema)
result_df = input_df.select(extract_numbers_udf("text").alias("numbers"))
result_df.show()
+-------+
|numbers|
+-------+
| 123|
| 456|
| 789|
| 0|
+-------+
2.3 GROUPED_MAP
# The `df` defined by the previous examples has no `id`/`v` columns, so build
# the small id/v frame this GROUPED_MAP example operates on.
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))


# Input/output are both a pandas.DataFrame
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    """GROUPED_MAP UDF: centre column ``v`` on its per-group mean.

    Receives all rows of one ``id`` group as a pandas DataFrame and returns a
    DataFrame with the same columns as ``df.schema``.
    """
    return pdf.assign(v=pdf.v - pdf.v.mean())


# On Spark 3.x, groupby(...).applyInPandas(fn, schema=...) is the preferred API.
df.groupby('id').apply(subtract_mean).show()
2.4 GROUPED_AGG
# Style 1: decorator. Series inputs with a scalar (float) return hint make
# Spark infer a GROUPED_AGG (Series-to-scalar) pandas UDF.
# NOTE(review): `df` must be the group/x/y_lin/y_qua DataFrame from section
# 2.1; later examples in this tutorial rebind `df`.
@F.pandas_udf(DoubleType())
def average_column(col1: pd.Series, col2: pd.Series) -> float:
    """Return the mean of the element-wise sum of ``col1`` and ``col2``."""
    return (col1 + col2).mean()


res = df.groupby('group').agg(
    average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))


# Style 2: plain function wrapped with an explicit GROUPED_AGG type.
def average_(col1: pd.Series, col2: pd.Series) -> float:
    """Return the mean of the element-wise sum of ``col1`` and ``col2``."""
    return (col1 + col2).mean()


average_column = pandas_udf(average_, DoubleType(), PandasUDFType.GROUPED_AGG)
res = df.groupby('group').agg(
    average_column(F.col('y_lin'), F.col('y_qua')).alias('average of y_lin + y_qua'))
res.show()
# +-------+------------------------+
# |group |average of y_lin + y_qua|
# +-------+------------------------+
# |group a|104.770 |
# |group b|121.621 |
# +-------+------------------------+
3.使用限制以及解决方案:
使用限制
- 自定义函数不接受额外的参数
- 不支持条件表达式(conditional expressions,例如 `F.when(F.col('a') > 1, my_udf(...))`)或布尔表达式中的短路求值(short-circuiting,例如 `and`/`or`):UDF 仍可能对所有行执行,不能依赖条件判断来避免 UDF 处理非法输入
- 元素为 pyspark.sql.types.TimestampType 的 pyspark.sql.types.ArrayType,以及嵌套的 pyspark.sql.types.StructType,目前不支持作为输出类型
- (错误)当函数类型是 GROUPED_AGG 时,只支持一列作为输入,所以无法将整个 pdf 输入到 UDF 函数里;后续验证表明这个限制并不成立,因此标注为"错误"
限制1(自定义函数不接受额外参数)可以通过 Python 闭包解决:用一个外层函数接收额外参数,并返回真正交给 pandas_udf 的内层函数
def sum_pd(pdf):v = pdf.vreturn pdf.assign(c=v.sum())
sum_udf = pandas_udf(sum_pd, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()+---+----+----+
| id| v| c|
+---+----+----+
| 1| 1.0| 3.0|
| 1| 2.0| 3.0|
| 2| 3.0|18.0|
| 2| 5.0|18.0|
| 2|10.0|18.0|
+---+----+----+# 增加参数的例子
def sum_pd(pp):def wrap(pdf):v = pdf.vreturn pdf.assign(c=v.sum() + pp)return wrappp = 1
sum_p = sum_pd(pp)
sum_udf = pandas_udf(sum_p, "id long, v double, c double", PandasUDFType.GROUPED_MAP)
df.groupby("id").apply(sum_udf).show()+---+----+----+
| id| v| c|
+---+----+----+
| 1| 1.0| 4.0|
| 1| 2.0| 4.0|
| 2| 3.0|19.0|
| 2| 5.0|19.0|
| 2|10.0|19.0|
+---+----+----+
对于限制4,首先需要说明:是否支持多列输入取决于函数本身的写法。在我最初的例子中,由于入参是整个 pdf,所以无法直接传入多列;这种情况下可以借助 StructType(struct 函数)解决,把需要输入的列打包成一个 struct 列传入 UDF 函数。
当入参本身就定义为多列时,也支持多列输入;不过为了代码的简洁性,个人更倾向于第一种写法
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, struct

spark = (
    SparkSession.builder
    .appName("Two Day Subtract Example")
    .getOrCreate()
)

data = [
    (1, "2023-04-12", 2.0),
    (1, "2023-04-11", 4.0),
    (2, "2023-04-12", 3.0),
    (2, "2023-04-11", 5.0),
]
columns = ["id", "dt", "v"]
df = spark.createDataFrame(data, columns)


def two_day_subtract(window_end, datapipe):
    """Build a GROUPED_AGG function returning v(window_end) - v(other day).

    The returned ``wrap`` receives the struct column as a pandas DataFrame
    whose columns are the struct fields. The sign is taken from the first
    row's date, so the result does not depend on which row comes first. A
    group with a single row yields +v if that row is ``window_end``, else -v.
    """
    def wrap(pdf):
        n_rows = pdf.shape[0]
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        # Rows must be read positionally; pdf[0] would look up a column named 0.
        first = pdf.iloc[0]
        sign = 1 if first[datapipe.dt] == window_end else -1
        if n_rows == 1:
            return first[datapipe.stat_col] * sign
        second = pdf.iloc[1]
        return (first[datapipe.stat_col] - second[datapipe.stat_col]) * sign
    return wrap


class dd:
    """Small config holder naming the value column and the date column."""

    def __init__(self, stat_col, dt):
        self.stat_col = stat_col  # name of the value column to subtract
        self.dt = dt  # name of the date column


window_end = '2023-04-12'
datapipe = dd(stat_col='v', dt='dt')
idd = 'id'
substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(substract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()
# The decorator form also works, but the decorator must go on the inner
# function that Spark calls, not on the factory that takes the parameters.
def two_day_subtract(window_end, datapipe):
    """Return a ready-to-use GROUPED_AGG pandas UDF: v(window_end) - v(other day).

    The UDF receives the struct column as a pandas DataFrame whose columns are
    the struct fields. The sign comes from the first row's date, so row order
    does not matter. A single-row group yields +v if that row is
    ``window_end``, else -v.
    """
    @pandas_udf("double", PandasUDFType.GROUPED_AGG)
    def wrap(pdf):
        n_rows = pdf.shape[0]
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        first = pdf.iloc[0]
        sign = 1 if first[datapipe.dt] == window_end else -1
        if n_rows == 1:
            return first[datapipe.stat_col] * sign
        second = pdf.iloc[1]
        return (first[datapipe.stat_col] - second[datapipe.stat_col]) * sign
    return wrap


substract_udf = two_day_subtract(window_end, datapipe)
stat_df = df.groupby(idd).agg(substract_udf(struct(df['v'], df['dt'])).alias("num"))
stat_df.show()
# Pass the columns directly: each column arrives as its own pandas Series.
def two_day_subtract(window_end, datapipe):
    """Build a GROUPED_AGG function over (values, dates) returning v(window_end) - v(other).

    ``s1`` holds the values and ``s2`` the dates of one group. ``datapipe`` is
    accepted for signature compatibility but not used here.
    """
    def wrap(s1, s2):
        n_rows = len(s1)
        if not 0 < n_rows <= 2:
            raise ValueError(f"expected 1 or 2 rows per group, got {n_rows}")
        # Positional access: the Series index is not guaranteed to start at 0.
        sign = 1 if s2.iloc[0] == window_end else -1
        if n_rows == 1:
            # GROUPED_AGG must return a scalar, not a one-element Series.
            return s1.iloc[0] * sign
        return (s1.iloc[0] - s1.iloc[1]) * sign
    return wrap


substract_udf = pandas_udf("double", PandasUDFType.GROUPED_AGG)(two_day_subtract(window_end, datapipe))
stat_df = df.groupby(idd).agg(substract_udf(df['v'], df['dt']).alias("num"))
stat_df.show()
reference:
1. Documents
2. Pandas UDFs in PySpark
3. Blog on databricks