
Any way to access the methods of individual stages in a PySpark PipelineModel? I created a PipelineModel (via the PySpark API) in Spark 2.0 to run LDA:

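from pyspark.ml import Pipeline
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import CountVectorizer, RegexTokenizer

def create_lda_pipeline(minTokenLength=1, minDF=1, minTF=1, numTopics=10, seed=42, pattern=r'[\W]+'):
    """
    Create a pipeline for running an LDA model on a corpus. This function does not need data and will not actually do
    any fitting until invoked by the caller.
    Args:
        minTokenLength: minimum length a token must have to be kept
        minDF: minimum number of documents in the corpus a word must be present in
        minTF: minimum number of times a word must be found in a document
        numTopics: number of topics (k) for LDA
        seed: random seed for LDA
        pattern: regular expression used to split words

    Returns:
        pipeline: class pyspark.ml.Pipeline (unfitted; calling fit() on it returns a pyspark.ml.PipelineModel)
    """
    # Split raw text into tokens, dropping tokens shorter than minTokenLength
    reTokenizer = RegexTokenizer(inputCol="text", outputCol="tokens", pattern=pattern, minTokenLength=minTokenLength)
    # Turn token lists into term-count vectors
    cntVec = CountVectorizer(inputCol=reTokenizer.getOutputCol(), outputCol="vectors", minDF=minDF, minTF=minTF)
    # LDA reads the count vectors produced by CountVectorizer
    lda = LDA(k=numTopics, seed=seed, optimizer="em", featuresCol=cntVec.getOutputCol())
    pipeline = Pipeline(stages=[reTokenizer, cntVec, lda])
    return pipeline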

I want to use the LDAModel.logPerplexity() method of the trained model to compute the perplexity of a dataset, so I tried running the following:

from pprint import pprint

# get_20_newsgroups_data is my own helper (not shown) that returns a DataFrame with a "text" column
training = get_20_newsgroups_data(test_or_train='train')
pipeline = create_lda_pipeline(numTopics=20, minDF=3, minTokenLength=5)
model = pipeline.fit(training)  # train model on training data
testing = get_20_newsgroups_data(test_or_train='test')
perplexity = model.logPerplexity(testing)
pprint(perplexity)

This just results in the following AttributeError:

'PipelineModel' object has no attribute 'logPerplexity' 

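I understand why this error happens: the logPerplexity method belongs to LDAModel, not to PipelineModel. But I would like to know whether there is a way to access the method from that stage.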

Answer


All transformers in a fitted pipeline are stored in its stages attribute. Take stages, grab the last one, and you're good to go:

model.stages[-1].logPerplexity(testing) 
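One caveat worth checking: logPerplexity reads the LDA model's featuresCol ("vectors" in the question's pipeline), so the DataFrame passed to it should already contain the CountVectorizer output. A raw DataFrame with only a "text" column will likely fail. The sketch below is a hypothetical helper (the name log_perplexity_on is illustrative, not part of PySpark). It assumes the fitted model and the raw test DataFrame from the question, runs the test data through every fitted stage with transform(), and locates the LDA stage by type instead of by position:

```python
def log_perplexity_on(pipeline_model, raw_df, stage_type):
    """Hypothetical helper: logPerplexity of a fitted stage on raw input data.

    pipeline_model: a fitted pyspark.ml.PipelineModel
    raw_df: DataFrame with the same input column(s) used for fitting ("text")
    stage_type: class to look for, e.g. pyspark.ml.clustering.LDAModel
    """
    # Find the stage by type rather than relying on it being stages[-1].
    matches = [s for s in pipeline_model.stages if isinstance(s, stage_type)]
    if not matches:
        raise ValueError("no stage of type %s in pipeline" % stage_type.__name__)
    # transform() applies every fitted stage; each one appends its output
    # column, so the result still carries the "vectors" column LDA reads.
    features_df = pipeline_model.transform(raw_df)
    return matches[0].logPerplexity(features_df)
```

Assumed usage with the question's objects: `from pyspark.ml.clustering import LDAModel` and then `log_perplexity_on(model, testing, LDAModel)`. Matching on LDAModel should also catch the DistributedLDAModel that the "em" optimizer produces, since it subclasses LDAModel.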

Wow, thanks. Saved my bacon!