RT-DETRv2による物体検出・ByteTrackによる追跡とTTAの機能付き(COCO 80クラス)(ソースコードと説明と利用ガイド)

【概要】 RT-DETRv2を用いた物体検出システムで、動画やウェブカメラからCOCO 80クラスの物体をリアルタイムで検出する。CLAHE前処理とTTAにより暗所でも高精度な検出が可能。ByteTrackによる物体追跡機能を搭載し、フレーム間での継続的な追跡を実現。3種類のバックボーン選択、日本語表示対応、検出結果の自動保存機能を備える。

RT-DETRv2物体検出 RT-DETRv2物体検出

プログラム利用ガイド

1. このプログラムの利用シーン

動画ファイルやウェブカメラの映像から、人、車、動物などの物体をリアルタイムで自動検出するためのツールである。監視システム、交通流解析、画像解析研究、教育用デモンストレーションなどの用途に適用できる。CLAHE前処理により暗い環境でも安定した検出性能を発揮する。

2. 主な機能

3. 基本的な使い方

  1. プログラムの起動: Pythonスクリプトを実行し、使用するRT-DETRv2モデル(1/2/3)を選択する。
  2. 入力ソースの選択: キーボードで0(動画ファイル)、1(ウェブカメラ)、2(サンプル動画)のいずれかを入力する。
  3. 検出処理の実行: 映像が表示され、検出された物体が色分けされたバウンディングボックスで囲まれる。ByteTrack有効時は追跡IDも表示される。
  4. プログラムの終了: 映像表示画面でqキーを押すと、処理を終了し結果がファイルに保存される。

4. 便利な機能

使用する学習済みモデル

RT-DETRv2事前学習済みモデル:

事前準備

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なパッケージのインストール

管理者権限でコマンドプロンプトを起動し、以下のコマンドを実行する:


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install -U transformers opencv-python numpy pillow boxmot

RT-DETRv2による物体検出プログラム・ByteTrackによる追跡とTTAの機能付き(COCO 80クラス)

概要

このプログラムは、RT-DETRv2を用いた物体検出システムである。動画ファイル、ウェブカメラ、サンプル動画から取得した映像に対してリアルタイムで物体検出を実行し、COCOデータセット80クラスの物体をバウンディングボックスで表示する。検出精度の向上を目的として、CLAHE(コントラスト制限付き適応ヒストグラム均一化)とTTA(Test-Time Augmentation)を組み合わせた前処理を実装している[1][2]。

主要技術

RT-DETRv2(Real-Time Detection Transformer version 2)

Peking UniversityとBaiduが開発したリアルタイム物体検出Transformerの改良版である[1][2]。RT-DETRの後継として、選択的マルチスケール特徴抽出、離散サンプリング演算子、動的データ拡張によりBag-of-Freebiesアプローチを実装し、速度を損なうことなく精度を向上させる。NMS(非最大抑制)を必要としないエンドツーエンドのアーキテクチャを採用する。

CLAHE(Contrast Limited Adaptive Histogram Equalization)

Zuiderveldが1994年に提案したコントラスト強化手法である[3][4]。画像を小領域(タイル)に分割し、各タイルでヒストグラム均一化を適用する。コントラスト制限機能により、ノイズの過度な増幅を防止する。

ByteTrack

カルマンフィルタとハンガリアンアルゴリズムを組み合わせた物体追跡手法である。低信頼度検出も含めた2段階の関連付けにより、遮蔽環境でも安定した追跡を実現する。

技術的特徴

実装の特色

リアルタイム映像処理に特化した設計を採用し、以下の機能を備える:

参考文献

[1] Lv, W., Zhao, Y., Chang, Q., Huang, K., Wang, G., & Liu, Y. (2024). RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer. arXiv preprint arXiv:2407.17140.

[2] Hugging Face. (2024). RT-DETRv2 Documentation. https://huggingface.co/docs/transformers/en/model_doc/rt_detr_v2

[3] Zuiderveld, K. (1994). Contrast limited adaptive histogram equalization. Graphics gems IV, 474-485.

[4] OpenCV Team. (2024). Histogram Equalization Documentation. https://docs.opencv.org/4.x/d5/daf/tutorial_py_histogram_equalization.html

[5] Shanmugam, D., Blalock, D., Balakrishnan, G., & Guttag, J. (2021). When and why test-time augmentation works. arXiv preprint arXiv:2011.11156.

[6] Machine Learning Mastery. (2020). How to Use Test-Time Augmentation. https://machinelearningmastery.com/how-to-use-test-time-augmentation-to-improve-model-performance-for-image-classification/

ソースコード


"""
プログラム名: RT-DETRv2による物体検出プログラム(COCO 80クラス)・ByteTrackによる追跡とTTAの機能付き
特徴技術名: RT-DETRv2 (Real-Time Detection Transformer version 2)
出典: W. Lv, Y. Zhao, Q. Chang, K. Huang, G. Wang, and Y. Liu, "RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer," arXiv preprint arXiv:2407.17140, 2024.
特徴機能: Bag-of-Freebiesによる改良ベースラインと選択的マルチスケール特徴抽出による物体検出
学習済みモデル: PekingU/rtdetr_v2_r50vd/r101vd/hgnetv2_l(Hugging Face Transformers)、COCOデータセットで事前学習済み
特徴技術および学習済モデルの利用制限: RT-DETRv2はApache 2.0ライセンス(商用利用可能)。boxmot(ByteTrack実装)はAGPL-3.0ライセンス(ネットワークサービスとして提供する場合はソースコード公開が必要)。必ず利用者自身で各ライセンスの詳細を確認すること。
方式設計:
  関連利用技術:
    - PyTorch: ディープラーニングフレームワーク、GPU/CPU自動選択
    - Transformers: Hugging Face Transformersライブラリ
    - OpenCV: 画像・動画処理、カメラ制御
    - CLAHE (Contrast Limited Adaptive Histogram Equalization): 低照度環境での画像品質向上
    - ByteTrack: カルマンフィルタとハンガリアンアルゴリズムによる物体追跡(boxmotパッケージ版)
    - TTA (Test Time Augmentation): 複数の画像変換で推論し結果を統合
  入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.aviを使用)、出力: OpenCV画面でリアルタイム表示、検出結果をresult.txtに保存
  処理手順: 1.動画フレーム取得→2.CLAHE前処理→3.TTA適用→4.RT-DETRv2推論実行→5.バウンディングボックス抽出→6.ByteTrack追跡→7.結果描画
  前処理、後処理: 前処理:CLAHE適用による画像コントラスト強化、後処理:post_process_object_detection()による結果整形、ByteTrack追跡による検出結果の安定化とID管理
  追加処理: TTA - 水平反転による推論結果の統合
  調整を必要とする設定値: CONF_THRESH(信頼度閾値、デフォルト0.25)- 検出感度を制御、値が低いほど多くの物体を検出、TTA_ENABLED(TTAの有効/無効、デフォルトTrue)
将来方策: 信頼度閾値の自動最適化 - 検出結果の時系列分析により、シーンごとに最適な閾値を動的に学習・適用する機能
その他の重要事項: COCOクラス検出可能、Windows環境での動作を想定(フォントはC:/Windows/Fonts/meiryo.ttc を使用)
前準備:
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install -U transformers opencv-python numpy pillow boxmot
"""
import cv2
import numpy as np
import torch
import torchvision
from transformers import RTDetrV2ForObjectDetection, AutoImageProcessor
import tkinter as tk
from tkinter import filedialog
import urllib.request
import time
import sys
import io
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from boxmot import ByteTrack
import warnings
import threading

warnings.filterwarnings('ignore')

# Windows文字エンコーディング設定
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# GPU使用時の最適化
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True

# モデル情報の構造化
MODEL_INFO = {
    '1': {
        'name': 'PekingU/rtdetr_v2_r50vd',
        'desc': 'ResNet-50D backbone (速度と精度のバランス)',
        'backbone': 'ResNet-50D'
    },
    '2': {
        'name': 'PekingU/rtdetr_v2_r101vd',
        'desc': 'ResNet-101D backbone (精度重視)',
        'backbone': 'ResNet-101D'
    },
    '3': {
        'name': 'PekingU/rtdetr_v2_hgnetv2_l',
        'desc': 'HGNetv2-L backbone (最高精度)',
        'backbone': 'HGNetv2-L'
    }
}

# 調整可能な設定値
CONF_THRESH = 0.25
NMS_THRESHOLD = 0.6
CLAHE_CLIP_LIMIT = 3.0
CLAHE_TILE_SIZE = (8, 8)
WINDOW_NAME = "RT-DETRv2 COCO Detection"
TTA_ENABLED = True
TTA_CONF_BOOST = 0.03
USE_TRACKER = True

# CLAHEオブジェクトをグローバルスコープで一度だけ定義
clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=CLAHE_TILE_SIZE)

# ByteTrackトラッカーを初期化
tracker = ByteTrack() if USE_TRACKER else None

# BGR→RGB色変換のヘルパー関数
def bgr_to_rgb(color_bgr):
    """BGRカラーをRGBカラーに変換"""
    return (color_bgr[2], color_bgr[1], color_bgr[0])

# クラスごとの色生成
def generate_class_colors(num_classes):
    colors = []
    for i in range(num_classes):
        hue = int(180.0 * i / num_classes)
        hsv = np.uint8([[[hue, 255, 255]]])
        bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)[0][0]
        colors.append((int(bgr[0]), int(bgr[1]), int(bgr[2])))
    return colors

# 日本語クラス名マッピング
CLASS_NAMES_JP = {
    'person': '人', 'bicycle': '自転車', 'car': '車', 'motorcycle': 'バイク',
    'airplane': '飛行機', 'bus': 'バス', 'train': '電車', 'truck': 'トラック',
    'boat': 'ボート', 'traffic light': '信号機', 'fire hydrant': '消火栓',
    'stop sign': '停止標識', 'parking meter': 'パーキングメーター', 'bench': 'ベンチ',
    'bird': '鳥', 'cat': '猫', 'dog': '犬', 'horse': '馬', 'sheep': '羊',
    'cow': '牛', 'elephant': '象', 'bear': '熊', 'zebra': 'シマウマ', 'giraffe': 'キリン',
    'backpack': 'リュック', 'umbrella': '傘', 'handbag': 'ハンドバッグ', 'tie': 'ネクタイ',
    'suitcase': 'スーツケース', 'frisbee': 'フリスビー', 'skis': 'スキー板',
    'snowboard': 'スノーボード', 'sports ball': 'ボール', 'kite': '凧',
    'baseball bat': 'バット', 'baseball glove': 'グローブ', 'skateboard': 'スケートボード',
    'surfboard': 'サーフボード', 'tennis racket': 'テニスラケット', 'bottle': 'ボトル',
    'wine glass': 'ワイングラス', 'cup': 'カップ', 'fork': 'フォーク', 'knife': 'ナイフ',
    'spoon': 'スプーン', 'bowl': 'ボウル', 'banana': 'バナナ', 'apple': 'リンゴ',
    'sandwich': 'サンドイッチ', 'orange': 'オレンジ', 'broccoli': 'ブロッコリー',
    'carrot': 'ニンジン', 'hot dog': 'ホットドッグ', 'pizza': 'ピザ', 'donut': 'ドーナツ',
    'cake': 'ケーキ', 'chair': '椅子', 'couch': 'ソファ', 'potted plant': '鉢植え',
    'bed': 'ベッド', 'dining table': 'テーブル', 'toilet': 'トイレ', 'tv': 'テレビ',
    'laptop': 'ノートPC', 'mouse': 'マウス', 'remote': 'リモコン', 'keyboard': 'キーボード',
    'cell phone': '携帯電話', 'microwave': '電子レンジ', 'oven': 'オーブン',
    'toaster': 'トースター', 'sink': 'シンク', 'refrigerator': '冷蔵庫',
    'book': '本', 'clock': '時計', 'vase': '花瓶', 'scissors': 'ハサミ',
    'teddy bear': 'ぬいぐるみ', 'hair drier': 'ドライヤー', 'toothbrush': '歯ブラシ'
}

# 日本語フォント設定
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE_MAIN = 16
font_main = ImageFont.truetype(FONT_PATH, FONT_SIZE_MAIN)

# グローバル変数
frame_count = 0
results_log = []
class_counts = {}
model = None
processor = None
id2label = {}
CLASS_COLORS = []


class ThreadedVideoCapture:
    """スレッド化されたVideoCapture(常に最新フレームを取得)"""
    def __init__(self, src, is_camera=False):
        if is_camera:
            self.cap = cv2.VideoCapture(src, cv2.CAP_DSHOW)
            fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
            self.cap.set(cv2.CAP_PROP_FOURCC, fourcc)
            self.cap.set(cv2.CAP_PROP_FPS, 60)
        else:
            self.cap = cv2.VideoCapture(src)

        self.grabbed, self.frame = self.cap.read()
        self.stopped = False
        self.lock = threading.Lock()
        self.thread = threading.Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()

    def update(self):
        """バックグラウンドでフレームを取得し続ける"""
        while not self.stopped:
            grabbed, frame = self.cap.read()
            with self.lock:
                self.grabbed = grabbed
                if grabbed:
                    self.frame = frame

    def read(self):
        """最新フレームを返す"""
        with self.lock:
            return self.grabbed, self.frame.copy() if self.grabbed else None

    def isOpened(self):
        return self.cap.isOpened()

    def get(self, prop):
        return self.cap.get(prop)

    def release(self):
        self.stopped = True
        self.thread.join()
        self.cap.release()


def display_program_header():
    print('=' * 60)
    print('=== RT-DETRv2オブジェクト検出プログラム ===')
    print('=' * 60)
    print('概要: CLAHEとTTAを適用し、リアルタイムでオブジェクトを検出します')
    print('機能: RT-DETRv2による物体検出(COCOデータセット対応)')
    print('技術: CLAHE (コントラスト強化) + ByteTrack による追跡 + TTA (Test Time Augmentation) + Transformers + RT-DETRv2')
    print('操作: qキーで終了')
    print('出力: 各フレームごとに処理結果を表示し、終了時にresult.txtへ保存')
    print()


def move_inputs_to_device(inputs, device):
    """入力データをデバイスに転送"""
    for k in inputs:
        if isinstance(inputs[k], torch.Tensor):
            inputs[k] = inputs[k].to(device)
    return inputs


def frame_to_pil(frame):
    """OpenCVフレームをPIL Imageに変換"""
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def apply_tta_inference(frame, model, processor, id2label, conf_thresh):
    """Test Time Augmentation (TTA)を適用した推論"""
    frame_width = frame.shape[1]

    img_pil = frame_to_pil(frame)
    inputs = processor(images=img_pil, return_tensors='pt')
    inputs = move_inputs_to_device(inputs, device)

    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([frame.shape[:2]], device=device)
    results_orig = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    flipped_frame = cv2.flip(frame, 1)
    img_pil_flipped = frame_to_pil(flipped_frame)
    inputs_flipped = processor(images=img_pil_flipped, return_tensors='pt')
    inputs_flipped = move_inputs_to_device(inputs_flipped, device)

    with torch.no_grad():
        outputs_flipped = model(**inputs_flipped)

    results_flipped = processor.post_process_object_detection(
        outputs_flipped,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    all_boxes = []
    all_confs = []
    all_labels = []

    if len(results_orig['scores']) > 0:
        all_boxes.append(results_orig['boxes'])
        all_confs.append(results_orig['scores'])
        all_labels.append(results_orig['labels'])

    if len(results_flipped['scores']) > 0:
        boxes_flipped = results_flipped['boxes'].clone()
        if boxes_flipped.shape[0] > 0:
            # 水平反転画像での検出結果を元の画像座標系に変換
            # x1, x2 の大小関係を保つ必要がある
            x1_flipped = boxes_flipped[:, 0].clone()
            x2_flipped = boxes_flipped[:, 2].clone()

            # 元の画像座標系での新しい座標
            boxes_flipped[:, 0] = frame_width - 1 - x2_flipped  # 新しいx1(左端)
            boxes_flipped[:, 2] = frame_width - 1 - x1_flipped  # 新しいx2(右端)

        all_boxes.append(boxes_flipped)
        all_confs.append(results_flipped['scores'])
        all_labels.append(results_flipped['labels'])

    if len(all_boxes) == 0:
        return []

    all_boxes = torch.cat(all_boxes, dim=0)
    all_confs = torch.cat(all_confs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    valid_indices = all_confs > conf_thresh
    if valid_indices.sum() > 0:
        all_boxes = all_boxes[valid_indices]
        all_confs = all_confs[valid_indices]
        all_labels = all_labels[valid_indices]

        nms_indices = torchvision.ops.nms(all_boxes, all_confs, iou_threshold=NMS_THRESHOLD)
        final_boxes = all_boxes[nms_indices].cpu().numpy()
        final_confs = all_confs[nms_indices].cpu().numpy()
        final_labels = all_labels[nms_indices].cpu().numpy()

        detections = []
        for i in range(len(final_confs)):
            conf_boost = TTA_CONF_BOOST if TTA_ENABLED else 0
            x1, y1, x2, y2 = map(int, final_boxes[i])
            cls = int(final_labels[i])
            name = id2label.get(cls, str(cls))
            detections.append({
                'x1': x1, 'y1': y1,
                'x2': x2, 'y2': y2,
                'conf': min(1.0, final_confs[i] + conf_boost),
                'class': cls,
                'name': name
            })

        return detections

    return []


def normal_inference(frame, model, processor, id2label, conf_thresh):
    """通常の推論処理"""
    img_pil = frame_to_pil(frame)
    inputs = processor(images=img_pil, return_tensors='pt')
    inputs = move_inputs_to_device(inputs, device)

    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([frame.shape[:2]], device=device)
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    curr_dets = []
    if len(results['scores']) > 0:
        scores = results['scores'].cpu().numpy()
        labels = results['labels'].cpu().numpy()
        boxes = results['boxes'].cpu().numpy()

        order = np.argsort(scores)[::-1]
        for i in order:
            x1, y1, x2, y2 = map(int, boxes[i])
            conf_score = float(scores[i])
            cls = int(labels[i])
            name = id2label.get(cls, str(cls))
            curr_dets.append({
                'x1': x1, 'y1': y1,
                'x2': x2, 'y2': y2,
                'conf': conf_score,
                'class': cls,
                'name': name
            })

    return curr_dets


def apply_tta_if_enabled(frame, model, processor, id2label, conf_thresh):
    """TTA機能を条件付きで適用"""
    if not TTA_ENABLED:
        return normal_inference(frame, model, processor, id2label, conf_thresh)
    return apply_tta_inference(frame, model, processor, id2label, conf_thresh)


def apply_bytetrack(detections, frame):
    """ByteTrackerを使用したトラッキング処理"""
    global tracker

    if len(detections) > 0:
        dets_array = np.array([[d['x1'], d['y1'], d['x2'], d['y2'], d['conf'], d['class']]
                               for d in detections])
    else:
        dets_array = np.empty((0, 6))

    tracks = tracker.update(dets_array, frame)

    tracked_dets = []
    if len(tracks) > 0:
        for track in tracks:
            if len(track) >= 7:
                x1, y1, x2, y2, track_id, conf, cls = track[:7]
                name = id2label.get(int(cls), str(int(cls)))
                tracked_dets.append({
                    'x1': int(x1), 'y1': int(y1),
                    'x2': int(x2), 'y2': int(y2),
                    'track_id': int(track_id),
                    'conf': float(conf),
                    'class': int(cls),
                    'name': name
                })
    return tracked_dets


def apply_tracking_if_enabled(detections, frame):
    """トラッキング機能を条件付きで適用"""
    if not USE_TRACKER:
        return detections
    return apply_bytetrack(detections, frame)


def draw_detection_results(frame, detections):
    """物体検出の描画処理"""
    for det in detections:
        color_seed = det['class']
        color = CLASS_COLORS[color_seed % len(CLASS_COLORS)]
        cv2.rectangle(frame, (det['x1'], det['y1']),
                      (det['x2'], det['y2']), color, 2)

    texts_to_draw = []
    for det in detections:
        color_seed = det['class']
        color = CLASS_COLORS[color_seed % len(CLASS_COLORS)]
        track_id = det.get('track_id', 0) if USE_TRACKER else 0
        jp_name = CLASS_NAMES_JP.get(det['name'], det['name'])
        if USE_TRACKER and track_id > 0:
            label = f"ID:{track_id} {jp_name}: {det['conf']:.2f}"
        else:
            label = f"{jp_name}: {det['conf']:.2f}"

        texts_to_draw.append({
            'text': label,
            'org': (det['x1'], det['y1']-20),
            'color': bgr_to_rgb(color),
            'font_type': 'main'
        })
    frame = draw_texts_with_pillow(frame, texts_to_draw)

    tta_status = "TTA:ON" if TTA_ENABLED else "TTA:OFF"
    tracker_status = "ByteTrack:ON" if USE_TRACKER else "ByteTrack:OFF"
    info_text = f"Objects: {len(detections)} | Frame: {frame_count} | Classes: {len(set(d['name'] for d in detections)) if detections else 0} | {tta_status} | {tracker_status}"
    cv2.putText(frame, info_text, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    return frame


def format_detection_output(detections):
    """物体検出の出力フォーマット"""
    if len(detections) == 0:
        return 'count=0'
    else:
        parts = []
        for det in detections:
            x1, y1, x2, y2 = det['x1'], det['y1'], det['x2'], det['y2']
            class_name = det['name']
            conf = det['conf']
            if USE_TRACKER and 'track_id' in det:
                track_id = det['track_id']
                parts.append(f'class={class_name},ID={track_id},conf={conf:.3f},box=[{x1},{y1},{x2},{y2}]')
            else:
                parts.append(f'class={class_name},conf={conf:.3f},box=[{x1},{y1},{x2},{y2}]')
        return f'count={len(detections)}; ' + ' | '.join(parts)


def draw_texts_with_pillow(bgr_frame, texts):
    """テキスト描画"""
    img_pil = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)

    for item in texts:
        text = item['text']
        x, y = item['org']
        color = item['color']
        draw.text((x, y), text, font=font_main, fill=color)

    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)


def detect_objects(frame):
    """共通の検出処理"""
    global model, processor, id2label

    yuv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
    yuv_frame[:, :, 0] = clahe.apply(yuv_frame[:, :, 0])
    enh_frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR)

    curr_dets = apply_tta_if_enabled(enh_frame, model, processor, id2label, CONF_THRESH)

    return curr_dets


def process_video_frame(frame, timestamp_ms, is_camera):
    """動画用フレーム処理"""
    detections = detect_objects(frame)

    tracked_dets = apply_tracking_if_enabled(detections, frame)

    global class_counts
    for det in tracked_dets:
        name = det['name']
        if name not in class_counts:
            class_counts[name] = 0
        class_counts[name] += 1

    frame = draw_detection_results(frame, tracked_dets)

    result = format_detection_output(tracked_dets)

    return frame, result


def video_frame_processing(frame, timestamp_ms, is_camera):
    """動画フレーム処理(標準形式)"""
    global frame_count
    current_time = time.time()
    frame_count += 1

    processed_frame, result = process_video_frame(frame, timestamp_ms, is_camera)
    return processed_frame, result, current_time


display_program_header()

print("\n=== RT-DETRv2モデル選択 ===")
print('使用するRT-DETRv2モデルを選択してください:')
for key, info in MODEL_INFO.items():
    print(f'{key}: {info["name"]} ({info["desc"]})')
print()

model_choice = ''
while model_choice not in MODEL_INFO.keys():
    model_choice = input("選択 (1/2/3) [デフォルト: 1]: ").strip()
    if model_choice == '':
        model_choice = '1'
        break
    if model_choice not in MODEL_INFO.keys():
        print("無効な選択です。もう一度入力してください。")

print(f"\nRT-DETRv2モデルをロード中...")
try:
    model_name = MODEL_INFO[model_choice]['name']
    model = RTDetrV2ForObjectDetection.from_pretrained(model_name)
    processor = AutoImageProcessor.from_pretrained(model_name)
    model.to(device)
    model.eval()

    id2label = {int(k): v for k, v in model.config.id2label.items()}
    NUM_CLASSES = len(id2label)
    CLASS_COLORS = generate_class_colors(NUM_CLASSES)

    print(f"\n検出可能なクラス数: {len(id2label)}")
    print(f"クラス一覧: {', '.join(id2label.values())}")
    print(f"モデル情報: {MODEL_INFO[model_choice]['desc']}")
    print("モデルのロード完了")
except Exception as e:
    print(f"モデルのロードに失敗しました: {e}")
    raise SystemExit(1)

if TTA_ENABLED:
    print("\nTest Time Augmentation (TTA): 有効")
    print("  - 水平反転による推論結果の統合")
    print(f"  - 信頼度ブースト値: {TTA_CONF_BOOST}")
    print(f"  - NMS閾値: {NMS_THRESHOLD}")
else:
    print("\nTest Time Augmentation (TTA): 無効")

if USE_TRACKER:
    print("\nByteTrack: 有効")
    print("  - カルマンフィルタによる動き予測")

print("\n=== RT-DETRv2リアルタイム物体検出(COCO対応) ===")
print("0: 動画ファイル")
print("1: カメラ")
print("2: サンプル動画")

choice = input("選択: ")

is_camera = (choice == '1')

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        raise SystemExit(1)
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = ThreadedVideoCapture(0, is_camera=True)
else:
    print("サンプル動画をダウンロード中...")
    url = "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi"
    filename = "vtest.avi"
    urllib.request.urlretrieve(url, filename)
    cap = cv2.VideoCapture(filename)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    raise SystemExit(1)

if is_camera:
    actual_fps = cap.get(cv2.CAP_PROP_FPS)
    print(f'カメラのfps: {actual_fps}')
    timestamp_increment = int(1000 / actual_fps) if actual_fps > 0 else 33
else:
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    timestamp_increment = int(1000 / video_fps) if video_fps > 0 else 33

print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')

start_time = time.time()
last_info_time = start_time
info_interval = 10.0
timestamp_ms = 0
total_processing_time = 0.0

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        timestamp_ms += timestamp_increment

        processing_start = time.time()
        processed_frame, result, current_time = video_frame_processing(frame, timestamp_ms, is_camera)
        processing_time = time.time() - processing_start
        total_processing_time += processing_time

        cv2.imshow(WINDOW_NAME, processed_frame)

        if result:
            if is_camera:
                timestamp = datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
                print(f'{timestamp}, {result}')
            else:
                print(f'Frame {frame_count}: {result}')

            results_log.append(result)

        if is_camera:
            elapsed = current_time - last_info_time
            if elapsed >= info_interval:
                total_elapsed = current_time - start_time
                actual_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
                avg_processing_time = (total_processing_time / frame_count * 1000) if frame_count > 0 else 0
                print(f'[情報] 経過時間: {total_elapsed:.1f}秒, 処理フレーム数: {frame_count}, 実測fps: {actual_fps:.1f}, 平均処理時間: {avg_processing_time:.1f}ms')
                last_info_time = current_time

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()

    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== RT-DETRv2物体検出結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用モデル: {model_name}\n')
            f.write(f'モデル情報: {MODEL_INFO[model_choice]["desc"]}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write(f'画像処理: CLAHE適用(YUV色空間)\n')
            f.write(f'TTA (Test Time Augmentation): {"有効" if TTA_ENABLED else "無効"}\n')
            if TTA_ENABLED:
                f.write(f'  - NMS閾値: {NMS_THRESHOLD}\n')
                f.write(f'  - 信頼度ブースト: {TTA_CONF_BOOST}\n')
            f.write(f'ByteTrack: {"有効" if USE_TRACKER else "無効"}\n')
            f.write(f'信頼度閾値: {CONF_THRESH}(固定値)\n')
            f.write(f'\n検出されたクラス一覧:\n')
            for class_name, count in sorted(class_counts.items()):
                jp_name = CLASS_NAMES_JP.get(class_name, class_name)
                f.write(f'  {jp_name} ({class_name}): {count}回\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')
        print(f'検出されたクラス数: {len(class_counts)}')