5 votes

Lecture de la base de données MNIST

J'explore actuellement les réseaux neuronaux et l'apprentissage automatique et j'ai implémenté un réseau neuronal de base en c#. Je souhaite maintenant tester mon algorithme de formation par rétropropagation avec la base de données MNIST. Cependant, j'ai de sérieux problèmes pour lire correctement les fichiers.

Spoiler le code est actuellement très mal optimisé pour les performances. Mon objectif actuel est d'appréhender le sujet et d'avoir une vision structurée de la façon dont les choses fonctionnent avant de commencer à jeter mes structures de données pour des structures plus rapides.

Pour entraîner le réseau, je veux lui fournir une structure de données TrainingSet personnalisée :

[Serializable]
public class TrainingSet
{
    public Dictionary<List<double>, List<double>> data = new Dictionary<List<double>, List<double>>();
}

Les clés seront mes données d'entrée (784 pixels par entrée (image) qui représenteront les valeurs en niveaux de gris dans la gamme de 0 à 1). Les valeurs seront mes données de sortie (10 entrées représentant les chiffres de 0 à 9 avec toutes les entrées sur 0 sauf celle attendue sur 1).

Maintenant je veux lire la base de données MNIST selon ce contrat. J'en suis actuellement à mon deuxième essai qui est inspiré par ce billet : https://jamesmccaffrey.wordpress.com/2013/11/23/reading-the-mnist-data-set-with-c/ . Malheureusement, il produit toujours les mêmes bêtises que mon premier essai, en dispersant les pixels selon un modèle étrange : Pattern screenshot

Mon algorithme de lecture actuel :

    public static TrainingSet GenerateTrainingSet(FileInfo imagesFile, FileInfo labelsFile)
    {
        MnistImageView imageView = new MnistImageView();
        imageView.Show();

        TrainingSet trainingSet = new TrainingSet();

        List<List<double>> labels = new List<List<double>>();
        List<List<double>> images = new List<List<double>>();

        using (BinaryReader brLabels = new BinaryReader(new FileStream(labelsFile.FullName, FileMode.Open)))
        {
            using (BinaryReader brImages = new BinaryReader(new FileStream(imagesFile.FullName, FileMode.Open)))
            {
                int magic1 = brImages.ReadBigInt32(); //Reading as BigEndian
                int numImages = brImages.ReadBigInt32();
                int numRows = brImages.ReadBigInt32();
                int numCols = brImages.ReadBigInt32();

                int magic2 = brLabels.ReadBigInt32();
                int numLabels = brLabels.ReadBigInt32();

                byte[] pixels = new byte[numRows * numCols];

                // each image
                for (int imageCounter = 0; imageCounter < numImages; imageCounter++)
                {
                    List<double> imageInput = new List<double>();
                    List<double> exspectedOutput = new List<double>();

                    for (int i = 0; i < 10; i++) //generate empty exspected output
                        exspectedOutput.Add(0);

                    //read image
                    for (int p = 0; p < pixels.Length; p++)
                    {
                        byte b = brImages.ReadByte();
                        pixels[p] = b;

                        imageInput.Add(b / 255.0f); //scale in 0 to 1 range
                    }

                    //read label
                    byte lbl = brLabels.ReadByte();
                    exspectedOutput[lbl] = 1; //modify exspected output

                    labels.Add(exspectedOutput);
                    images.Add(imageInput);

                    //Debug view showing parsed image.......................
                    Bitmap image = new Bitmap(numCols, numRows);

                    for (int y = 0; y < numRows; y++)
                    {
                        for (int x = 0; x < numCols; x++)
                        {
                            image.SetPixel(x, y, Color.FromArgb(255 - pixels[x * y], 255 - pixels[x * y], 255 - pixels[x * y])); //invert colors to have 0,0,0 be white as specified by mnist
                        }
                    }

                    imageView.SetImage(image);
                    imageView.Refresh();
                    //.......................................................
                }

                brImages.Close();
                brLabels.Close();
            }
        }

        for (int i = 0; i < images.Count; i++)
        {
            trainingSet.data.Add(images[i], labels[i]);
        }

        return trainingSet;
    }

Toutes les images produisent un motif comme indiqué ci-dessus. Ce n'est jamais exactement le même motif, mais les pixels semblent toujours être "tirés" vers le bas, dans le coin droit.

10voto

koryakinp Points 1980

C'est comme ça que j'ai fait :

public static class MnistReader
{
    private const string TrainImages = "mnist/train-images.idx3-ubyte";
    private const string TrainLabels = "mnist/train-labels.idx1-ubyte";
    private const string TestImages = "mnist/t10k-images.idx3-ubyte";
    private const string TestLabels = "mnist/t10k-labels.idx1-ubyte";

    public static IEnumerable<Image> ReadTrainingData()
    {
        foreach (var item in Read(TrainImages, TrainLabels))
        {
            yield return item;
        }
    }

    public static IEnumerable<Image> ReadTestData()
    {
        foreach (var item in Read(TestImages, TestLabels))
        {
            yield return item;
        }
    }

    private static IEnumerable<Image> Read(string imagesPath, string labelsPath)
    {
        BinaryReader labels = new BinaryReader(new FileStream(labelsPath, FileMode.Open));
        BinaryReader images = new BinaryReader(new FileStream(imagesPath, FileMode.Open));

        int magicNumber = images.ReadBigInt32();
        int numberOfImages = images.ReadBigInt32();
        int width = images.ReadBigInt32();
        int height = images.ReadBigInt32();

        int magicLabel = labels.ReadBigInt32();
        int numberOfLabels = labels.ReadBigInt32();

        for (int i = 0; i < numberOfImages; i++)
        {
            var bytes = images.ReadBytes(width * height);
            var arr = new byte[height, width];

            arr.ForEach((j,k) => arr[j, k] = bytes[j * height + k]);

            yield return new Image()
            {
                Data = arr,
                Label = labels.ReadByte()
            };
        }
    }
}

Image classe :

public class Image
{
    public byte Label { get; set; }
    public byte[,] Data { get; set; }
}

Quelques méthodes d'extension :

public static class Extensions
{
    public static int ReadBigInt32(this BinaryReader br)
    {
        var bytes = br.ReadBytes(sizeof(Int32));
        if (BitConverter.IsLittleEndian) Array.Reverse(bytes);
        return BitConverter.ToInt32(bytes, 0);
    }

    public static void ForEach<T>(this T[,] source, Action<int, int> action)
    {
        for (int w = 0; w < source.GetLength(0); w++)
        {
            for (int h = 0; h < source.GetLength(1); h++)
            {
                action(w, h);
            }
        }
    }
}

Utilisation :

foreach (var image in MnistReader.ReadTrainingData())
{
    //use image here     
}

ou

foreach (var image in MnistReader.ReadTestData())
{
    //use image here     
}

3voto

Guy Langston Points 61

Pourquoi ne pas utiliser un paquet nuget :

  • MNIST.IO Just a datareader (disclaimer : mon paquet)
  • Accord.DataSets Contient des classes pour télécharger et analyser des ensembles de données d'apprentissage automatique tels que MNIST, News20, Iris. Ce paquet fait partie du cadre Accord.NET.

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X