MMPretrain による画像分類(静止画像向けプログラム)(ソースコードと説明と利用ガイド)

プログラム利用ガイド

1. このプログラムの利用シーン

静止画像の内容を自動的に識別するプログラムである。写真に写っている物体や動物、風景などを判定し、上位5つの候補とその確率を表示する。

2. 主な機能

3. 基本的な使い方

  1. モデル選択

    プログラム起動後、0から5の番号を入力してモデルを選択する。初回実行時はモデルの自動ダウンロードが行われる。

  2. 入力方法の選択

    0(画像ファイル)、1(カメラ)、2(サンプル画像)のいずれかを入力する。

  3. 画像の処理
    • 画像ファイル選択時:ダイアログで画像を選択すると、分類結果が表示される。
    • カメラ選択時:スペースキーで撮影し、qキーで終了する。
    • サンプル画像選択時:自動的にダウンロードされ、分類結果が表示される。
  4. 終了

    任意のキーを押してプログラムを終了する。

4. 便利な機能

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install mmengine mmcv-lite opencv-python pillow
pip install transformers==4.36.2 --force-reinstall
pip install --no-build-isolation mmpretrain

MMPretrain画像分類プログラム

概要

このプログラムは、MMPretrainライブラリを使用して静止画像の分類を行う。事前訓練済みの畳み込みニューラルネットワーク(CNN)モデルを用いて、入力画像に対するTop-5分類結果を出力する。

主要技術

ConvNeXt V2

2023年にWooらが提案した純粋畳み込みアーキテクチャである[1]。Fully Convolutional Masked Autoencoder(FCMAE)とGlobal Response Normalization(GRN)を組み合わせることで、Transformerを使用せずにImageNet分類で競争力のある性能を達成する。

MMPretrain

OpenMMLab が開発する画像分類・事前訓練のためのツールボックスである[2]。統一されたAPIにより、複数の事前訓練モデルを簡単に利用できる。

技術的特徴

実装の特色

参考文献

[1] Woo, S., Debnath, S., Hu, R., Chen, X., Liu, Z., Kweon, I. S., & Xie, S. (2023). ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 16133-16142). https://openaccess.thecvf.com/content/CVPR2023/papers/

[2] MMPretrain Contributors. (2023). OpenMMLab's Pre-training Toolbox and Benchmark. https://github.com/open-mmlab/mmpretrain

ソースコード


"""
- プログラム名: MMPretrain画像分類デモプログラム
- 特徴技術名: ConvNeXt V2(純粋ConvNet)
- 出典: S. Woo, S. Debnath, R. Hu, X. Chen, Z. Liu, I. S. Kweon, and S. Xie, "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders," in Proc. IEEE/CVF Conf. Computer Vision and Pattern Recognition (CVPR), 2023, pp. 16133-16142.
- 特徴機能: Fully Convolutional Masked Autoencoder (FCMAE)とGlobal Response Normalization (GRN)による自己教師あり学習。従来のConvNetでImageNet精度を達成し、公開データのみを使用した手法
- 特徴技術および学習済みモデルの利用制限: ImageNet事前訓練・ファインチューニング済みモデルはCC-BY-NC-4.0ライセンス(非商用のみ、帰属表示必須)。商用利用には制限あり。コード自体はMITライセンス。必ず利用者自身で公式リポジトリ(https://github.com/facebookresearch/ConvNeXt-V2)の利用制限を確認すること。
- 学習済みモデル:
  各モデルの学習済み重みの詳細情報
  - ResNet50: 25.6Mパラメータ 98MBファイルサイズ 224x224入力 ImageNet-1k 1.28M画像 90エポック訓練
  - ConvNeXt-Tiny: 28.6Mパラメータ 109MBファイルサイズ 224x224入力 ImageNet-1k 1.28M画像 300エポック訓練
  - ConvNeXt-Base: 88.6Mパラメータ 338MBファイルサイズ 224x224入力 ImageNet-1k 1.28M画像 300エポック訓練
  - EfficientNet-B4: 19.3Mパラメータ 74MBファイルサイズ 380x380入力 ImageNet-1k 1.28M画像訓練
  - ConvNeXt-V2-Atto: 3.7Mパラメータ 15MBファイルサイズ 224x224入力 FCMAE自己教師あり事前訓練→ImageNet-1kファインチューニング
  - ConvNeXt-V2-Huge: 650Mパラメータ 2.4GBファイルサイズ 224x224入力 ImageNet-22k 14M画像FCMAE事前訓練→ImageNet-1kファインチューニング
  全モデル MMPretrainライブラリ経由で自動ダウンロード
- 方式設計
  - 関連利用技術:
    * MMPretrain(OpenMMLab画像分類ツールボックス、統一的なAPIと豊富な事前訓練モデルを提供)
    * PyTorch(深層学習フレームワーク、動的計算グラフと柔軟なモデル定義を提供)
    * MMCV(OpenMMLab基盤ライブラリ、画像処理とデータ変換機能を提供)
  - 入力と出力:
    入力: 1つの静止画像,カメラ(ユーザは「0:画像ファイル,1:カメラ,2:サンプル画像」のメニューで選択.0:画像ファイルの場合はtkinterでファイル選択可能.1の場合はOpenCVでカメラが開き,スペースキーで撮影.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/fruits.jpg とhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/messi5.jpgとhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/aero3.jpgとhttps://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpgからダウンロード)
    出力: 処理結果をOpenCV画面でリアルタイムに表示.OpenCV画面内に処理結果をテキストで表示.さらに,print()で処理結果を表示.プログラム終了時にprint()で表示した処理結果をresult.txtファイルに保存し,「result.txtに保存」したことをprint()で表示.プログラム開始時に,プログラムの概要,ユーザが行う必要がある操作(もしあれば)をprint()で表示
  - 処理手順: 1)モデル初期化(ConvNeXt V2事前訓練モデル読み込み)、2)画像入力(ファイル/カメラ/サンプルから選択)、3)前処理(画像リサイズと正規化)、4)推論実行(FCMAE学習とGRNによる特徴抽出と分類)、5)結果出力(予測クラスと信頼度表示)
  - 前処理、後処理: 前処理:画像を224x224ピクセルにリサイズ、ImageNet統計による正規化(平均値と標準偏差でピクセル値を正規化)。後処理:softmax関数による確率変換、上位予測クラスの選択
  - 追加処理: なし
  - 調整を必要とする設定値: なし(デモプログラムのため固定設定)
- 将来方策: デモプログラムのため特別な調整値なし
- その他の重要事項: 初回実行時はモデルの自動ダウンロードが発生し時間を要する場合がある
- 前準備: Python 3.12対応の手順
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install mmengine mmcv-lite opencv-python pillow
pip install transformers==4.36.2 --force-reinstall
pip install --no-build-isolation mmpretrain
"""

import cv2
import tkinter as tk
from tkinter import filedialog
import urllib.request
import os
import ssl
import time
from datetime import datetime
import traceback

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

from mmpretrain import ImageClassificationInferencer

ssl._create_default_https_context = ssl._create_unverified_context

# 日本語フォント設定
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 20

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')
# GPU使用時の最適化
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True

# モデル一覧(環境のmmpretrain版によりID差異の可能性あり)
models = {
    '0': {'name': 'ResNet50',            'model_id': 'resnet50_8xb32_in1k',                         'params': '25.6M', 'size': '98MB',  'input': '224x224', 'pretrain': 'ImageNet-1k 1.28M画像 90エポック'},
    '1': {'name': 'ConvNeXt-Tiny',       'model_id': 'convnext-tiny_32xb128_in1k',                  'params': '28.6M', 'size': '109MB', 'input': '224x224', 'pretrain': 'ImageNet-1k 1.28M画像 300エポック'},
    '2': {'name': 'ConvNeXt-Base',       'model_id': 'convnext-base_32xb128_in1k',                  'params': '88.6M', 'size': '338MB', 'input': '224x224', 'pretrain': 'ImageNet-1k 1.28M画像 300エポック'},
    '3': {'name': 'EfficientNet-B4',     'model_id': 'efficientnet-b4_3rdparty_8xb32_in1k',         'params': '19.3M', 'size': '74MB',  'input': '380x380', 'pretrain': 'ImageNet-1k 1.28M画像'},
    '4': {'name': 'ConvNeXt-V2-Atto',    'model_id': 'convnext-v2-atto_fcmae-pre_3rdparty_in1k',    'params': '3.7M',  'size': '15MB',  'input': '224x224', 'pretrain': 'FCMAE→ImageNet-1k'},
    '5': {'name': 'ConvNeXt-V2-Huge',    'model_id': 'convnext-v2-huge_fcmae-pre_3rdparty_in1k',    'params': '650M',  'size': '2.4GB', 'input': '224x224', 'pretrain': 'ImageNet-22k FCMAE→ImageNet-1k'}
}

results_log = []
inferencer = None
selected_model = None

def load_font():
    try:
        return ImageFont.truetype(FONT_PATH, FONT_SIZE)
    except Exception:
        return ImageFont.load_default()

def draw_text_lines_bgr(img_bgr, lines, color=(0, 255, 0), anchor=(10, 10), line_gap=28):
    font = load_font()
    img_pil = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    x, y = anchor
    for line in lines:
        draw.text((x, y), line, font=font, fill=color)
        y += line_gap
    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

def format_topk(classes, scores, k=5):
    idx = np.argsort(scores)[::-1][:k]
    items = []
    lines = []
    for rank, i in enumerate(idx, start=1):
        label = classes[i] if (classes is not None and i < len(classes)) else str(i)
        sc = float(scores[i])
        items.append(f"{rank}) {label}:{sc:.3f}")
        lines.append(f"{rank}. {label}  {sc:.3f}")
    return "Top-5: " + ", ".join(items), lines

def resolve_classes_from_inferencer(out):
    cls = getattr(inferencer, 'classes', None)
    if cls is not None:
        return cls
    mdl = getattr(inferencer, 'model', None)
    if mdl is not None:
        dm = getattr(mdl, 'dataset_meta', None)
        if isinstance(dm, dict) and 'classes' in dm:
            return dm['classes']
    if isinstance(out, dict) and 'classes' in out:
        return out['classes']
    return None

def image_processing(img_bgr):
    current_time = time.time()
    try:

        result = inferencer(img_bgr)
        out = result[0]

        if hasattr(out, 'pred_scores'):
            scores = out.pred_scores
        elif hasattr(out, 'get'):
            scores = out.get('pred_scores', None)
        else:
            print(f"ERROR: Cannot find pred_scores")
            raise ValueError("pred_scores not found")

        if scores is None:
            score = float(out.get('pred_score', 0.0))
            scores = np.array([score], dtype=np.float32)
        else:
            if hasattr(scores, 'detach'):
                scores = scores.detach().cpu().numpy()
            elif hasattr(scores, 'numpy'):
                scores = scores.numpy()
            else:
                scores = np.asarray(scores)

        classes = resolve_classes_from_inferencer(out)

        top5_line, top5_lines = format_topk(classes, scores, k=min(5, len(scores)))

        lines = [
            f"Model: {selected_model['name']}",
            "画像分類 Top-5"
        ] + top5_lines

        processed = draw_text_lines_bgr(img_bgr.copy(), lines, color=(0, 255, 0), anchor=(10, 10))
        result_text = f"{top5_line}"
        return processed, result_text, current_time
    except Exception as e:
        print(f"\n=== エラー詳細 ===")
        print(traceback.format_exc())
        print(f"=== エラー詳細終了 ===\n")
        err_msg = f"推論エラー: {e}"
        processed = draw_text_lines_bgr(img_bgr.copy(), [err_msg], color=(0, 0, 255), anchor=(10, 10))
        return processed, err_msg, current_time

def process_and_display_images(image_sources, source_type):
    display_index = 1
    for source in image_sources:
        img = cv2.imread(source) if source_type == 'file' else source
        if img is None:
            continue
        cv2.imshow(f'Image_{display_index}', img)
        processed_img, result, current_time = image_processing(img)
        cv2.imshow(f'画像分類 Top-5_{display_index}', processed_img)
        print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
        results_log.append(result)
        display_index += 1

print("=" * 60)
print("MMPretrain 画像分類デモプログラム")
print("=" * 60)
print("概要: 事前訓練モデルを用いて静止画の画像分類(Top-5)を行う")
print("操作方法:")
print("  1) モデル番号を選択する")
print("  2) 入力方法を選択する(0:画像ファイル, 1:カメラ, 2:サンプル画像)")
print("  3) カメラ選択時はスペースキーで撮影、qキーで終了")
print("注意: 初回実行時はモデルの自動ダウンロードに時間を要する場合がある")
print("=" * 60)

print("\n利用可能なモデル:")
print("-" * 110)
print(f"{'No.':<3} {'Model':<18} {'Parameters':<12} {'Size':<8} {'Input':<10} {'学習済みモデル':<30}")
print("-" * 110)
for key, model in models.items():
    print(f"{key:<3} {model['name']:<18} {model['params']:<12} {model['size']:<8} {model['input']:<10} {model['pretrain']:<30}")
print("-" * 110)

model_choice = input("モデル番号を選択 (0-5): ").strip()
if model_choice not in models:
    print("無効な選択です")
    raise SystemExit

selected_model = models[model_choice]
print(f"\n選択されたモデル: {selected_model['name']} ({selected_model['params']} parameters)")

inferencer = ImageClassificationInferencer(selected_model['model_id'], pretrained=True, device=str(device))

print("\n画像入力方法を選択してください:")
print("0: 画像ファイル")
print("1: カメラ")
print("2: サンプル画像")

choice = input("選択: ").strip()

try:
    if choice == '0':
        root = tk.Tk()
        root.withdraw()
        if not (paths := filedialog.askopenfilenames()):
            raise SystemExit
        process_and_display_images(paths, 'file')
        cv2.waitKey(0)

    elif choice == '1':
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        if not cap.isOpened():
            cap = cv2.VideoCapture(0)
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.imshow('Camera', frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord(' '):
                    processed_img, result, current_time = image_processing(frame)
                    cv2.imshow('画像分類 Top-5', processed_img)
                    print(datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], result)
                    results_log.append(result)
                elif key == ord('q'):
                    break
        finally:
            cap.release()

    else:
        print("\nサンプル画像をダウンロードしています...")
        opener = urllib.request.build_opener()
        opener.addheaders = [('User-Agent', 'Mozilla/5.0')]
        urllib.request.install_opener(opener)

        urls = [
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/fruits.jpg",
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/messi5.jpg",
            "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/aero3.jpg",
            "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg"
        ]
        downloaded_files = []
        for i, url in enumerate(urls):
            try:
                urllib.request.urlretrieve(url, f"sample_{i}.jpg")
                downloaded_files.append(f"sample_{i}.jpg")
                print(f"sample_{i}.jpg をダウンロードしました")
            except Exception as e:
                print(f"画像のダウンロードに失敗しました: {url}")
                print(f"エラー: {e}")

        if downloaded_files:
            print(f"\n{len(downloaded_files)}個のサンプル画像の処理を開始します...\n")
            process_and_display_images(downloaded_files, 'file')
            print("\n画像を表示中です。任意のキーを押すと終了します。")
            cv2.waitKey(0)
        else:
            print("\nサンプル画像のダウンロードに失敗しました。")

finally:
    print('\n=== プログラム終了 ===')
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')