Je voudrais mettre en œuvre une régression simple de forêt aléatoire pour prédire une valeur. Les entrées sont des échantillons avec plusieurs caractéristiques, et l'étiquette est une valeur. Cependant, je ne trouve pas d'exemple simple sur le problème de la régression par forêt aléatoire. Ainsi, j'ai vu le document de tensorflow et j'ai trouvé ça :
Un estimateur qui peut entraîner et évaluer une forêt aléatoire. Exemple :
python
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
# Estimator using the default graph builder.
estimator = TensorForestEstimator(params, model_dir=model_dir)
# Or estimator using TrainingLossForest as the graph builder.
estimator = TensorForestEstimator(
params, graph_builder_class=tensor_forest.TrainingLossForest,
model_dir=model_dir)
# Input builders
def input_fn_train: # returns x, y
...
def input_fn_eval: # returns x, y
...
estimator.fit(input_fn=input_fn_train)
estimator.evaluate(input_fn=input_fn_eval)
# Predict returns an iterable of dicts.
results = list(estimator.predict(x=x))
prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]
Cependant, lorsque je suis l'exemple, j'ai obtenu l'erreur sur la ligne, prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
l'erreur montre que :
Example conversion:
est = Estimator(...) -> est = SKCompat(Estimator(...))
Traceback (most recent call last):
File "RF_2.py", line 312, in <module>
main()
File "RF_2.py", line 298, in main
train_eval(x_train, y_train, x_validation, y_validation, x_test, y_test, num_tree)
File "RF_2.py", line 221, in train_eval
prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
KeyError: 'probabilities'
Je pense que l'erreur se produit sur INFERENCE_PROB_NAME
et j'ai vu le document . Cependant, je ne sais toujours pas quel est le mot pour remplacer INFERENCE_PROB_NAME
.
J'ai essayé get_metric('accuracy')
pour remplacer INFERENCE_PROB_NAME
il renvoie l'erreur : KeyError: <function _accuracy at 0x11a06eaa0>
.
J'ai aussi essayé get_prediction_key('accuracy')
pour remplacer INFERENCE_PROB_NAME
il renvoie l'erreur : KeyError: 'classes'
.
Si vous connaissez la réponse possible, veuillez me la communiquer. Merci d'avance.