RT-DETRv2による人物検出・ByteTrackによる追跡とTTAの機能付き(ソースコードと説明と利用ガイド)

プログラム利用ガイド

1. このプログラムの利用シーン

動画内の人物をリアルタイムで検出するためのプログラムである。監視カメラ映像の解析、人流計測、スポーツ映像の分析など、人物の位置と動きを把握する必要がある場面で使用できる。

2. 主な機能

3. 基本的な使い方

  1. プログラムを起動すると、モデル選択画面が表示される。1(ResNet-50D)、2(ResNet-101D)、3(HGNetv2-L)から選択する。
  2. 入力ソースを選択する。0(動画ファイル)、1(カメラ)、2(サンプル動画)から選択する。
  3. 処理が開始され、検出結果がリアルタイムで画面に表示される。
  4. キーボードのqキーを押すとプログラムが終了し、結果がresult.txtに保存される。

4. 便利な機能

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install transformers opencv-python numpy pillow boxmot

RT-DETRv2による人物検出プログラム・ByteTrackによる追跡とTTAの機能付き

概要

このプログラムは、RT-DETRv2を用いて動画からpersonカテゴリのみを検出する。CLAHE(Contrast Limited Adaptive Histogram Equalization)による前処理、TTA(Test Time Augmentation)による精度向上、ByteTrackによる物体追跡を組み合わせ、リアルタイム処理を実現する。

主要技術

RT-DETRv2 (Real-Time Detection Transformer version 2)

2024年にLvらが発表したTransformerベースの物体検出モデルである[1]。選択的マルチスケール特徴抽出とBag-of-Freebiesと呼ばれる手法により、速度を維持しながら検出精度を向上させる。deformable attentionモジュールを用いてスケールごとに異なるサンプリング点数を設定することで、柔軟な特徴抽出を実現する。

ByteTrack

2022年にZhangらが発表した多物体追跡アルゴリズムである[2]。カルマンフィルタによる動き予測とハンガリアンアルゴリズムによるデータアソシエーションを組み合わせる。低信頼度の検出結果も活用することで、追跡の頑健性を向上させる。

技術的特徴

実装の特色

3種類のバックボーン(ResNet-50D、ResNet-101D、HGNetv2-L)から選択可能な構成である。動画ファイル、カメラ、サンプル動画の3つの入力ソースに対応する。ByteTrackによるトラッキングIDの管理と表示機能を備え、処理結果をリアルタイムで画面表示するとともに、ファイルに保存する。

参考文献

[1] Lv, W., Zhao, Y., Chang, Q., Huang, K., Wang, G., & Liu, Y. (2024). RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer. arXiv preprint arXiv:2407.17140. https://arxiv.org/abs/2407.17140

[2] Zhang, Y., Sun, P., Jiang, Y., Yu, D., Weng, F., Yuan, Z., Luo, P., Liu, W., & Wang, X. (2022). ByteTrack: Multi-Object Tracking by Associating Every Detection Box. In Computer Vision – ECCV 2022. https://arxiv.org/abs/2110.06864

[3] Zuiderveld, K. (1994). Contrast Limited Adaptive Histogram Equalization. In P. Heckbert (Ed.), Graphics Gems IV (pp. 474-485). Academic Press.

[4] Shanmugam, D., Blalock, D., Balakrishnan, G., & Guttag, J. (2020). Better Aggregation in Test-Time Augmentation. 2021 IEEE/CVF International Conference on Computer Vision (ICCV).

ソースコード


"""
プログラム名: RT-DETRv2による人物検出プログラム・ByteTrackによる追跡とTTAの機能付き
特徴技術名: RT-DETRv2 (Real-Time Detection Transformer version 2)
出典: W. Lv, Y. Zhao, Q. Chang, K. Huang, G. Wang, and Y. Liu, "RT-DETRv2: Improved Baseline with Bag-of-Freebies for Real-Time Detection Transformer," arXiv preprint arXiv:2407.17140, 2024.
特徴機能: Bag-of-Freebiesによる改良ベースラインと選択的マルチスケール特徴抽出による物体検出
学習済みモデル: PekingU/rtdetr_v2_r50vd/r101vd/hgnetv2_l(Hugging Face Transformers)、COCOデータセットで事前学習済み
特徴技術および学習済モデルの利用制限: Apache 2.0ライセンス(Hugging Face Transformers)。学術研究および商用利用が可能。ただし、利用者自身で最新の利用規約を確認すること。
方式設計:
  関連利用技術:
    - PyTorch: ディープラーニングフレームワーク、GPU/CPU自動選択
    - Transformers: Hugging Face Transformersライブラリ
    - OpenCV: 画像・動画処理、カメラ制御
    - CLAHE (Contrast Limited Adaptive Histogram Equalization): 低照度環境での画像品質向上
    - ByteTrack: カルマンフィルタとハンガリアンアルゴリズムによる物体追跡(boxmotパッケージ版)
    - TTA (Test Time Augmentation): 複数の画像変換で推論し結果を統合
  入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.aviを使用)、出力: OpenCV画面でリアルタイム表示、検出結果をresult.txtに保存
  処理手順: 1.動画フレーム取得→2.CLAHE前処理→3.TTA適用→4.RT-DETRv2推論実行→5.バウンディングボックス抽出(personのみ)→6.ByteTrack追跡→7.結果描画
  前処理、後処理: 前処理:CLAHE適用による画像コントラスト強化、後処理:post_process_object_detection()による結果整形、ByteTrack追跡による検出結果の安定化とID管理
  追加処理: TTA - 水平反転による推論結果の統合
  調整を必要とする設定値: CONF_THRESH(信頼度閾値、デフォルト0.25)- 検出感度を制御、値が低いほど多くの物体を検出、TTA_ENABLED(TTAの有効/無効)、USE_TRACKER(ByteTrackの有効/無効)
将来方策: 信頼度閾値の自動最適化 - 検出結果の時系列分析により、シーンごとに最適な閾値を動的に学習・適用する機能
その他の重要事項: person検出専用、Windows環境での動作を想定(フォントはC:/Windows/Fonts/meiryo.ttc を使用)
前準備:
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install transformers opencv-python numpy pillow boxmot
"""
import cv2
import numpy as np
import torch
import torchvision
from transformers import RTDetrV2ForObjectDetection, AutoImageProcessor
import tkinter as tk
from tkinter import filedialog
import urllib.request
import time
import sys
import io
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from boxmot import ByteTrack
import warnings
import threading

warnings.filterwarnings('ignore')

# Windows文字エンコーディング設定
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)

# GPU/CPU自動選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'デバイス: {str(device)}')

# GPU使用時の最適化
if device.type == 'cuda':
    torch.backends.cudnn.benchmark = True

# モデル情報の構造化
MODEL_INFO = {
    '1': {
        'name': 'PekingU/rtdetr_v2_r50vd',
        'desc': 'ResNet-50D backbone(速度と精度のバランス)',
        'backbone': 'ResNet-50D'
    },
    '2': {
        'name': 'PekingU/rtdetr_v2_r101vd',
        'desc': 'ResNet-101D backbone(精度重視)',
        'backbone': 'ResNet-101D'
    },
    '3': {
        'name': 'PekingU/rtdetr_v2_hgnetv2_l',
        'desc': 'HGNetv2-L backbone(最高精度)',
        'backbone': 'HGNetv2-L'
    }
}

# 調整可能な設定値
CONF_THRESH = 0.25      # 信頼度閾値 - 検出感度制御
NMS_THRESHOLD = 0.6     # TTA用のNMS閾値(独立管理)
CLAHE_CLIP_LIMIT = 3.0  # CLAHE制限値
CLAHE_TILE_SIZE = (8, 8)  # CLAHEタイルサイズ
WINDOW_NAME = "RT-DETRv2 COCO Detection"  # OpenCVウィンドウ名
TTA_ENABLED = False     # TTA(Test Time Augmentation)の有効/無効(デフォルト)
TTA_CONF_BOOST = 0.03   # TTA使用時の信頼度ブースト値
USE_TRACKER = False     # トラッカーの使用有無(デフォルト)

# CLAHEオブジェクトをグローバルスコープで一度だけ定義(AIモデルの入力用にCLAHEを適用)
clahe = cv2.createCLAHE(clipLimit=CLAHE_CLIP_LIMIT, tileGridSize=CLAHE_TILE_SIZE)

# ByteTrackトラッカーを初期化(後で設定に応じて初期化)
tracker = None

# BGR→RGB色変換のヘルパー関数
def bgr_to_rgb(color_bgr):
    """BGRカラーをRGBカラーに変換"""
    return (color_bgr[2], color_bgr[1], color_bgr[0])

# IDから色を生成する関数
def get_color_from_id(track_id):
    """IDをハッシュ化してHSV色空間の色を生成(視認性重視)"""
    hue = int((track_id * 37) % 180)
    hsv = np.uint8([[[hue, 255, 255]]])
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)[0][0]
    return (int(bgr[0]), int(bgr[1]), int(bgr[2]))

# person用の色
PERSON_COLOR = (0, 255, 0)

# 日本語フォント設定
FONT_PATH = 'C:/Windows/Fonts/meiryo.ttc'
FONT_SIZE_MAIN = 16
font_main = ImageFont.truetype(FONT_PATH, FONT_SIZE_MAIN)

# グローバル変数
frame_count = 0
results_log = []
person_count = 0
model = None
processor = None
id2label = {}


class ThreadedVideoCapture:
    """スレッド化されたVideoCapture(常に最新フレームを取得)"""
    def __init__(self, src, is_camera=False):
        if is_camera:
            self.cap = cv2.VideoCapture(src, cv2.CAP_DSHOW)
            fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
            self.cap.set(cv2.CAP_PROP_FOURCC, fourcc)
            self.cap.set(cv2.CAP_PROP_FPS, 60)
        else:
            self.cap = cv2.VideoCapture(src)

        self.grabbed, self.frame = self.cap.read()
        self.stopped = False
        self.lock = threading.Lock()
        self.thread = threading.Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()

    def update(self):
        """バックグラウンドでフレームを取得し続ける"""
        while not self.stopped:
            grabbed, frame = self.cap.read()
            with self.lock:
                self.grabbed = grabbed
                if grabbed:
                    self.frame = frame

    def read(self):
        """最新フレームを返す"""
        with self.lock:
            return self.grabbed, self.frame.copy() if self.grabbed else None

    def isOpened(self):
        return self.cap.isOpened()

    def get(self, prop):
        return self.cap.get(prop)

    def release(self):
        self.stopped = True
        self.thread.join()
        self.cap.release()


def display_program_header():
    print('=' * 60)
    print('=== RT-DETRv2オブジェクト検出プログラム ===')
    print('=' * 60)
    print('概要: CLAHEとTTAを適用し、リアルタイムでオブジェクトを検出します')
    print('機能: RT-DETRv2による物体検出(person検出専用)')
    print('技術: CLAHE (コントラスト強化) + ByteTrack による追跡 + TTA (Test Time Augmentation) + Transformers + RT-DETRv2')
    print('操作: qキーで終了')
    print('出力: 各フレームごとに処理結果を表示し、終了時にresult.txtへ保存')
    print()


# ===== 共通処理関数 =====
def prepare_image_inputs(image_bgr):
    """画像をPIL形式に変換し、processor処理を行い、デバイスに移動する共通処理"""
    img_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    inputs = processor(images=img_pil, return_tensors='pt')
    for k in inputs:
        if isinstance(inputs[k], torch.Tensor):
            inputs[k] = inputs[k].to(device)
    return inputs


# ===== TTA機能 =====
def apply_tta_inference(frame, model, processor, id2label, conf_thresh):
    """Test Time Augmentation (TTA)を適用した推論"""
    frame_width = frame.shape[1]

    # 元画像の推論
    inputs = prepare_image_inputs(frame)

    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([frame.shape[:2]], device=device)
    results_orig = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    # 水平反転画像の推論
    flipped_frame = cv2.flip(frame, 1)
    inputs_flipped = prepare_image_inputs(flipped_frame)

    with torch.no_grad():
        outputs_flipped = model(**inputs_flipped)

    results_flipped = processor.post_process_object_detection(
        outputs_flipped,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    # 結果を結合(personのみ)
    all_boxes = []
    all_confs = []
    all_labels = []

    # 元画像の結果
    if len(results_orig['scores']) > 0:
        person_mask = results_orig['labels'] == 0
        if person_mask.any():
            all_boxes.append(results_orig['boxes'][person_mask])
            all_confs.append(results_orig['scores'][person_mask])
            all_labels.append(results_orig['labels'][person_mask])

    # 反転画像の結果(座標を元に戻す)
    if len(results_flipped['scores']) > 0:
        person_mask = results_flipped['labels'] == 0
        if person_mask.any():
            boxes_flipped = results_flipped['boxes'][person_mask].clone()
            if boxes_flipped.shape[0] > 0:
                # x1 < x2の関係を保つための修正
                x1_flipped = boxes_flipped[:, 0].clone()
                x2_flipped = boxes_flipped[:, 2].clone()
                boxes_flipped[:, 0] = frame_width - 1 - x2_flipped  # 新しいx1(左端)
                boxes_flipped[:, 2] = frame_width - 1 - x1_flipped  # 新しいx2(右端)

            all_boxes.append(boxes_flipped)
            all_confs.append(results_flipped['scores'][person_mask])
            all_labels.append(results_flipped['labels'][person_mask])

    if len(all_boxes) == 0:
        return []

    # 全ての結果を結合
    all_boxes = torch.cat(all_boxes, dim=0)
    all_confs = torch.cat(all_confs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # 信頼度閾値でフィルタリング(NMS前に実施)
    valid_indices = all_confs > conf_thresh
    if valid_indices.sum() > 0:
        all_boxes = all_boxes[valid_indices]
        all_confs = all_confs[valid_indices]
        all_labels = all_labels[valid_indices]

        # NMSを適用
        nms_indices = torchvision.ops.nms(all_boxes, all_confs, iou_threshold=NMS_THRESHOLD)
        final_boxes = all_boxes[nms_indices].cpu().numpy()
        final_confs = all_confs[nms_indices].cpu().numpy()
        final_labels = all_labels[nms_indices].cpu().numpy()

        # 結果をリスト形式に変換
        detections = []
        for i in range(len(final_confs)):
            conf_boost = TTA_CONF_BOOST if TTA_ENABLED else 0
            x1, y1, x2, y2 = map(int, final_boxes[i])
            cls = int(final_labels[i])
            name = id2label.get(cls, str(cls))
            detections.append({
                'x1': x1, 'y1': y1,
                'x2': x2, 'y2': y2,
                'conf': min(1.0, final_confs[i] + conf_boost),
                'class': cls,
                'name': name
            })

        return detections

    return []


def normal_inference(frame, model, processor, id2label, conf_thresh):
    """通常の推論処理"""
    inputs = prepare_image_inputs(frame)

    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([frame.shape[:2]], device=device)
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_thresh
    )[0]

    curr_dets = []
    if len(results['scores']) > 0:
        scores = results['scores'].cpu().numpy()
        labels = results['labels'].cpu().numpy()
        boxes = results['boxes'].cpu().numpy()

        order = np.argsort(scores)[::-1]
        for i in order:
            cls = int(labels[i])
            if cls != 0:
                continue
            x1, y1, x2, y2 = map(int, boxes[i])
            conf_score = float(scores[i])
            name = id2label.get(cls, str(cls))
            curr_dets.append({
                'x1': x1, 'y1': y1,
                'x2': x2, 'y2': y2,
                'conf': conf_score,
                'class': cls,
                'name': name
            })

    return curr_dets


def apply_tta_if_enabled(frame, model, processor, id2label, conf_thresh):
    """TTA機能を条件付きで適用"""
    if not TTA_ENABLED:
        return normal_inference(frame, model, processor, id2label, conf_thresh)
    return apply_tta_inference(frame, model, processor, id2label, conf_thresh)


# ===== トラッキング機能 =====
def apply_bytetrack(detections, frame):
    """ByteTrackerを使用したトラッキング処理"""
    global tracker

    if len(detections) > 0:
        dets_array = np.array([[d['x1'], d['y1'], d['x2'], d['y2'], d['conf'], d['class']]
                               for d in detections])
    else:
        dets_array = np.empty((0, 6))

    tracks = tracker.update(dets_array, frame)

    tracked_dets = []
    if len(tracks) > 0:
        for track in tracks:
            if len(track) >= 7:
                x1, y1, x2, y2, track_id, conf, cls = track[:7]
                name = id2label.get(int(cls), str(int(cls)))
                tracked_dets.append({
                    'x1': int(x1), 'y1': int(y1),
                    'x2': int(x2), 'y2': int(y2),
                    'track_id': int(track_id),
                    'conf': float(conf),
                    'class': int(cls),
                    'name': name
                })
    return tracked_dets


def apply_tracking_if_enabled(detections, frame):
    """トラッキング機能を条件付きで適用"""
    if not USE_TRACKER:
        return detections
    return apply_bytetrack(detections, frame)


# ===== 物体検出タスク固有の処理 =====
def draw_detection_results(frame, detections):
    """物体検出の描画処理"""
    # バウンディングボックスを描画(OpenCVで)
    for det in detections:
        # トラッキング有効時はIDに基づく色、無効時は緑色を使用
        if USE_TRACKER:
            box_color = get_color_from_id(det['track_id'])
        else:
            box_color = PERSON_COLOR

        cv2.rectangle(frame, (det['x1'], det['y1']),
                      (det['x2'], det['y2']), box_color, 2)

    # 構造化されたテキスト描画を実行
    texts_to_draw = []
    for det in detections:
        track_id = det.get('track_id', 0) if USE_TRACKER else 0
        if USE_TRACKER and track_id > 0:
            label = f"ID:{track_id} 人: {det['conf']:.2f}"
        else:
            label = f"人: {det['conf']:.2f}"

        texts_to_draw.append({
            'text': label,
            'org': (det['x1'], det['y1']-20),
            'color': bgr_to_rgb(PERSON_COLOR),
            'font_type': 'main'
        })
    frame = draw_texts_with_pillow(frame, texts_to_draw)

    # 統計情報を描画
    tta_status = "TTA:ON" if TTA_ENABLED else "TTA:OFF"
    tracker_status = "ByteTrack:ON" if USE_TRACKER else "ByteTrack:OFF"
    info_text = f"Objects: {len(detections)} | Frame: {frame_count} | Classes: {len(set(d['name'] for d in detections)) if detections else 0} | {tta_status} | {tracker_status}"
    cv2.putText(frame, info_text, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    return frame


def format_detection_output(detections):
    """物体検出の出力フォーマット"""
    if len(detections) == 0:
        return 'count=0'
    else:
        parts = []
        for det in detections:
            x1, y1, x2, y2 = det['x1'], det['y1'], det['x2'], det['y2']
            class_name = det['name']
            conf = det['conf']
            parts.append(f'class={class_name},conf={conf:.3f},box=[{x1},{y1},{x2},{y2}]')
        return f'count={len(detections)}; ' + ' | '.join(parts)


# ===== 共通処理関数 =====
def draw_texts_with_pillow(bgr_frame, texts):
    """テキスト描画, texts: list of dict with keys {text, org, color, font_type}"""
    img_pil = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)

    for item in texts:
        text = item['text']
        x, y = item['org']
        color = item['color']  # RGB
        draw.text((x, y), text, font=font_main, fill=color)

    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)


def detect_objects(frame):
    """共通の検出処理(CLAHE、推論、検出を実行)"""
    global model, processor, id2label

    # AIモデルの入力用にCLAHEを適用(YUV色空間で輝度チャンネルのみ処理)
    yuv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
    yuv_frame[:, :, 0] = clahe.apply(yuv_frame[:, :, 0])
    enh_frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR)

    # TTA機能を条件付きで適用
    curr_dets = apply_tta_if_enabled(enh_frame, model, processor, id2label, CONF_THRESH)

    return curr_dets


def process_video_frame(frame):
    """動画用ラッパー"""
    # 共通の検出処理
    detections = detect_objects(frame)

    # トラッキングを条件付きで適用
    tracked_dets = apply_tracking_if_enabled(detections, frame)

    # person検出数を更新
    global person_count
    person_count += len(tracked_dets)

    # 物体検出固有の描画処理
    frame = draw_detection_results(frame, tracked_dets)

    # 物体検出固有の出力フォーマット
    result = format_detection_output(tracked_dets)

    return frame, result


def video_frame_processing(frame, timestamp_ms, is_camera):
    """動画フレーム処理(標準形式)"""
    global frame_count
    current_time = time.time()
    frame_count += 1

    processed_frame, result = process_video_frame(frame)
    return processed_frame, result, current_time


# プログラムヘッダー表示
display_program_header()

# モデル選択(対話的実装)
print("\n=== RT-DETRv2モデル選択 ===")
print('使用するRT-DETRv2モデルを選択してください:')
for key, info in MODEL_INFO.items():
    print(f'{key}: {info["name"]} ({info["desc"]})')
print()

model_choice = ''
while model_choice not in MODEL_INFO.keys():
    model_choice = input("選択 (1/2/3) [デフォルト: 1]: ").strip()
    if model_choice == '':
        model_choice = '1'
        break
    if model_choice not in MODEL_INFO.keys():
        print("無効な選択です。もう一度入力してください。")

# ByteTrackとTTAの設定選択
print("\n=== 機能設定 ===")
print("1: ByteTrack, TTA (Test time augmentation) 無効化")
print("2: ByteTrack, TTA (Test time augmentation) 有効化")
print()

feature_choice = input("選択 (1/2) [デフォルト: 1]: ").strip()
if feature_choice == '2':
    TTA_ENABLED = True
    USE_TRACKER = True
    print("\nByteTrackとTTAを有効化しました")
else:
    TTA_ENABLED = False
    USE_TRACKER = False
    print("\nByteTrackとTTAを無効化しました(デフォルト)")

# ByteTrackトラッカーを初期化
if USE_TRACKER:
    tracker = ByteTrack()

# モデルの初期化
print(f"\nRT-DETRv2モデルをロード中...")
try:
    model_name = MODEL_INFO[model_choice]['name']
    model = RTDetrV2ForObjectDetection.from_pretrained(model_name)
    processor = AutoImageProcessor.from_pretrained(model_name)
    model.to(device)
    model.eval()

    # ラベルマッピング
    id2label = {int(k): v for k, v in model.config.id2label.items()}

    print(f"\n検出対象: person")
    print(f"モデル情報: {MODEL_INFO[model_choice]['desc']}")
    print("モデルのロード完了")
except Exception as e:
    print(f"モデルのロードに失敗しました: {e}")
    raise SystemExit(1)

# TTA設定の表示
if TTA_ENABLED:
    print("\nTest Time Augmentation (TTA): 有効")
    print("  - 水平反転による推論結果の統合")
    print(f"  - 信頼度ブースト値: {TTA_CONF_BOOST}")
    print(f"  - NMS閾値: {NMS_THRESHOLD}")
else:
    print("\nTest Time Augmentation (TTA): 無効")

# ByteTrack設定の表示
if USE_TRACKER:
    print("\nByteTrack: 有効")
    print("  - カルマンフィルタによる動き予測")
    print("  - IDごとに異なる色でバウンディングボックスを表示")
else:
    print("\nByteTrack: 無効")

# 入力選択
print("\n=== RT-DETRv2リアルタイム物体検出(person検出専用) ===")
print("0: 動画ファイル")
print("1: カメラ")
print("2: サンプル動画")

choice = input("選択: ")

is_camera = (choice == '1')

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        raise SystemExit(1)
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = ThreadedVideoCapture(0, is_camera=True)
else:
    SAMPLE_URL = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    SAMPLE_FILE = 'vtest.avi'
    print('サンプル動画をダウンロード中...')
    urllib.request.urlretrieve(SAMPLE_URL, SAMPLE_FILE)
    cap = cv2.VideoCapture(SAMPLE_FILE)

if not cap.isOpened():
    print('動画ファイル・カメラを開けませんでした')
    raise SystemExit(1)

# フレームレートの取得とタイムスタンプ増分の計算
if is_camera:
    actual_fps = cap.get(cv2.CAP_PROP_FPS)
    print(f'カメラのfps: {actual_fps}')
    timestamp_increment = int(1000 / actual_fps) if actual_fps > 0 else 33
else:
    video_fps = cap.get(cv2.CAP_PROP_FPS)
    timestamp_increment = int(1000 / video_fps) if video_fps > 0 else 33

# メイン処理
print('\n=== 動画処理開始 ===')
print('操作方法:')
print('  q キー: プログラム終了')

start_time = time.time()
last_info_time = start_time
info_interval = 10.0  # 10秒ごとに表示
timestamp_ms = 0
total_processing_time = 0.0

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        timestamp_ms += timestamp_increment

        processing_start = time.time()
        processed_frame, result, current_time = video_frame_processing(frame, timestamp_ms, is_camera)
        processing_time = time.time() - processing_start
        total_processing_time += processing_time

        cv2.imshow(WINDOW_NAME, processed_frame)

        if result:
            if is_camera:
                timestamp = datetime.fromtimestamp(current_time).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
                print(f'{timestamp}, {result}')
            else:
                print(f'Frame {frame_count}: {result}')

            results_log.append(result)

        # 情報提供(カメラモードのみ、info_interval秒ごと)
        if is_camera:
            elapsed = current_time - last_info_time
            if elapsed >= info_interval:
                total_elapsed = current_time - start_time
                actual_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
                avg_processing_time = (total_processing_time / frame_count * 1000) if frame_count > 0 else 0
                print(f'[情報] 経過時間: {total_elapsed:.1f}秒, 処理フレーム数: {frame_count}, 実測fps: {actual_fps:.1f}, 平均処理時間: {avg_processing_time:.1f}ms')
                last_info_time = current_time

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

finally:
    print('\n=== プログラム終了 ===')
    cap.release()
    cv2.destroyAllWindows()

    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('=== RT-DETRv2物体検出結果 ===\n')
            f.write(f'処理フレーム数: {frame_count}\n')
            f.write(f'使用モデル: {model_name}\n')
            f.write(f'モデル情報: {MODEL_INFO[model_choice]["desc"]}\n')
            f.write(f'使用デバイス: {str(device).upper()}\n')
            if device.type == 'cuda':
                f.write(f'GPU: {torch.cuda.get_device_name(0)}\n')
            f.write(f'画像処理: CLAHE適用(YUV色空間)\n')
            f.write(f'TTA (Test Time Augmentation): {"有効" if TTA_ENABLED else "無効"}\n')
            if TTA_ENABLED:
                f.write(f'  - NMS閾値: {NMS_THRESHOLD}\n')
                f.write(f'  - 信頼度ブースト: {TTA_CONF_BOOST}\n')
            f.write(f'ByteTrack: {"有効" if USE_TRACKER else "無効"}\n')
            if USE_TRACKER:
                f.write(f'  - IDごとに異なる色でバウンディングボックスを表示\n')
            f.write(f'信頼度閾値: {CONF_THRESH}\n')
            f.write(f'\n検出されたクラス:\n')
            f.write(f'  人 (person): {person_count}回\n')
            if is_camera:
                f.write('形式: タイムスタンプ, 検出結果\n')
            else:
                f.write('形式: フレーム番号, 検出結果\n')
            f.write('\n')
            f.write('\n'.join(results_log))
        print(f'\n処理結果をresult.txtに保存しました')
        print(f'検出されたperson数: {person_count}')