3 votes

Rééchantillonnage imbriqué + LASSO (regr.cvglment) en utilisant mlr

Je suis en train de m'efforcer d'effectuer un échantillonnage imbriqué avec 10 CV pour la boucle interne et 10 CV pour la boucle externe en utilisant regr.cvglmnet. Mlr fournit le code en utilisant une fonction d'enrobage (https://mlr-org.github.io/mlr/articles/tutorial/devel/nested_resampling.html)

Maintenant, j'ai échangé deux choses dans leur code fourni 1) "regr.cvglmnet" au lieu de la machine à vecteurs de support (ksvm) 2) le nombre d'itérations pour la boucle interne et externe

Après la fonction lrn, je reçois l'erreur spécifiée ci-dessous. Quelqu'un pourrait-il m'expliquer cela? Je suis complètement nouveau en programmation et en apprentissage automatique, donc j'ai peut-être fait quelque chose de très stupide dans le code....

ps = makeParamSet(
  makeDiscreteParam("C", values = 2^(-12:12)),
  makeDiscreteParam("sigma", values = 2^(-12:12))
)
ctrl = makeTuneControlGrid()
inner = makeResampleDesc("Subsample", iters = 10)
lrn = makeTuneWrapper("regr.cvglmnet", resampling = inner, par.set = ps, 
                      control = ctrl, show.info = FALSE)

# Erreur dans checkTunerParset(learner, par.set, mesures, control) :
# Ne peut ajuster les paramètres pour lesquels des paramètres de l'apprenant existent : C, sigma

### Boucle externe d'échantillonnage
outer = makeResampleDesc("CV", iters = 10) 
r = resample(lrn, iris.task, resampling = outer, extract = getTuneResult, 
             show.info = FALSE)

2voto

pat-s Points 1813

En utilisant LASSO avec glmnet, vous n'avez besoin de régler que s. C'est le paramètre important utilisé lorsque le modèle prédit de nouvelles données. Le paramètre lambda n'a absolument aucune influence en raison de la façon dont le package est codé sur la prédiction. Si vous définissez s différemment des valeurs de lambda qui ont été choisies, le modèle sera retravaillé avec s comme terme de pénalisation.

Par défaut, plusieurs modèles avec diverses valeurs de lambda sont ajustés pendant l'appel de train. Cependant, pour la prédiction, un nouveau modèle sera ajusté en utilisant la meilleure valeur de lambda. En fait, le réglage est effectué à l'étape de prédiction.

De bonnes plages par défaut pour s peuvent être choisies en

  1. Entrainer le modèle avec les valeurs par défaut de glmnet
  2. Vérifier les valeurs minimales et maximales de lambda
  3. Utiliser celles-ci comme bornes inférieures et supérieures pour s qui est ensuite ajusté en utilisant mlr

Voir aussi cette discussion.

library(mlr)
#> Chargement du package requis: ParamHelpers

lrn_glmnet <- makeLearner("regr.glmnet",
                          alpha = 1,
                          intercept = FALSE)

# vérifier lambda
glmnet_train = mlr::train(lrn_glmnet, bh.task)
summary(glmnet_train$learner.model$lambda)
#>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
#>   143.5   157.4   172.8   174.3   189.6   208.1

# définir les limites
ps_glmnet <- makeParamSet(makeNumericParam("s", lower = 140, upper = 208))

# ajuster les paramètres en parallèle en utilisant une recherche en grille pour la simplicité
tune.ctrl = makeTuneControlGrid()
inner <- makeResampleDesc("CV", iters = 2)

configureMlr(on.learner.error = "warn", on.error.dump = TRUE)
library(parallelMap)
parallelStart(mode = "multicore", level = "mlr.tuneParams", cpus = 4,
              mc.set.seed = TRUE) # paralleliser uniquement la recherche d'ajustement
#> Démarrage de la parallélisation en mode=multicore avec cpus=4.
set.seed(12345)
params_tuned_glmnet = tuneParams(lrn_glmnet, task = bh.task, resampling = inner,
                                 par.set = ps_glmnet, control = tune.ctrl, 
                                 measure = list(rmse))
#> [Tune] Démarrage de la recherche pour le learner regr.glmnet pour le jeu de paramètres:
#>      Type len Def     Constr Req Tunable Trafo
#> s numeric   -   - 140 à 208   -    TRUE     -
#> Avec une classe de contrôle: TuneControlGrid
#> Valeur d'imputation: Inf
#> Mapping en parallèle: mode = multicore; cpus = 4; éléments = 10.
#> [Tune] Résultat: s=140 : rmse.test.rmse=17.9803086
parallelStop()
#> Arrêt de la parallélisation. Tout est nettoyé.

# entrainer le modèle sur l'ensemble du jeu de données en utilisant la valeur de `s` provenant de l'ajustement

lrn_glmnet_tuned <- makeLearner("regr.glmnet",
                                alpha = 1,
                                s = 140,
                                intercept = FALSE)
#lambda = sort(seq(0, 5, length.out = 100), decreasing = T))
glmnet_train_tuned = mlr::train(lrn_glmnet_tuned, bh.task)

Créé le 03-07-2018 par le paquet reprex (v0.2.0).

devtools::session_info()
#> Informations de session -------------------------------------------------------------
#>  paramètre   valeur                       
#>  version  R version 3.5.0 (2018-04-23)
#>  système   x86_64, linux-gnu           
#>  ui       X11                         
#>  langue (EN)                        
#>  collate  en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2018-07-03
#> Packages -----------------------------------------------------------------
#>  package      * version   date       source         
#>  backports      1.1.2     2017-12-13 CRAN (R 3.5.0) 
#>  base         * 3.5.0     2018-06-04 local          
#>  BBmisc         1.11      2017-03-10 CRAN (R 3.5.0) 
#>  bit            1.1-14    2018-05-29 cran (@1.1-14) 
#>  bit64          0.9-7     2017-05-08 CRAN (R 3.5.0) 
#>  blob           1.1.1     2018-03-25 CRAN (R 3.5.0) 
#>  checkmate      1.8.5     2017-10-24 CRAN (R 3.5.0) 
#>  codetools      0.2-15    2016-10-05 CRAN (R 3.5.0) 
#>  colorspace     1.3-2     2016-12-14 CRAN (R 3.5.0) 
#>  compiler       3.5.0     2018-06-04 local          
#>  data.table     1.11.4    2018-05-27 CRAN (R 3.5.0) 
#>  datasets     * 3.5.0     2018-06-04 local          
#>  DBI            1.0.0     2018-05-02 cran (@1.0.0)  
#>  devtools       1.13.6    2018-06-27 CRAN (R 3.5.0) 
#>  digest         0.6.15    2018-01-28 CRAN (R 3.5.0) 
#>  evaluate       0.10.1    2017-06-24 CRAN (R 3.5.0) 
#>  fastmatch      1.1-0     2017-01-28 CRAN (R 3.5.0) 
#>  foreach        1.4.4     2017-12-12 CRAN (R 3.5.0) 
#>  ggplot2        2.2.1     2016-12-30 CRAN (R 3.5.0) 
#>  git2r          0.21.0    2018-01-04 CRAN (R 3.5.0) 
#>  glmnet         2.0-16    2018-04-02 CRAN (R 3.5.0) 
#>  graphics     * 3.5.0     2018-06-04 local          
#>  grDevices    * 3.5.0     2018-06-04 local          
#>  grid           3.5.0     2018-06-04 local          
#>  gtable         0.2.0     2016-02-26 CRAN (R 3.5.0) 
#>  htmltools      0.3.6     2017-04-28 CRAN (R 3.5.0) 
#>  iterators      1.0.9     2017-12-12 CRAN (R 3.5.0) 
#>  knitr          1.20      2018-02-20 CRAN (R 3.5.0) 
#>  lattice        0.20-35   2017-03-25 CRAN (R 3.5.0) 
#>  lazyeval       0.2.1     2017-10-29 CRAN (R 3.5.0) 
#>  magrittr       1.5       2014-11-22 CRAN (R 3.5.0) 
#>  Matrix         1.2-14    2018-04-09 CRAN (R 3.5.0) 
#>  memoise        1.1.0     2017-04-21 CRAN (R 3.5.0) 
#>  memuse         4.0-0     2017-11-10 CRAN (R 3.5.0) 
#>  methods      * 3.5.0     2018-06-04 local          
#>  mlr          * 2.13      2018-07-01 local          
#>  munsell        0.5.0     2018-06-12 CRAN (R 3.5.0) 
#>  parallel       3.5.0     2018-06-04 local          
#>  parallelMap  * 1.3       2015-06-10 CRAN (R 3.5.0) 
#>  ParamHelpers * 1.11      2018-06-25 CRAN (R 3.5.0) 
#>  pillar         1.2.3     2018-05-25 CRAN (R 3.5.0) 
#>  plyr           1.8.4     2016-06-08 CRAN (R 3.5.0) 
#>  Rcpp           0.12.17   2018-05-18 cran (@0.12.17)
#>  rlang          0.2.1     2018-05-30 CRAN (R 3.5.0) 
#>  rmarkdown      1.10      2018-06-11 CRAN (R 3.5.0) 
#>  rprojroot      1.3-2     2018-01-03 CRAN (R 3.5.0) 
#>  RSQLite        2.1.1     2018-05-06 cran (@2.1.1)  
#>  scales         0.5.0     2017-08-24 CRAN (R 3.5.0) 
#>  splines        3.5.0     2018-06-04 local          
#>  stats        * 3.5.0     2018-06-04 local          
#>  stringi        1.2.3     2018-06-12 CRAN (R 3.5.0) 
#>  stringr        1.3.1     2018-05-10 CRAN (R 3.5.0) 
#>  survival       2.42-3    2018-04-16 CRAN (R 3.5.0) 
#>  tibble         1.4.2     2018-01-22 CRAN (R 3.5.0) 
#>  tools          3.5.0     2018-06-04 local          
#>  utils        * 3.5.0     2018-06-04 local          
#>  withr          2.1.2     2018-03-15 CRAN (R 3.5.0) 
#>  XML            3.98-1.11 2018-04-16 CRAN (R 3.5.0) 
#>  yaml           2.1.19    2018-05-01 CRAN (R 3.5.0)

1voto

Lars Kotthoff Points 44924

Le message d'erreur vous indique que vous ne pouvez pas régler des paramètres que mlr ne connaît pas pour cet apprenant -- regr.cvglmnet n'a pas les paramètres C et sigma. Vous pouvez obtenir les paramètres que mlr connaît pour un apprenant avec la fonction getLearnerParamSet():

\> getLearnerParamSet(makeLearner("regr.cvglmnet"))
                          Type  len        Def                Constr Req
family                discrete    -   gaussian      gaussian,poisson   -
alpha                  numeric    -          1                0 to 1   -
nfolds                 integer    -         10              3 to Inf   -
type.measure          discrete    -        mse               mse,mae   -
s                     discrete    - lambda.1se lambda.1se,lambda.min   -
nlambda                integer    -        100              1 to Inf   -
lambda.min.ratio       numeric    -          -                0 to 1   -
standardize            logical    -       TRUE                     -   -
intercept              logical    -       TRUE                     -   -
thresh                 numeric    -      1e-07              0 to Inf   -
dfmax                  integer    -          -              0 to Inf   -
pmax                   integer    -          -              0 to Inf   -
exclude          integervector           -              1 to Inf   -
penalty.factor   numericvector           -                0 to 1   -
lower.limits     numericvector           -             -Inf to 0   -
upper.limits     numericvector           -              0 to Inf   -
maxit                  integer    -     100000              1 to Inf   -
type.gaussian         discrete    -          -      covariance,naive   -
fdev                   numeric    -      1e-05                0 to 1   -
devmax                 numeric    -      0.999                0 to 1   -
eps                    numeric    -      1e-06                0 to 1   -
big                    numeric    -    9.9e+35           -Inf to Inf   -
mnlam                  integer    -          5              1 to Inf   -
pmin                   numeric    -      1e-09                0 to 1   -
exmx                   numeric    -        250           -Inf to Inf   -
prec                   numeric    -      1e-10           -Inf to Inf   -
mxit                   integer    -        100              1 to Inf   -
                 Tunable Trafo
family              TRUE     -
alpha               TRUE     -
nfolds              TRUE     -
type.measure        TRUE     -
s                   TRUE     -
nlambda             TRUE     -
lambda.min.ratio    TRUE     -
standardize         TRUE     -
intercept           TRUE     -
thresh              TRUE     -
dfmax               TRUE     -
pmax                TRUE     -
exclude             TRUE     -
penalty.factor      TRUE     -
lower.limits        TRUE     -
upper.limits        TRUE     -
maxit               TRUE     -
type.gaussian       TRUE     -
fdev                TRUE     -
devmax              TRUE     -
eps                 TRUE     -
big                 TRUE     -
mnlam               TRUE     -
pmin                TRUE     -
exmx                TRUE     -
prec                TRUE     -
mxit                TRUE     -

Vous pouvez utiliser l'un de ces paramètres pour définir un ensemble de paramètres valide pour l'accord de cet apprenant particulier, par exemple:

ps = makeParamSet(
  makeDiscreteParam("family", values = c("gaussian", "poisson")),
  makeDiscreteParam("alpha", values = 0.1\*0:10)
)

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X