import os
os.environ['DISABLE_XFORMERS'] = '1'  # Disable xformers before any model code is imported

import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import numpy as np
from retinaface import RetinaFace
import gazelle
from gazelle.model import get_gazelle_model

# Automatically detect and select the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load the Gazelle model and its preprocessing transform from torch hub
# (xformers is disabled via the environment variable set above)
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14_inout', force_reload=True)
model.eval()
model = model.to(device)

# Load the input image
image_url = "https://www.looper.com/img/gallery/the-office-funniest-moments-ranked/jim-and-dwights-customer-service-training-1627594561.jpg"
try:
    response = requests.get(image_url, stream=True)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    width, height = image.size
    plt.figure(1)
    plt.imshow(image)
    plt.axis('off')
    plt.savefig('original.png')
except requests.exceptions.RequestException as e:
    print(f"Error downloading image: {e}")
    raise  # Stop here; the rest of the script needs the image

# Face detection with RetinaFace
resp = RetinaFace.detect_faces(np.array(image))
bboxes = [resp[key]['facial_area'] for key in resp.keys()]

# Prepare the Gazelle inputs: a batched image tensor and per-image normalized bounding boxes
img_tensor = transform(image).unsqueeze(0).to(device)
norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

model_input = {
    "images": img_tensor,
    "bboxes": norm_bboxes
}

# Run inference
with torch.no_grad():
    output = model(model_input)


def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    # Overlay a single gaze heatmap on the image, optionally drawing the head box and in-frame score
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()
    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
    heatmap = plt.cm.jet(np.array(heatmap) / 255.)
    heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap).convert("RGBA")
    heatmap.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height],
                       outline="lime", width=int(min(width, height) * 0.01))

        if inout_score is not None:
            text = f"in-frame: {inout_score:.2f}"
            text_height = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_height
            draw.text((text_x, text_y), text, fill="lime",
                      font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
    return overlay_image


# Visualize a heatmap for each detected person
for i in range(len(bboxes)):
    plt.figure(i + 2)
    plt.imshow(visualize_heatmap(image, output['heatmap'][0][i], norm_bboxes[0][i],
                                 inout_score=output['inout'][0][i] if output['inout'] is not None else None))
    plt.axis('off')
    plt.savefig(f'heatmap_{i}.png')


def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    # Draw every head box and in-frame score, plus a line from each head to its gaze target
    # (the argmax of that person's heatmap) when the in-frame score exceeds the threshold
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size

    for i in range(len(bboxes)):
        xmin, ymin, xmax, ymax = bboxes[i]
        color = colors[i % len(colors)]
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height],
                       outline=color, width=int(min(width, height) * 0.01))

        if inout_scores is not None:
            inout_score = inout_scores[i]
            text = f"in-frame: {inout_score:.2f}"
            text_height = int(height * 0.01)
            text_x = xmin * width
            text_y = ymax * height + text_height
            draw.text((text_x, text_y), text, fill=color,
                      font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

        if inout_scores is not None and inout_score > inout_thresh:
            heatmap = heatmaps[i]
            heatmap_np = heatmap.detach().cpu().numpy()
            max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
            gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height
            draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)],
                         fill=color, width=int(0.005 * min(width, height)))
            draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)],
                      fill=color, width=int(0.005 * min(width, height)))
    return overlay_image


# Combined visualization of all people
plt.figure(len(bboxes) + 2)
plt.imshow(visualize_all(image, output['heatmap'][0], norm_bboxes[0],
                         output['inout'][0] if output['inout'] is not None else None,
                         inout_thresh=0.5))
plt.axis('off')
plt.savefig('combined_visualization.png')

# Show all figures
plt.show()