SAM2による前景背景分離(ソースコード)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install crepe librosa matplotlib numpy scipy sounddevice japanize-matplotlib tensorflow

SAM2による前景背景分離プログラム


#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SAM2による前景背景分離プログラム

特徴技術名:SAM2(Segment Anything Model 2)
出典:Ravi, N., Gabeur, V., Hu, Y. T., Hu, R., Ryali, C., Ma, T., ... & Feichtenhofer, C. (2024). SAM 2: Segment anything in images and videos. arXiv preprint arXiv:2408.00714.
特徴機能:promptable visual segmentation - 画像と動画の両方で、ユーザーからのプロンプト(点、ボックス、マスク等)に基づいてリアルタイムでオブジェクトをセグメンテーションする機能。

学習済みモデル:
- sam2_hiera_large.pt:最高精度モデル(636M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
- sam2_hiera_base_plus.pt:バランス型モデル(308M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt
- sam2_hiera_small.pt:軽量モデル - https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt
- sam2_hiera_tiny.pt:最軽量モデル(91M parameters)- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt

方式設計:
- 関連利用技術:
  - OpenCV:カメラ入力と画像処理(リアルタイム表示、マウスイベント処理)
  - PyTorch:深層学習フレームワーク(SAM2モデルの実行)
  - gitpython:Gitリポジトリの自動クローン(SAM2公式実装の取得)
  - PIL:画像処理ライブラリ(色空間変換)
  - NumPy:数値計算(配列操作、座標処理)
- 入力と出力:
  入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.aviを使用)
  出力: リアルタイム前景背景分離結果をOpenCV画面で表示。元画像、前景画像、マスク画像を同時表示。OpenCV画面内に処理結果をテキストで表示。1秒間隔でprint()で処理結果を表示。プログラム終了時にprint()で表示した処理結果をresult.txtファイルに保存し、「result.txtに保存」したことをprint()で表示。プログラム開始時に、プログラムの概要、ユーザが行う必要がある操作をprint()で表示。
- 処理手順:
  1. gitpythonによるSAM2リポジトリの自動ダウンロード
  2. 事前学習済みモデル重みの自動取得
  3. SAM2モデルの初期化とメモリ配置
  4. 動画ソースからのフレーム取得
  5. マウスクリックによるプロンプトポイント設定
  6. SAM2によるセグメンテーション実行
  7. 前景抽出とマスク生成
  8. リアルタイム結果表示
- 前処理:RGB色空間変換(OpenCVのBGRからRGBへの変換)、フレームサイズ正規化
- 後処理:セグメンテーションマスクのスコア評価による最適マスク選択、前景領域の抽出処理
- 追加処理:メモリ効率化のためのGPU/CPU自動選択、マウスコールバックによるインタラクティブなプロンプト設定

調整を必要とする設定値:
- model_cfg:モデル設定ファイル('sam2_hiera_l.yaml'等)- 使用するSAM2モデルの種類を決定
- camera_id:カメラデバイス番号(通常は0)- 複数カメラ接続時の選択

将来方策:プログラム内でのモデル性能評価機能を追加し、利用可能なGPUメモリ量に基づいて最適なmodel_cfg(tiny/small/base/large)を自動選択する機能の実装

その他の重要事項:
- Windows環境専用設計
- CUDA対応GPU推奨(CPU動作も可能だが処理速度が大幅に低下)
- 初回実行時に数GB のモデルダウンロードが発生

前準備:
- pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
- pip install opencv-python numpy pillow gitpython requests
"""

import os
import sys
import cv2
import numpy as np
import torch
from pathlib import Path
import git
from PIL import Image
import urllib.request
import tkinter as tk
from tkinter import filedialog

# 設定値
base_dir = './sam2_models'
repo_url = 'https://github.com/facebookresearch/segment-anything-2.git'

# グローバル変数
predictor = None
point_coords_list = []
point_labels_list = []
results = []

def download_sam2_repository():
    repo_dir = Path(base_dir) / 'segment-anything-2'
    if repo_dir.exists():
        print('SAM2リポジトリは既に存在します')
        return str(repo_dir)

    Path(base_dir).mkdir(exist_ok=True)
    try:
        git.Repo.clone_from(repo_url, repo_dir)
        print('SAM2リポジトリのダウンロード完了')
        return str(repo_dir)
    except Exception as e:
        print(f'SAM2リポジトリのダウンロードに失敗しました: {e}')
        exit()

def download_model_weights():
    model_urls = {
        'sam2_hiera_tiny.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt',
        'sam2_hiera_small.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt',
        'sam2_hiera_base_plus.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
        'sam2_hiera_large.pt': 'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt'
    }

    checkpoints_dir = Path(base_dir) / 'checkpoints'
    checkpoints_dir.mkdir(exist_ok=True)

    # 選択されたモデルのみダウンロード
    if model_file in model_urls:
        model_path = checkpoints_dir / model_file
        if model_path.exists():
            print(f'モデルは既に存在: {model_file}')
        else:
            print(f'モデルをダウンロード中: {model_file}')
            try:
                urllib.request.urlretrieve(model_urls[model_file], model_path)
                print(f'モデルダウンロード完了: {model_file}')
            except Exception as e:
                print(f'モデル重みのダウンロードに失敗しました: {e}')
                exit()

def initialize_sam2():
    global predictor

    repo_dir = download_sam2_repository()
    download_model_weights()

    # sys.pathにSAM2ディレクトリを追加
    sys.path.insert(0, repo_dir)

    # PYTHONPATHにSAM2リポジトリを追加
    if 'PYTHONPATH' in os.environ:
        os.environ['PYTHONPATH'] = f"{repo_dir}{os.pathsep}{os.environ['PYTHONPATH']}"
    else:
        os.environ['PYTHONPATH'] = repo_dir

    # 作業ディレクトリをSAM2リポジトリに変更
    original_cwd = os.getcwd()
    os.chdir(repo_dir)

    try:
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f'使用デバイス: {device}')

        # チェックポイントファイルの絶対パスを設定
        checkpoint_path = str(Path(original_cwd) / base_dir / 'checkpoints' / model_file)

        sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
        predictor = SAM2ImagePredictor(sam2_model)
        print('SAM2モデルの初期化完了')

    except Exception as e:
        print(f'SAM2モデルの初期化に失敗しました: {e}')
        exit()
    finally:
        # 作業ディレクトリを元に戻す
        os.chdir(original_cwd)

def video_processing(frame):
    global predictor, point_coords_list, point_labels_list, results

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if len(point_coords_list) == 0:
        return frame

    try:
        predictor.set_image(frame_rgb)

        # 複数のプロンプトポイントを配列として渡す
        point_coords = np.array(point_coords_list)
        point_labels = np.array(point_labels_list)

        masks, scores, logits = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True
        )

        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx]

        # マスクをboolean型に変換
        mask_bool = mask.astype(bool)

        foreground = frame_rgb.copy()
        foreground[~mask_bool] = [0, 0, 0]
        foreground_bgr = cv2.cvtColor(foreground, cv2.COLOR_RGB2BGR)

        # マスクを3チャンネルに変換(uint8型に変換)
        mask_display = (mask_bool.astype(np.uint8) * 255)
        mask_3ch = cv2.cvtColor(mask_display, cv2.COLOR_GRAY2BGR)

        # 結果を横に並べて表示
        combined = np.hstack((frame, foreground_bgr, mask_3ch))

        # テキスト情報を画像に描画
        cv2.putText(combined, 'Original', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(combined, 'Foreground', (frame.shape[1] + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(combined, 'Mask', (frame.shape[1] * 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        # 全てのプロンプトポイントを描画(元画像部分に描画)
        for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # 元画像の座標にポイントを描画
            cv2.circle(combined, tuple(coord), 5, color, -1)
            # ポイント番号も表示
            cv2.putText(combined, str(i+1), (coord[0]+8, coord[1]-8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        # デバッグ情報を画面に表示
        debug_text = f'Points: {len(point_coords_list)}'
        cv2.putText(combined, debug_text, (10, combined.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

        # 全てのプロンプトポイントを描画(元画像部分に描画)
        for i, (coord, label) in enumerate(zip(point_coords_list, point_labels_list)):
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            # 元画像の座標にポイントを描画
            cv2.circle(combined, tuple(coord), 5, color, -1)
            # ポイント番号も表示
            cv2.putText(combined, str(i+1), (coord[0]+8, coord[1]-8), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        # 1秒間隔で結果を記録
        current_time = cv2.getTickCount() / cv2.getTickFrequency()
        if not hasattr(video_processing, 'last_log_time'):
            video_processing.last_log_time = 0

        if current_time - video_processing.last_log_time >= 1.0:
            result_text = f'前景背景分離実行中 - ポイント数: {len(point_coords_list)}, スコア: {scores[best_mask_idx]:.3f}'
            print(result_text)
            results.append(result_text)
            video_processing.last_log_time = current_time

        return combined

    except Exception as e:
        print(f'前景抽出エラー: {e}')
        return frame

def mouse_callback(event, x, y, flags, param):
    global point_coords_list, point_labels_list

    # 元画像の範囲内のクリックのみ処理(修正)
    if x > param.shape[1]:
        return

    if event == cv2.EVENT_LBUTTONDOWN:
        point_coords_list.append([x, y])
        point_labels_list.append(1)
        result_text = f'前景ポイント追加: ({x}, {y}) - 総数: {len(point_coords_list)}'
        print(result_text)
        results.append(result_text)
    elif event == cv2.EVENT_RBUTTONDOWN:
        point_coords_list.append([x, y])
        point_labels_list.append(0)
        result_text = f'背景ポイント追加: ({x}, {y}) - 総数: {len(point_coords_list)}'
        print(result_text)
        results.append(result_text)
    elif event == cv2.EVENT_MBUTTONDOWN:
        # 中ボタンクリックで全てのポイントをクリア
        point_coords_list.clear()
        point_labels_list.clear()
        result_text = '全ポイントクリア'
        print(result_text)
        results.append(result_text)

print('SAM2カメラ前景背景分離システム')
print('概要: 動画からSAM2を使用してリアルタイムで前景背景を分離します')
print('')
print('基本操作: 前景ポイント1つ + 背景ポイント1つを指定')
print('- 前景ポイント: 分離したいオブジェクトの上で左クリック(緑色の丸)')
print('- 背景ポイント: 除外したい部分で右クリック(赤色の丸)')
print('- 精度向上のため複数ポイント指定可能')
print('')
print('操作方法:')
print('- 左クリック: 前景ポイント追加(元画像部分のみ)')
print('- 右クリック: 背景ポイント追加(元画像部分のみ)')
print('- 中ボタンクリック: 全ポイントクリア')
print('- q: 終了')
print('')

print('モデル選択:')
print('1: sam2_hiera_tiny.pt (155.9MB, 91M parameters)')
print('2: sam2_hiera_small.pt (180MB, parameters未公開)')
print('3: sam2_hiera_base_plus.pt (320MB, 308M parameters)')
print('4: sam2_hiera_large.pt (900MB, 636M parameters)')

model_choice = input('モデル選択 (1-4): ')

if model_choice == '1':
    model_file = 'sam2_hiera_tiny.pt'
    model_cfg = 'sam2_hiera_t.yaml'
elif model_choice == '2':
    model_file = 'sam2_hiera_small.pt'
    model_cfg = 'sam2_hiera_s.yaml'
elif model_choice == '3':
    model_file = 'sam2_hiera_base_plus.pt'
    model_cfg = 'sam2_hiera_b+.yaml'
elif model_choice == '4':
    model_file = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
else:
    print('無効な選択です')
    exit()

print(f'選択されたモデル: {model_file}')
print('')

print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')

choice = input('選択: ')
temp_file = None

initialize_sam2()

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
elif choice == '2':
    url = 'https://raw.githubusercontent.com/opencv/opencv/master/samples/data/vtest.avi'
    filename = 'vtest.avi'
    try:
        urllib.request.urlretrieve(url, filename)
        temp_file = filename
        cap = cv2.VideoCapture(filename)
        result_text = f'サンプル動画をダウンロード: {filename}'
        print(result_text)
        results.append(result_text)
    except Exception as e:
        print(f'動画のダウンロードに失敗しました: {url}')
        print(f'エラー: {e}')
        exit()
else:
    print('無効な選択です')
    exit()

# メイン処理
try:
    # 最初のフレームを取得してマウスコールバックを設定
    ret, first_frame = cap.read()
    if ret:
        cv2.namedWindow('Video', cv2.WINDOW_AUTOSIZE)
        cv2.setMouseCallback('Video', mouse_callback, first_frame)
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

    while True:
        cap.grab()
        ret, frame = cap.retrieve()
        if not ret:
            break

        processed_frame = video_processing(frame)
        cv2.imshow('Video', processed_frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
    if temp_file:
        os.remove(temp_file)

# 結果をファイルに保存
with open('result.txt', 'w', encoding='utf-8') as f:
    for result in results:
        f.write(result + '\n')
print('result.txtに保存しました')

print('SAM2カメラ前景背景分離システム終了')