金子邦彦研究室 > インストール > オープンデータ,データファイル処理 > IDX 形式(*-ubyte)のMNIST データセットを,CSVファイルに変換

IDX 形式(*-ubyte)のMNIST データセットを,CSVファイルに変換

要点

  1. MNIST の Web ページを開く http://yann.lecun.com/exdb/mnist/
  2. 説明文を確認

    この Web ページ、http://yann.lecun.com/exdb/mnist/ に説明が書いてあるので、読んでおくこと。この Web ページの一部分を引用すると、

    "The MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers."
    
  3. ダウンロードと解凍

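    次の4つのファイルをダウンロードし,展開(解凍)する.

      - train-images-idx3-ubyte.gz (学習用の画像)
      - train-labels-idx1-ubyte.gz (学習用のラベル)
      - t10k-images-idx3-ubyte.gz (テスト用の画像)
      - t10k-labels-idx1-ubyte.gz (テスト用のラベル)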

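  4. 展開(解凍)してできたファイルを1つのディレクトリに集める。プログラムは train-images.idx3-ubyte, train-labels.idx1-ubyte, t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte というファイル名で読み込むので,展開後の名前が異なる場合はこの名前に変更しておく。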
  5. この Web ページの末尾にあるプログラムのソースコードからビルドして,インストールする.

    ソースコード中の「 private static String DIR = "C:\\R\\";」 の行は、ディレクトリ名にあわせて書き換える。DIR とファイル名は単純に連結されるので,末尾の区切り文字 \\ は残しておくこと。

  6. プログラムの実行
  7. 入力ファイルと同じディレクトリに「元のファイル名.csv」(例: train-labels.idx1-ubyte.csv)ができるので確認する
  8. ファイルの中身を確認する

使用するプログラムのソースコード

package hoge.hoge.com;

import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;

import sun.java2d.pipe.BufferedOpCodes;

/**
 * Converts the four MNIST IDX files (training/test images and labels) into CSV files.
 *
 * <p>Each CSV is written next to its input file as {@code <input file name>.csv}, encoded as
 * Shift_JIS (the content itself is plain ASCII). Label CSVs have one row {@code id,label} per
 * item. Image CSVs have one row per image: the 1-based {@code id} followed by every pixel value
 * (unsigned, 0-255) in file order. Progress and header problems are reported on standard output.
 */
public class MNISTRead {
    private static String DIR = "C:\\R\\";

    private static final String T10K_IMAGES_IDX3_FILE_NAME = "t10k-images.idx3-ubyte";
    private static final String T10K_LABELS_IDX1_FILE_NAME = "t10k-labels.idx1-ubyte";
    private static final String TRAIN_IMAGES_IDX3_FILE_NAME = "train-images.idx3-ubyte";
    private static final String TRAIN_LABELS_IDX1_FILE_NAME = "train-labels.idx1-ubyte";

    /** Header magic of a label file: bytes 00 00 08 01 (unsigned-byte data, 1 dimension). */
    private static final int LABEL_FILE_MAGIC = 0x00000801;
    /** Header magic of an image file: bytes 00 00 08 03 (unsigned-byte data, 3 dimensions). */
    private static final int IMAGE_FILE_MAGIC = 0x00000803;

    /** Row and column count expected for MNIST images (28 x 28). */
    private static final int EXPECTED_SIDE = 28;
    /** Sanity limit so a corrupt header cannot trigger a huge buffer allocation. */
    private static final int MAX_PIXELS_PER_IMAGE = 1 << 20;

    private static final String CSV_CHARSET = "Shift_JIS";
    private static final String SEPARATOR_LINE = "----------------";

    /**
     * Converts all four MNIST files.
     *
     * @param args optional; when present, {@code args[0]} is used instead of {@code DIR} as the
     *     directory containing the MNIST files. With no arguments {@code DIR} is used, as before.
     */
    public static void main(String[] args) {
        String dir = withTrailingSeparator(args != null && args.length > 0 ? args[0] : DIR);
        convertLabelFile("TRAINING SET LABEL FILE", dir + TRAIN_LABELS_IDX1_FILE_NAME, 60000);
        convertLabelFile("TEST SET LABEL FILE", dir + T10K_LABELS_IDX1_FILE_NAME, 10000);
        convertImageFile("TRAINING SET IMAGE FILE", dir + TRAIN_IMAGES_IDX3_FILE_NAME, 60000);
        convertImageFile("TEST SET IMAGE FILE", dir + T10K_IMAGES_IDX3_FILE_NAME, 10000);
    }

    /**
     * Converts one label file into {@code path + ".csv"} with rows {@code id,label}.
     *
     * <p>A wrong magic number aborts the conversion before the CSV is created. An unexpected item
     * count is only reported; every label byte present in the file is still written.
     */
    private static void convertLabelFile(String title, String path, int expectedItems) {
        printBanner(title + ", file name = " + path);
        try {
            InputStream in = new BufferedInputStream(new FileInputStream(path));
            try {
                if (readInt(in) != LABEL_FILE_MAGIC) {
                    System.out.println("BAD Magic Number");
                    return;
                }
                if (readInt(in) != expectedItems) {
                    System.out.println("BAD Number of Items");
                }
                BufferedWriter out = openCsv(path);
                try {
                    out.write("id, label");
                    out.newLine();
                    int id = 0;
                    int label;
                    while ((label = in.read()) != -1) {
                        id++;
                        out.write(id + "," + label);
                        out.newLine();
                    }
                } finally {
                    out.close(); // also flushes the buffered rows
                }
            } finally {
                in.close();
            }
            printBanner("CSV file created ! file name = " + path + ".csv");
        } catch (FileNotFoundException e) {
            System.out.println("File Not Found: " + e.getMessage());
        } catch (IOException e) {
            System.out.println("I/O Exception: " + e.getMessage());
        }
    }

    /**
     * Converts one image file into {@code path + ".csv"}, one row {@code id,p0,p1,...} per image.
     *
     * <p>The image size is taken from the header's row/column counts. Counts that differ from the
     * MNIST values are reported but still honoured. A wrong magic number or an impossible image
     * size aborts the conversion before the CSV is created. A trailing partial image is reported
     * and skipped.
     */
    private static void convertImageFile(String title, String path, int expectedItems) {
        printBanner(title + ", file name = " + path);
        try {
            InputStream in = new BufferedInputStream(new FileInputStream(path));
            try {
                if (readInt(in) != IMAGE_FILE_MAGIC) {
                    System.out.println("BAD Magic Number");
                    return;
                }
                if (readInt(in) != expectedItems) {
                    System.out.println("BAD Number of Items");
                }
                int rows = readInt(in);
                if (rows != EXPECTED_SIDE) {
                    System.out.println("BAD Number of Rows");
                }
                int columns = readInt(in);
                if (columns != EXPECTED_SIDE) {
                    System.out.println("BAD Number of Columns");
                }
                if (rows <= 0 || columns <= 0 || (long) rows * columns > MAX_PIXELS_PER_IMAGE) {
                    System.out.println("BAD Image Size: " + rows + " x " + columns);
                    return;
                }
                byte[] pixels = new byte[rows * columns];
                BufferedWriter out = openCsv(path);
                try {
                    out.write("id, pixel");
                    out.newLine();
                    int id = 0;
                    int filled;
                    while ((filled = readFully(in, pixels)) == pixels.length) {
                        id++;
                        out.write(toCsvRow(id, pixels));
                        out.newLine();
                    }
                    if (filled > 0) {
                        System.out.println(
                                "Truncated image data: last " + filled + " bytes ignored");
                    }
                } finally {
                    out.close(); // also flushes the buffered rows
                }
            } finally {
                in.close();
            }
            printBanner("CSV file created ! file name = " + path + ".csv");
        } catch (FileNotFoundException e) {
            System.out.println("File Not Found: " + e.getMessage());
        } catch (IOException e) {
            System.out.println("I/O Exception: " + e.getMessage());
        }
    }

    /** Reads one big-endian 32-bit header field, failing if the file ends early. */
    private static int readInt(InputStream in) throws IOException {
        int value = 0;
        for (int i = 0; i < 4; i++) {
            int b = in.read();
            if (b < 0) {
                throw new IOException("Unexpected end of file while reading the IDX header");
            }
            value = (value << 8) | b;
        }
        return value;
    }

    /** Fills {@code buffer} from {@code in}; returns the byte count (less only at end of file). */
    private static int readFully(InputStream in, byte[] buffer) throws IOException {
        int filled = 0;
        while (filled < buffer.length) {
            int n = in.read(buffer, filled, buffer.length - filled);
            if (n < 0) {
                break;
            }
            filled += n;
        }
        return filled;
    }

    /** Formats one image as {@code id,p0,p1,...} with pixels as unsigned values 0-255. */
    private static String toCsvRow(int id, byte[] pixels) {
        StringBuilder row = new StringBuilder(pixels.length * 4 + 12);
        row.append(id);
        for (int i = 0; i < pixels.length; i++) {
            row.append(',').append(pixels[i] & 0xFF);
        }
        return row.toString();
    }

    /** Opens {@code inputPath + ".csv"} for writing, closing the file if the writer setup fails. */
    private static BufferedWriter openCsv(String inputPath) throws IOException {
        FileOutputStream file = new FileOutputStream(inputPath + ".csv");
        try {
            return new BufferedWriter(new OutputStreamWriter(file, CSV_CHARSET));
        } catch (IOException e) {
            file.close();
            throw e;
        }
    }

    /** Appends "/" unless the directory is empty or already ends with a path separator. */
    private static String withTrailingSeparator(String dir) {
        if (dir.length() == 0 || dir.endsWith("/") || dir.endsWith("\\")) {
            return dir;
        }
        return dir + "/";
    }

    /** Prints {@code message} framed by separator lines. */
    private static void printBanner(String message) {
        System.out.println(SEPARATOR_LINE);
        System.out.println(message);
        System.out.println(SEPARATOR_LINE);
    }
}