78 votes

Matplotlib: enregistrer le tracé dans le tableau numpy

En Python et Matplotlib, il est facile d'afficher le tracé sous forme de fenêtre contextuelle ou d'enregistrer le tracé sous forme de fichier PNG. Comment puis-je enregistrer le tracé dans un tableau numpy au format RVB ?

103voto

Joe Kington Points 68089

C'est une astuce pratique pour les tests unitaires et autres, lorsque vous devez effectuer une comparaison pixel à pixel avec un tracé enregistré.

Une façon consiste à utiliser fig.canvas.tostring_rgb puis numpy.fromstring avec le dtype approprié. Il y a aussi d'autres méthodes, mais c'est celle que j'ai tendance à utiliser.

Par exemple

 import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

43voto

Il existe une option un peu plus simple pour la réponse de @JUN_NETWORKS. Au lieu d'enregistrer le chiffre en png , on peut utiliser un autre format, comme raw ou rgba et sauter l'étape de décodage cv2

En d'autres termes, la conversion réelle de plot à numpy se résume à :

 io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()

J'espère que cela t'aides.

21voto

JUN_NETWORKS Points 181

Certaines personnes proposent une méthode qui ressemble à celle-ci

 np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

Bien sûr, ce code fonctionne. Mais, l'image du tableau numpy de sortie est si basse résolution.

Mon code de proposition est le suivant.

 import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")


# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

Ce code fonctionne bien. Vous pouvez obtenir une image haute résolution sous forme de tableau numpy si vous définissez un grand nombre sur l'argument dpi.

4voto

Nagabhushan S N Points 630

Au cas où quelqu'un voudrait une solution plug and play, sans modifier aucun code antérieur (obtenir la référence à la figure pyplot et tout), ce qui suit a fonctionné pour moi. Ajoutez simplement ceci après toutes les pyplot , c'est-à-dire juste avant pyplot.show()

 canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X