IDX 形式の MNIST データセットを CSV ファイルに変換
要点
- MNIST データセットを R で扱いたい.CSV ファイル形式に変換しておく方が取り扱いやすい(と感じた)
- MNIST の Web ページを開く http://yann.lecun.com/exdb/mnist/
- 説明文を確認
この Web ページ、http://yann.lecun.com/exdb/mnist/ に説明が書いてあるので、読んでおくこと。この Web ページの一部分を引用すると、
"The MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers."
- ダウンロードと解凍
次の4つのファイルをダウンロードし,展開(解凍)する.
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
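- 展開(解凍)してできたファイルを1つのディレクトリに集める。後述のプログラムは t10k-images.idx3-ubyte のように「.idx3-ubyte」「.idx1-ubyte」の前がドット区切りのファイル名を前提としている。展開したファイル名(例: train-images-idx3-ubyte)がこれと異なる場合は、ファイル名を変更するか、ソースコード中のファイル名を書き換える。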
- この Web ページの末尾にあるプログラムのソースコードからビルドして,インストールする.
ソースコード中の「 private static String DIR = "C:\\R\\";」 の行は、ファイルを集めたディレクトリ名にあわせて書き換える。ファイル名は DIR の直後にそのまま連結されるので、ディレクトリ名の末尾の区切り文字(この例では最後の \\)を省略しないこと。
- プログラムの実行
- 入力ファイルと同じディレクトリに、入力ファイル名の末尾に .csv を付けた CSV ファイル(例: train-labels.idx1-ubyte.csv)が4つできるので確認する
- ファイルの中身を確認する
使用するプログラムのソースコード
package hoge.hoge.com;

import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;

/**
 * Converts the four uncompressed MNIST IDX files (training/test labels and images) into CSV
 * files that are easy to load from R or other tools.
 *
 * <p>Each output file is written next to its input and is named {@code <input file name>.csv}.
 * <ul>
 *   <li>Label files: header {@code id,label}, then one row {@code <id>,<label>} per label byte.</li>
 *   <li>Image files: header {@code id,pixel1,...,pixel784}, then one row per 28x28 image holding
 *       the id followed by its 784 pixel bytes (0-255) in file order.</li>
 * </ul>
 * Ids are 1-based and follow the order of the items in the input file. Header mismatches are
 * reported on standard output but do not stop the conversion.
 */
public class MNISTRead {
    /** Default directory holding the IDX files; the file names below are appended to it directly. */
    private static String DIR = "C:\\R\\";

    private static final String T10K_IMAGES_IDX3_FILE_NAME = "t10k-images.idx3-ubyte";
    private static final String T10K_LABELS_IDX1_FILE_NAME = "t10k-labels.idx1-ubyte";
    private static final String TRAIN_IMAGES_IDX3_FILE_NAME = "train-images.idx3-ubyte";
    private static final String TRAIN_LABELS_IDX1_FILE_NAME = "train-labels.idx1-ubyte";

    /** Magic numbers in the first 4 header bytes: 0x00 0x00, type 0x08 (unsigned byte), dimension count. */
    private static final int LABEL_FILE_MAGIC = 0x00000801;
    private static final int IMAGE_FILE_MAGIC = 0x00000803;

    private static final int TRAINING_SET_SIZE = 60000;
    private static final int TEST_SET_SIZE = 10000;

    private static final int ROWS = 28;
    private static final int COLUMNS = 28;
    private static final int PIXELS_PER_IMAGE = ROWS * COLUMNS;

    private static final String SEPARATOR_LINE = "----------------";

    /**
     * Converts the training/test label files, then the training/test image files.
     *
     * @param args optional; {@code args[0]} overrides {@link #DIR} as the directory containing the
     *     IDX files (a trailing path separator is added if missing)
     */
    public static void main(String[] args) {
        String dir = DIR;
        if (args != null && args.length > 0) {
            dir = args[0];
            if (!dir.endsWith("/") && !dir.endsWith("\\")) {
                dir += File.separator;
            }
        }
        convert("TRAINING SET LABEL FILE", dir + TRAIN_LABELS_IDX1_FILE_NAME, TRAINING_SET_SIZE, false);
        convert("TEST SET LABEL FILE", dir + T10K_LABELS_IDX1_FILE_NAME, TEST_SET_SIZE, false);
        convert("TRAINING SET IMAGE FILE", dir + TRAIN_IMAGES_IDX3_FILE_NAME, TRAINING_SET_SIZE, true);
        convert("TEST SET IMAGE FILE", dir + T10K_IMAGES_IDX3_FILE_NAME, TEST_SET_SIZE, true);
    }

    /**
     * Converts one IDX file to {@code fileName + ".csv"}, reporting progress and errors on stdout.
     * Both streams are closed even if an error occurs.
     *
     * @param title label printed before the file name
     * @param fileName path of the IDX input file
     * @param expectedItems item count the header is expected to declare
     * @param imageFile {@code true} for an idx3 image file, {@code false} for an idx1 label file
     */
    private static void convert(String title, String fileName, int expectedItems, boolean imageFile) {
        String csvFileName = fileName + ".csv";
        System.out.println(SEPARATOR_LINE);
        System.out.println(title + ", file name = " + fileName);
        System.out.println(SEPARATOR_LINE);
        // The input is opened first, so no CSV file is created when the input is missing.
        // The output is pure ASCII (digits, commas, header names).
        try (DataInputStream in = new DataInputStream(
                    new BufferedInputStream(new FileInputStream(fileName)));
                BufferedWriter out = new BufferedWriter(new OutputStreamWriter(
                    new FileOutputStream(csvFileName), StandardCharsets.US_ASCII))) {
            if (imageFile) {
                writeImages(in, out, expectedItems);
            } else {
                writeLabels(in, out, expectedItems);
            }
            out.flush();
            System.out.println(SEPARATOR_LINE);
            System.out.println("CSV file created ! file name = " + csvFileName);
            System.out.println(SEPARATOR_LINE);
        } catch (FileNotFoundException e) {
            System.out.println("File Not Found: " + e.getMessage());
        } catch (EOFException e) {
            // readInt() hit end of file, i.e. the header itself is truncated.
            System.out.println("Unexpected end of file in header: " + fileName);
        } catch (IOException e) {
            System.out.println("I/O Exception: " + e);
        }
    }

    /**
     * Writes the CSV header, validates the 8-byte label-file header (magic, item count), and writes
     * one {@code id,label} row per remaining byte.
     *
     * @throws EOFException if the input ends inside the header
     * @throws IOException on read/write failure
     */
    private static void writeLabels(DataInputStream in, BufferedWriter out, int expectedItems)
            throws IOException {
        out.write("id,label");
        out.newLine();

        // IDX header integers are big-endian, which is what readInt() reads.
        check(in.readInt(), LABEL_FILE_MAGIC, "BAD Magic Number");
        int declaredItems = in.readInt();
        check(declaredItems, expectedItems, "BAD Number of Items");

        int written = 0;
        int label;
        while ((label = in.read()) != -1) {
            written++;
            out.write(written + "," + label);
            out.newLine();
        }
        checkCount(written, declaredItems);
    }

    /**
     * Writes the CSV header, validates the 16-byte image-file header (magic, item count, rows,
     * columns), and writes one row per complete 28x28 image: its id followed by its pixel values.
     *
     * @throws EOFException if the input ends inside the header
     * @throws IOException on read/write failure
     */
    private static void writeImages(DataInputStream in, BufferedWriter out, int expectedItems)
            throws IOException {
        // One header name per column so that every line has the same number of fields.
        StringBuilder header = new StringBuilder("id");
        for (int i = 1; i <= PIXELS_PER_IMAGE; i++) {
            header.append(",pixel").append(i);
        }
        out.write(header.toString());
        out.newLine();

        check(in.readInt(), IMAGE_FILE_MAGIC, "BAD Magic Number");
        int declaredItems = in.readInt();
        check(declaredItems, expectedItems, "BAD Number of Items");
        check(in.readInt(), ROWS, "BAD Number of Rows");
        check(in.readInt(), COLUMNS, "BAD Number of Columns");

        StringBuilder row = new StringBuilder();
        int written = 0;
        int pixelIndex = 0; // position of the next pixel within the current image
        int pixel;
        while ((pixel = in.read()) != -1) {
            if (pixelIndex == 0) { // start of a new image
                row.setLength(0);
                row.append(written + 1);
            }
            row.append(',').append(pixel);
            pixelIndex++;
            if (pixelIndex == PIXELS_PER_IMAGE) { // end of the image
                out.write(row.toString());
                out.newLine();
                written++;
                pixelIndex = 0;
            }
        }
        if (pixelIndex != 0) {
            System.out.println("WARNING: last image is incomplete (" + pixelIndex + " of "
                    + PIXELS_PER_IMAGE + " pixels) and was skipped");
        }
        checkCount(written, declaredItems);
    }

    /** Prints {@code message} when a header field does not have the expected value. */
    private static void check(int actual, int expected, String message) {
        if (actual != expected) {
            System.out.println(message);
        }
    }

    /** Warns when the number of rows written differs from the item count declared in the header. */
    private static void checkCount(int written, int declaredItems) {
        if (written != declaredItems) {
            System.out.println("WARNING: header declares " + declaredItems + " items but "
                    + written + " were written");
        }
    }
}