2017-02-20 184 views
1

我對scala和spark 2.1很陌生。 我試圖計算一個數據幀,它看起來像這些元素之間的相關性:將Spark數據幀轉換爲org.apache.spark.rdd.RDD [org.apache.spark.mllib.linalg.Vector]

item_1 | item_2 | item_3 | item_4 
    1 |  1 |  4 |  3 
    2 |  0 |  2 |  0 
    0 |  2 |  0 |  1 

這裏是我試過:元素之間

val df = sqlContext.createDataFrame(
    Seq((1, 1, 4, 3), 
     (2, 0, 2, 0), 
     (0, 2, 0, 1) 
).toDF("item_1", "item_2", "item_3", "item_4") 


val items = df.select(array(df.columns.map(col(_)): _*)).rdd.map(_.getSeq[Double](0)) 

而且calcualte相關:

val correlMatrix: Matrix = Statistics.corr(items, "pearson") 

隨着followning錯誤消息:

<console>:89: error: type mismatch; 
found : org.apache.spark.rdd.RDD[Seq[Double]] 
required: org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector] 
     val correlMatrix: Matrix = Statistics.corr(items, "pearson") 

我不知道如何從數據框中創建org.apache.spark.rdd.RDD[org.apache.spark.mllib.linalg.Vector]

這可能是一個非常簡單的任務,但我有點掙扎,我很樂意提供任何建議。

回答

5

例如,您可以使用VectorAssembler。裝配載體和轉換爲RDD

import org.apache.spark.ml.feature.VectorAssembler 

val rows = new VectorAssembler().setInputCols(df.columns).setOutputCol("vs") 
    .transform(df) 
    .select("vs") 
    .rdd 

提取Vectors

  • 星火1.x中:

    rows.map(_.getAs[org.apache.spark.mllib.linalg.Vector](0)) 
    
  • 星火2.X:

    rows 
        .map(_.getAs[org.apache.spark.ml.linalg.Vector](0)) 
        .map(org.apache.spark.mllib.linalg.Vectors.fromML) 
    

關於你的代碼:

  • Integer列不Double
  • 數據不是array所以你不能使用_.getSeq[Double](0)
+0

非常感謝你 - 這就是我一直在尋找的解決方案 – Duesentrieb

2

如果您的目標是執行皮爾森相關性,則不必真正使用RDD和向量。以下是直接在DataFrame列上執行Pearson相關性的示例(所討論的列是Doublebles類型)。

代碼:

import org.apache.spark.sql.{SQLContext, Row, DataFrame} 
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType, DoubleType} 
import org.apache.spark.sql.functions._ 


val rb = spark.read.option("delimiter","|").option("header","false").option("inferSchema","true").format("csv").load("rb.csv").toDF("name","beerId","brewerId","abv","style","appearance","aroma","palate","taste","overall","time","reviewer").cache() 

rb.agg(
    corr("overall","taste"), 
    corr("overall","aroma"), 
    corr("overall","palate"), 
    corr("overall","appearance"), 
    corr("overall","abv") 
    ).show() 

在本例中,我導入數據幀(具有自定義分隔符,無標頭,和推斷的數據類型),然後簡單地對數據幀執行AGG功能其中有多個相關關係。



輸出:

+--------------------+--------------------+---------------------+-------------------------+------------------+ 
|corr(overall, taste)|corr(overall, aroma)|corr(overall, palate)|corr(overall, appearance)|corr(overall, abv)| 
+--------------------+--------------------+---------------------+-------------------------+------------------+ 
| 0.8762432795943761| 0.789023067942876| 0.7008942639550395|  0.5663593891357243|0.3539158620897098| 
+--------------------+--------------------+---------------------+-------------------------+------------------+ 

你可以從結果中看到,(整體,味道)列是高度相關的,而(整體,ABV)沒有這麼多。

這是鏈接到Scala Docs DataFrame page which has the Aggregation Correlation Function

+0

謝謝你的這種方式。它做的工作,但我有超過300列來計算 – Duesentrieb

+0

有沒有一種方法來計算許多列沒有明確定義每個組合? – Duesentrieb