5 votes

Légende Pandas pour la matrice de dispersion

J'ai un dataframe pandas avec 3 classes et des datapoints de n caractéristiques.

Le code suivant produit une matrice de dispersion avec des histogrammes en diagonale, de 4 des caractéristiques du dataframe.

colums = ['n1','n2','n3','n4']
grr = pd.scatter_matrix(
dataframe[columns], c=y_train, figsize=(15,15), label=['B','N','O'], marker='.',
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg')
plt.legend()
plt.show()

comme ceci:

Matrice de dispersion de ce dataframe

Le problème que je rencontre est que plt.legend() ne semble pas fonctionner, aucune légende n'est affichée du tout (ou c'est le petit 'le8' à peine visible dans la première colonne de la deuxième ligne...)

Ce que je voudrais avoir est une seule légende qui montre simplement quelle couleur correspond à quelle classe.

J'ai essayé toutes les questions suggérées mais aucune n'a de solution. J'ai également essayé de mettre les étiquettes dans les paramètres de la fonction legend comme ceci:

plt.legend(label=['B','N','O'], loc=1)

mais en vain..

Qu'est-ce que je fais de mal?

5voto

ImportanceOfBeingErnest Points 119438

Le scatter_matrix des pandas est un wrapper pour plusieurs graphiques scatter de matplotlib. Les arguments sont transmis à la fonction scatter. Cependant, le scatter est généralement utilisé avec une colormap et non une légende avec des points étiquetés de manière discrète, donc aucun argument n'est disponible pour créer automatiquement une légende.

Je crains que vous deviez créer la légende manuellement. À cette fin, vous pouvez créer les points à partir du scatter en utilisant la fonction plot de matplotlib (avec des données vides) et les ajouter en tant que handles à la légende.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.subplot.right"] = 0.8

v= np.random.rayleigh(size=(30,5))
v[:,4] = np.random.randint(1,4,size=30)/3.
dataframe= pd.DataFrame(v, columns=['n1','n2','n3','n4',"c"])

columns = ['n1','n2','n3','n4']
grr = pd.scatter_matrix(
dataframe[columns], c=dataframe["c"], figsize=(7,5), label=['B','N','O'], marker='.',
    hist_kwds={'bins':20}, s=10, alpha=.8, cmap='brg')

handles = [plt.plot([],[],color=plt.cm.brg(i/2.), ls="", marker=".", \
                    markersize=np.sqrt(10))[0] for i in range(3)]
labels=["Label A", "Label B", "Label C"]
plt.legend(handles, labels, loc=(1.02,0))
plt.show()

description de l'image ici

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X