PaDiM による画像からの異常検知(ソースコードと実行結果)

Python開発環境,ライブラリ類

ここでは、最低限の事前準備について説明する。機械学習や深層学習を行う場合は、NVIDIA CUDA、Visual Studio、Cursorなどを追加でインストールすると便利である。これらについては別ページ https://www.kkaneko.jp/cc/dev/aiassist.htmlで詳しく解説しているので、必要に応じて参照してください。

Python 3.12 のインストール

インストール済みの場合は実行不要。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要である。

REM Python をシステム領域にインストール
winget install --scope machine --id Python.Python.3.12 -e --silent
REM Python のパス設定
set "PYTHON_PATH=C:\Program Files\Python312"
set "PYTHON_SCRIPTS_PATH=C:\Program Files\Python312\Scripts"
echo "%PATH%" | find /i "%PYTHON_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_PATH%" /M >nul
echo "%PATH%" | find /i "%PYTHON_SCRIPTS_PATH%" >nul
if errorlevel 1 setx PATH "%PATH%;%PYTHON_SCRIPTS_PATH%" /M >nul

関連する外部ページ

Python の公式ページ: https://www.python.org/

AI エディタ Windsurf のインストール

Pythonプログラムの編集・実行には、AI エディタの利用を推奨する。ここでは,Windsurfのインストールを説明する。

管理者権限でコマンドプロンプトを起動(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行して、Windsurfをシステム全体にインストールする。管理者権限は、wingetの--scope machineオプションでシステム全体にソフトウェアをインストールするために必要となる。

winget install --scope machine Codeium.Windsurf -e --silent

関連する外部ページ

Windsurf の公式ページ: https://windsurf.com/

必要なライブラリのインストール

コマンドプロンプトを管理者として実行(手順:Windowsキーまたはスタートメニュー > cmd と入力 > 右クリック > 「管理者として実行」)し、以下を実行する


前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python scikit-learn pillow tqdm numpy

PaDiM による画像からの異常検知プログラム

概要

このプログラムは動画の視覚的パターンから正常・異常を判別する。深層ニューラルネットワークにより画像特徴を抽出し、統計的分布モデリングによって正常状態を学習する。その後、新しい入力フレームが,学習した正常分布からどの程度逸脱しているかを測定することで異常検知を実現する。

主要技術

参考文献

[1] Defard, T., Setkov, A., Loesch, A., & Audigier, R. (2021). PaDiM: a patch distribution modeling framework for anomaly detection and localization. In International Conference on Computer Vision Theory and Applications (pp. 475-489).

[2] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 770-778).

[3] Ledoit, O., & Wolf, M. (2004). A well-conditioned estimator for large-dimensional covariance matrices. Journal of Multivariate Analysis, 88(2), 365-411.

ソースコード


"""
プログラム名: PaDiM による画像からの異常検知プログラム
特徴技術名: PaDiM (Patch Distribution Modeling)
出典: Defard, T., Setkov, A., Loesch, A., & Audigier, R. (2021). PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. In International Conference on Computer Vision Theory and Applications (pp. 475-489). arXiv:2011.08785
特徴機能: CNN多階層特徴マップから得られるパッチ埋め込みベクトルを多変量ガウス分布でモデリングし、マハラノビス距離による異常スコア算出で高精度な異常検知・局在化を実現
学習済みモデル: ResNet18/ResNet50/EfficientNet-B0 ImageNet学習済みモデル(torchvision提供)- ユーザが選択可能な3種類の深層学習モデル
方式設計:
  - 関連利用技術:
    * LedoitWolf共分散推定(scikit-learn)- 高次元データに対する正則化共分散行列推定
    * マハラノビス距離 - 多変量分布における異常度測定
    * OpenCV - 動画処理・表示
    * PIL/Pillow - 日本語テキスト描画
  - 入力と出力: 入力: 動画(ユーザは「0:動画ファイル,1:カメラ,2:サンプル動画」のメニューで選択.0:動画ファイルの場合はtkinterでファイル選択.1の場合はOpenCVでカメラが開く.2の場合はhttps://github.com/opencv/opencv/blob/master/samples/data/vtest.aviを使用)、出力: 異常スコアマップ重畳動画のOpenCV表示、1秒間隔での異常スコア値print出力・result.txt保存、異常領域切り抜き画像の連番保存(オプション)
  - 処理手順: 1)正常画像からCNN多階層特徴抽出→パッチ埋め込み作成、2)パッチ位置別多変量ガウス分布推定、3)動画フレーム毎の同様特徴抽出、4)マハラノビス距離による異常スコア算出、5)異常マップ生成・リアルタイム可視化、6)異常領域のバウンディングボックス表示と切り抜き保存(オプション)
  - 前処理、後処理: 前処理:フレーム抽出・リサイズ(256x256)・正規化、後処理:異常マップのガウシアンフィルタ適用・色マップ変換・重畳表示・バウンディングボックス描画
  - 追加処理: Ledoit-Wolf正則化による高次元共分散行列の数値安定化、CNN特徴マップのbilinear補間による解像度統一、異常領域の自動検出と切り抜き保存(ユーザ選択式)
  - 調整を必要とする設定値: train_dir(正常画像ディレクトリパス)- 学習用正常画像群の格納場所
将来方策: 設定ファイル(config.json)による学習ディレクトリ自動設定機能、複数ディレクトリ対応GUI選択機能
その他の重要事項: Windows環境対応、CUDA利用可能時GPU自動選択、日本語フォント対応、リアルタイム動画処理対応、複数CNNモデル選択機能、異常画像保存オプション
前準備: pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
pip install opencv-python scikit-learn pillow tqdm numpy
"""

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import cv2
import os
import time
import datetime
import tkinter as tk
from tkinter import filedialog
import urllib.request
import shutil
from tqdm import tqdm
from sklearn.covariance import LedoitWolf
from PIL import Image, ImageFont, ImageDraw

def log_message(message, log_file='system.log'):
    """タイムスタンプ付きログメッセージの記録と表示"""
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
    log_entry = f'[{timestamp}] {message}'
    print(log_entry)

    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(log_entry + '\n')

class ResNet18_FeatureExtractor(nn.Module):
    def __init__(self, layers=['layer1', 'layer2', 'layer3']):
        super().__init__()
        log_message('ResNet18特徴抽出器を初期化中')
        backbone = models.resnet18(weights='DEFAULT')
        self.layer1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layers = layers
        log_message(f'ResNet18初期化完了 - 使用層: {layers}')

    def forward(self, x):
        out = {}
        x = self.layer1(x)
        if 'layer1' in self.layers:
            out['layer1'] = x
        x = self.layer2(x)
        if 'layer2' in self.layers:
            out['layer2'] = x
        x = self.layer3(x)
        if 'layer3' in self.layers:
            out['layer3'] = x
        return out

class ResNet50_FeatureExtractor(nn.Module):
    def __init__(self, layers=['layer1', 'layer2', 'layer3']):
        super().__init__()
        log_message('ResNet50特徴抽出器を初期化中')
        backbone = models.resnet50(weights='DEFAULT')
        self.layer1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layers = layers
        log_message(f'ResNet50初期化完了 - 使用層: {layers}')

    def forward(self, x):
        out = {}
        x = self.layer1(x)
        if 'layer1' in self.layers:
            out['layer1'] = x
        x = self.layer2(x)
        if 'layer2' in self.layers:
            out['layer2'] = x
        x = self.layer3(x)
        if 'layer3' in self.layers:
            out['layer3'] = x
        return out

class EfficientNetB0_FeatureExtractor(nn.Module):
    def __init__(self, layers=['features.2', 'features.3', 'features.4']):
        super().__init__()
        log_message('EfficientNet-B0特徴抽出器を初期化中')
        backbone = models.efficientnet_b0(weights='DEFAULT')
        self.features = backbone.features
        self.layers = layers
        log_message(f'EfficientNet-B0初期化完了 - 使用層: {layers}')

    def forward(self, x):
        out = {}
        for idx, module in enumerate(self.features):
            x = module(x)
            layer_name = f'features.{idx}'
            if layer_name in self.layers:
                out[layer_name] = x
        return out

def select_model():
    """モデル選択メニューを表示し、選択されたモデルを返す"""
    print('\n=== CNN特徴抽出モデル選択 ===')
    print('\n利用可能なモデル一覧:')
    print('\n1. ResNet18 (Residual Network 18層)')
    print('   - 公式名: ResNet18')
    print('   - パラメータ数: 約11.7M')
    print('   - 特徴: 軽量で高速、基本的な異常検知に適している')
    print('   - ImageNet Top-1精度: 69.76%')
    print('   - 推奨用途: リアルタイム処理、計算資源が限られた環境')

    print('\n2. ResNet50 (Residual Network 50層)')
    print('   - 公式名: ResNet50')
    print('   - パラメータ数: 約25.6M')
    print('   - 特徴: バランスの取れた性能、より複雑な特徴抽出が可能')
    print('   - ImageNet Top-1精度: 76.13%')
    print('   - 推奨用途: 高精度が必要な検査、十分な計算資源がある環境')

    print('\n3. EfficientNet-B0')
    print('   - 公式名: EfficientNet-B0')
    print('   - パラメータ数: 約5.3M')
    print('   - 特徴: 効率的なアーキテクチャ、精度と速度のバランスが良い')
    print('   - ImageNet Top-1精度: 77.69%')
    print('   - 推奨用途: モバイル環境、エッジデバイス、省電力が必要な環境')

    while True:
        choice = input('\nモデルを選択してください (1/2/3): ')

        if choice == '1':
            log_message('ResNet18モデルが選択されました')
            print('\nResNet18を使用します')
            return ResNet18_FeatureExtractor()
        elif choice == '2':
            log_message('ResNet50モデルが選択されました')
            print('\nResNet50を使用します')
            return ResNet50_FeatureExtractor()
        elif choice == '3':
            log_message('EfficientNet-B0モデルが選択されました')
            print('\nEfficientNet-B0を使用します')
            return EfficientNetB0_FeatureExtractor()
        else:
            print('無効な選択です。1、2、または3を入力してください。')

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def copy_files_from_pc(train_dir, current_files):
    """パソコン内のファイルを複数選択してコピーする関数"""
    log_message('パソコン内ファイル選択モードを開始')
    print('\nファイル選択ダイアログを開きます(複数選択可能)...')

    root = tk.Tk()
    root.withdraw()

    # 複数ファイル選択ダイアログ
    file_paths = filedialog.askopenfilenames(
        title='正常画像ファイルを選択してください(複数選択可能)',
        filetypes=[
            ('画像ファイル', '*.jpg *.jpeg *.png'),
            ('JPEGファイル', '*.jpg *.jpeg'),
            ('PNGファイル', '*.png'),
            ('すべてのファイル', '*.*')
        ]
    )

    if not file_paths:
        log_message('ファイル選択がキャンセルされました')
        print('ファイル選択がキャンセルされました')
        return False

    log_message(f'{len(file_paths)}個のファイルが選択されました')
    print(f'{len(file_paths)}個のファイルが選択されました')

    # 既存画像を削除
    for filename in current_files:
        filepath = os.path.join(train_dir, filename)
        if os.path.exists(filepath):
            os.remove(filepath)
            log_message(f'既存画像削除: {filename}')
    print('既存画像を削除しました')

    # 選択されたファイルをコピー
    copied_count = 0
    for file_path in file_paths:
        try:
            filename = os.path.basename(file_path)
            dest_path = os.path.join(train_dir, filename)

            # 同名ファイルが存在する場合の処理
            if os.path.exists(dest_path):
                name, ext = os.path.splitext(filename)
                counter = 1
                while os.path.exists(dest_path):
                    new_filename = f'{name}_{counter}{ext}'
                    dest_path = os.path.join(train_dir, new_filename)
                    counter += 1
                filename = os.path.basename(dest_path)

            shutil.copy2(file_path, dest_path)
            copied_count += 1
            log_message(f'ファイルコピー完了: {filename}')
            print(f'コピー完了: {filename}')

        except Exception as e:
            log_message(f'ファイルコピーに失敗: {file_path}, エラー: {e}')
            print(f'ファイルコピーに失敗: {os.path.basename(file_path)}, エラー: {e}')

    log_message(f'ファイルコピー完了: {copied_count}個のファイル')
    print(f'\n{copied_count}個のファイルをコピーしました')

    if copied_count > 0:
        return True
    else:
        print('有効なファイルがコピーされませんでした')
        return False

def setup_training_data():
    """学習用ディレクトリとサンプル画像のセットアップ"""
    log_message('学習用データのセットアップを開始')
    train_dir = './train_normal'

    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
        log_message(f'学習用ディレクトリ {train_dir} を作成しました')
        print(f'学習用ディレクトリ {train_dir} を作成しました')

    existing_images = [f for f in os.listdir(train_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
    log_message(f'既存画像数: {len(existing_images)}枚')

    # ダウンロード判定: フォルダが空の場合のみダウンロード
    if len(existing_images) == 0:
        log_message('正常画像が存在しません。サンプル画像のダウンロードを開始')
        print('正常画像がありません。サンプル画像をダウンロード中...')

        sample_urls = [
            'https://github.com/opencv/opencv/raw/master/samples/data/fruits.jpg',
            'https://github.com/opencv/opencv/raw/master/samples/data/messi5.jpg',
            'https://github.com/opencv/opencv/raw/master/samples/data/aero3.jpg'
        ]

        for i, url in enumerate(sample_urls):
            try:
                filename = f'normal_sample_{i+1}.jpg'
                filepath = os.path.join(train_dir, filename)
                urllib.request.urlretrieve(url, filepath)
                log_message(f'ダウンロード完了: {filename}')
                print(f'ダウンロード完了: {filename}')
            except Exception as e:
                log_message(f'画像ダウンロードに失敗: {url}, エラー: {e}')
                print(f'画像ダウンロードに失敗: {url}, エラー: {e}')

        log_message('サンプル正常画像のダウンロードが完了')
        print('サンプル正常画像のダウンロードが完了しました')

    while True:
        # フォルダ内のファイル一覧を再取得
        current_files = [f for f in os.listdir(train_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]

        # 必ずメニューを表示
        print('\n--- 利用可能な正常画像 ---')
        for filename in current_files:
            print(filename)

        print('\n選択してください:')
        print('1. これらの画像を使用')
        print('2. これらの画像を削除してカメラで撮影')
        print('3. これらの画像に追加してカメラで撮影')
        print('4. これらの画像を削除してパソコン内のファイル(複数選択可能)をコピーして使用')

        choice = input('\n選択 (1/2/3/4): ')
        log_message(f'画像使用選択: {choice}')

        if choice == '2':
            # 既存画像を削除
            for filename in current_files:
                filepath = os.path.join(train_dir, filename)
                if os.path.exists(filepath):
                    os.remove(filepath)
                    log_message(f'画像削除: {filename}')
            print('既存画像を削除しました')
            capture_normal_images(train_dir)
            break
        elif choice == '3':
            # 追加撮影
            print('既存画像に追加してカメラ撮影します')
            capture_normal_images(train_dir)
            break
        elif choice == '4':
            # パソコン内ファイルをコピー
            if copy_files_from_pc(train_dir, current_files):
                break
            # キャンセルされた場合はループを続行(四択に戻る)
        elif choice == '1':
            # 既存画像をそのまま使用
            log_message('既存画像をそのまま使用')
            print('既存画像を使用します')
            break
        else:
            print('無効な選択です。再度選択してください。')

    log_message('学習用データのセットアップが完了')
    return train_dir

def capture_normal_images(train_dir):
    """カメラで正常画像を撮影する関数"""
    log_message('正常画像撮影モードを開始')
    print('\n=== 正常画像撮影モード ===')
    print('\n推奨撮影条件の詳細説明')
    print('撮影枚数:')
    print('・最低限: 5枚以上')
    print('・推奨: 10-30枚')
    print('・理想: 50枚以上(多いほど精度向上)')
    print('\n撮影のポイント:')
    print('・一貫性: 同じ場所、同じ角度、同じ距離')
    print('・バリエーション: 照明条件、時間帯の変化を含める')
    print('・正常状態のみ: 異常や不要なものが写らないよう注意')
    print('・網羅性: 監視対象の様々な正常パターンを含める')
    print('\n用途別推奨例:')
    print('用途           撮影対象         枚数      条件')
    print('部屋の監視     無人の部屋       15-30枚   朝・昼・夜の照明条件')
    print('デスク監視     整理されたデスク 10-20枚   様々な物の配置パターン')
    print('製品検査       正常製品         20-50枚   複数角度、照明条件')
    print('機械監視       正常稼働中       10-30枚   動作サイクルの各段階')
    print('\n操作方法:')
    print('・スペースキー: 撮影')
    print('・qキー: 撮影終了')
    print('撮影開始します...')

    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    if not cap.isOpened():
        log_message('エラー: カメラの初期化に失敗')
        print('エラー: カメラが開けません')
        return 0

    log_message('カメラ初期化完了')
    capture_count = 0

    try:
        while True:
            cap.grab()
            ret, frame = cap.retrieve()
            if not ret:
                log_message('フレーム取得に失敗')
                print('フレーム取得に失敗しました')
                break

            # 撮影ガイド表示
            display_frame = frame.copy()
            cv2.putText(display_frame, f'Captured: {capture_count}', (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(display_frame, 'SPACE: Capture, Q: Quit', (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            cv2.imshow('Normal Image Capture', display_frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord(' '):  # スペースキーで撮影
                capture_count += 1
                filename = f'captured_normal_{capture_count:03d}.jpg'
                filepath = os.path.join(train_dir, filename)
                cv2.imwrite(filepath, frame)
                log_message(f'正常画像撮影完了: {filename}')
                print(f'撮影完了: {filename}')

                # 撮影フィードバック(画面を一瞬白くする)
                white_frame = np.ones_like(frame) * 255
                cv2.imshow('Normal Image Capture', white_frame)
                cv2.waitKey(100)

            elif key == ord('q'):  # qキーで終了
                log_message('正常画像撮影を終了')
                break

    finally:
        cap.release()
        cv2.destroyAllWindows()
        log_message('カメラリソースを解放')

    log_message(f'正常画像撮影完了: 合計 {capture_count} 枚')
    print(f'\n撮影完了: {capture_count}枚の正常画像を保存しました')
    if capture_count >= 10:
        print('十分な枚数が撮影されました。精度の高い異常検知が期待できます')
    elif capture_count >= 5:
        print('最低限の枚数は撮影されました。可能であればさらに追加撮影を推奨します')
    else:
        print('撮影枚数が少ないです。異常検知の精度に影響する可能性があります')

    return capture_count

def train_padim(normal_imgs, model, device):
    log_message(f'PaDiM学習を開始 - 正常画像数: {len(normal_imgs)}枚, デバイス: {device}')
    model.eval()
    embeddings = []

    for i, img in enumerate(tqdm(normal_imgs, desc='特徴抽出中')):
        x = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model(x)

        # モデルタイプに応じて層名を取得
        if isinstance(model, EfficientNetB0_FeatureExtractor):
            layer_names = ['features.2', 'features.3', 'features.4']
        else:
            layer_names = ['layer1', 'layer2', 'layer3']

        feat_list = [features[layer] for layer in layer_names]
        for j in range(len(feat_list)):
            feat_list[j] = nn.functional.interpolate(feat_list[j], size=(64, 64), mode='bilinear', align_corners=False)
        embedding = torch.cat(feat_list, 1)
        embeddings.append(embedding.squeeze(0).cpu().numpy())

        if (i + 1) % 5 == 0:
            log_message(f'特徴抽出進捗: {i + 1}/{len(normal_imgs)}枚 完了')

    log_message('特徴抽出完了')

    embeddings = np.stack(embeddings, axis=0)
    N, C, H, W = embeddings.shape
    log_message(f'特徴マップサイズ: N={N}, C={C}, H={H}, W={W}')
    embeddings = embeddings.transpose(0, 2, 3, 1).reshape(N * H * W, C)
    log_message(f'特徴ベクトル変換完了: shape={embeddings.shape}')

    mean = np.mean(embeddings, axis=0)
    log_message(f'平均ベクトル計算完了: shape={mean.shape}')

    cov = LedoitWolf().fit(embeddings).covariance_
    log_message(f'共分散行列推定完了: shape={cov.shape}')

    log_message('PaDiM学習が完了')
    return mean, cov, (H, W)

def mahalanobis_map(img, model, mean, cov, out_size, device):
    model.eval()
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model(x)

    # モデルタイプに応じて層名を取得
    if isinstance(model, EfficientNetB0_FeatureExtractor):
        layer_names = ['features.2', 'features.3', 'features.4']
    else:
        layer_names = ['layer1', 'layer2', 'layer3']

    feat_list = [features[layer] for layer in layer_names]
    for i in range(len(feat_list)):
        feat_list[i] = nn.functional.interpolate(feat_list[i], size=(64, 64), mode='bilinear', align_corners=False)
    embedding = torch.cat(feat_list, 1)
    C = embedding.shape[1]
    embedding = embedding.squeeze(0).cpu().numpy().transpose(1, 2, 0).reshape(-1, C)

    try:
        inv_cov = np.linalg.inv(cov)
    except np.linalg.LinAlgError:
        log_message('警告: 特異行列のため疑似逆行列を使用')
        inv_cov = np.linalg.pinv(cov)

    diff = embedding - mean
    dist = np.einsum('ij,jk,ik->i', diff, inv_cov, diff)
    dist_map = dist.reshape(out_size)

    return dist_map

def setup_detection_directory():
    """異常検出画像保存用ディレクトリのセットアップ"""
    base_dir = './detected_regions'

    # タイムスタンプ付きディレクトリ名
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    detection_dir = os.path.join(base_dir, f'detection_{timestamp}')

    if not os.path.exists(detection_dir):
        os.makedirs(detection_dir)
        log_message(f'異常検出画像保存ディレクトリを作成: {detection_dir}')
        print(f'異常検出画像保存ディレクトリを作成しました: {detection_dir}')

    return detection_dir

def get_next_file_number(directory):
    """ディレクトリ内の既存ファイルから次の連番を取得"""
    existing_files = [f for f in os.listdir(directory) if f.endswith('.png') and f[:6].isdigit()]
    if not existing_files:
        return 1

    numbers = [int(f[:6]) for f in existing_files]
    return max(numbers) + 1

def visualize_anomaly(img_bgr, score_map, label='異常スコア', detection_dir=None, file_counter=None):
    score_map = cv2.resize(score_map, (img_bgr.shape[1], img_bgr.shape[0]))
    norm_map = (score_map - np.min(score_map)) / (np.max(score_map) - np.min(score_map))
    heatmap = cv2.applyColorMap(np.uint8(255 * norm_map), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)

    # バウンディングボックス表示の追加
    # しきい値の設定(平均値 + 標準偏差 * 2)
    threshold = np.mean(score_map) + np.std(score_map) * 2

    # 二値化
    _, binary_map = cv2.threshold(score_map, threshold, 255, cv2.THRESH_BINARY)
    binary_map = np.uint8(binary_map)

    # モルフォロジー処理でノイズ除去
    kernel = np.ones((5, 5), np.uint8)
    binary_map = cv2.morphologyEx(binary_map, cv2.MORPH_CLOSE, kernel)
    binary_map = cv2.morphologyEx(binary_map, cv2.MORPH_OPEN, kernel)

    # 輪郭検出
    contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 各輪郭に対してバウンディングボックスを描画
    for contour in contours:
        area = cv2.contourArea(contour)
        # 小さすぎる領域は除外(画像面積の0.1%以上)
        if area > img_bgr.shape[0] * img_bgr.shape[1] * 0.001:
            x, y, w, h = cv2.boundingRect(contour)
            # 赤色の矩形を描画
            cv2.rectangle(overlay, (x, y), (x + w, y + h), (0, 0, 255), 2)

            # 領域内の最大スコアを取得
            region_scores = score_map[y:y+h, x:x+w]
            if region_scores.size > 0:
                max_region_score = np.max(region_scores)
                # スコアをボックス上部に表示
                cv2.putText(overlay, f'{max_region_score:.1f}', (x, y - 5),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

                # 異常領域を切り抜いて保存
                if detection_dir is not None and file_counter is not None:
                    # 切り抜き領域を取得
                    cropped_region = img_bgr[y:y+h, x:x+w]

                    # ファイル名を生成(6桁の連番)
                    filename = f'{file_counter[0]:06d}.png'
                    filepath = os.path.join(detection_dir, filename)

                    # 画像を保存
                    cv2.imwrite(filepath, cropped_region)
                    log_message(f'異常領域を保存: {filename} (スコア: {max_region_score:.1f})')

                    # カウンターを増加
                    file_counter[0] += 1

    try:
        font = ImageFont.truetype('C:/Windows/Fonts/msgothic.ttc', 30)
        img_pil = Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)
        draw.text((30, 30), label, font=font, fill=(0, 255, 0))
        overlay = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    except:
        cv2.putText(overlay, label, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

    return overlay

def video_processing(frame, model, mean, cov, out_size, device, detection_dir=None, file_counter=None):
    score_map = mahalanobis_map(frame, model, mean, cov, out_size, device)
    max_score = np.max(score_map)
    result_frame = visualize_anomaly(frame, score_map, f'異常スコア: {max_score:.2f}', detection_dir, file_counter)
    return result_frame, max_score

# プログラム開始時のログ初期化
log_message('=' * 50)
log_message('PaDiM異常検知システムを開始')
log_message('=' * 50)

print('PaDiM異常検知システムを開始します')
print('初回実行時は自動的に学習用サンプル画像をダウンロードします')
print('独自の正常画像を使用する場合は ./train_normal ディレクトリに配置してください')
print('操作方法: 動画再生中は qキー で終了できます')

# モデル選択
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_message(f'使用デバイス: {device}')
model = select_model().to(device)

# 学習用データのセットアップ
train_dir = setup_training_data()

train_imgs = []
for f in os.listdir(train_dir):
    if f.lower().endswith(('.jpg', '.png', '.jpeg')):
        img_path = os.path.join(train_dir, f)
        img = cv2.imread(img_path)
        if img is not None:
            train_imgs.append(img)

if len(train_imgs) == 0:
    log_message('エラー: 学習用画像が見つかりません')
    print('エラー: 学習用画像が見つかりません')
    exit()

log_message(f'学習用画像読み込み完了: {len(train_imgs)}枚')
print(f'正常画像 {len(train_imgs)} 枚でモデルを学習中...')
mean, cov, out_size = train_padim(train_imgs, model, device)

# 異常検出画像保存オプションの選択
print('\n異常検出画像の保存設定:')
print('検出された異常領域を切り抜いて画像として保存しますか?')
print('1: 保存する(連番で自動保存)')
print('2: 保存しない(表示のみ)')

save_choice = input('\n選択 (1/2): ')
log_message(f'異常検出画像保存設定: {save_choice}')

detection_dir = None
file_counter = None

if save_choice == '1':
    # 異常検出画像保存用ディレクトリのセットアップ
    detection_dir = setup_detection_directory()
    file_counter = [get_next_file_number(detection_dir)]  # リストで参照渡しを実現
    print('異常検出画像を保存します')
elif save_choice == '2':
    log_message('異常検出画像を保存しない設定')
    print('異常検出画像は保存しません(表示のみ)')
else:
    log_message('無効な選択のため、保存しない設定を適用')
    print('無効な選択です。異常検出画像は保存しません(表示のみ)')

print('\n動画入力方法を選択してください:')
print('0: 動画ファイル')
print('1: カメラ')
print('2: サンプル動画')

choice = input('選択: ')
log_message(f'入力選択: {choice}')
temp_file = None

# 入力ソースの初期化
if choice == '0':
    log_message('動画ファイル選択モード')
    root = tk.Tk()
    root.withdraw()
    path = filedialog.askopenfilename(filetypes=[('Video files', '*.mp4 *.avi *.mov')])
    if not path:
        log_message('動画ファイルが選択されませんでした')
        exit()
    log_message(f'選択された動画ファイル: {path}')
    cap = cv2.VideoCapture(path)
elif choice == '1':
    log_message('カメラモード')
    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
elif choice == '2':
    log_message('サンプル動画モード')
    url = 'https://github.com/opencv/opencv/raw/master/samples/data/vtest.avi'
    filename = 'vtest.avi'
    try:
        urllib.request.urlretrieve(url, filename)
        log_message(f'サンプル動画ダウンロード完了: {filename}')
        temp_file = filename
        cap = cv2.VideoCapture(filename)
    except Exception as e:
        log_message(f'サンプル動画ダウンロードに失敗: {url}, エラー: {e}')
        print(f'動画のダウンロードに失敗しました: {url}')
        print(f'エラー: {e}')
        exit()
else:
    log_message(f'無効な選択: {choice}')
    print('無効な選択です')
    exit()

# 動画処理の開始
log_message('動画処理を開始')
results = []
last_print_time = time.time()
frame_count = 0

try:
    while True:
        cap.grab()
        ret, frame = cap.retrieve()
        if not ret:
            log_message('動画終了またはフレーム取得失敗')
            break

        frame_count += 1
        processed_frame, anomaly_score = video_processing(frame, model, mean, cov, out_size, device, detection_dir, file_counter)

        cv2.imshow('PaDiM Anomaly Detection', processed_frame)

        current_time = time.time()
        if current_time - last_print_time >= 1.0:
            log_message(f'フレーム#{frame_count}: 異常スコア={anomaly_score:.4f}')
            print(f'異常スコア: {anomaly_score:.4f}')
            results.append(f'時刻: {current_time:.2f}s, フレーム: {frame_count}, 異常スコア: {anomaly_score:.4f}')
            last_print_time = current_time

        if cv2.waitKey(1) & 0xFF == ord('q'):
            log_message('ユーザーによる終了')
            break
finally:
    cap.release()
    cv2.destroyAllWindows()
    log_message('動画キャプチャとウィンドウを終了')

    if temp_file:
        os.remove(temp_file)
        log_message(f'一時ファイルを削除: {temp_file}')

    # 結果ファイルの保存
    with open('result.txt', 'w', encoding='utf-8') as f:
        f.write(f'# PaDiM異常検知結果 - {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n')
        f.write(f'# 総フレーム数: {frame_count}\n')
        f.write('# 時刻, フレーム番号, 異常スコア\n')
        for result in results:
            f.write(result + '\n')

    log_message('結果ファイル保存完了: result.txt')
    print('result.txtに保存しました')

    # 異常検出画像の保存情報
    if detection_dir is not None:
        total_detections = file_counter[0] - get_next_file_number(detection_dir) + len([f for f in os.listdir(detection_dir) if f.endswith('.png')])
        if total_detections > 0:
            log_message(f'異常検出画像保存完了: {detection_dir} に {total_detections} 個の画像を保存')
            print(f'\n異常検出画像を {detection_dir} に保存しました(合計 {total_detections} 個)')

log_message('=' * 50)
log_message('PaDiM異常検知システムを終了')
log_message('=' * 50)