U-2-Net (U Square Net) による動画用オブジェクト顕著性検出(ソースコードと実行結果)


Python開発環境,ライブラリ類
ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。
Python 3.12 のインストール
インストール済みの場合は実行不要。
管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。
REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul
【関連する外部ページ】
Python の公式ページ: https://www.python.org/
AI エディタ Windsurf のインストール
Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。
管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。
winget install --scope machine Codeium.Windsurf -e --silent
【関連する外部ページ】
Windsurf の公式ページ: https://windsurf.com/
必要なライブラリのインストール
コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python pillow numpy requests
事前学習済みモデルのダウンロード
次のURLからダウンロード
https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view
U-2-Net (U Square Net) による動画用オブジェクト顕著性検出プログラム
概要
動画像中の各フレームから人間の視覚的注意を引く領域を自動的に検出する能力を持つ。U-2-Netの深層学習モデルにより、物体の境界を高精度に認識し、背景から主要な物体を分離する。時間的フィルタリングにより、フレーム間の一貫性を保ちながら動画全体で安定した顕著性検出を実現する。
主要技術
- U-2-Net(U Square Net)
入れ子状のU構造(Nested U-Structure)による顕著性物体検出である[1]。従来のU-Netと異なり、各段階で異なる深さのU構造を組み合わせることで、多スケールの特徴を扱う。RSU(Residual U-blocks)と呼ばれる基本ブロックを使用し、浅い層では高解像度の詳細情報を、深い層では意味的な情報を学習する。
- 時間的フィルタリング
指数移動平均(Exponential Moving Average)を用いて連続するフレーム間の顕著性マップを平滑化する。現在フレームと前フレームの顕著性マップを重み付き平均することで、動画における急激な変化を抑制し、視覚的に安定した出力を生成する。
参考文献
[1] Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O., & Jagersand, M. (2020). U2-Net: Going deeper with nested U-structure for salient object detection. Pattern Recognition, 106, 107404.
[2] Borji, A., Cheng, M. M., Jiang, H., & Li, J. (2015). Salient object detection: A benchmark. IEEE Transactions on Image Processing, 24(12), 5706-5722.
ソースコード
# プログラム名: U-2-Net (U Square Net) による動画用オブジェクト顕著性検出プログラム
# 特徴技術名: U-2-Net (U Square Net)
# 出典: Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O., & Jagersand, M. (2020). U2-Net: Going deeper with nested U-structure for salient object detection. Pattern Recognition, 106, 107404
# 特徴機能: Nested U-Structure with ReSidual U-blocks (RSU)による多スケール特徴抽出。2レベルのネストしたU構造により、異なる深さとスケールの特徴を効果的に統合し、高解像度を維持しながら計算コストを抑制
# 学習済みモデル: u2net.pth(176.3MB)- DUTS-TR(10,553枚)で学習済み、高精度な顕著性検出性能、https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view
# 方式設計:
# 関連利用技術: OpenCV(動画処理・表示、BGR形式画像処理)、PyTorch(深層学習推論エンジン)、PIL(画像形式変換)、NumPy(配列演算)
# 入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://github.com/opencv/opencv/blob/master/samples/data/vtest.aviを使用)、出力: 処理結果が画像化できる場合にはOpenCV画面でリアルタイムに表示.OpenCV画面内に処理結果をテキストで表示.さらに,1秒間隔で,print()で処理結果を表示.プログラム終了時にprint()で表示した処理結果をresult.txtファイルに保存し,「result.txtに保存」したことをprint()で表示.プログラム開始時に,プログラムの概要,ユーザが行う必要がある操作(もしあれば)をprint()で表示
# 処理手順: フレーム取得→前処理(320x320リサイズ・正規化)→U-2-Net推論(6段階エンコーダ・5段階デコーダ・顕著性マップ融合)→後処理(Sigmoid活性化・正規化)→顕著性マップ生成→カラーマップ適用→表示
# 前処理、後処理: 前処理: 320x320固定サイズリサイズ・0-1正規化、後処理: Sigmoid活性化・最大最小値正規化・カラーマップ適用
# 追加処理: フレーム間の時間的一貫性を保つための指数移動平均フィルタリング(α=0.7)
# 調整を必要とする設定値: CONFIDENCE_THRESHOLD(顕著性閾値、0.0-1.0、デフォルト0.5)
# 将来方策: Otsuの閾値法による自動閾値決定機能の実装
# その他の重要事項: U-2-Netモデルが利用できない場合はOpenCVのStaticSaliencyFineGrainedを代替使用
# 前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# pip install opencv-python pillow numpy requests
import sys
import io
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image
import requests
import time
import urllib.request
# 設定値(利用者が調整可能)
CONFIDENCE_THRESHOLD = 0.5 # 顕著性閾値(0.0-1.0)- 物体検出の感度調整
COLOR_MAP = cv2.COLORMAP_JET # カラーマップ(JET, HOT, COOL等)- 顕著性の可視化色
FILTER_ALPHA = 0.7 # 時間的フィルタリング強度(0.0-1.0)- 値が大きいほど現在フレームの影響が強い
MODEL_PATH = 'u2net.pth' # モデルファイルパス
MODEL_URL = 'https://drive.usercontent.google.com/download?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ&export=download&authuser=0'
# Windows文字エンコーディング設定
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', line_buffering=True)
def _upsample_like(src, tar):
"""アップサンプリング関数"""
src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
return src
class REBNCONV(nn.Module):
"""ReLU-BatchNorm-Conv基本ブロック"""
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1*dirate, dilation=1*dirate)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
class RSU7(nn.Module):
"""U-2-Netの基本ブロック RSU-7"""
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU6(nn.Module):
"""U-2-Netの基本ブロック RSU-6"""
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU5(nn.Module):
"""U-2-Netの基本ブロック RSU-5"""
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU4(nn.Module):
"""U-2-Netの基本ブロック RSU-4"""
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU4F(nn.Module):
"""U-2-Netの基本ブロック RSU-4F"""
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class U2NET(nn.Module):
"""U-2-Net メインネットワーク"""
def __init__(self, in_ch=3, out_ch=1):
super(U2NET, self).__init__()
self.stage1 = RSU7(in_ch, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6*out_ch, out_ch, 1)
def forward(self, x):
hx = x
# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
def normPRED(d):
"""予測結果の正規化"""
ma = torch.max(d)
mi = torch.min(d)
dn = (d - mi) / (ma - mi)
return dn
def preprocess_frame(frame):
"""フレーム前処理"""
# 320x320にリサイズ
frame_resized = cv2.resize(frame, (320, 320))
# RGB変換と正規化
frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
frame_normalized = frame_rgb.astype(np.float32) / 255.0
# テンソル変換
frame_tensor = torch.from_numpy(frame_normalized).permute(2, 0, 1).unsqueeze(0)
return frame_tensor
def postprocess_saliency(saliency_map, original_shape):
"""顕著性マップ後処理"""
# CPU、NumPy変換
sal_map = saliency_map.squeeze().cpu().data.numpy()
# 元のサイズにリサイズ
sal_map_resized = cv2.resize(sal_map, (original_shape[1], original_shape[0]))
# 閾値適用
sal_map_thresholded = np.where(sal_map_resized > CONFIDENCE_THRESHOLD,
sal_map_resized, 0)
# 0-255スケール変換
sal_map_scaled = (sal_map_thresholded * 255).astype(np.uint8)
# カラーマップ適用
sal_map_colored = cv2.applyColorMap(sal_map_scaled, COLOR_MAP)
return sal_map_colored, np.mean(sal_map_resized)
def download_model():
"""学習済みモデルの自動ダウンロード"""
if not os.path.exists(MODEL_PATH):
print('U-2-Netモデルをダウンロード中...')
print('ファイルサイズ: 約176MB')
response = requests.get(MODEL_URL, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(MODEL_PATH, 'wb') as f:
downloaded = 0
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if total_size > 0:
percent = (downloaded / total_size) * 100
print(f'\rダウンロード進捗: {percent:.1f}%', end='')
print('\nダウンロード完了')
# ファイルサイズ確認
file_size = os.path.getsize(MODEL_PATH)
if file_size < 100 * 1024 * 1024: # 100MB未満の場合は異常
os.remove(MODEL_PATH)
print('ダウンロードに失敗しました。ファイルサイズが異常です。')
print('手動でダウンロードしてください。')
return False
return True
return True
def download_model_guide():
"""学習済みモデルのダウンロード案内"""
if not os.path.exists(MODEL_PATH):
print('=' * 60)
print('U-2-Netモデルの手動ダウンロードが必要です')
print('=' * 60)
print('以下の手順でモデルをダウンロードしてください:')
print('1. ブラウザで以下のURLにアクセス:')
print(' https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view')
print("2. 'ダウンロード'ボタンをクリック")
print("3. ダウンロードした'u2net.pth'ファイルをこのPythonファイルと同じフォルダに配置")
print('4. プログラムを再実行')
print('=' * 60)
print('注意: モデルファイルサイズは約176MBです')
print('=' * 60)
return False
return True
def load_model():
"""U-2-Netモデルのロード"""
# 自動ダウンロード試行
if not download_model():
# 自動ダウンロード失敗時は手動案内
if not download_model_guide():
return None
# モデルロード
model = U2NET(3, 1)
if torch.cuda.is_available():
model.cuda()
model.load_state_dict(torch.load(MODEL_PATH, weights_only=False))
else:
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu', weights_only=False))
model.eval()
return model
def opencv_saliency_process(frame):
"""OpenCV組み込み顕著性検出"""
# Fine Grained静的顕著性検出器を使用
saliency = cv2.saliency.StaticSaliencyFineGrained_create()
(success, saliency_map) = saliency.computeSaliency(frame)
if success:
# 閾値適用
saliency_thresholded = np.where(saliency_map > CONFIDENCE_THRESHOLD,
saliency_map, 0)
# 0-255スケール変換
saliency_scaled = (saliency_thresholded * 255).astype(np.uint8)
# カラーマップ適用
saliency_colored = cv2.applyColorMap(saliency_scaled, COLOR_MAP)
return saliency_colored, np.mean(saliency_map)
else:
return np.zeros_like(frame), 0.0
def video_processing(frame, model, use_u2net, prev_saliency):
"""動画フレーム処理関数"""
if use_u2net:
# U-2-Net処理
# 前処理
input_tensor = preprocess_frame(frame)
if torch.cuda.is_available():
input_tensor = input_tensor.cuda()
# 推論
with torch.no_grad():
d0, d1, d2, d3, d4, d5, d6 = model(input_tensor)
pred = d0[:, 0, :, :]
pred = normPRED(pred)
# 後処理
saliency_colored, mean_saliency = postprocess_saliency(pred, frame.shape)
else:
# OpenCV顕著性検出
saliency_colored, mean_saliency = opencv_saliency_process(frame)
# 時間的フィルタリング処理
if prev_saliency is not None:
saliency_colored = cv2.addWeighted(
saliency_colored, FILTER_ALPHA,
prev_saliency, 1 - FILTER_ALPHA, 0
)
return saliency_colored, mean_saliency
# プログラム開始時の概要表示
print('=' * 60)
print('U-2-Net動画用オブジェクト顕著性検出プログラム')
print('=' * 60)
print('概要: U-2-Netを使用して動画中の顕著なオブジェクトを検出します')
print('操作方法:')
print(' - qキー: プログラム終了')
print(' - 処理結果は2つのウィンドウで表示されます')
print(' 1. Original: 元の動画')
print(' 2. Saliency Map: 顕著性マップ(赤色が顕著な領域)')
print('=' * 60)
# メイン処理
print('U-2-Netモデルをロード中...')
model = load_model()
if model is None:
print('U-2-Netが利用できません。OpenCV顕著性検出を使用します。')
use_u2net = False
else:
print('U-2-Netモデルロード完了')
use_u2net = True
# 入力選択
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')
choice = input('選択: ')
temp_file = None
if choice == '0':
root = tk.Tk()
root.withdraw()
path = filedialog.askopenfilename()
if not path:
exit()
cap = cv2.VideoCapture(path)
elif choice == '1':
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
elif choice == '2':
# サンプル動画ダウンロード・処理
url = 'https://github.com/opencv/opencv/raw/master/samples/data/vtest.avi'
filename = 'vtest.avi'
try:
urllib.request.urlretrieve(url, filename)
temp_file = filename
cap = cv2.VideoCapture(filename)
except Exception as e:
print(f'動画のダウンロードに失敗しました: {url}')
print(f'エラー: {e}')
exit()
else:
print('無効な選択です')
exit()
# 処理統計用変数
frame_count = 0
start_time = time.time()
last_print_time = start_time
results_log = []
# 時間的フィルタリング用の変数
prev_saliency = None
print("処理開始 ('q'キーで終了)")
# メイン処理ループ
try:
while True:
cap.grab()
ret, frame = cap.retrieve()
if not ret:
break
# フレーム処理
saliency_colored, mean_saliency = video_processing(frame, model, use_u2net, prev_saliency)
prev_saliency = saliency_colored.copy()
# 統計情報更新
frame_count += 1
current_time = time.time()
elapsed_time = current_time - start_time
fps = frame_count / elapsed_time if elapsed_time > 0 else 0
# OpenCV画面にテキスト表示
info_text = f'FPS: {fps:.1f} | Mean Saliency: {mean_saliency:.3f}'
cv2.putText(saliency_colored, info_text, (10, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
# 表示
cv2.imshow('Original', frame)
cv2.imshow('Saliency Map', saliency_colored)
# 1秒間隔でprint表示
if current_time - last_print_time >= 1.0:
result_text = f'Frame: {frame_count}, FPS: {fps:.1f}, Mean Saliency: {mean_saliency:.3f}, Elapsed: {elapsed_time:.1f}s'
print(result_text)
results_log.append(result_text)
last_print_time = current_time
if cv2.waitKey(1) & 0xFF == ord('q'):
break
finally:
# 終了処理
cap.release()
cv2.destroyAllWindows()
# 一時ファイル削除
if temp_file and os.path.exists(temp_file):
os.remove(temp_file)
# 結果をファイルに保存
if results_log:
with open('result.txt', 'w', encoding='utf-8') as f:
f.write('U-2-Net動画用オブジェクト顕著性検出プログラム実行結果\n')
f.write('=' * 60 + '\n')
f.write(f'使用モデル: {"U-2-Net" if use_u2net else "OpenCV StaticSaliencyFineGrained"}\n')
f.write(f'総フレーム数: {frame_count}\n')
f.write(f'総処理時間: {elapsed_time:.1f}秒\n')
f.write(f'平均FPS: {fps:.1f}\n')
f.write('=' * 60 + '\n')
f.write('詳細ログ:\n')
for log in results_log:
f.write(log + '\n')
print('result.txtに保存しました')