2 votes

Comment ajouter une légende dans un nuage de points DataFrame de Pandas ?

J'ai un DataFrame pandas qui contient les colonnes d'intérêt suivantes :

['Relative Width', 'Relative Height', 'Object Name', 'Object ID']

Il y a 15 noms d'objets avec 15 couleurs déterminées avec df.plot(c='Object ID') qui produisent la figure suivante :

fig

Je veux afficher une légende avec les 15 noms d'objets, comment faire ?

import matplotlib.pyplot as plt
from annotation_parsers import parse_voc_folder

def visualize_box_relative_sizes(folder_path, voc_conf, cache_file='data_set_labels.csv'):
    frame = parse_voc_folder(folder_path, voc_conf, cache_file)
    title = f'Relative width and height for {frame.shape[0]} boxes.'
    frame.plot(
        kind='scatter',
        x='Relative Width',
        y='Relative Height',
        title=title,
        c='Object ID',
        colormap='gist_rainbow',
        colorbar=False,
    )
    plt.show()

Sur la base de la recommandation de wwnde, j'ai modifié le code comme suit :

def visualize_box_relative_sizes(folder_path, voc_conf, cache_file='data_set_labels.csv'):
    frame = parse_voc_folder(folder_path, voc_conf, cache_file)
    title = f'Relative width and height for {frame.shape[0]} boxes.'
    sns.scatterplot(x=frame["Relative Width"], y=frame["Relative Height"], hue=frame["Object Name"])
    plt.title(title)
    plt.show()

ce qui donne le résultat suivant :

enter image description here

1voto

wwnde Points 14457

Essayez, s'il vous plaît

fig, ax = plt.subplots()

ax = sns.scatterplot(x="total_bill", y="tip",
                     hue="size", size="size",
                     data=tips)
ax.set_title('title')
plt.show()

Cela devrait vous donner une légende colorée par défaut

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X