2014-09-11 80 views
1

我有以下方法,執行分層k摺疊交叉驗證的邏輯的一部分。試圖瞭解linq /延遲執行如何工作

private static IEnumerable<IEnumerable<int>> GenerateFolds(
    IClassificationProblemData problemData, int numberOfFolds) 
{ 
    IRandom random = new MersenneTwister(); 
    IEnumerable<double> values = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 

    var valuesIndices = 
     problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v }); 

    IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass = 
     valuesIndices.GroupBy(x => x.Value, x => x.Index) 
        .Select(g => GenerateFolds(g, g.Count(), numberOfFolds)); 

    var enumerators = foldsByClass.Select(x => x.GetEnumerator()).ToList(); 

    while (enumerators.All(e => e.MoveNext())) 
    { 
     var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()); 
     yield return fold.ToList(); 
    } 
} 

褶皺產生:

private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(
    IEnumerable<T> values, int valuesCount, int numberOfFolds) 
{ 
    // number of folds rounded to integer and remainder 
    int f = valuesCount/numberOfFolds, r = valuesCount % numberOfFolds; 
    int start = 0, end = f; 

    for (int i = 0; i < numberOfFolds; ++i) 
    { 
     if (r > 0) 
     { 
      ++end; 
      --r; 
     } 

     yield return values.Skip(start).Take(end - start); 
     start = end; 
     end += f; 
    } 
} 

通用GenerateFolds<T方法只是根據摺疊的指定數目的分割的IEnumerable<T>IEnumerable秒的序列。例如,如果我有101個訓練樣本,它將產生11倍大小的一倍和10倍大小的9倍。

上面的方法根據類別值對樣本進行分組,將每個組分割成指定的摺疊數然後將最後的摺疊連接到類別摺疊,以確保類別標籤的分佈相同。

我的問題關於行yield return fold.ToList()。實際上,如果我刪除了ToList(),則該方法正常工作,但結果不再正確。在我的測試案例中,我有641個訓練樣本和10個摺疊,這意味着第一個摺疊的大小應爲65,剩餘的摺疊大小爲64.但是,當我刪除ToList()時,所有摺疊的大小爲64,並且類標籤不正確分散式。任何想法爲什麼?謝謝。

+0

旁註 - '的IEnumerable >>'爽:) – 2014-09-11 07:58:12

+1

感謝編輯和製作我的問題看起來更好! :) – 2014-09-11 08:10:47

+0

剛剛驗證你的代碼 - 工作正常,沒有保存摺疊列表。它返回'numberOfFolds'索引組數 – 2014-09-11 08:25:02

回答

1

讓我們覺得是fold變量:

var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()); 

這不是查詢執行的結果。這是一個查詢定義。因爲SelectManyOrderBy都是具有延遲執行方式的運算符。所以,它只會保存所有統計員平整當前項目的知識,並以隨機順序返回它們。我已經突出顯示字當前,因爲它是查詢執行時的當前項目。

現在讓我們考慮何時執行此查詢。 GenerateFolds方法執行的結果是IEnumerableIEnumerable<int>查詢。以下代碼不會執行任何查詢:

var folds = GenerateFolds(indices, values, numberOfFolds); 

這又是一個查詢。你可以通過調用ToList()或枚舉它執行:

var f = folds.ToList(); 

但即使是現在的內層查詢不被執行。他們全部退回,但沒有執行。即while循環在GenerateFolds已執行,同時您將查詢保存到列表f。而e.MoveNext()已經打過幾次電話,直到你退出循環:

while (enumerators.All(e => e.MoveNext())) 
{ 
    var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next()); 
    yield return fold; 
} 

那麼,什麼f持有?它包含查詢列表。因此,你已經得到了他們全部,當前項目是每個調查員的最後一項(記住 - 我們已經在此時完全迭代while循環)。但是這些查詢都沒有執行!在這裏你執行其中的第一個:

f[0].Count() 

你得到第一個查詢返回的項數(定義在問題頂部)。但是,因此您已經列舉了所有查詢當前項目是最後一個項目。你可以得到最後一個項目的索引數。

現在採取

folds.First().Count() 

看看這裏你不枚舉所有的查詢將其保存到列表中。即while循環只執行一次,當前項目是第一項。這就是爲什麼你有第一項索引計數。這就是爲什麼這些值是不同的。

最後一個問題 - 爲什麼當您在while循環內添加ToList()時,所有工作都正常。答案很簡單 - 執行每個查詢。你有索引列表而不是查詢定義。每個查詢在每次迭代中執行,因此當前的項目總是不同的。而你的代碼工作正常。

+1

謝謝,我的印象是,查詢定義與e.Current的狀態無關(即,'enumerators.SelectMany(e => e.Current)'將在內部存儲關於相應的信息IEnumerables被夷爲平地)。當您被迫執行內部查詢來強制執行正確的行爲時,並不是真正的延遲執行。 – 2014-09-11 21:38:11