
How do I return a "tuple type" in a UDF in PySpark?

All the data types in pyspark.sql.types are:

__all__ = [ 
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", 
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", 
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] 

I have to write a UDF (in PySpark) that returns an array of tuples. What do I pass as the second argument, i.e. the return type of the udf method? It would be something along the lines of ArrayType(TupleType())...


Your question in the title doesn't seem to match the body. And don't the docs tell you how to set the return value to a *"container type of other types"*? – jonrsharpe


@jonrsharpe I've changed the title. Hopefully it is now representative of the body. – kamalbanga

Answers


There is no such thing as a TupleType in Spark. Product types are represented as structs, with fields of specific types. For example, if you want to return an array of pairs (integer, string), you can use a schema like this:

from pyspark.sql.types import * 

schema = ArrayType(StructType([ 
    StructField("char", StringType(), False), 
    StructField("count", IntegerType(), False) 
])) 

Example usage:

from pyspark.sql.functions import udf 
from collections import Counter 

char_count_udf = udf(
    lambda s: Counter(s).most_common(), 
    schema 
) 

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"]) 

df.select("*", char_count_udf(df["value"])).show(2, False) 

## +---+-----+-------------------------+ 
## |id |value|PythonUDF#<lambda>(value)| 
## +---+-----+-------------------------+ 
## |1  |foo  |[[o,2], [f,1]]           | 
## |2  |bar  |[[r,1], [a,1], [b,1]]    | 
## +---+-----+-------------------------+ 
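If you then need to work with the individual pairs, one option (a sketch reusing the df and char_count_udf defined above; the column aliases are mine) is to explode the array and address the struct fields with dot notation:

from pyspark.sql.functions import explode 

counts = df.select("id", char_count_udf(df["value"]).alias("char_counts")) 

# explode turns each (char, count) struct into its own row 
exploded = counts.select("id", explode("char_counts").alias("cc")) 

# struct fields are accessed with dot notation 
exploded.select("id", "cc.char", "cc.count").show() 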

Your answer works, but my case is a bit more complex. My return data is `[('a1', [('b1', 1), ('b2', 2)]), ('a2', [('b1', 1), ('b2', 2)])]`, so I created a type like `ArrayType(StructType([StructField("date", StringType(), False), ArrayType(StructType([StructField("hashId", StringType(), False), StructField("TimeSpent-Front", FloatType(), False), StructField("TimeSpent-Back", FloatType(), False)]))]))`, which gives **'ArrayType' object has no attribute 'name'**... – kamalbanga


'StructType' expects a sequence of 'StructFields', so you cannot use 'ArrayTypes' on their own. You need a 'StructField' to hold the 'ArrayType'. One more suggestion - if you find yourself creating structures like this, you should rethink your data model. Deeply nested structures are hard to work with without UDFs, and Python UDFs are far from efficient. – zero323
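For reference, a corrected version of the schema from the comment above could look like the following sketch. The field names are taken from the comment, except "details", which is a placeholder name I made up for the inner array, since every member of a StructType must be a named StructField:

from pyspark.sql.types import (ArrayType, StructType, StructField, 
                               StringType, FloatType) 

# Each element is (date, [(hashId, TimeSpent-Front, TimeSpent-Back), ...]); 
# the inner ArrayType is wrapped in its own StructField ("details" is a 
# placeholder name) instead of being placed directly inside the StructType. 
schema = ArrayType(StructType([ 
    StructField("date", StringType(), False), 
    StructField("details", ArrayType(StructType([ 
        StructField("hashId", StringType(), False), 
        StructField("TimeSpent-Front", FloatType(), False), 
        StructField("TimeSpent-Back", FloatType(), False) 
    ])), False) 
])) 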


How do I specify the schema in the udf to return a list? F.udf(lambda start_date, end_date: [0, 1] if start_date … pseudocode
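As far as the list question in the last comment goes, a minimal sketch might look like this; the column names and the [0, 1] / [1, 0] logic are guessed from the truncated snippet:

from pyspark.sql import functions as F 
from pyspark.sql.types import ArrayType, IntegerType 

# The second argument of udf is the schema of the returned value: 
# here, an array of integers. 
flag_udf = F.udf( 
    lambda start_date, end_date: [0, 1] if start_date < end_date else [1, 0], 
    ArrayType(IntegerType()) 
) 

# assuming a DataFrame with 'start_date' and 'end_date' columns: 
# df = df.withColumn('flags', flag_udf('start_date', 'end_date')) 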


Searches kept directing me to this question, so I guess I'll add some information here.

UDFs that return simple types:

from pyspark.sql.types import * 
from pyspark.sql import functions as F 
from pyspark.sql.functions import udf 

def get_df(): 
    d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)] 
    df = sqlContext.createDataFrame(d, ['x', 'y']) 
    return df 

df = get_df() 
df.show() 

# +---+---+ 
# |  x|  y| 
# +---+---+ 
# |0.0|0.0| 
# |0.0|3.0| 
# |1.0|6.0| 
# |1.0|9.0| 
# +---+---+ 

func = udf(lambda x: str(x), StringType()) 
df = df.withColumn('y_str', func('y')) 

func = udf(lambda x: int(x), IntegerType()) 
df = df.withColumn('y_int', func('y')) 

df.show() 

# +---+---+-----+-----+ 
# |  x|  y|y_str|y_int| 
# +---+---+-----+-----+ 
# |0.0|0.0|  0.0|    0| 
# |0.0|3.0|  3.0|    3| 
# |1.0|6.0|  6.0|    6| 
# |1.0|9.0|  9.0|    9| 
# +---+---+-----+-----+ 

df.printSchema() 

# root 
# |-- x: double (nullable = true) 
# |-- y: double (nullable = true) 
# |-- y_str: string (nullable = true) 
# |-- y_int: integer (nullable = true) 

When an integer is not enough:

df = get_df() 

func = udf(lambda x: [0]*int(x), ArrayType(IntegerType())) 
df = df.withColumn('list', func('y')) 

func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
      MapType(FloatType(), StringType())) 
df = df.withColumn('map', func('y')) 

df.show() 
# +---+---+--------------------+--------------------+ 
# |  x|  y|                list|                 map| 
# +---+---+--------------------+--------------------+ 
# |0.0|0.0|                  []|               Map()| 
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...| 
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...| 
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...| 
# +---+---+--------------------+--------------------+ 

df.printSchema() 
# root 
# |-- x: double (nullable = true) 
# |-- y: double (nullable = true) 
# |-- list: array (nullable = true) 
# | |-- element: integer (containsNull = true) 
# |-- map: map (nullable = true) 
# | |-- key: float 
# | |-- value: string (valueContainsNull = true) 

Returning complex data types from a UDF:

df = get_df() 
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]')) 
df.show() 

# +---+----------+ 
# |  x|       y[]| 
# +---+----------+ 
# |0.0|[0.0, 3.0]| 
# |1.0|[9.0, 6.0]| 
# +---+----------+ 

schema = StructType([ 
    StructField("min", FloatType(), True), 
    StructField("size", IntegerType(), True), 
    StructField("edges", ArrayType(FloatType()), True), 
    StructField("val_to_index", MapType(FloatType(), IntegerType()), True) 
    # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)])) 

]) 

def func(values): 
    mn = min(values) 
    size = len(values) 
    lst = sorted(values)[::-1] 
    val_to_index = {x: i for i, x in enumerate(values)} 
    return (mn, size, lst, val_to_index) 

func = udf(func, schema) 
dff = df.select('*', func('y[]').alias('complex_type')) 
dff.show(10, False) 

# +---+----------+------------------------------------------------------+ 
# |x  |y[]       |complex_type                                          | 
# +---+----------+------------------------------------------------------+ 
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]| 
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]| 
# +---+----------+------------------------------------------------------+ 

dff.printSchema() 

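Once the struct column exists, its fields can be pulled out with dot notation; a small sketch continuing from the dff above:

# select individual fields of the struct returned by the UDF 
dff.select('x', 'complex_type.min', 'complex_type.size').show() 

# nested containers can be indexed as well, e.g. the first edge 
dff.select(dff['complex_type']['edges'][0]).show() 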

Passing multiple arguments to a UDF:

df = get_df() 
func = udf(lambda arr: arr[0]*arr[1],FloatType()) 
df = df.withColumn('x*y', func(F.array('x', 'y'))) 

# +---+---+---+ 
# |  x|  y|x*y| 
# +---+---+---+ 
# |0.0|0.0|0.0| 
# |0.0|3.0|0.0| 
# |1.0|6.0|6.0| 
# |1.0|9.0|9.0| 
# +---+---+---+ 
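Packing the columns with F.array is not strictly required, by the way; a Python UDF can also take several columns directly, one per lambda argument (a sketch on the same data):

from pyspark.sql.functions import udf 
from pyspark.sql.types import DoubleType 

# one lambda argument per column passed to the UDF 
mult = udf(lambda x, y: x * y, DoubleType()) 
df = df.withColumn('xy', mult('x', 'y')) 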

This code is purely for demonstration purposes; all of the transformations above are available as built-in Spark functions and will yield much better performance. As @zero323 noted in the comments above, UDFs should generally be avoided in PySpark; if you find yourself returning complex types, consider simplifying your logic instead.
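For example, the casts and the element-wise product from the examples above can be written with built-in column operations only, no UDF involved (a sketch):

df = get_df() 

# casts replace the str()/int() UDFs 
df = df.withColumn('y_str', df['y'].cast('string')) 
df = df.withColumn('y_int', df['y'].cast('int')) 

# plain column arithmetic replaces the multiplication UDF 
df = df.withColumn('x*y', df['x'] * df['y']) 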