2016-08-30 88 views
4

我有一個包含多個類的大型數據集。我的目標是爲每個班級擬定一個模型,然後預測結果並在一個方面爲每個班級形象化他們。將不同的模型擬合到R中的每個數據子集

對於一個可重複的例子,我創建了一些基本的使用mtcars。這適用於每個類的簡單迴歸模型。

mtcars = data.table(mtcars) 
model = mtcars[, list(fit = list(lm(mpg~disp+hp+wt))), keyby = cyl] 
setkey(mtcars, cyl) 
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI] 
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl) 

但是,我想嘗試下面的東西,它還沒有工作。這個嘗試與公式列表有關,但我也希望向每個數據子集發送不同的模型(一些glms,幾棵樹)。

mtcars = data.table(mtcars) 
factors = list(c("disp","wt"), c("disp"), c("hp")) 
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+")))) 
model = mtcars[, list(fit = list(lm(form))), keyby = cyl] 
setkey(mtcars, cyl) 
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI] 
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl) 
+0

數據表真的有必要嗎? – rawr

+0

不,但在大型數據集上更快。 dplyr也可以。 – Divi

+1

我只是暗示瓶頸將會被預測,lm,ggplot。是'list(fit = lapply(form,lm,data = .SD))'你想要什麼 – rawr

回答

4

這裏就是我們每個模式作爲一個未評估的列表中設置了predict的方法,該data.table對象,gather輸出中評估他們,並把它傳遞到ggplot

models = quote(list(
     predict(lm(form[[1]], .SD)), 
     predict(lm(form[[2]], .SD)), 
     predict(lm(form[[3]], .SD)))) 

d <- mtcars 
d[, c("est1", "est2", "est3") := eval(models), by = cyl] 
d <- tidyr::gather(d, key = model, value = pred, est1:est3) 

library(ggplot2) 
ggplot(d, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model) 

輸出:

enter image description here

3

lm()也接受公式作爲字符向量。因此,我想簡單地創建form爲:

form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+"))) 

而且,你需要提供正確的數據(使用內置的特殊符號.SD對應於每個組):

model = mtcars[, list(fit=lapply(form, lm, data=.SD)), keyby=cyl] 

對於每個cylform被循環遍歷,並且將相應的公式作爲第一個參數傳遞給lm,每次連同data = .SD,其中.SD代表數據子集並且其本身是data.table。您可以從vignettes瞭解更多關於它的信息。


如果你也想在結果中的公式,則:

chform = unlist(form) 
model = mtcars[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl] 

HTH

PS:請閱讀this post如果你打算使用[...]內data.tables使用update()

+0

這解決了我目前面臨的問題。唯一的問題 - 我不明白爲什麼在不提供'data = .SD'的情況下安裝一個通用模型時,它可以工作? – Divi

+0

公式對象還捕獲它們創建的環境..這就是要使用的。看一下'?lm'。 – Arun

1

我現在確實在做這個事情,所以完美的時機。這將是一個「反轉」 - 重要的答案,但我真的很喜歡它的運作方式。

purrr有一些非常方便的功能map用時列表的列tibble相結合,使這個令人難以置信的流暢。使用您的定義(我並不想以優化)

library(data.table) 
mtcars = data.table(mtcars) 
factors = list(c("disp","wt"), c("disp"), c("hp")) 
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+")))) 

它提供的功能列表,這些可以被傳遞到purrr::invoke_map它適用的參數列表(你有),以功能列表(在你的情況下,只是lm,但我懷疑可擴展到其他人)與可選參數(在你的例子中,mtcars)。使用tibble,這些被存儲成爲純淨data.frame -esque list,否則會被返回lm對象

library(tibble) 
library(purrr) 
models <- tibble(fit = invoke_map(lm, form, data = mtcars)) 
models 
#> # A tibble: 3 x 1 
#>   fit 
#>  <list> 
#> 1 <S3: lm> 
#> 2 <S3: lm> 
#> 3 <S3: lm> 

超有用的部分來當你想要做的事,以所有這些因素,比方說,提取擬合係數:

map(models$fit, coefficients) 
#> [[1]] 
#> (Intercept)  disp   wt 
#> 34.96055404 -0.01772474 -3.35082533 
#> 
#> [[2]] 
#> (Intercept)  disp 
#> 29.59985476 -0.04121512 
#> 
#> [[3]] 
#> (Intercept)   hp 
#> 30.09886054 -0.06822828 

或重新審視式使用

map(models$fit, formula) 
#> [[1]] 
#> mpg ~ disp + wt 
#> <environment: 0x0000000017ee73a8> 
#> 
#> [[2]] 
#> mpg ~ disp 
#> <environment: 0x0000000018392c58> 
#> 
#> [[3]] 
#> mpg ~ hp 
#> <environment: 0x0000000018471d18> 

菲爾特ermore,如果你想從型號上添加一些預測,這是很容易使用broom::augment

library(broom) 
models_with_predicts <- models %>% mutate(predict = map(fit, augment)) 
models_with_predicts 
#> # A tibble: 3 x 2 
#>   fit    predict 
#>  <list>     <list> 
#> 1 <S3: lm> <data.frame [32 x 10]> 
#> 2 <S3: lm> <data.frame [32 x 9]> 
#> 3 <S3: lm> <data.frame [32 x 9]> 

你可以回去數據級(含預測)由unnest() ING實現,但這將結合所有的數據(添加分組級別以保持擬合獨立)

library(tidyr) 
unnest(models_with_predicts, predict) 

#> # A tibble: 96 x 11 
#> mpg disp wt .fitted .se.fit  .resid  .hat .sigma  .cooksd .std.resid hp 
#> <dbl> <dbl> <dbl> <dbl>  <dbl>  <dbl>  <dbl> <dbl>  <dbl>  <dbl> <dbl> 
#> 1 21.0 160.0 2.620 23.34543 0.6075520 -2.3454326 0.04339369 2.933379 0.010222201 -0.8222164 NA 
#> 2 21.0 160.0 2.875 22.49097 0.6221836 -1.4909721 0.04550894 2.954135 0.004351414 -0.5232550 NA 
#> 3 22.8 108.0 2.320 25.27237 0.7326015 -2.4723669 0.06309504 2.928665 0.017217431 -0.8757799 NA 
#> 4 21.4 258.0 3.215 19.61467 0.5743205 1.7853334 0.03877647 2.948162 0.005241995 0.6243627 NA 
#> 5 18.7 360.0 3.440 17.05281 1.0943208 1.6471930 0.14078260 2.949120 0.020275438 0.6092882 NA 
#> 6 18.1 225.0 3.460 19.37863 0.6122393 -1.2786309 0.04406584 2.957872 0.003089406 -0.4483953 NA 
#> 7 14.3 360.0 3.570 16.61720 0.9897465 -2.3171997 0.11516157 2.931444 0.030948880 -0.8446199 NA 
#> 8 24.4 146.7 3.190 21.67120 0.9053245 2.7287988 0.09635365 2.918183 0.034431234 0.9842424 NA 
#> 9 22.8 140.8 3.150 21.90981 0.9165259 0.8901898 0.09875274 2.962885 0.003775416 0.3215070 NA 
#> 10 19.2 167.6 3.440 20.46305 0.9678618 -1.2630477 0.11012510 2.957375 0.008693734 -0.4590766 NA 
#> # ... with 86 more rows