2016-06-09 380 views
0

如何在單列(在新DataFrame中)中將DataFrame中的多個列(例如3)組合成一個Spark DenseVector?類似於這個thread,但在Java中,並在下面提到的一些調整。在Spark DataFrame中合併多個列[Java]

我嘗試使用UDF這樣的:

// UDF that merges three double columns, but wraps the resulting dense
// vector inside a Row — this is the cause of the problem described below:
// the output schema becomes a struct containing a vector rather than a
// top-level vector column.
private UDF3<Double, Double, Double, Row> toColumn = new UDF3<Double, Double, Double, Row>() { 

    private static final long serialVersionUID = 1L; 

    public Row call(Double first, Double second, Double third) throws Exception {   
     // Wrapping Vectors.dense(...) in a Row nests it one level too deep.
     Row row = RowFactory.create(Vectors.dense(first, second, third)); 

     return row; 
    } 
}; 

,然後註冊UDF:

sqlContext.udf().register("toColumn", toColumn, dataType); 

dataType是:

// Declared UDF return type: a struct with a single vector field. The outer
// StructField is what nests the vector inside a struct in the result schema.
StructType dataType = DataTypes.createStructType(new StructField[]{ 
    new StructField("bla", new VectorUDT(), false, Metadata.empty()), 
    }); 

當我把這個UDF一個有3列的DataFrame並打印出新的DataFrame的模式,我得到這個:

root |-- features: struct (nullable = true) | |-- bla: vector (nullable = false)

這裏的問題是,我需要向量位於最外層,而不是嵌套在一個結構內。 像這樣:

root 
|-- features: vector (nullable = true) 

我不知道該如何做到這一點,因爲 register 函數要求 UDF 的返回類型是 DataType(而 DataType 並不提供 VectorType)

回答

0

你實際上是透過使用這種數據類型,手動把向量類型嵌套進了一個結構中:

new StructField("bla", new VectorUDT(), false, Metadata.empty()), 

如果刪除外層的 StructField,你就會得到你想要的結果。當然,在這種情況下,你需要修改函數定義的簽名,也就是說,讓它直接返回 Vector 類型。

請參閱下面的具體示例,我的意思是以簡單的JUnit測試的形式。

package sample.spark.test; 

import org.apache.spark.api.java.JavaSparkContext; 
import org.apache.spark.mllib.linalg.Vector; 
import org.apache.spark.mllib.linalg.VectorUDT; 
import org.apache.spark.mllib.linalg.Vectors; 
import org.apache.spark.sql.DataFrame; 
import org.apache.spark.sql.RowFactory; 
import org.apache.spark.sql.SQLContext; 
import org.apache.spark.sql.api.java.UDF3; 
import org.apache.spark.sql.types.DataTypes; 
import org.apache.spark.sql.types.Metadata; 
import org.apache.spark.sql.types.StructField; 
import org.junit.Test; 

import java.io.Serializable; 
import java.util.Arrays; 
import java.util.HashSet; 
import java.util.Set; 

import static org.junit.Assert.assertEquals; 
import static org.junit.Assert.assertTrue; 

/**
 * Demonstrates how to combine three double columns of a DataFrame into a
 * single top-level vector column: the UDF returns a {@link Vector} directly
 * (registered with {@code new VectorUDT()} as its return type) instead of
 * wrapping the vector in a Row/struct.
 */
public class ToVectorTest implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Packs three double columns into one dense vector — no Row wrapper. */
    private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {

        private static final long serialVersionUID = 1L;

        public Vector call(Double first, Double second, Double third) throws Exception {
            return Vectors.dense(first, second, third);
        }
    };

    @Test
    public void testUDF() {
        // context
        final JavaSparkContext sc = new JavaSparkContext("local", "ToVectorTest");
        try {
            final SQLContext sqlContext = new SQLContext(sc);

            // test input: 4 rows of 3 non-nullable double features
            final DataFrame input = sqlContext.createDataFrame(
                sc.parallelize(
                    Arrays.asList(
                        RowFactory.create(1.0, 2.0, 3.0),
                        RowFactory.create(4.0, 5.0, 6.0),
                        RowFactory.create(7.0, 8.0, 9.0),
                        RowFactory.create(10.0, 11.0, 12.0)
                    )),
                DataTypes.createStructType(
                    Arrays.asList(
                        new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("feature3", DataTypes.DoubleType, false, Metadata.empty())
                    )
                )
            );
            input.registerTempTable("input");

            // expected output: one dense vector per input row (order-independent)
            final Set<Vector> expectedOutput = new HashSet<>(Arrays.asList(
                Vectors.dense(1.0, 2.0, 3.0),
                Vectors.dense(4.0, 5.0, 6.0),
                Vectors.dense(7.0, 8.0, 9.0),
                Vectors.dense(10.0, 11.0, 12.0)
            ));

            // processing: register with VectorUDT so the schema has a
            // top-level vector column, not a struct wrapping a vector
            sqlContext.udf().register("toColumn", toColumn, new VectorUDT());
            final DataFrame outputDF = sqlContext.sql(
                "SELECT toColumn(feature1, feature2, feature3) AS x FROM input");
            final Set<Vector> output =
                new HashSet<>(outputDF.toJavaRDD().map(r -> r.<Vector>getAs("x")).collect());

            // evaluation
            assertEquals(expectedOutput.size(), output.size());
            for (Vector x : output) {
                assertTrue(expectedOutput.contains(x));
            }

            // show the schema and the content
            System.out.println(outputDF.schema());
            outputDF.show();
        } finally {
            // FIX: the original only called sc.stop() at the end of the method,
            // so a failed assertion leaked the local SparkContext across test
            // runs. Stop it unconditionally.
            sc.stop();
        }
    }
}
+0

這正是我所需要的。不知何故,我設法不考慮從UDF返回一個Vector,並向VectorUDT註冊函數。感謝羅伯特! – Rajko