En Python et Matplotlib, il est facile d'afficher le tracé sous forme de fenêtre contextuelle ou d'enregistrer le tracé sous forme de fichier PNG. Comment puis-je enregistrer le tracé dans un tableau numpy au format RVB ?
Réponses
Trop de publicités?C'est une astuce pratique pour les tests unitaires et autres, lorsque vous devez effectuer une comparaison pixel à pixel avec un tracé enregistré.
Une façon consiste à utiliser fig.canvas.tostring_rgb
puis numpy.fromstring
avec le dtype approprié. Il y a aussi d'autres méthodes, mais c'est celle que j'ai tendance à utiliser.
Par exemple
import matplotlib.pyplot as plt
import numpy as np
# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)
# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()
# Now we can save it to a numpy array.
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
Il existe une option un peu plus simple pour la réponse de @JUN_NETWORKS. Au lieu d'enregistrer le chiffre en png
, on peut utiliser un autre format, comme raw
ou rgba
et sauter l'étape de décodage cv2
En d'autres termes, la conversion réelle de plot à numpy se résume à :
io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()
J'espère que cela t'aides.
Certaines personnes proposent une méthode qui ressemble à celle-ci
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
Bien sûr, ce code fonctionne. Mais, l'image du tableau numpy de sortie est si basse résolution.
Mon code de proposition est le suivant.
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt
# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)
x = np.linspace(-np.pi, np.pi)
ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.plot(x, np.sin(x), label="sin")
ax.legend()
ax.set_title("sin(x)")
# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=dpi)
buf.seek(0)
img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
buf.close()
img = cv2.imdecode(img_arr, 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)
Ce code fonctionne bien. Vous pouvez obtenir une image haute résolution sous forme de tableau numpy si vous définissez un grand nombre sur l'argument dpi.
Au cas où quelqu'un voudrait une solution plug and play, sans modifier aucun code antérieur (obtenir la référence à la figure pyplot et tout), ce qui suit a fonctionné pour moi. Ajoutez simplement ceci après toutes les pyplot
, c'est-à-dire juste avant pyplot.show()
canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))