RTMDet による物体検出(MMDetection を使用 )(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

Visual Studio 2022 Build Toolsとランタイムのインストール

mmcv 2.1.0 のインストールに使用する.

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。


REM Visual Studio 2022 Build Toolsとランタイムのインストール
winget install --scope machine Microsoft.VisualStudio.2022.BuildTools Microsoft.VCRedist.2015+.x64
set VS_INSTALLER="C:\Program Files (x86)\Microsoft Visual Studio\Installer\setup.exe"
set VS_PATH="C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools"
REM C++開発ワークロードのインストール
%VS_INSTALLER% modify --installPath %VS_PATH% ^
--add Microsoft.VisualStudio.Workload.VCTools ^
--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 ^
--add Microsoft.VisualStudio.Component.Windows11SDK.22621 ^
--includeRecommended --quiet --norestart

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvars64.bat"
set DISTUTILS_USE_SDK=1
pip install mmengine mmcv==2.1.0 mmdet opencv-python pillow tqdm

関連する外部ページ

MMDetection と RTMDet による物体検出プログラム

概要

このプログラムは、MMDetectionフレームワークのRTMDet物体検出モデルを使用して、動画ファイル・カメラ・サンプル動画からリアルタイム物体検出を行うプログラムです。

ソースコード


"""
- プログラム名: MMDetection と RTMDet による物体検出プログラム
- 特徴技術名: RTMDet
- 出典: Lyu, C., Zhang, W., Huang, H., Zhou, Y., Wang, Y., Liu, Y., ... & Chen, K. (2022). RTMDet: An empirical study of designing real-time object detectors. arXiv preprint arXiv:2212.07784.
- 特徴機能: 大型カーネル深度別分離畳み込みと動的ソフトラベル割り当て - リアルタイム性能とパラメータ-精度トレードオフを最適化。混合画像データ拡張キャッシュ機能により学習効率を向上
- 学習済みモデル: rtmdet_tiny_8xb32-300e_coco - CSPNeXtをバックボーンとしたRTMDet-Tiny。COCO 2017データセットで事前学習済み。性能と精度のバランスを実現。URL: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/
- 方式設計
  - 関連利用技術:
    * MMDetection 3.3.0 - OpenMMLab物体検出フレームワーク、モジュラー設計と性能を提供
    * DetInferencer - MMDetection統一推論インターフェース、簡潔なAPIで推論実行
    * PyTorch - 深層学習フレームワーク、動的グラフとGPU加速をサポート
  - 入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.aviを使用)、出力: 処理結果が画像化できる場合にはOpenCV画面でリアルタイムに表示.OpenCV画面内に処理結果をテキストで表示.プログラム終了時に処理結果をresult.txtファイルに保存し,「result.txtに保存」したことをprint()で表示.プログラム開始時に,プログラムの概要,ユーザが行う必要がある操作をprint()で表示
  - 処理手順: 1)DetInferencerでRTMDetモデル初期化、2)入力画像読み込み、3)単段階エンコーダー-デコーダーによる特徴抽出、4)動的ソフトラベル割り当てによる物体分類と位置回帰、5)検出結果の描画と表示
  - 前処理、後処理: 前処理: 画像の正規化とリサイズ、混合画像データ拡張(Mosaic+MixUp)、後処理: NMS(Non-Maximum Suppression)による重複検出の除去、信頼度閾値による結果フィルタリング
  - 追加処理: GPU/CPU自動デバイス選択による推論処理、検出結果の可視化処理、キャッシュ機能によるデータ拡張
  - 調整を必要とする設定値: model_choice(モデル選択)- 使用するRTMDetモデルの種類を指定。tiny/s/m/l/xで性能と速度のバランスを調整
- 将来方策: 複数のRTMDetモデル設定を自動ベンチマークし、ハードウェア性能に基づいて最適なmodel_choiceを推奨する機能の追加
- その他の重要事項: 初回実行時にモデルの自動ダウンロードが発生、CUDA対応GPU推奨、RTMDetは実時間性能に最適化
- 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
  "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvars64.bat"
  set DISTUTILS_USE_SDK=1
  pip install mmengine mmcv==2.1.0 mmdet opencv-python pillow tqdm
"""

import os
import cv2
import time
import torch
import urllib.request
import ssl
import numpy as np
import tkinter as tk
from tkinter import filedialog
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from mmdet.apis import DetInferencer
import warnings
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

# tqdmプログレスバーを無効化
os.environ["TQDM_DISABLE"] = "1"

# SSL証明書検証を無効化(モデルダウンロード用)
ssl._create_default_https_context = ssl._create_unverified_context

# 重要でないUserWarningを最小限に抑制
warnings.filterwarnings("once", category=UserWarning, module="mmdet")
warnings.filterwarnings("once", category=UserWarning, module="mmengine")

# 設定定数
MODEL_CONFIGS = {
    'tiny': 'rtmdet_tiny_8xb32-300e_coco',
    's': 'rtmdet_s_8xb32-300e_coco',
    'm': 'rtmdet_m_8xb32-300e_coco',
    'l': 'rtmdet_l_8xb32-300e_coco',
    'x': 'rtmdet_x_8xb32-300e_coco'
}
PRED_SCORE_THR = 0.3
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE = 20
SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}', flush=True)

# 変数
frame_count = 0
results_log = []
inferencer = None

def get_font():
    """フォントを取得"""
    try:
        return ImageFont.truetype(FONT_PATH, FONT_SIZE)
    except OSError:
        return ImageFont.load_default()

def infer_silently(img):
    """DetInferencer 呼び出し時の内部進捗出力を抑止"""
    out_buf, err_buf = StringIO(), StringIO()
    try:
        with redirect_stdout(out_buf), redirect_stderr(err_buf):
            return inferencer(
                img,
                return_vis=True,
                no_save_vis=True,
                pred_score_thr=PRED_SCORE_THR
            )
    finally:
        # 例外時は呼び出し元でraiseされる。ここではバッファを破棄するだけ。
        out_buf.close()
        err_buf.close()

def video_frame_processing(frame):
    """1フレームを推論し、可視化フレームと検出結果文字列リスト、タイムスタンプを返す"""
    global frame_count
    current_time = time.time()
    frame_count += 1

    # 進捗表示を抑止して推論
    result = infer_silently(frame)
    predictions = result['predictions'][0]
    vis_frame = result['visualization'][0]

    obj_lines = []
    if hasattr(predictions, 'pred_instances'):
        pred_instances = predictions.pred_instances
        bboxes = pred_instances.bboxes.cpu().numpy()
        labels = pred_instances.labels.cpu().numpy()
        scores = pred_instances.scores.cpu().numpy()

        # クラス名(メタ情報を優先し、なければCOCOの既定にフォールバック)
        classes = None
        if hasattr(predictions, 'metainfo') and isinstance(predictions.metainfo, dict):
            classes = predictions.metainfo.get('classes', None)
        if classes is None:
            classes = [
                'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
                'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat','dog',
                'horse','sheep','cow','elephant','bear','zebra','giraffe','backpack','umbrella',
                'handbag','tie','suitcase','frisbee','skis','snowboard','sports ball','kite',
                'baseball bat','baseball glove','skateboard','surfboard','tennis racket','bottle',
                'wine glass','cup','fork','knife','spoon','bowl','banana','apple','sandwich','orange',
                'broccoli','carrot','hot dog','pizza','donut','cake','chair','couch','potted plant',
                'bed','dining table','toilet','tv','laptop','mouse','remote','keyboard','cell phone',
                'microwave','oven','toaster','sink','refrigerator','book','clock','vase','scissors',
                'teddy bear','hair drier','toothbrush'
            ]

        # スコア閾値を適用し、物体単位の出力行を作成
        for bbox, label, score in zip(bboxes, labels, scores):
            if score < PRED_SCORE_THR:
                continue
            x1, y1, x2, y2 = bbox
            class_name = classes[int(label)] if int(label) < len(classes) else str(int(label))
            obj_lines.append(
                f"{class_name} ({score:.2f}), x1={x1:.0f}, y1={y1:.0f}, x2={x2:.0f}, y2={y2:.0f}"
            )

        # 画面に日本語テキスト描画(検出数)
        font = get_font()
        img_pil = Image.fromarray(cv2.cvtColor(vis_frame, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)
        draw.text((10, 30), f"検出物体数: {len(obj_lines)}", font=font, fill=(0, 255, 0))
        vis_frame = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    return vis_frame, obj_lines, current_time

def init_inferencer(model_choice_key):
    """モデルを初期化"""
    selected_model = MODEL_CONFIGS[model_choice_key]
    print(f"選択されたモデル: RTMDet-{model_choice_key.upper()} ({selected_model})", flush=True)
    print("モデルを初期化中...", flush=True)
    inf = DetInferencer(model=selected_model, device=str(device))
    print("RTMDetモデルの初期化が完了した", flush=True)
    return inf

# ガイダンス表示
print("概要: MMDetection/RTMDetで物体検出を行う。各物体ごとに1行で出力する。", flush=True)
print("操作方法:", flush=True)
print("  1) モデルを選択する(tiny/s/m/l/x)", flush=True)
print("  2) 入力を選択する(0:動画ファイル, 1:カメラ, 2:サンプル動画)", flush=True)
print("  3) OpenCVウィンドウで結果を確認し、q キーで終了", flush=True)
print("注意事項: 初回実行時はモデルを自動ダウンロードする場合がある", flush=True)

# モデル選択
print("\nRTMDetモデルを選択してください:", flush=True)
print("tiny: RTMDet-Tiny (COCO2017, CSPNeXt)", flush=True)
print("s: RTMDet-S (COCO2017, CSPNeXt)", flush=True)
print("m: RTMDet-M (COCO2017, CSPNeXt)", flush=True)
print("l: RTMDet-L (COCO2017, CSPNeXt)", flush=True)
print("x: RTMDet-X (COCO2017, CSPNeXt)", flush=True)
model_choice = input("モデル選択 (tiny/s/m/l/x): ").lower().strip()
if model_choice not in MODEL_CONFIGS:
    print("無効な選択。RTMDet-Tinyを使用する", flush=True)
    model_choice = 'tiny'
inferencer = init_inferencer(model_choice)

# 入力選択
print("0: 動画ファイル", flush=True)
print("1: カメラ", flush=True)
print("2: サンプル動画", flush=True)
choice = input("選択: ").strip()

temp_file = None
if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename(
        title="動画ファイルを選択",
        filetypes=[("Video files", "*.mp4 *.avi *.mov *.mkv")]
    )
    if not path:
        print("動画ファイルが選択されなかったため終了する", flush=True)
        raise SystemExit
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    if not cap.isOpened():
        cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
else:
    print("サンプル動画をダウンロード中...", flush=True)
    SAMPLE_FILE = 'vtest.avi'
    urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    temp_file = SAMPLE_FILE
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けなかった', flush=True)
    if temp_file and os.path.exists(temp_file):
        os.remove(temp_file)
    raise SystemExit

# メイン処理
MAIN_FUNC_DESC = "RTMDet 物体検出"
print('\n=== 動画処理開始 ===', flush=True)
print('操作方法:', flush=True)
print('  q キー: プログラム終了', flush=True)
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame, obj_lines, current_time = video_frame_processing(frame)
        cv2.imshow(MAIN_FUNC_DESC, processed_frame)

        # 物体単位1行。カメラ=日付を含む現地時刻(YYYY-MM-DD HH:MM:SS.mmm)、動画=フレーム番号(1開始)。
        if choice == '1':  # カメラ
            ts = datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            for line in obj_lines:
                print(ts, line, flush=True)
                results_log.append(f"{ts} {line}")
        else:  # 動画ファイル/サンプル動画
            for line in obj_lines:
                print(frame_count, line, flush=True)
                results_log.append(f"{frame_count} {line}")

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== 結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print('処理結果をresult.txtに保存しました', flush=True)
    if temp_file and os.path.exists(temp_file):
        os.remove(temp_file)
    print('\n=== プログラム終了 ===', flush=True)
# プログラム名: MMDetection と最新Transformerモデルによる物体検出プログラム
# 特徴技術名: Co-DINO-SQL-L, Grounding DINO, RTMDet
# pip install transformers tokenizers sentencepiece protobuf mltk
# python -c "import nltk; nltk.download('all')"

import os
import urllib.request
import ssl
import zipfile
import shutil
from pathlib import Path
import torch
from mmdet.apis import init_detector, inference_detector

# ========== モデル設定(他のモデルに変更する場合はここを修正) ==========
# モデル情報の確認方法:
# https://github.com/open-mmlab/mmdetection/tree/main/configs
# 上記URLの各モデルフォルダ内のREADME.mdに以下の情報が記載されています:
# - 学習に使用したデータセット
# - mAP(精度)
# - FPS(推論速度)
# - 学習済みモデルのダウンロードURL
# - 設定ファイル名
#
# 例:RTMDetの情報
# https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet

# モデル定義
MODELS = [
    {
        'name': 'RTMDet-Tiny',
        'config_file': 'rtmdet/rtmdet_tiny_8xb32-300e_coco.py',
        'checkpoint_url': 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth',
        'checkpoint_file': 'rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth',
        'dataset': 'COCO 2017',
        'num_classes': 80,
        'is_project': False
    },
    {
        'name': 'RTMDet-L',
        'config_file': 'rtmdet/rtmdet_l_8xb32-300e_coco.py',
        'checkpoint_url': 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',
        'checkpoint_file': 'rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',
        'dataset': 'COCO 2017',
        'num_classes': 80,
        'is_project': False
    },
    {
        'name': 'Co-DINO-Swin-L (60.7 mAP)',
        'config_file': 'projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py',
        'checkpoint_url': 'https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth',
        'checkpoint_file': 'co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth',
        'dataset': 'COCO 2017',
        'num_classes': 80,
        'is_project': True
    },
    {
        'name': 'Co-DINO-Swin-L (64.1 mAP)',
        'config_file': 'projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py',
        'checkpoint_url': 'https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth',
        'checkpoint_file': 'co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth',
        'dataset': 'Objects365事前学習 + COCO 2017',
        'num_classes': 80,
        'is_project': True
    },
    {
        'name': 'MM-Grounding-DINO-Swin-L',
        'config_file': 'mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py',
        'checkpoint_url': 'https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_all/grounding_dino_swin-l_pretrain_all-56d69e78.pth',
        'checkpoint_file': 'grounding_dino_swin-l_pretrain_all-56d69e78.pth',
        'dataset': 'O365V2, OpenImageV6, ALL',
        'num_classes': 80,
        'is_project': False
    }
]

# モデル選択メニュー
print("\n利用可能なモデル:")
print("-" * 70)
for i, model in enumerate(MODELS, 1):
    print(f"{i}. {model['name']:<30} | {model['dataset']}")
print("-" * 70)

while True:
    try:
        choice = input(f"\nモデルを選択してください (1-{len(MODELS)}): ")
        choice_idx = int(choice) - 1
        if 0 <= choice_idx < len(MODELS):
            MODEL_CONFIG = MODELS[choice_idx]
            print(f"\n選択されたモデル: {MODEL_CONFIG['name']}")
            break
        else:
            print(f"1から{len(MODELS)}の間で入力してください")
    except ValueError:
        print("数値を入力してください")

# 共通設定
MMDET_VERSION = "v3.3.0"
BASE_DIR = Path("./mmdetection_configs")
DEMO_IMAGE_URL = "https://raw.githubusercontent.com/open-mmlab/mmdetection/main/demo/demo.jpg"
CONFIDENCE_THRESHOLD = 0.3
DEVICE = 'cpu'  # 'cuda:0' for GPU

# SSL証明書検証を無効化
ssl._create_default_https_context = ssl._create_unverified_context

# PyTorch 2.6でモデルファイル読み込み時のセキュリティエラー回避
# PyTorch 2.6から weights_only=True がデフォルトになり、
# MMDetectionの学習済みモデルに含まれる HistoryBuffer 等のオブジェクトが
# ブロックされるため、weights_only=False に強制変更
_original_torch_load = torch.load
def patched_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)
torch.load = patched_load

# 設定ファイル一式のダウンロード
if MODEL_CONFIG.get('is_project', False):
    # projectsフォルダのモデルの場合
    config_file_path = BASE_DIR / MODEL_CONFIG['config_file']
    base_config_path = BASE_DIR / "configs" / "_base_" / "default_runtime.py"
    projects_path = BASE_DIR / "projects"
else:
    # configsフォルダのモデルの場合
    config_file_path = BASE_DIR / "configs" / MODEL_CONFIG['config_file']
    base_config_path = BASE_DIR / "configs" / "_base_" / "default_runtime.py"
    projects_path = None

if not (config_file_path.exists() and base_config_path.exists()):
    BASE_DIR.mkdir(parents=True, exist_ok=True)
    zip_url = f"https://github.com/open-mmlab/mmdetection/archive/refs/tags/{MMDET_VERSION}.zip"
    zip_path = BASE_DIR / "mmdetection.zip"

    print(f"\n設定ファイルをダウンロード中...")
    urllib.request.urlretrieve(zip_url, str(zip_path))

    print("設定ファイルを展開中...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for member in zip_ref.namelist():
            # configsとprojectsフォルダを抽出
            if (f"mmdetection-{MMDET_VERSION[1:]}/configs/" in member or
                f"mmdetection-{MMDET_VERSION[1:]}/projects/" in member):
                target_path = member.replace(f"mmdetection-{MMDET_VERSION[1:]}/", "")
                target_file = BASE_DIR / target_path

                if member.endswith('/'):
                    target_file.mkdir(parents=True, exist_ok=True)
                else:
                    target_file.parent.mkdir(parents=True, exist_ok=True)
                    with zip_ref.open(member) as source, open(target_file, 'wb') as target:
                        shutil.copyfileobj(source, target)

    zip_path.unlink()

# モデルのダウンロード
checkpoint_file = Path(MODEL_CONFIG['checkpoint_file'])

if not checkpoint_file.exists():
    print(f"\nモデルファイルをダウンロード中: {MODEL_CONFIG['checkpoint_file']}")
    urllib.request.urlretrieve(MODEL_CONFIG['checkpoint_url'], str(checkpoint_file))

# デモ画像のダウンロード
demo_dir = Path("./demo")
demo_dir.mkdir(exist_ok=True)
demo_image = demo_dir / "demo.jpg"

if not demo_image.exists():
    print("\nデモ画像をダウンロード中...")
    urllib.request.urlretrieve(DEMO_IMAGE_URL, str(demo_image))

# 元の5行相当の処理
print("\nモデルを初期化中...")

# projectsモデルの場合、PYTHONPATHに追加
import sys
if MODEL_CONFIG.get('is_project', False):
    mmdet_config_path = str(BASE_DIR.resolve())
    if mmdet_config_path not in sys.path:
        sys.path.insert(0, mmdet_config_path)

config_file = str(config_file_path)
checkpoint_file = str(checkpoint_file)
model = init_detector(config_file, checkpoint_file, device=DEVICE)

print("推論を実行中...")
if MODEL_CONFIG['name'] == 'MM-Grounding-DINO-Swin-L':
    import torch
    from PIL import Image
    import numpy as np
    from mmdet.structures import DetDataSample
    from mmengine.structures import PixelData

    # 画像を読み込んでテンソルに変換
    img = Image.open('demo/demo.jpg')
    img_array = np.array(img)
    img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float().unsqueeze(0)

    data_sample = DetDataSample()
    data_sample.text = 'person . bicycle . car . motorcycle . airplane . bus . train . truck . boat . traffic light . fire hydrant . stop sign . parking meter . bench . bird . cat . dog . horse . sheep . cow . elephant . bear . zebra . giraffe . backpack . umbrella . handbag . tie . suitcase . frisbee . skis . snowboard . sports ball . kite . baseball bat . baseball glove . skateboard . surfboard . tennis racket . bottle . wine glass . cup . fork . knife . spoon . bowl . banana . apple . sandwich . orange . broccoli . carrot . hot dog . pizza . donut . cake . chair . couch . potted plant . bed . dining table . toilet . tv . laptop . mouse . remote . keyboard . cell phone . microwave . oven . toaster . sink . refrigerator . book . clock . vase . scissors . teddy bear . hair drier . toothbrush'
    data_sample.set_metainfo({'img_shape': img_array.shape[:2]})
    data_sample.set_metainfo({'ori_shape': img_array.shape[:2]})
    data_sample.set_metainfo({'scale_factor': (1.0, 1.0)})

    data_batch = dict(
        inputs=img_tensor,
        data_samples=[data_sample]
    )
    result = model.test_step(data_batch)[0]
else:
    result = inference_detector(model, 'demo/demo.jpg')

# COCOクラス名
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# 検出結果の詳細表示
print(f"\n{'='*70}")
print(f"モデル: {MODEL_CONFIG['name']}")
print(f"学習データセット: {MODEL_CONFIG['dataset']}")
print(f"クラス数: {MODEL_CONFIG['num_classes']}")
print(f"{'='*70}")
print("\n検出結果詳細:")
pred_instances = result.pred_instances

# 検出結果の詳細表示部分を修正
detected_count = 0
for i in range(len(pred_instances)):
    score = pred_instances.scores[i].item()
    if score > CONFIDENCE_THRESHOLD:
        detected_count += 1
        label = pred_instances.labels[i].item()
        bbox = pred_instances.bboxes[i].detach().cpu().numpy()  # detach()を追加
        class_name = COCO_CLASSES[label] if label < len(COCO_CLASSES) else f"class_{label}"
        print(f"{detected_count:2d}. {class_name:<20} | score: {score:.3f} | bbox: [{bbox[0]:6.1f}, {bbox[1]:6.1f}, {bbox[2]:6.1f}, {bbox[3]:6.1f}]")

if detected_count == 0:
    print("検出されたオブジェクトはありません")
else:
    print(f"\n合計 {detected_count} 個のオブジェクトを検出しました (閾値: {CONFIDENCE_THRESHOLD})")