62 votes

Existe-t-il une fonction pour faire des matrices de nuage de points dans matplotlib ?

Exemple de matrice de nuage de points

enter image description here

Existe-t-il une telle fonction dans matplotlib.pyplot ?

1 votes

120voto

Roman Pekar Points 31863

Pour ceux qui ne veulent pas définir leurs propres fonctions, il existe une excellente bibliothèque d'analyse de données en Python, appelée Pandas où l'on peut trouver le matrice de dispersion() método:

from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')

enter image description here

2 votes

Bonjour, comment se fait-il que seule une partie des sous-intrigues ait une grille ? Peut-on modifier cela (tout ou rien) ? Merci

5 votes

+1 Ça m'apprendra à chercher une fonctionnalité Python avant de regarder si elle n'est pas déjà dans pandas. Etape 1 : Toujours se demander, est-ce que ça existe déjà dans pandas ? pd.scatter_matrix(df); plt.show() . Incroyable.

2 votes

Placer un kde dans la matrice scatterplot de matplotlib est un sport extrême. J'aime les pandas.

30voto

Joe Kington Points 68089

D'une manière générale, matplotlib ne contient pas de fonctions de traçage qui opèrent sur plus d'un objet d'axe (sous-plot, dans ce cas). On s'attend à ce que vous écriviez une fonction simple pour enchaîner les choses comme vous le souhaitez.

Je ne sais pas exactement à quoi ressemblent vos données, mais il est assez simple de construire une fonction pour faire cela à partir de zéro. Si vous êtes amené à travailler avec des tableaux structurés ou récurrents, vous pouvez simplifier un peu les choses. (c'est-à-dire qu'il y a toujours un nom associé à chaque série de données, vous pouvez donc éviter de devoir spécifier des noms).

A titre d'exemple :

import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()

enter image description here

17voto

sushmit Points 1590

Vous pouvez également utiliser Seaborn's pairplot fonction :

import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")

0 votes

La partie ennuyeuse de Seaborn est qu'elle est centrée sur les DataFrames de Pandas. Si vous avez un tableau NumPy, cette solution de contournement est ennuyeuse, et si vous avez déjà un DataFrame pandas, pourquoi ne pas simplement utiliser la méthode scatter_matrix intégrée de pandas ?

0 votes

Malheureusement, il ne permet pas de réaliser des matrices de nuage de points formées par deux groupes distincts de variables. Il ne donne que le tracé vars vs vars. Cela complique l'analyse pour les ensembles de données de taille moyenne et grande.

10voto

tisimst Points 226

Merci de partager votre code ! Vous avez résolu tous les problèmes pour nous. En travaillant avec, j'ai remarqué quelques petites choses qui ne semblaient pas tout à fait correctes.

  1. [CORRECTIF #1] Les axes ne s'alignaient pas comme je l'aurais souhaité (c'est-à-dire que dans votre exemple ci-dessus, vous devriez pouvoir dessiner une ligne verticale et horizontale passant par n'importe quel point sur tous les graphiques et les lignes devraient passer par le point correspondant dans les autres graphiques, mais tel qu'il est maintenant, cela ne se produit pas.

  2. [FIX #2] Si vous avez un nombre impair de variables avec lesquelles vous tracez, les axes du coin inférieur droit ne tirent pas les xtics ou ytics corrects. Il laisse juste les 0..1 ticks par défaut.

  3. Ce n'est pas un correctif, mais j'ai rendu optionnel le fait d'entrer explicitement les données. names de façon à ce qu'il mette une valeur par défaut xi pour la variable i dans les positions diagonales.

Vous trouverez ci-dessous une version actualisée de votre code qui répond à ces deux points, tout en préservant la beauté de votre code.

import itertools
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

Merci encore de partager cela avec nous. Je l'ai utilisé de nombreuses fois ! Oh, et j'ai réarrangé le main() une partie du code afin qu'il puisse être un exemple formel de code ou ne pas être appelé s'il est importé dans un autre morceau de code.

4voto

omun Points 103

En lisant la question, je m'attendais à voir une réponse comprenant rpy . Je pense qu'il s'agit d'une option intéressante qui tire parti de deux belles langues. Alors voilà :

import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)

if __name__ == '__main__': main()

Je ne peux pas poster une image pour montrer le résultat :( désolé !

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X