MMClassification は, OpenMMLab の構成物で, Few Shot Classification, Few Shot Detection の機能を提供する.
【文献】
mmfewshot Contributors, OpenMMLab Few Shot Learning Toolbox and Benchmark, https://github.com/open-mmlab/mmfewshot, 2021.
【関連する外部ページ】
Windows での Git のインストール: 別ページ »で説明
【関連する外部ページ】
Git の公式ページ: https://git-scm.com/
Windows での Python 3.10,関連パッケージ,Python 開発環境のインストール: 別ページ »で説明
【サイト内の関連ページ】
Python のまとめ: 別ページ »にまとめ
【関連する外部ページ】
Python の公式ページ: https://www.python.org/
Windows での Build Tools for Visual Studio 2022,NVIDIA ドライバ,NVIDIA CUDA ツールキット 11.8,NVIDIA cuDNN v8.6 のインストールと動作確認: 別ページ »で説明
【関連する外部ページ】
コマンドプロンプトを管理者として実行: 別ページ »で説明
PyTorch のページ: https://pytorch.org/index.html
次のコマンドは, PyTorch 2.0 (NVIDIA CUDA 11.8 用) をインストールする. 事前に NVIDIA CUDA のバージョンを確認しておくこと(ここでは,NVIDIA CUDA ツールキット 11.8 が前もってインストール済みであるとする).
python -m pip install -U pip python -m pip install -U torch torchvision torchaudio numpy numba --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
インストールの方法は複数ある. ここでは, NVIDIA CUDA ツールキットを使うことも考え, インストールしやすい方法として,ソースコードからビルドしてインストールする方法を案内している.
MMCV のインストールを行う.
インストールの方法は複数ある. ここでは, NVIDIA CUDA ツールキットを使うことも考え, インストールしやすい方法として,ソースコードからビルドしてインストールする方法を案内している.
コマンドプロンプトを管理者として実行: 別ページ »で説明
python -c "import torch; TORCH_VERSION = '.'.join(torch.__version__.split('.')[:2]); print(TORCH_VERSION)"
このとき,実際には 11.8 をインストールしているのに,「cu117」のように古いバージョンが表示されることがある.このような場合は,気にせずに続行する.
python -c "import torch; CUDA_VERSION = torch.__version__.split('+')[-1]; print(CUDA_VERSION)"
MMFewShot が MMCV 1.6.0 に依存している (2023/1).
https://mmcv.readthedocs.io/en/latest/get_started/installation.html に記載の手順による
python -m pip install -U pip python -m pip install -U opencv-python python -m pip install mmcv-full==1.6.0
python -c "from mmcv.ops import get_compiling_cuda_version, get_compiler_version; print(get_compiling_cuda_version()); print(get_compiler_version())"
MIM, MMClassification, MMDetection, MMFewShot のインストールを行う.
コマンドプロンプトを管理者として実行: 別ページ »で説明
MMFewShot が MMDetection 2.25.0 に依存している (2023/1).
https://github.com/open-mmlab/mmfewshot/blob/main/docs/en/install.md に記載の手順による.
https://mmclassification.readthedocs.io/en/latest/getting_started.html#installation による.
python -m pip install -U git+https://github.com/open-mmlab/mim.git python -m pip install -U git+https://github.com/open-mmlab/mmclassification.git python -m pip install -U mmdet==2.25.0 python -c "import mmcls; print(mmcls.__version__)" python -c "import mmdet; print(mmdet.__version__)"
(省略)
cd %HOMEPATH% rmdir /s /q mmfewshot git clone https://github.com/open-mmlab/mmfewshot.git cd mmfewshot python setup.py build python setup.py install python -c "import mmfewshot; print(mmfewshot.__version__)"
(省略)
コマンドプロンプトを管理者として実行: 別ページ »で説明
次のコマンドを実行する.
cd %HOMEPATH%\mmfewshot mkdir checkpoints cd checkpoints curl -O https://download.openmmlab.com/mmfewshot/detection/attention_rpn/coco/attention-rpn_r50_c4_4xb2_coco_base-training_20211102_003348-da28cdfd.pth
次の Python プログラムを実行する.Matplotlib を使うので,Jupyter QtConsole や Jupyter ノートブック (Jupyter Notebook) の利用が便利である.
Python プログラムは,公式ページhttps://mmclassification.readthedocs.io/en/latest/get_started.html のものを書き換えて使用.
下図では,Python プログラムの実行のため,jupyter qtconsole を使用している.
import os from mmdet.apis import show_result_pyplot from mmfewshot.detection.apis import (inference_detector, init_detector, process_support_images) %matplotlib inline import matplotlib.pyplot as plt import torch import torchvision.models as models from IPython.display import display fconfig = 'configs/detection/attention_rpn/coco/attention-rpn_r50_c4_4xb2_coco_base-training.py' fcheckpoint = 'checkpoints/attention-rpn_r50_c4_4xb2_coco_base-training_20211102_003348-da28cdfd.pth' device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') fsupport_images_dir = 'demo/demo_detection_images/support_images' model = init_detector(fconfig,fcheckpoint, device=device) files = os.listdir(fsupport_images_dir) support_images = [ os.path.join(fsupport_images_dir, file) for file in files ] classes = [file.split('.')[0] for file in files] support_labels = [[file.split('.')[0]] for file in files] print("support_images") display(support_images) print("classes") display(classes) print("support_labels") display(support_labels) process_support_images(model, support_images, support_labels, classes=classes) # single image fimage = 'demo/demo_detection_images/query_images/demo_query.jpg' fscore_thr = 0.3 result = inference_detector(model, fimage) show_result_pyplot(model, fimage, result, score_thr=fscore_thr)
エラーメッセージが出ないことを確認.