MambaOut による画像分類

目次

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

ライブラリのインストール:

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python pillow timm

MambaOut による画像分類プログラム

概要

本プログラムは、動画の各フレームを1000種類のカテゴリに分類する。ConvNeXtアーキテクチャを用いて、リアルタイムで画像認識を行い、認識結果の確信度とともに上位5つの候補を提示する。

主要技術

1. ConvNeXt
畳み込みニューラルネットワーク(CNN)の一種である[1]。Vision Transformerの設計思想を取り入れながら、純粋なCNNとして実装されている。7×7の大きな畳み込みカーネル、深さ方向分離畳み込み、LayerNormalizationなどの特徴を持つ。

2. MambaOut
Vision Mambaから状態空間モデル(SSM)を除去した結果、ConvNeXtと同一のアーキテクチャになることを示した研究である[2]。この発見により、視覚タスクにおいてSSMが必須ではないことが示された。

参考文献

ソースコード


# プログラム名: MambaOut リアルタイム動画分類プログラム
# 特徴技術名: MambaOut
# 出典: Yu, W., & Wang, X. (2025). MambaOut: Do We Really Need Mamba for Vision? In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition.
# 特徴機能: Gated CNNブロックを積み重ねた階層的アーキテクチャによる画像分類。7x7カーネルサイズの深さ方向畳み込みによるトークンミキシングを実現し、ImageNet画像分類において視覚的Mambaモデルを上回る性能を達成
# 学習済みモデル: MambaOut-Tiny(5M parameters)、MambaOut-Small(25M parameters)、MambaOut-Base(76M parameters)、MambaOut-Kobe(11M parameters)が利用可能。timmライブラリから事前学習済みモデルをダウンロード可能
# 方式設計:
#   関連利用技術: timm(PyTorch Image Models、MambaOutモデルを提供するライブラリ)、PyTorch(深層学習フレームワーク)、OpenCV(画像処理・カメラ制御)、PIL(画像処理・テキスト描画)
#   入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://github.com/opencv/opencv/blob/master/samples/data/vtest.aviを使用)、出力: 処理結果が画像化できる場合にはOpenCV画面でリアルタイムに表示.OpenCV画面内に処理結果をテキストで表示.さらに,1秒間隔で,print()で処理結果を表示.プログラム終了時にprint()で表示した処理結果をresult.txtファイルに保存し,「result.txtに保存」したことをprint()で表示
#   処理手順: 1.カメラ/動画から画像フレーム取得、2.画像を224x224にリサイズ、3.ImageNet正規化を適用、4.MambaOutモデルで推論、5.softmaxで確率計算、6.上位5クラスを抽出、7.結果を画像に描画して表示
#   前処理、後処理: 前処理: Resize(224,224)、ToTensor変換、ImageNet統計量での正規化(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])。後処理: softmax関数による確率計算、top-k抽出(k=5)
#   追加処理: なし
#   調整を必要とする設定値: MODEL_NAME (使用するMambaOutモデルの選択。'mambaout_tiny.in1k'、'mambaout_small.in1k'、'mambaout_base.in1k'、'mambaout_kobe.in1k'から選択可能。大きいモデルほど精度向上するが処理速度低下)
# 将来方策: MODEL_NAMEの自動選択機能。処理開始時に利用可能なGPUメモリを検出し、最適なモデルサイズを自動選択する機能の実装
# その他の重要事項: MambaOutはGated CNNブロックを用いることで、SSMなしでも画像分類を実現。timmライブラリのバージョン0.6.11以上が必要
# 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
#         pip install opencv-python pillow timm

import cv2
import tkinter as tk
from tkinter import filedialog
import os
import numpy as np
import torch
import torchvision.transforms as transforms
import urllib.request
from PIL import Image, ImageDraw, ImageFont
import timm
import time

# 定数定義
MAMBAOUT_MODELS = {
    'mambaout_tiny.in1k': {'params': '5M', 'input_size': '224x224'},
    'mambaout_small.in1k': {'params': '25M', 'input_size': '224x224'},
    'mambaout_base.in1k': {'params': '76M', 'input_size': '224x224'},
    'mambaout_kobe.in1k': {'params': '11M', 'input_size': '224x224'},
    'mambaout_femto.in1k': {'params': '3M', 'input_size': '224x224'}
}
FONT_PATH = 'C:/Windows/Fonts/msgothic.ttc'  # 日本語フォントパス
FONT_SIZE = 18  # フォントサイズ
RANDOM_SEED = 42  # 乱数シード
WINDOW_NAME = 'MambaOut Image Classification'  # OpenCVウィンドウ名
RESULT_FILE = 'result.txt'  # 結果保存ファイル名
LOG_INTERVAL = 1.0  # ログ出力間隔(秒)
TOP_K = 5  # Top-K設定(1-10まで設定可能)
SUPPORTED_VIDEO_FORMATS = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm', '.m4v', '.mpg', '.mpeg']  # サポート動画形式

def detect_available_models():
    """利用可能なMambaOutモデルを検出"""
    available_models = []
    for model_name in MAMBAOUT_MODELS.keys():
        try:
            timm.create_model(model_name, pretrained=False)
            available_models.append(model_name)
        except:
            continue
    return available_models

def select_model():
    """モデル選択メニューを表示してユーザーに選択させる"""
    available_models = detect_available_models()

    if not available_models:
        print('利用可能なMambaOutモデルが見つかりません')
        exit()

    print('利用可能なMambaOutモデル:')
    print('')
    for i, model_name in enumerate(available_models):
        model_info = MAMBAOUT_MODELS[model_name]
        print(f'{i}: {model_name} ({model_info["params"]} params, {model_info["input_size"]})')
    print('')

    try:
        choice = input(f'モデルを選択 (デフォルト: 0): ').strip()
        if choice == '':
            choice = 0
        else:
            choice = int(choice)

        if 0 <= choice < len(available_models):
            return available_models[choice]
        else:
            print('無効な選択です。デフォルトモデルを使用します')
            return available_models[0]
    except:
        print('無効な入力です。デフォルトモデルを使用します')
        return available_models[0]

# ImageNet クラス名の動的取得
try:
    url = 'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt'
    with urllib.request.urlopen(url) as response:
        IMAGENET_CLASSES = [line.decode('utf-8').strip() for line in response.readlines()]
    print('ImageNetクラス名を取得しました')
except Exception as e:
    print(f'ImageNetクラス名の取得に失敗しました: {e}')
    print('定義済みのクラス名を使用します')
    IMAGENET_CLASSES = [f'class_{i}' for i in range(1000)]

print('MambaOut リアルタイム動画分類プログラム')
print('')

# モデル選択
MODEL_NAME = select_model()
print('')

print('特徴技術: MambaOut')
print(f'使用モデル: {MODEL_NAME}')
print('機能: リアルタイム動画像分類(ImageNet 1000クラス)')
print(f'サポート動画形式: {", ".join(SUPPORTED_VIDEO_FORMATS)}')
print(f'Top-{TOP_K} 分類結果を表示')
print('')
print('操作方法:')
print("- 動画再生中: 'q'キーで終了")
print('- 1秒ごとに分類結果を表示します')
print('- 終了時にresult.txtに結果を保存します')
print('')

# 乱数シード設定
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f'使用デバイス: {device} ({gpu_name})')
else:
    print(f'使用デバイス: {device}')

# モデルのロード
print(f"MambaOutモデル '{MODEL_NAME}' をロード中...")
try:
    model = timm.create_model(MODEL_NAME, pretrained=True)
    model = model.to(device)
    model.eval()
    print('モデルのロードが完了しました')
except Exception as e:
    print(f'モデルの読み込みに失敗しました: {e}')
    exit()

# ImageNetクラス名を使用
class_names = IMAGENET_CLASSES

# 前処理の定義
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# フォントの読み込み(一度だけ実行)
try:
    font = ImageFont.truetype(FONT_PATH, FONT_SIZE)
except:
    font = None

# 結果保存用リスト
results_log = []

def get_confidence_color(prob):
    """確信度に応じた色を返す"""
    if prob >= 0.7:
        return (0, 255, 0)    # 緑(高信頼度)
    elif prob >= 0.5:
        return (0, 255, 255)  # 黄(中信頼度)
    elif prob >= 0.3:
        return (0, 165, 255)  # オレンジ(低中信頼度)
    else:
        return (0, 0, 255)    # 赤(低信頼度)

def video_processing(frame, frame_count, elapsed_time):
    # フレームを処理して分類結果を描画
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.nn.functional.softmax(outputs[0], dim=0)

    topk_prob, topk_idx = torch.topk(probs, TOP_K)

    # 結果を画像に描画
    draw = ImageDraw.Draw(img)

    # タイトル
    draw.text((10, 10), 'MambaOut 分類結果', font=font, fill=(0, 255, 0))
    draw.text((10, 35), f'Top-{TOP_K} 予測:', font=font, fill=(0, 255, 0))

    # Top-K結果
    result_text = []
    for i, (prob, idx) in enumerate(zip(topk_prob, topk_idx)):
        class_idx = idx.item()
        name = class_names[class_idx] if class_idx < len(class_names) else f'class_{class_idx}'
        conf = prob.item()
        text = f'{i+1}. {name} ({conf:.3f})'
        color = get_confidence_color(conf)
        draw.text((10, 60 + i * 25), text, font=font, fill=color)
        result_text.append(text)

    # フレーム情報
    info = f'フレーム: {frame_count} | 経過時間: {elapsed_time:.1f}秒'
    draw.text((10, 60 + TOP_K * 25 + 10), info, font=font, fill=(0, 255, 0))

    processed_frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    return processed_frame, result_text


print('')
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')
print('')

choice = input('選択: ')
temp_file = None

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename(
        title='動画ファイルを選択',
        filetypes=[('動画ファイル', ' '.join(f'*{ext}' for ext in SUPPORTED_VIDEO_FORMATS)), ('すべてのファイル', '*.*')]
    )
    if not path:
        exit()
    # サポート形式の確認
    if not any(path.lower().endswith(ext) for ext in SUPPORTED_VIDEO_FORMATS):
        print(f'警告: {os.path.splitext(path)[1]} は推奨されない形式です')
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
elif choice == '2':
    # サンプル動画ダウンロード・処理
    import urllib.request
    url = 'https://github.com/opencv/opencv/raw/master/samples/data/vtest.avi'
    filename = 'vtest.avi'
    try:
        print(f'サンプル動画をダウンロード中: {url}')
        urllib.request.urlretrieve(url, filename)
        temp_file = filename
        cap = cv2.VideoCapture(filename)
        print('サンプル動画のダウンロードが完了しました')
    except Exception as e:
        print(f'動画のダウンロードに失敗しました: {url}')
        print(f'エラー: {e}')
        exit()
else:
    print('無効な選択です')
    exit()

# メイン処理
frame_count = 0
start_time = time.time()
last_print_time = start_time

try:
    while True:
        cap.grab()
        ret, frame = cap.retrieve()
        if not ret:
            break

        frame_count += 1
        current_time = time.time()
        elapsed_time = current_time - start_time

        processed_frame, result_text = video_processing(frame, frame_count, elapsed_time)
        cv2.imshow(WINDOW_NAME, processed_frame)

        # 指定間隔で結果を表示
        if current_time - last_print_time >= LOG_INTERVAL:
            print(f'\n[時刻: {elapsed_time:.1f}秒] フレーム: {frame_count}')
            print(f'分類結果 (Top-{TOP_K}):')
            for text in result_text:
                print(f'  {text}')

            # ログに保存
            log_entry = {
                'time': elapsed_time,
                'frame': frame_count,
                'results': result_text
            }
            results_log.append(log_entry)

            last_print_time = current_time

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

finally:
    cap.release()
    cv2.destroyAllWindows()

    # 結果をファイルに保存
    if results_log:
        with open(RESULT_FILE, 'w', encoding='utf-8') as f:
            f.write('MambaOut リアルタイム動画分類プログラム 実行結果\n')
            f.write('=' * 50 + '\n')
            f.write(f'特徴技術: MambaOut\n')
            f.write(f'使用モデル: {MODEL_NAME}\n')
            f.write(f'総フレーム数: {frame_count}\n')
            f.write(f'総実行時間: {elapsed_time:.1f}秒\n')
            f.write('=' * 50 + '\n\n')

            for entry in results_log:
                f.write(f"[時刻: {entry['time']:.1f}秒] フレーム: {entry['frame']}\n")
                f.write(f'分類結果 (Top-{TOP_K}):\n')
                for text in entry['results']:
                    f.write(f'  {text}\n')
                f.write('\n')

        print(f'\n{RESULT_FILE}に保存しました')

    if temp_file:
        os.remove(temp_file)

    print('プログラムを終了しました')

使用方法

  1. 上記のプログラムを実行する
  2. 動作確認

    • Webカメラ映像が表示される
    • 映像上にリアルタイムで分類結果が表示される
    • 上位5位の分類結果と信頼度が表示される
  3. 終了方法:映像ウィンドウで 'q' キーを押す。

実験・探求のアイデア

AIモデル選択の実験

実験要素

体験・実験・探求のアイデア