You need to create a UDF for the Vector column to do the filtering. The following works for me:
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Vector, Vectors}  // the Vector type is needed for the UDF signature
import org.apache.spark.sql.functions.udf
import spark.implicits._  // for toDF and the $"..." syntax when not in the spark-shell

val df = sc.parallelize(Seq(
  (1, 1, 1), (1, 2, 3), (1, 3, 5), (2, 4, 6),
  (2, 5, 2), (2, 6, 1), (3, 7, 5), (3, 8, 16),
  (1, 1, 1))).toDF("c1", "c2", "c3")

// Assemble c1, c2, c3 into a single Vector column named "features".
val dfVec = new VectorAssembler()
  .setInputCols(Array("c1", "c2", "c3"))
  .setOutputCol("features")
  .transform(df)

// Curried UDF: capture the vector to remove, return true for rows whose vector differs.
def vectors_unequal(vec1: Vector) = udf((vec2: Vector) => !vec1.equals(vec2))

val vecToRemove = Vectors.dense(1, 1, 1)
val filtered = dfVec.where(vectors_unequal(vecToRemove)(dfVec.col("features")))
val filtered2 = dfVec.filter(vectors_unequal(vecToRemove)($"features")) // Also possible

dfVec.show()
yields:
+---+---+---+--------------+
| c1| c2| c3| features|
+---+---+---+--------------+
| 1| 1| 1| [1.0,1.0,1.0]|
| 1| 2| 3| [1.0,2.0,3.0]|
| 1| 3| 5| [1.0,3.0,5.0]|
| 2| 4| 6| [2.0,4.0,6.0]|
| 2| 5| 2| [2.0,5.0,2.0]|
| 2| 6| 1| [2.0,6.0,1.0]|
| 3| 7| 5| [3.0,7.0,5.0]|
| 3| 8| 16|[3.0,8.0,16.0]|
| 1| 1| 1| [1.0,1.0,1.0]|
+---+---+---+--------------+
filtered.show()
yields:
+---+---+---+--------------+
| c1| c2| c3| features|
+---+---+---+--------------+
| 1| 2| 3| [1.0,2.0,3.0]|
| 1| 3| 5| [1.0,3.0,5.0]|
| 2| 4| 6| [2.0,4.0,6.0]|
| 2| 5| 2| [2.0,5.0,2.0]|
| 2| 6| 1| [2.0,6.0,1.0]|
| 3| 7| 5| [3.0,7.0,5.0]|
| 3| 8| 16|[3.0,8.0,16.0]|
+---+---+---+--------------+
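If you need to drop more than one specific vector, the same curried-UDF pattern extends naturally. This is only a sketch building on the code above; the not_in helper and the toRemove list are hypothetical names, not part of the original answer:

// Hypothetical extension of the pattern above: drop every row whose feature
// vector appears in a list of unwanted vectors.
def not_in(bad: Seq[Vector]) = udf((v: Vector) => !bad.exists(_.equals(v)))

val toRemove = Seq(Vectors.dense(1, 1, 1), Vectors.dense(2, 5, 2))
val filteredMany = dfVec.filter(not_in(toRemove)($"features"))

The Seq[Vector] is captured in the UDF's closure and serialized to the executors, so it behaves exactly like the single-vector version.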