MNIST データセットのファイル形式を CSV ファイル形式に変換

サイト構成 連絡先,業績など 実践知識 コンピュータ 教材 サポートページ

要点

  1. MNIST の Web ページを開く http://yann.lecun.com/exdb/mnist/

  2. 説明文を確認

    この Web ページ、http://yann.lecun.com/exdb/mnist/ に説明が書いてあるので、読んでおくこと。この Web ページの一部分を引用すると、

    "The MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers."
    

  3. ダウンロード解凍

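    次の4つのファイル(gzip 形式)をダウンロードし,解凍する.

      train-images-idx3-ubyte.gz (訓練用の画像)
      train-labels-idx1-ubyte.gz (訓練用のラベル)
      t10k-images-idx3-ubyte.gz (テスト用の画像)
      t10k-labels-idx1-ubyte.gz (テスト用のラベル)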

  4. 解凍してできたファイルを1つのディレクトリに集める。

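    例えば,C:\R\ に次のように置く.プログラムは下記のファイル名で読み込むので,解凍後のファイル名が異なる場合は名前を変更する.

      train-images.idx3-ubyte
      train-labels.idx1-ubyte
      t10k-images.idx3-ubyte
      t10k-labels.idx1-ubyte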

  5. この Web ページの末尾にあるプログラムのソースコードを使う.

    ソースコード中の「 private static String DIR = "C:\\R\\";」 の行は、ディレクトリ名にあわせて書き換える。

  6. プログラムの実行

  7. 元のファイル名の末尾に「.csv」を付けた名前の CSV ファイルが,同じディレクトリに4つできるので確認する

  8. ファイルの中身を確認する


使用するプログラムのソースコード

package hoge.hoge.com;

import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;

import sun.java2d.pipe.BufferedOpCodes;

/**
 * Converts the four MNIST IDX files (training/test labels and images) into CSV files.
 *
 * <p>Each output file is written next to its input and is named after it with
 * {@code ".csv"} appended. Label files produce one {@code id,label} row per item; image
 * files produce one row per image holding the 1-based image id followed by its
 * 28 * 28 pixel values (0-255).
 *
 * <p>The directory defaults to {@link #DIR}; it can be overridden by passing a directory
 * as the first command-line argument.
 */
public class MNISTRead {
	// Default directory that holds the decompressed MNIST files (must end with a separator).
	// Edit this line, or pass a directory as the first command-line argument instead.
	private static String DIR = "C:\\R\\";

	// File names, relative to the directory, that the program reads.
	private static final String T10K_IMAGES_IDX3_FILE_NAME = "t10k-images.idx3-ubyte";
	private static final String T10K_LABELS_IDX1_FILE_NAME = "t10k-labels.idx1-ubyte";
	private static final String TRAIN_IMAGES_IDX3_FILE_NAME = "train-images.idx3-ubyte";
	private static final String TRAIN_LABELS_IDX1_FILE_NAME = "train-labels.idx1-ubyte";

	// Expected header values: bytes 00 00 08 01 for label files, 00 00 08 03 for image files.
	private static final int LABEL_MAGIC = 0x00000801;
	private static final int IMAGE_MAGIC = 0x00000803;
	private static final int IMAGE_ROWS = 28;
	private static final int IMAGE_COLUMNS = 28;
	private static final int TRAIN_ITEMS = 60000;
	private static final int TEST_ITEMS = 10000;

	// The CSV output consists of ASCII letters, digits, commas and blanks only, so US-ASCII
	// yields the same bytes as Shift_JIS while being available on every Java runtime.
	private static final String CSV_CHARSET = "US-ASCII";

	private static final String RULE = "----------------";

	/**
	 * Converts the four MNIST files, one after another. A failure on one file is reported
	 * on standard output and does not prevent the remaining files from being processed.
	 *
	 * @param args optional; {@code args[0]} is the directory containing the MNIST files.
	 *             When absent or empty, {@link #DIR} is used.
	 */
	public static void main(String[] args) {
		String dir = DIR;
		if (args != null && args.length > 0 && args[0] != null && args[0].length() > 0) {
			dir = withTrailingSeparator(args[0]);
		}

		convertLabelFile("TRAINING SET LABEL FILE", dir + TRAIN_LABELS_IDX1_FILE_NAME, TRAIN_ITEMS);
		convertLabelFile("TEST SET LABEL FILE", dir + T10K_LABELS_IDX1_FILE_NAME, TEST_ITEMS);
		convertImageFile("TRAINING SET IMAGE FILE", dir + TRAIN_IMAGES_IDX3_FILE_NAME, TRAIN_ITEMS);
		convertImageFile("TEST SET IMAGE FILE", dir + T10K_IMAGES_IDX3_FILE_NAME, TEST_ITEMS);
	}

	/** Converts one IDX1 label file into {@code fileName + ".csv"} with rows {@code id,label}. */
	private static void convertLabelFile(String title, String fileName, int expectedItems) {
		InputStream in = null;
		BufferedWriter fw = null;
		try {
			printBanner(title + ", file name = " + fileName);
			// Open the input first so that a missing input does not leave an empty CSV behind.
			in = new BufferedInputStream(new FileInputStream(fileName));
			fw = openCsv(fileName);
			fw.write("id, label");
			fw.newLine();

			// Header problems are reported but, as before, do not stop the conversion.
			checkField("Magic Number", readInt(in), LABEL_MAGIC);
			int declaredItems = readInt(in);
			checkField("Number of Items", declaredItems, expectedItems);

			// Every byte after the 8-byte header is one label; ids start at 1.
			int id = 0;
			int label;
			while ((label = in.read()) != -1) {
				id++;
				fw.write(id + "," + label);
				fw.newLine();
			}
			if (id != declaredItems) {
				System.out.println("BAD Number of Items (header says " + declaredItems
						+ ", file contains " + id + ")");
			}

			fw.flush();
			printBanner("CSV file created ! file name = " + fileName + ".csv");
		} catch (FileNotFoundException e) {
			System.out.println("File Not Found: " + e.getMessage());
		} catch (IOException e) {
			System.out.println("I/O Exception: " + e.getMessage());
		} finally {
			close(in);
			close(fw);
		}
	}

	/**
	 * Converts one IDX3 image file into {@code fileName + ".csv"}; each row is the 1-based
	 * image id followed by the 28 * 28 pixel values of that image.
	 *
	 * <p>The header line stays {@code "id, pixel"} (two names for 785 columns) so that the
	 * generated files are unchanged for existing consumers.
	 */
	private static void convertImageFile(String title, String fileName, int expectedItems) {
		InputStream in = null;
		BufferedWriter fw = null;
		try {
			printBanner(title + ", file name = " + fileName);
			in = new BufferedInputStream(new FileInputStream(fileName));
			fw = openCsv(fileName);
			fw.write("id, pixel");
			fw.newLine();

			checkField("Magic Number", readInt(in), IMAGE_MAGIC);
			int declaredItems = readInt(in);
			checkField("Number of Items", declaredItems, expectedItems);
			checkField("Number of Rows", readInt(in), IMAGE_ROWS);
			checkField("Number of Columns", readInt(in), IMAGE_COLUMNS);

			final int pixelsPerImage = IMAGE_ROWS * IMAGE_COLUMNS;
			StringBuilder row = new StringBuilder(pixelsPerImage * 4 + 8);
			int images = 0; // complete images written so far
			int filled = 0; // pixels collected for the image in progress
			int pixel;
			while ((pixel = in.read()) != -1) {
				if (filled == 0) {
					// Start of an image: the row begins with its 1-based id.
					row.setLength(0);
					row.append(images + 1).append(',');
				}
				row.append(pixel);
				filled++;
				if (filled < pixelsPerImage) {
					row.append(',');
				} else {
					// End of an image: emit the row and start over.
					fw.write(row.toString());
					fw.newLine();
					images++;
					filled = 0;
				}
			}
			if (filled != 0) {
				// An incomplete trailing image is dropped (as before) but no longer silently.
				System.out.println("BAD Image Data (" + filled
						+ " trailing bytes do not form a complete image and were ignored)");
			}
			if (images != declaredItems) {
				System.out.println("BAD Number of Items (header says " + declaredItems
						+ ", file contains " + images + ")");
			}

			fw.flush();
			printBanner("CSV file created ! file name = " + fileName + ".csv");
		} catch (FileNotFoundException e) {
			System.out.println("File Not Found: " + e.getMessage());
		} catch (IOException e) {
			System.out.println("I/O Exception: " + e.getMessage());
		} finally {
			close(in);
			close(fw);
		}
	}

	/** Opens the CSV file that belongs to the given IDX file for writing. */
	private static BufferedWriter openCsv(String idxFileName) throws IOException {
		return new BufferedWriter(
				new OutputStreamWriter(new FileOutputStream(idxFileName + ".csv"), CSV_CHARSET));
	}

	/**
	 * Reads one big-endian 32-bit header field.
	 *
	 * @throws IOException if the stream ends before four bytes were read
	 */
	private static int readInt(InputStream in) throws IOException {
		int value = 0;
		for (int i = 0; i < 4; i++) {
			int b = in.read();
			if (b == -1) {
				throw new IOException("unexpected end of file inside the IDX header");
			}
			value = (value << 8) | b;
		}
		return value;
	}

	/** Prints a "BAD ..." warning when a header field differs from the expected value. */
	private static void checkField(String fieldName, int actual, int expected) {
		if (actual != expected) {
			System.out.println("BAD " + fieldName + " (expected " + expected + ", found " + actual + ")");
		}
	}

	/** Prints a message framed by two separator lines. */
	private static void printBanner(String message) {
		System.out.println(RULE);
		System.out.println(message);
		System.out.println(RULE);
	}

	/** Returns {@code dir} with a path separator appended unless it already ends with one. */
	private static String withTrailingSeparator(String dir) {
		char last = dir.charAt(dir.length() - 1);
		if (last == '/' || last == '\\') {
			return dir;
		}
		return dir + System.getProperty("file.separator");
	}

	/** Closes the input stream, reporting (not swallowing) a failure. */
	private static void close(InputStream in) {
		if (in == null) {
			return;
		}
		try {
			in.close();
		} catch (IOException e) {
			System.out.println("I/O Exception while closing input: " + e.getMessage());
		}
	}

	/** Closes the writer, reporting (not swallowing) a failure, since it may mean lost output. */
	private static void close(BufferedWriter fw) {
		if (fw == null) {
			return;
		}
		try {
			fw.close();
		} catch (IOException e) {
			System.out.println("I/O Exception while closing output: " + e.getMessage());
		}
	}
}