CoTracker3による動画ポイント追跡(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install hf_xet
pip install transformers pillow opencv-python

CoTracker3による動画ポイント追跡プログラム


# CoTracker3による動画ポイント追跡プログラム
# 特徴技術名: CoTracker3
# 出典: Karaev, N., Makarov, I., Wang, J., Rocco, I., Graham, B., Neverova, N., Vedaldi, A., & Rupprecht, C. (2024). CoTracker3: Simpler and Better Point Tracking by Pseudo-Labelling Real Videos. arXiv:2410.11831.
# 特徴機能: Co-tracking(協調追跡)機能。複数ポイントを相互関係を活用してグループとして追跡することで、オクルージョンや長期間追跡における安定性を大幅向上
# 学習済みモデル: CoTracker3 (online/offline), PyTorch Hub経由で利用可能, 従来モデルより1000倍少ないデータで訓練されながら最高性能を実現
# 方式設計
#   関連利用技術: OpenCV(動画処理), PyTorch(深層学習), tkinter(ファイル選択), urllib(ダウンロード)
#   入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://github.com/opencv/opencv/blob/master/samples/data/vtest.aviを使用), 出力: 動画でのポイント追跡結果をOpenCV画面でリアルタイム表示
#   処理手順: 1.動画フレーム読み込み → 2.CoTracker3モデル読み込み → 3.グリッドポイント設定 → 4.協調追跡実行 → 5.結果可視化
#   前処理、後処理: 前処理:動画フレームのテンソル変換とGPU転送, 後処理:追跡結果の可視化とファイル保存
#   追加処理: グリッドベースポイントサンプリング,協調追跡による相互関係活用,可視性判定
#   調整を必要とする設定値: grid_size(追跡グリッドサイズ,デフォルト10)
# 将来方策: 動的grid_size調整機能,カスタムポイント選択機能
# その他の重要事項: GPU使用推奨,メモリ効率のためのオンライン/オフライン選択
# 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# pip install opencv-python imageio[ffmpeg]

import cv2
import torch
import tkinter as tk
from tkinter import filedialog
import urllib.request
import os
import numpy as np

# グローバル変数
cotracker_model = None
device = None
grid_size = 10
result_log = []
frame_buffer = []
buffer_size = 32
tracking_results = None
current_frame_idx = 0
fps = 30  # 仮のfps。実際は動画ファイルなどから取得可能
processing_mode = 'offline'  # 現状固定(オンラインモデルを使うならここを 'online' に)
point_history = []  # (x, y, speed) のタプルのリスト
max_speed_global = 1.0  # 全履歴での最大速度(色正規化用)

def load_cotracker_model():
    global cotracker_model, device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        cotracker_model = torch.hub.load('facebookresearch/co-tracker', 'cotracker3_offline').to(device)
        print(f'CoTracker3オフラインモデル読み込み完了 (デバイス: {device})')
    except Exception as e:
        print(f'モデル読み込みエラー: {e}')
        exit()

def process_frame_buffer():
    global cotracker_model, device, grid_size, frame_buffer
    if len(frame_buffer) < buffer_size:
        return None
    try:
        frames_array = np.stack(frame_buffer)
        video_tensor = torch.tensor(frames_array).permute(0, 3, 1, 2)[None].float().to(device)
        pred_tracks, pred_visibility = cotracker_model(video_tensor, grid_size=grid_size)
        return pred_tracks, pred_visibility
    except Exception as e:
        print(f'バッファ処理エラー: {e}')
        return None

def get_speed_color(speed, max_speed):
    # 速度を0〜1に正規化
    ratio = min(speed / max_speed, 1.0)
    # 青(低速)→赤(高速)のグラデーション(HSVで色相を変化)
    # HSV色相: 120(青)→0(赤)
    hue = int(120 * (1 - ratio))
    hsv = np.uint8([[[hue, 255, 255]]])
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)[0][0]
    return int(bgr[0]), int(bgr[1]), int(bgr[2])

def video_processing(frame):
    global cotracker_model, grid_size, result_log, frame_buffer, tracking_results, current_frame_idx, point_history, max_speed_global

    if cotracker_model is None:
        return frame

    frame_buffer.append(frame.copy())
    if len(frame_buffer) > buffer_size:
        frame_buffer.pop(0)

    if len(frame_buffer) == buffer_size and current_frame_idx % buffer_size == 0:
        tracking_results = process_frame_buffer()
        if tracking_results is not None:
            print(f'追跡処理完了: {buffer_size}フレーム処理')

            # 速度情報付き履歴保存処理
            pred_tracks, pred_visibility = tracking_results
            num_points = pred_tracks.shape[2]

            # 速度計算のために座標をnumpy配列に変換
            tracks_np = pred_tracks[0].cpu().numpy()  # shape: (buffer_size, num_points, 2)

            # 各時刻の各点について履歴に追加
            for t in range(buffer_size):
                for pid in range(num_points):
                    if pred_visibility[0, t, pid] > 0.5:
                        x, y = int(tracks_np[t, pid, 0]), int(tracks_np[t, pid, 1])
                        if 0 <= x < frame.shape[1] and 0 <= y < frame.shape[0]:

                            # 速度計算(前フレームとの差分)
                            if t > 0 and pred_visibility[0, t-1, pid] > 0.5:
                                prev_x, prev_y = tracks_np[t-1, pid, 0], tracks_np[t-1, pid, 1]
                                speed = np.sqrt((x - prev_x)**2 + (y - prev_y)**2)
                            else:
                                speed = 0.0  # 初回フレームまたは前フレーム非可視の場合

                            # 最大速度更新
                            max_speed_global = max(max_speed_global, speed)

                            # 履歴に追加
                            point_history.append((x, y, speed))

    vis_frame = frame.copy()

    # 履歴の全点を速度に応じた色で描画
    for x, y, speed in point_history:
        color = get_speed_color(speed, max_speed_global)
        cv2.circle(vis_frame, (x, y), 2, color, -1)

    # 現在フレームの状態表示
    history_count = len(point_history)
    cv2.putText(vis_frame, f'履歴点数: {history_count}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)

    if tracking_results is None:
        status_text = f'バッファリング中: {len(frame_buffer)}/{buffer_size}'
        cv2.putText(vis_frame, status_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 1)

    # 結果ログ記録(バッファ処理完了時)
    if tracking_results is not None and current_frame_idx % buffer_size == 0:
        result_text = f'フレーム{current_frame_idx}: 履歴点数{history_count}'
        print(result_text)
        result_log.append(result_text)
        # 座標データのCSV形式での出力(バッファ処理完了時に1回のみ)
        # ヘッダー: PointID, t, x, y
        pred_tracks, pred_visibility = tracking_results
        num_points = pred_tracks.shape[2]
        for pid in range(num_points):
            for t in range(buffer_size):
                x, y = int(pred_tracks[0, t, pid, 0].item()), int(pred_tracks[0, t, pid, 1].item())
                result_log.append(f'{pid},{t},{x},{y}')

    current_frame_idx += 1
    return vis_frame

def main():
    global fps
    print('CoTracker3動画ポイント追跡プログラム')
    print('複数ポイントの協調追跡により、オクルージョンに対応した高精度追跡を実現')
    print('操作: qキーで終了')

    load_cotracker_model()

    print('0: 動画ファイル')
    print('1: カメラ')
    print('2: サンプル動画')
    choice = input('選択: ')
    temp_file = None

    if choice == '0':
        root = tk.Tk()
        root.withdraw()
        path = filedialog.askopenfilename()
        if not path:
            return
        cap = cv2.VideoCapture(path)
        fps = cap.get(cv2.CAP_PROP_FPS)
    elif choice == '1':
        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
        cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
    elif choice == '2':
        url = 'https://github.com/opencv/opencv/raw/master/samples/data/vtest.avi'
        filename = 'vtest.avi'
        try:
            urllib.request.urlretrieve(url, filename)
            temp_file = filename
            cap = cv2.VideoCapture(filename)
            fps = cap.get(cv2.CAP_PROP_FPS)
        except Exception as e:
            print(f'動画ダウンロード失敗: {e}')
            return
    else:
        print('無効な選択です')
        return

    try:
        while True:
            cap.grab()
            ret, frame = cap.retrieve()
            if not ret:
                break
            processed = video_processing(frame)
            cv2.imshow('CoTracker3 Video Tracking', processed)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
        if temp_file:
            os.remove(temp_file)
        # result.txtにメタ情報とCSV形式の軌跡データを出力
        with open('result.txt', 'w', encoding='utf-8') as f:
            # メタ情報追記
            f.write(f'grid_size,{grid_size}\n')
            f.write(f'buffer_size,{buffer_size}\n')
            f.write(f'fps,{fps}\n')
            f.write(f'processing_mode,{processing_mode}\n')
            f.write(f'total_history_points,{len(point_history)}\n')
            f.write(f'max_speed_global,{max_speed_global}\n')
            # CSVヘッダー
            f.write('PointID,TimeFrame,X,Y\n')
            for log in result_log:
                f.write(log + '\n')
        print('result.txtに保存完了')

if __name__ == '__main__':
    main()