# Source code for neodroidvision.data.synthesis.conversion.mnist.convert_mnist_to_png

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Module metadata. ``__doc__`` is assigned explicitly here, so the module
# docstring only carries the creation date.
__author__ = "Christian"
__doc__ = r"""

           Created on 22/03/2020
           """

# Names exported by ``from ... import *``; ``read_data`` and ``write_dataset``
# are not listed and therefore not part of the star-import surface.
__all__ = ["convert_to_mnist_png", "generate_mnist_png"]

import os
import struct
import sys
from array import array
from pathlib import Path
from typing import Tuple

import torchvision.datasets
from draugr.numpy_utilities import SplitEnum

from neodroidvision import PROJECT_APP_PATH


def read_data(
    dataset: SplitEnum = SplitEnum.training,
    path: Path = PROJECT_APP_PATH.user_cache / "mnist",
) -> Tuple:
    """
    Read one split of the raw (uncompressed) MNIST IDX files into flat arrays.

    Args:
      dataset: Which split to read. ``SplitEnum.training`` selects the
        ``train-*`` files and ``SplitEnum.testing`` the ``t10k-*`` files.
      path: Directory containing the IDX files. A ``str`` is accepted too and
        converted to ``Path``.

    Returns:
      A ``(labels, pixels, size, rows, cols)`` tuple:
        * ``labels`` is a signed-byte ``array`` with one entry per image.
        * ``pixels`` is an unsigned-byte ``array`` holding every image
          row-major, back to back.
        * ``size`` is the image count from the image-file header.
        * ``rows`` and ``cols`` are the image dimensions.

    Raises:
      ValueError: If ``dataset`` is not a supported split, if a file does not
        carry the expected IDX magic number, or if the label and image files
        disagree on the number of items (or are truncated).
    """
    # Accept plain strings (e.g. command-line arguments) as well as Path objects.
    path = Path(path)

    if dataset is SplitEnum.training:
        file_name_img = path / "train-images-idx3-ubyte"
        file_name_category = path / "train-labels-idx1-ubyte"
    elif dataset is SplitEnum.testing:
        file_name_img = path / "t10k-images-idx3-ubyte"
        file_name_category = path / "t10k-labels-idx1-ubyte"
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # IDX magic numbers: 0x00000801 = unsigned-byte data with 1 dimension
    # (labels), 0x00000803 = unsigned-byte data with 3 dimensions (images).
    label_magic = 2049
    image_magic = 2051

    with open(file_name_category, "rb") as category_file:
        magic_nr, num_labels = struct.unpack(">II", category_file.read(8))
        if magic_nr != label_magic:
            raise ValueError(
                f"{file_name_category}: bad magic number {magic_nr}, expected {label_magic}"
            )
        category = array("b", category_file.read())

    with open(file_name_img, "rb") as image_file:
        magic_nr, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
        if magic_nr != image_magic:
            raise ValueError(
                f"{file_name_img}: bad magic number {magic_nr}, expected {image_magic}"
            )
        image = array("B", image_file.read())

    # Guard against mismatched or truncated files, which would otherwise
    # produce mislabelled or short images downstream.
    if num_labels != size:
        raise ValueError(
            f"label file has {num_labels} items but image file has {size}"
        )
    if len(category) != num_labels:
        raise ValueError(
            f"{file_name_category}: expected {num_labels} labels, found {len(category)}"
        )
    if len(image) != size * rows * cols:
        raise ValueError(
            f"{file_name_img}: expected {size * rows * cols} pixels, found {len(image)}"
        )

    return category, image, size, rows, cols


def write_dataset(labels, data, size, rows, cols, output_dir) -> None:
    """
    Write every image of a split as a greyscale PNG, grouped by label.

    Image ``i`` is written to ``output_dir/<label>/<i>.png``, where ``label``
    is ``labels[i]``.

    Args:
        labels: Sequence of integer labels, one per image. Each must index
          one of the ten sub-directories ``0``..``9``.
        data: Flat sequence of pixel values. Images are stored row-major and
          back to back, ``rows * cols`` values per image.
        size: Image count. It is not used (iteration follows ``labels``) and
          is kept so the tuple returned by ``read_data`` can be splatted in.
        rows: Image height in pixels.
        cols: Image width in pixels.
        output_dir: Destination directory (``Path`` or ``str``). It and the
          ten label sub-directories are created if missing.

    Returns:
        None
    """
    # Import first so a missing dependency fails before anything is created.
    import png  # pip install pypng

    output_dir = Path(output_dir)
    label_dirs = [output_dir / str(digit) for digit in range(10)]
    for label_dir in label_dirs:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        label_dir.mkdir(parents=True, exist_ok=True)

    pixels_per_image = rows * cols
    for i, label in enumerate(labels):
        output_filename = label_dirs[label] / f"{i}.png"
        print(f"writing {output_filename}")
        offset = i * pixels_per_image
        # Split the flat pixel run of image i into one slice per image row.
        image_rows = [
            data[offset + r * cols : offset + (r + 1) * cols] for r in range(rows)
        ]
        with open(output_filename, "wb") as h:
            w = png.Writer(cols, rows, greyscale=True)
            w.write(h, image_rows)


def convert_to_mnist_png(input_path, output_path) -> None:
    """
    Convert the raw MNIST IDX files into per-label PNG directories.

    Both the training and the testing split are converted. Each split is
    written to ``output_path / <split value>``.

    Args:
      input_path: Directory containing the uncompressed MNIST IDX files
        (``Path`` or ``str``).
      output_path: Root directory for the PNG output (``Path`` or ``str``).
    """
    # Coerce so string paths (e.g. from the command line) support ``/``.
    input_path = Path(input_path)
    output_path = Path(output_path)
    for dataset in (SplitEnum.training, SplitEnum.testing):
        write_dataset(*read_data(dataset, input_path), output_path / dataset.value)
def generate_mnist_png(
    destination_path: Path = PROJECT_APP_PATH.user_cache / "mnist_png",
    cache_path: Path = PROJECT_APP_PATH.user_cache,
) -> None:
    """
    Download MNIST through torchvision and export it as PNG files.

    Args:
      destination_path: Root directory the PNG images are written to.
      cache_path: Directory whose ``mnist`` sub-directory receives the
        downloaded dataset (``Path`` or ``str``).
    """
    base = Path(cache_path) / "mnist"
    # Two calls because each MNIST instance covers a single split.
    torchvision.datasets.MNIST(str(base), download=True)
    torchvision.datasets.MNIST(str(base), train=False, download=True)
    # torchvision places the extracted IDX files under <root>/MNIST/raw.
    convert_to_mnist_png(base / "MNIST" / "raw", destination_path)
    print(f"generated mnist_png at {destination_path}")
if __name__ == "__main__":

    def main2():
        """
        Convert raw MNIST IDX files named on the command line to PNGs.

        Usage: ``<script> <input_path> <output_path>``. This entry point is
        not invoked below; the script runs ``generate_mnist_png()`` instead.
        """
        if len(sys.argv) != 3:
            # Report misuse on stderr with a non-zero status (was exit 0).
            print(f"usage: {sys.argv[0]} <input_path> <output_path>", file=sys.stderr)
            sys.exit(2)
        input_path = Path(sys.argv[1])
        output_path = Path(sys.argv[2])
        convert_to_mnist_png(input_path, output_path)

    generate_mnist_png()