2016-12-04 70 views
1

我即將開發一個函數,該函數使用spark sql爲每列執行操作。在這個函數中,我需要引用列名:Spark SQL以編程方式引用列

val input = Seq(
    (0, "A", "B", "C", "D"), 
    (1, "A", "B", "C", "D"), 
    (0, "d", "a", "jkl", "d"), 
    (0, "d", "g", "C", "D"), 
    (1, "A", "d", "t", "k"), 
    (1, "d", "c", "C", "D"), 
    (1, "c", "B", "C", "D") 
).toDF("TARGET", "col1", "col2", "col3TooMany", "col4") 

下面的例子明確地通過'column表示列工作正常。

val pre1_1 = input.groupBy('col1).agg(mean($"TARGET").alias("pre_col1")) 
val pre2_1 = input.groupBy('col1, 'TARGET).agg(count("*")/input.filter('TARGET === 1).count alias ("pre2_col1")) 

input.as('a) 
    .join(pre1_1.as('b), $"a.col1" === $"b.col1").drop($"b.col1") 
    .join(pre2_1.as('b), ($"a.col1" === $"b.col1") and ($"a.TARGET" === $"b.TARGET")).drop($"b.col1").drop($"b.TARGET").show 

When referring to the columns programmatically they can no longer be resolved. When 2 joins are performed one after the other which worked fine for the code snippet above. 

我可以觀察到,對於此代碼段的第一和初始的dfcol1從開始到結束移動。可能這是它不能再解決的原因。 但是到目前爲止,我無法弄清楚如何在只傳遞字符串時如何訪問列/如何正確引用函數中的名稱。

val pre1_1 = input.groupBy("col1").agg(mean('TARGET).alias("pre_" + "col1")) 
val pre2_1 = input.groupBy("col1", "TARGET").agg(count("*")/input.filter('TARGET === 1).count alias ("pre2_" + "col1")) 
    input.join(pre1_1, input("col1") === pre1_1("col1")).drop(pre1_1("col1")) 
    .join(pre2_1, (input("col1") === pre2_1("col1")) and (input("TARGET") === pre2_1("TARGET"))).drop(pre2_1("col1")).drop(pre2_1("TARGET")) 

以及像一個替代方法:

df.as('a) 
     .join(pre1_1.as('b), $"a.${col}" === $"b.${col}").drop($"b.${col}") 

沒有成功,因爲$"a.${col}"不再被解析爲a.Column而是df("a.col1")不存在。

回答

2

在複雜情況下,始終使用唯一的別名來引用具有共享沿襲的列。這是確保正確和穩定行爲的唯一方法。

import org.apache.spark.sql.functions.col 

val pre1_1 = input.groupBy("col1").agg(mean('TARGET).alias("pre_" + "col1")).alias("pre1_1") 
val pre2_1 = input.groupBy("col1", "TARGET").agg(count("*")/input.filter('TARGET === 1).count alias ("pre2_" + "col1")).alias("pre2_1") 

input.alias("input") 
    .join(pre1_1, col("input.col1") === col("pre1_1.col1")) 
    .join(pre2_1, (col("input.col1") === col("pre2_1.col1")) and (col("input.TARGET") === col("pre2_1.TARGET"))) 

如果您檢查日誌你其實看到這樣的警告:

WARN柱:構建平凡真實等於斷言, 'COL1#12 =#COL1 12'。也許你需要使用別名

和你使用的代碼只有工作,因爲在Spark源中有「特殊情況」。

在簡單的情況下,這樣只需使用等連接語法:

input.join(pre1_1, Seq("col1")) 
    .join(pre2_1, Seq("col1", "TARGET"))