2016-12-05

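Random Forest analysis

I have a Spark (1.5.2) DataFrame and a trained RandomForestClassificationModel. I can easily score the data and get a prediction, but I would like to analyze in more depth which edge values are the most common participants in each binary classification scenario.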

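In the past, I did something similar with RDDs, tracking feature usage by computing the predictions myself. In the code below, I track the list of features used to compute a prediction. DataFrames do not seem to be as straightforward to work with as RDDs in this respect.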

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.Node

// Walks one tree of the RDD-based (mllib) model for a single feature vector.
// Returns (predicted label, probability of that label, indices of the features split on along the path).
def predict(node: Node, features: Vector, pathIn: Array[Int]): (Double, Double, Array[Int]) = {
  if (node.isLeaf) {
    (node.predict.predict, node.predict.prob, pathIn)
  } else {
    val split = node.split.get
    // Track our path through the tree
    val path = pathIn :+ split.feature

    // Continuous splits go left when value <= threshold; categorical splits go left when
    // the value is one of the split's categories
    val goLeft =
      if (split.featureType == FeatureType.Continuous)
        features(split.feature) <= split.threshold
      else
        split.categories.contains(features(split.feature))

    predict(if (goLeft) node.leftNode.get else node.rightNode.get, features, path)
  }
}

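I would like to do something like this code, but instead return, for each feature vector, a list of all feature/edge-value pairs. Note that in my dataset all features are categorical, and the bin setting (maxBins) was chosen appropriately when building the forest.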
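A minimal, Spark-free sketch of the requested behaviour follows. It assumes a simplified, hypothetical node type (`TreeNode`, `Leaf`, `CatSplit`; loosely shaped like Spark's internal/categorical split nodes but not Spark's API) with categorical splits only, as in the question, and it accumulates every `(featureIndex, featureValue)` pair tested on the way from the root to a leaf instead of only the feature indices.

```scala
// Hypothetical, Spark-free stand-in for tree nodes (NOT Spark's API); categorical splits only
sealed trait TreeNode
final case class Leaf(prediction: Double) extends TreeNode
final case class CatSplit(featureIndex: Int, leftCategories: Set[Double],
                          left: TreeNode, right: TreeNode) extends TreeNode

object PathTracker {
  // Walks the tree for one feature vector and returns the leaf prediction together with
  // every (featureIndex, featureValue) pair tested on the way down, in root-to-leaf order.
  @annotation.tailrec
  def predictWithPath(node: TreeNode, features: Array[Double],
                      acc: List[(Int, Double)] = Nil): (Double, List[(Int, Double)]) =
    node match {
      case Leaf(p) => (p, acc.reverse)
      case CatSplit(i, leftCats, l, r) =>
        val v = features(i)
        // Go left when the value is one of the split's left categories, otherwise right
        val next = if (leftCats.contains(v)) l else r
        predictWithPath(next, features, (i, v) :: acc)
    }
}
```

For a forest, mapping this over every tree for one feature vector gives one full path per tree; the same idea could be ported to Spark's node classes, which is an assumption not tested here.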

Answer

I ended up building a custom UDF to do this:

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode, Node}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.functions.udf

// Base prediction method. Accepts a random forest model and a feature vector.
// Returns the mean of the per-tree predictions, plus one entry per tree holding
// (prediction, impurity, (feature index used on the final edge, that feature's value)).
def predictForest(m: RandomForestClassificationModel, point: Vector): (Double, Array[(Double, Double, (Int, Double))]) = {
  val results = m.trees.map(t => predict(t.rootNode, point))

  (results.map(_._1).sum / results.length, results)
}

def predict(node: Node, features: Vector): (Double, Double, (Int, Double)) = node match {
  case internal: InternalNode =>
    internal.split match {
      case split: CategoricalSplit =>
        // Track our path through the tree: follow the edge selected by this feature value
        val featureValue = features(split.featureIndex)
        val child =
          if (split.leftCategories.contains(featureValue)) internal.leftChild
          else internal.rightChild
        child match {
          // Final edge: report the leaf's prediction and impurity with the split feature/value
          case leaf: LeafNode => (leaf.prediction, leaf.impurity, (split.featureIndex, featureValue))
          case _ => predict(child, features)
        }
      // If we run into an unimplemented split type (e.g. continuous), we just return
      case _ => (node.prediction, node.impurity, (-1, -1.0))
    }
  // If we run into an unimplemented node type (e.g. a single-leaf tree), we just return
  case _ => (node.prediction, node.impurity, (-1, -1.0))
}

val rfModel = yourInstanceOfRandomForestClassificationModel

// This custom UDF executes the random forest classification in a trackable way
def treeAnalyzer(m: RandomForestClassificationModel) = udf((x: Vector) => predictForest(m, x))

// Execute the UDF: this runs the random forest classification on each row and stores
// the results from each tree in a new column named `prediction`
val df3 = testData.withColumn("prediction", treeAnalyzer(rfModel)(testData("indexedFeatures")))
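As a sketch of the final step the question asks about (which edge values occur most often per class), the per-tree results produced by the UDF could be tallied on the driver. This assumes the arrays have already been turned into plain Scala collections of `(prediction, impurity, (featureIndex, featureValue))` tuples, for example after a `collect()` (that `Row` conversion is not shown), and it groups by each tree's own prediction; `EdgeStats` and `edgeCountsByClass` are hypothetical names.

```scala
// Hypothetical helper, not part of Spark: tallies final-edge (feature, value) pairs per predicted class
object EdgeStats {
  // One per-tree result: (prediction, impurity, (featureIndex, featureValue))
  type TreeResult = (Double, Double, (Int, Double))

  // Counts how often each (featureIndex, featureValue) edge led to each predicted class,
  // skipping the (-1, -1.0) marker used for unhandled split types.
  // Each class's edges are sorted by count (descending), then by feature index and value.
  def edgeCountsByClass(rows: Seq[Seq[TreeResult]]): Map[Double, Seq[((Int, Double), Int)]] =
    rows.flatten
      .filter { case (_, _, (feature, _)) => feature >= 0 }
      .groupBy { case (prediction, _, _) => prediction }
      .map { case (label, hits) =>
        val counts = hits.groupBy(_._3).map { case (edge, xs) => (edge, xs.size) }.toSeq
        label -> counts.sortBy { case (edge, n) => (-n, edge._1, edge._2) }
      }
}
```

Grouping by the per-tree prediction rather than by the averaged forest prediction is a design choice; grouping by the row-level average would be an equally reasonable variant.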