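We need to use the new DataFrame-based ml API, not the RDD-based mllib API, to get the probabilities. Classifiers in the ml API add a probability column to the DataFrame returned by transform.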
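To show how the probability column relates to the raw scores, here is a minimal pure-Scala sketch with no Spark dependency. It assumes that the probability vector of a random forest is its rawPrediction vector normalized to sum to 1. That is our understanding of how Spark's RandomForestClassificationModel converts raw scores, so check it against the Spark version you use. The object name RawToProbability and the input values are hypothetical.

```scala
// Hypothetical helper, not part of Spark: normalizes raw per-class scores
// into probabilities, ASSUMING probability = rawPrediction / sum(rawPrediction).
object RawToProbability {

  /** Divide every raw class score by the total so the result sums to 1. */
  def normalize(raw: Array[Double]): Array[Double] = {
    val total = raw.sum
    require(total > 0.0, "raw scores must have a positive sum")
    raw.map(_ / total)
  }

  /** Probability of the positive class (index 1) in a binary problem. */
  def positiveClassProbability(raw: Array[Double]): Double =
    normalize(raw)(1)

  def main(args: Array[String]): Unit = {
    // Hypothetical raw scores from a 10-tree forest (they sum to 10).
    val raw = Array(3.0, 7.0)
    println(normalize(raw).mkString("[", ", ", "]")) // [0.3, 0.7]
    println(positiveClassProbability(raw))           // 0.7
  }
}
```

With the real Spark output, you would read the same value directly from element 1 of the probability column instead of recomputing it.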
UPDATE
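Below is the random forest example from the Spark documentation, updated to use BinaryClassificationEvaluator and print two metrics: the Area Under the Receiver Operating Characteristic curve (AUROC) and the Area Under the Precision-Recall curve (AUPRC).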
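For intuition about what areaUnderROC measures, here is a minimal pure-Scala sketch using the pairwise-ranking formulation of AUROC: the fraction of (positive, negative) pairs in which the positive example has the higher score, with ties counted as 1/2. Spark's evaluator builds the ROC curve and integrates it instead. When fed the same scores, the two approaches should agree, up to any curve down-sampling Spark applies. As far as we know, the evaluator takes element 1 of rawPrediction as the score. The object name AurocSketch and the sample data are hypothetical.

```scala
// Hypothetical helper, not Spark's implementation: exact AUROC via pairwise ranking.
object AurocSketch {

  /** scoreAndLabels: (score, label) with label 1.0 = positive, 0.0 = negative. */
  def areaUnderROC(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoreAndLabels.collect { case (s, l) if l == 1.0 => s }
    val neg = scoreAndLabels.collect { case (s, l) if l == 0.0 => s }
    require(pos.nonEmpty && neg.nonEmpty, "need both classes")
    // Count correctly ranked positive/negative pairs; ties count as half.
    val wins = (for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }).sum
    wins / (pos.size.toDouble * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical (score, label) pairs, e.g. (rawPrediction(1), indexedLabel).
    val data = Seq((0.9, 1.0), (0.8, 0.0), (0.7, 1.0), (0.2, 0.0))
    println(areaUnderROC(data)) // 0.75
  }
}
```

The quadratic pair loop keeps the definition obvious. A sort-based version would be needed for large test sets.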
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
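// Load and parse the data file, converting it to a DataFrame.
// sqlContext is the Spark 1.x SQLContext (in Spark 2.x+, use spark.read); adjust the path to your Spark checkout.
val data = sqlContext.read.format("libsvm").load("D:/Sources/spark/data/mllib/sample_libsvm_data.txt")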
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)
// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)
// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)
// Chain indexers and forest in a Pipeline
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)
// Make predictions.
val predictions = model.transform(testData)
// Select example rows to display, including the per-class probability column.
predictions
  .select("indexedLabel", "rawPrediction", "probability", "prediction")
  .show()
// Evaluate the test-set predictions; the score for each row comes from the rawPrediction column.
val binaryClassificationEvaluator = new BinaryClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setRawPredictionCol("rawPrediction")

def printlnMetric(metricName: String): Unit = {
  println(metricName + " = " + binaryClassificationEvaluator.setMetricName(metricName).evaluate(predictions))
}
printlnMetric("areaUnderROC")
printlnMetric("areaUnderPR")
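To see roughly what areaUnderPR integrates, here is a minimal pure-Scala sketch. It builds (recall, precision) points at each distinct score threshold, from the highest score down, and applies the trapezoidal rule. One assumption is the starting point: the curve starts at recall 0 with the precision of the highest-score threshold. The exact starting-point convention, and any curve down-sampling, may differ between Spark versions, so small differences from Spark's value are possible. The object name AuprcSketch and the sample data are hypothetical.

```scala
// Hypothetical helper, not Spark's implementation: trapezoidal area under the PR curve.
object AuprcSketch {

  /** scoreAndLabels: (score, label) with label 1.0 = positive, 0.0 = negative. */
  def areaUnderPR(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val totalPos = scoreAndLabels.count(_._2 == 1.0)
    require(totalPos > 0, "need at least one positive example")
    // Distinct thresholds, highest first; each point covers all rows with score >= threshold.
    val thresholds = scoreAndLabels.map(_._1).distinct.sorted.reverse
    val points = thresholds.map { t =>
      val selected = scoreAndLabels.filter(_._1 >= t)
      val tp = selected.count(_._2 == 1.0).toDouble
      (tp / totalPos, tp / selected.size) // (recall, precision)
    }
    // ASSUMPTION: the curve starts at recall 0 with the first threshold's precision.
    val curve = (0.0, points.head._2) +: points
    // Trapezoidal rule over consecutive (recall, precision) points.
    curve.zip(curve.tail).map { case ((r0, p0), (r1, p1)) =>
      (r1 - r0) * (p0 + p1) / 2.0
    }.sum
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical (score, label) pairs.
    val data = Seq((0.9, 1.0), (0.8, 0.0), (0.7, 1.0), (0.2, 0.0))
    println(areaUnderPR(data)) // ≈ 0.7917 (exactly 19/24)
  }
}
```

Unlike AUROC, the PR area depends on the class balance, which is why the evaluator reports both metrics.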
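Possible duplicate of [Spark 1.5.1, MLLib Random Forest Probability](http://stackoverflow.com/questions/33401437/spark-1-5-1-mllib-random-forest-probability) – eliasah
@eliasah It isn't actually a duplicate question, but the answer there does provide the solution to this problem. I had already added that to my answer before you commented. –
No worries, no problem! That's why I used the word "possible". – eliasah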