12 votes

Élaguer les feuilles inutiles dans le classificateur DecisionTreeClassifier de sklearn

J'utilise sklearn.tree.DecisionTreeClassifier pour construire un arbre de décision. Avec les réglages optimaux des paramètres, j'obtiens un arbre qui a des feuilles inutiles (cf. exemple image ci-dessous - je n'ai pas besoin de probabilités, donc les nœuds feuilles marqués en rouge sont une division inutile)

Tree

Existe-t-il une bibliothèque tierce pour élaguer ces nœuds inutiles ? Ou un extrait de code ? Je pourrais en écrire un, mais je ne pense pas être la première personne à rencontrer ce problème...

Code à répliquer :

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)

PS : J'ai essayé plusieurs recherches par mots-clés et je suis un peu surpris de ne rien trouver - n'y a-t-il vraiment pas de post-élimination en général dans Sklearn ?

PPS : En réponse à l'éventuel doublon : Alors que la question proposée pourrait m'aider à coder moi-même l'algorithme d'élagage, il répond à une question différente - je veux me débarrasser des feuilles qui ne changent pas la décision finale, alors que l'autre question veut un seuil minimum pour la division des nœuds.

PPPS : L'arbre présenté est un exemple pour montrer mon problème. Je suis conscient du fait que les paramètres utilisés pour créer l'arbre sont sous-optimaux. Je ne demande pas d'optimiser cet arbre spécifique, j'ai besoin de faire un post-élagage pour me débarrasser des feuilles qui pourraient être utiles si on a besoin des probabilités de classe, mais qui ne sont pas utiles si on est seulement intéressé par la classe la plus probable.

16voto

Thomas Points 1141

En utilisant le lien de ncfirth, j'ai pu modifier le code qui s'y trouve pour qu'il corresponde à mon problème :

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove leaves if both 
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

Utilisation sur un DecisionTreeClassifier clf :

prune_duplicate_leaves(clf)

Edit : Correction d'un bug pour les arbres plus complexes

1voto

jonnor Points 587

DecisionTreeClassifier(max_leaf_nodes=8) spécifie (max) 8 feuilles, donc à moins que le constructeur d'arbres ait une autre raison d'arrêter, il atteindra le maximum.

Dans l'exemple présenté, 5 des 8 feuilles ont un nombre très faible d'échantillons (<=3) par rapport aux 3 autres feuilles (>50), ce qui peut être un signe de sur-ajustement. Au lieu d'élaguer l'arbre après la formation, on peut spécifier soit min_samples_leaf o min_samples_split pour mieux guider la formation, ce qui permettra probablement d'éliminer les feuilles problématiques. Par exemple, utilisez la valeur 0.05 pour au moins 5% des échantillons.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X