U-2-Net (U Square Net) による動画用オブジェクト顕著性検出(ソースコードと実行結果)


Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python pillow numpy requests

事前学習済みモデルのダウンロード

次のURLからダウンロード

https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view

U-2-Net (U Square Net) による動画用オブジェクト顕著性検出プログラム

概要

動画像中の各フレームから人間の視覚的注意を引く領域を自動的に検出する能力を持つ。U-2-Netの深層学習モデルにより、物体の境界を高精度に認識し、背景から主要な物体を分離する。時間的フィルタリングにより、フレーム間の一貫性を保ちながら動画全体で安定した顕著性検出を実現する。

主要技術

参考文献

[1] Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O., & Jagersand, M. (2020). U2-Net: Going deeper with nested U-structure for salient object detection. Pattern Recognition, 106, 107404.

[2] Borji, A., Cheng, M. M., Jiang, H., & Li, J. (2015). Salient object detection: A benchmark. IEEE Transactions on Image Processing, 24(12), 5706-5722.

ソースコード


# プログラム名: U-2-Net (U Square Net) による動画用オブジェクト顕著性検出プログラム
# 特徴技術名: U-2-Net (U Square Net)
# 出典: Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O., & Jagersand, M. (2020). U2-Net: Going deeper with nested U-structure for salient object detection. Pattern Recognition, 106, 107404
# 特徴機能: Nested U-Structure with ReSidual U-blocks (RSU)による多スケール特徴抽出。2レベルのネストしたU構造により、異なる深さとスケールの特徴を効果的に統合し、高解像度を維持しながら計算コストを抑制
# 学習済みモデル: u2net.pth(176.3MB)- DUTS-TR(10,553枚)で学習済み、高精度な顕著性検出性能、https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view
# 方式設計:
#   関連利用技術: OpenCV(動画処理・表示、BGR形式画像処理)、PyTorch(深層学習推論エンジン)、PIL(画像形式変換)、NumPy(配列演算)
#   入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://github.com/opencv/opencv/blob/master/samples/data/vtest.aviを使用)、出力: 処理結果が画像化できる場合にはOpenCV画面でリアルタイムに表示.OpenCV画面内に処理結果をテキストで表示.さらに,1秒間隔で,print()で処理結果を表示.プログラム終了時にprint()で表示した処理結果をresult.txtファイルに保存し,「result.txtに保存」したことをprint()で表示.プログラム開始時に,プログラムの概要,ユーザが行う必要がある操作(もしあれば)をprint()で表示
#   処理手順: フレーム取得→前処理(320x320リサイズ・正規化)→U-2-Net推論(6段階エンコーダ・5段階デコーダ・顕著性マップ融合)→後処理(Sigmoid活性化・正規化)→顕著性マップ生成→カラーマップ適用→表示
#   前処理、後処理: 前処理: 320x320固定サイズリサイズ・0-1正規化、後処理: Sigmoid活性化・最大最小値正規化・カラーマップ適用
#   追加処理: フレーム間の時間的一貫性を保つための指数移動平均フィルタリング(α=0.7)
#   調整を必要とする設定値: CONFIDENCE_THRESHOLD(顕著性閾値、0.0-1.0、デフォルト0.5)
# 将来方策: Otsuの閾値法による自動閾値決定機能の実装
# その他の重要事項: U-2-Netモデルが利用できない場合はOpenCVのStaticSaliencyFineGrainedを代替使用
# 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# pip install opencv-python pillow numpy requests

import sys
import io
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image
import requests
import time
import urllib.request

# 設定値(利用者が調整可能)
CONFIDENCE_THRESHOLD = 0.5    # 顕著性閾値(0.0-1.0)- 物体検出の感度調整
COLOR_MAP = cv2.COLORMAP_JET  # カラーマップ(JET, HOT, COOL等)- 顕著性の可視化色
FILTER_ALPHA = 0.7            # 時間的フィルタリング強度(0.0-1.0)- 値が大きいほど現在フレームの影響が強い
MODEL_PATH = 'u2net.pth'      # モデルファイルパス
MODEL_URL = 'https://drive.usercontent.google.com/download?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ&export=download&authuser=0'

# Windows文字エンコーディング設定
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)


def _upsample_like(src, tar):
    """アップサンプリング関数"""
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
    return src


class REBNCONV(nn.Module):
    """ReLU-BatchNorm-Conv基本ブロック"""

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


class RSU7(nn.Module):
    """U-2-Netの基本ブロック RSU-7"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU6(nn.Module):
    """U-2-Netの基本ブロック RSU-6"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU5(nn.Module):
    """U-2-Netの基本ブロック RSU-5"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU4(nn.Module):
    """U-2-Netの基本ブロック RSU-4"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


class RSU4F(nn.Module):
    """U-2-Netの基本ブロック RSU-4F"""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


class U2NET(nn.Module):
    """U-2-Net メインネットワーク"""

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)

    def forward(self, x):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)


def normPRED(d):
    """予測結果の正規化"""
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d - mi) / (ma - mi)

    return dn


def preprocess_frame(frame):
    """フレーム前処理"""
    # 320x320にリサイズ
    frame_resized = cv2.resize(frame, (320, 320))

    # RGB変換と正規化
    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
    frame_normalized = frame_rgb.astype(np.float32) / 255.0

    # テンソル変換
    frame_tensor = torch.from_numpy(frame_normalized).permute(2, 0, 1).unsqueeze(0)

    return frame_tensor


def postprocess_saliency(saliency_map, original_shape):
    """顕著性マップ後処理"""
    # CPU、NumPy変換
    sal_map = saliency_map.squeeze().cpu().data.numpy()

    # 元のサイズにリサイズ
    sal_map_resized = cv2.resize(sal_map, (original_shape[1], original_shape[0]))

    # 閾値適用
    sal_map_thresholded = np.where(sal_map_resized > CONFIDENCE_THRESHOLD,
                                   sal_map_resized, 0)

    # 0-255スケール変換
    sal_map_scaled = (sal_map_thresholded * 255).astype(np.uint8)

    # カラーマップ適用
    sal_map_colored = cv2.applyColorMap(sal_map_scaled, COLOR_MAP)

    return sal_map_colored, np.mean(sal_map_resized)


def download_model():
    """学習済みモデルの自動ダウンロード"""
    if not os.path.exists(MODEL_PATH):
        print('U-2-Netモデルをダウンロード中...')
        print('ファイルサイズ: 約176MB')

        response = requests.get(MODEL_URL, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        with open(MODEL_PATH, 'wb') as f:
            downloaded = 0
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        print(f'\rダウンロード進捗: {percent:.1f}%', end='')

        print('\nダウンロード完了')

        # ファイルサイズ確認
        file_size = os.path.getsize(MODEL_PATH)
        if file_size < 100 * 1024 * 1024:  # 100MB未満の場合は異常
            os.remove(MODEL_PATH)
            print('ダウンロードに失敗しました。ファイルサイズが異常です。')
            print('手動でダウンロードしてください。')
            return False

        return True
    return True


def download_model_guide():
    """学習済みモデルのダウンロード案内"""
    if not os.path.exists(MODEL_PATH):
        print('=' * 60)
        print('U-2-Netモデルの手動ダウンロードが必要です')
        print('=' * 60)
        print('以下の手順でモデルをダウンロードしてください:')
        print('1. ブラウザで以下のURLにアクセス:')
        print('   https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view')
        print("2. 'ダウンロード'ボタンをクリック")
        print("3. ダウンロードした'u2net.pth'ファイルをこのPythonファイルと同じフォルダに配置")
        print('4. プログラムを再実行')
        print('=' * 60)
        print('注意: モデルファイルサイズは約176MBです')
        print('=' * 60)
        return False
    return True


def load_model():
    """U-2-Netモデルのロード"""
    # 自動ダウンロード試行
    if not download_model():
        # 自動ダウンロード失敗時は手動案内
        if not download_model_guide():
            return None

    # モデルロード
    model = U2NET(3, 1)

    if torch.cuda.is_available():
        model.cuda()
        model.load_state_dict(torch.load(MODEL_PATH, weights_only=False))
    else:
        model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu', weights_only=False))

    model.eval()
    return model


def opencv_saliency_process(frame):
    """OpenCV組み込み顕著性検出"""
    # Fine Grained静的顕著性検出器を使用
    saliency = cv2.saliency.StaticSaliencyFineGrained_create()
    (success, saliency_map) = saliency.computeSaliency(frame)

    if success:
        # 閾値適用
        saliency_thresholded = np.where(saliency_map > CONFIDENCE_THRESHOLD,
                                       saliency_map, 0)

        # 0-255スケール変換
        saliency_scaled = (saliency_thresholded * 255).astype(np.uint8)

        # カラーマップ適用
        saliency_colored = cv2.applyColorMap(saliency_scaled, COLOR_MAP)

        return saliency_colored, np.mean(saliency_map)
    else:
        return np.zeros_like(frame), 0.0


def video_processing(frame, model, use_u2net, prev_saliency):
    """動画フレーム処理関数"""
    if use_u2net:
        # U-2-Net処理
        # 前処理
        input_tensor = preprocess_frame(frame)

        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()

        # 推論
        with torch.no_grad():
            d0, d1, d2, d3, d4, d5, d6 = model(input_tensor)
            pred = d0[:, 0, :, :]
            pred = normPRED(pred)

        # 後処理
        saliency_colored, mean_saliency = postprocess_saliency(pred, frame.shape)
    else:
        # OpenCV顕著性検出
        saliency_colored, mean_saliency = opencv_saliency_process(frame)

    # 時間的フィルタリング処理
    if prev_saliency is not None:
        saliency_colored = cv2.addWeighted(
            saliency_colored, FILTER_ALPHA,
            prev_saliency, 1 - FILTER_ALPHA, 0
        )

    return saliency_colored, mean_saliency


# プログラム開始時の概要表示
print('=' * 60)
print('U-2-Net動画用オブジェクト顕著性検出プログラム')
print('=' * 60)
print('概要: U-2-Netを使用して動画中の顕著なオブジェクトを検出します')
print('操作方法:')
print('  - qキー: プログラム終了')
print('  - 処理結果は2つのウィンドウで表示されます')
print('    1. Original: 元の動画')
print('    2. Saliency Map: 顕著性マップ(赤色が顕著な領域)')
print('=' * 60)

# メイン処理
print('U-2-Netモデルをロード中...')
model = load_model()

if model is None:
    print('U-2-Netが利用できません。OpenCV顕著性検出を使用します。')
    use_u2net = False
else:
    print('U-2-Netモデルロード完了')
    use_u2net = True

# 入力選択
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')

choice = input('選択: ')
temp_file = None

if choice == '0':
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename()
    if not path:
        exit()
    cap = cv2.VideoCapture(path)
elif choice == '1':
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
elif choice == '2':
    # サンプル動画ダウンロード・処理
    url = 'https://github.com/opencv/opencv/raw/master/samples/data/vtest.avi'
    filename = 'vtest.avi'
    try:
        urllib.request.urlretrieve(url, filename)
        temp_file = filename
        cap = cv2.VideoCapture(filename)
    except Exception as e:
        print(f'動画のダウンロードに失敗しました: {url}')
        print(f'エラー: {e}')
        exit()
else:
    print('無効な選択です')
    exit()

# 処理統計用変数
frame_count = 0
start_time = time.time()
last_print_time = start_time
results_log = []

# 時間的フィルタリング用の変数
prev_saliency = None

print("処理開始 ('q'キーで終了)")

# メイン処理ループ
try:
    while True:
        cap.grab()
        ret, frame = cap.retrieve()
        if not ret:
            break

        # フレーム処理
        saliency_colored, mean_saliency = video_processing(frame, model, use_u2net, prev_saliency)
        prev_saliency = saliency_colored.copy()

        # 統計情報更新
        frame_count += 1
        current_time = time.time()
        elapsed_time = current_time - start_time
        fps = frame_count / elapsed_time if elapsed_time > 0 else 0

        # OpenCV画面にテキスト表示
        info_text = f'FPS: {fps:.1f} | Mean Saliency: {mean_saliency:.3f}'
        cv2.putText(saliency_colored, info_text, (10, 30),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        # 表示
        cv2.imshow('Original', frame)
        cv2.imshow('Saliency Map', saliency_colored)

        # 1秒間隔でprint表示
        if current_time - last_print_time >= 1.0:
            result_text = f'Frame: {frame_count}, FPS: {fps:.1f}, Mean Saliency: {mean_saliency:.3f}, Elapsed: {elapsed_time:.1f}s'
            print(result_text)
            results_log.append(result_text)
            last_print_time = current_time

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

finally:
    # 終了処理
    cap.release()
    cv2.destroyAllWindows()

    # 一時ファイル削除
    if temp_file and os.path.exists(temp_file):
        os.remove(temp_file)

    # 結果をファイルに保存
    if results_log:
        with open('result.txt', 'w', encoding='utf-8') as f:
            f.write('U-2-Net動画用オブジェクト顕著性検出プログラム実行結果\n')
            f.write('=' * 60 + '\n')
            f.write(f'使用モデル: {"U-2-Net" if use_u2net else "OpenCV StaticSaliencyFineGrained"}\n')
            f.write(f'総フレーム数: {frame_count}\n')
            f.write(f'総処理時間: {elapsed_time:.1f}秒\n')
            f.write(f'平均FPS: {fps:.1f}\n')
            f.write('=' * 60 + '\n')
            f.write('詳細ログ:\n')
            for log in results_log:
                f.write(log + '\n')
        print('result.txtに保存しました')