init
This commit is contained in:
130
.gitignore
vendored
Normal file
130
.gitignore
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.envrc
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre
|
||||
.pyre/
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.project
|
||||
.pydevproject
|
||||
.settings/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# Deep Learning
|
||||
*.pth
|
||||
*.pt
|
||||
*.onnx
|
||||
*.engine
|
||||
*.trt
|
||||
checkpoints/
|
||||
work_dirs/
|
||||
logs/
|
||||
runs/
|
||||
outputs/
|
||||
wandb/
|
||||
mlruns/
|
||||
tensorboard/
|
||||
|
||||
# Data (uncomment if needed)
|
||||
# data/
|
||||
# datasets/
|
||||
# *.tif
|
||||
# *.tiff
|
||||
# *.zip
|
||||
# *.tar
|
||||
# *.tar.gz
|
||||
|
||||
# Misc
|
||||
*.log
|
||||
*.tmp
|
||||
*.temp
|
||||
.cache/
|
||||
353
README.md
Normal file
353
README.md
Normal file
@@ -0,0 +1,353 @@
|
||||
# SkySense++
|
||||
|
||||
This repository is the official implementation of the paper "SkySense++: A Semantic-Enhanced Multi-Modal Remote Sensing Foundation Model Beyond SkySense for Earth Observation".
|
||||
|
||||
## 📢 Latest Updates
|
||||
|
||||
🔥🔥🔥 Last Updated on 2025.09.15 🔥🔥🔥
|
||||
- [2025.09.15] Add a [🌍 project page](https://zqcrafts.github.io/SkySense-O/project.html).
|
||||
- [2025.08.04] Our work has been published in [*Nature Machine Intelligence*](https://www.nature.com/articles/s42256-025-01078-8).
|
||||
- [2025.03.23] Code for preprocessing/pretraining/application and [model weights](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509?source=copy_link) for models have been uploaded.
|
||||
- [2025.03.14] updated optical images of JL-16 dataset in [Huggingface](https://huggingface.co/datasets/KKKKKKang/JL-16).
|
||||
- [2025.03.12] updated sentinel-1 images and labels of JL-16 dataset in [Zenodo](https://zenodo.org/records/15010418).
|
||||
- [2025.03.09] created repo in [Zenodo](https://zenodo.org/records/15010418), datasets are uploading.
|
||||
- [2024.11.13] updated details of pretrain and evaluation data.
|
||||
|
||||
## Pretrain Data
|
||||
|
||||
### RS-Semantic Dataset
|
||||
|
||||
We conduct semantic-enhanced pretraining on the RS-Semantic dataset, which consists of 13 datasets with pixel-level annotations. Below are the specifics of these datasets. (also see in [Zenodo](https://zenodo.org/records/15010418)).
|
||||
|
||||
| Dataset | Modalities | GSD(m) | Size | Categories | Download Link |
|
||||
| ------------------- | ---------------- | ------ | --------------------- | ---------- | --------------------------------------------------------------------------------------------- |
|
||||
| Five Billion Pixels | Gaofen-2 | 4 | 6800x7200 | 24 | [Download](https://x-ytong.github.io/project/Five-Billion-Pixels.html) |
|
||||
| Potsdam | Airborne | 0.05 | 6000x6000 | 5 | [Download](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-potsdam.aspx) |
|
||||
| Vaihingen | Airborne | 0.05 | 2494x2064 | 5 | [Download](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx) |
|
||||
| Deepglobe | WorldView | 0.5 | 2448x2448 | 6 | [Download](https://www.kaggle.com/datasets/balraj98/deepglobe-land-cover-classification-dataset) |
|
||||
| iSAID | Multiple Sensors | - | 800x800 to 4000x13000 | 15 | [Download](https://captain-whu.github.io/iSAID/index.html) |
|
||||
| LoveDA | Spaceborne | 0.3 | 1024x1024 | 7 | [Download](https://github.com/Junjue-Wang/LoveDA) |
|
||||
| DynamicEarthNet | WorldView | 0.3 | 1024x1024 | 7 | [Download](https://github.com/aysim/dynnet) |
|
||||
| | Sentinel-2* | 10 | 32x32 | | |
|
||||
| | Sentinel-1* | 10 | 32x33 | | |
|
||||
| Pastis-MM | WorldView | 0.3 | 1024x1024 | 18 | [Download](https://github.com/VSainteuf/pastis-benchmark) |
|
||||
| | Sentinel-2* | 10 | 32x32 | | |
|
||||
| | Sentinel-1* | 10 | 32x33 | | |
|
||||
| C2Seg-AB | Sentinel-2* | 10 | 128x128 | 13 | [Download](https://github.com/danfenghong/RSE_Cross-city) |
|
||||
| | Sentinel-1* | 10 | 128x128 | | |
|
||||
| FLAIR | Spot-5 | 0.2 | 512x512 | 12 | [Download](https://github.com/IGNF/FLAIR-2) |
|
||||
| | Sentinel-2* | 10 | 40x40 | | |
|
||||
| DFC20 | Sentinel-2 | 10 | 256x256 | 9 | [Download](https://ieee-dataport.org/competitions/2020-ieee-grss-data-fusion-contest) |
|
||||
| | Sentinel-1 | 10 | 256x256 | | |
|
||||
| S2-naip | NAIP | 1 | 512x512 | 32 | [Download](https://huggingface.co/datasets/allenai/s2-naip) |
|
||||
| | Sentinel-2* | 10 | 64x64 | | |
|
||||
| | Sentinel-1* | 10 | 64x64 | | |
|
||||
| JL-16 | Jilin-1 | 0.72 | 512x512 | 16 | [Download](https://zenodo.org/records/15010418) |
|
||||
| | Sentinel-1* | 10 | 40x40 | | |
|
||||
|
||||
*\* for time-series data.*
|
||||
|
||||
### RS-Representation Dataset
|
||||
|
||||
The pretraining list is in the [Zenodo](https://zenodo.org/records/15068572)- `rep_data_list.tar`. The download and process scripts are in [tools/pretraining_data_builder](tools/pretraining_data_builder).
|
||||
|
||||
## EO Benchmark
|
||||
|
||||
We evaluate our SkySense++ on 12 typical Earth Observation (EO) tasks across 7 domains: *agriculture*, *forestry*, *oceanography*, *atmosphere*, *biology*, *land surveying*, and *disaster management*. The detailed information about the datasets used for evaluation is as follows.
|
||||
|
||||
| Domain | Task type | Dataset | Modalities | GSD | Image size | Download Link | Notes |
|
||||
| ------------------- | --------------------------- | --------------------- | ---------------------- | ---- | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | ----- |
|
||||
| Agriculture | Crop classification | Germany | Sentinel-2* | 10 | 24x24 | [Download](https://github.com/michaeltrs/DeepSatModels/tree/main/data) | |
|
||||
| Foresetry | Tree species classification | TreeSatAI-Time-Series | Airborne, | 0.2 | 304x304 | [Download](http://example.com/download/treesatai-time-series) | |
|
||||
| | | | Sentinel-2* | 10 | 6x6 | | |
|
||||
| | | | Sentinel-1* | 10 | 6x6 | | |
|
||||
| | Deforestation segmentation | Atlantic | Sentinel-2 | 10 | 512x512 | [Download](https://github.com/davej23/attention-mechanism-unet) | |
|
||||
| Oceanography | Oil spill segmentation | SOS | Sentinel-1 | 10 | 256x256 | [Download](https://grzy.cug.edu.cn/zhuqiqi/en/yjgk/32384/list/index.htm) | |
|
||||
| Atmosphere | Air pollution regression | 3pollution | Sentinel-2 | 10 | 200x200 | [Download](https://github.com/CoDIS-Lab/AQNet) | |
|
||||
| | | | Sentinel-5P | 2600 | 120x120 | | |
|
||||
| Biology | Wildlife detection | Kenya | Airborne | - | 3068x4603 | [Download](https://data.4tu.nl/articles/_/12713903/1) | |
|
||||
| Land surveying | LULC mapping | C2Seg-BW | Gaofen-6 | 10 | 256x256 | [Download](https://github.com/danfenghong/RSE_Cross-city) | |
|
||||
| | | | Gaofen-3 | 10 | 256x256 | | |
|
||||
| | Change detection | dsifn-cd | GoogleEarth | 0.3 | 512x512 | [Download](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset) | |
|
||||
| Disaster management | Flood monitoring | Flood-3i | Airborne | 0.05 | 256 × 256 | [Download](https://drive.google.com/drive/folders/1FMAKf2sszoFKjq0UrUmSLnJDbwQSpfxR) | |
|
||||
| | | C2SMSFloods | Sentinel-2, Sentinel-1 | 10 | 512x512 | [Download](https://beta.source.coop/c2sms/) | |
|
||||
| | Wildfire monitoring | CABUAR | Sentinel-2 | 10 | 5490 × 5490 | [Download](https://github.com/DarthReca/CaBuAr) | |
|
||||
| | Landslide mapping | GVLM | GoogleEarth | 0.3 | 1748x1748 ~ 10808x7424 | [Download](https://github.com/zxk688/GVLM) | |
|
||||
| | Building damage assessment | xBD | WorldView | 0.3 | 1024x1024 | [Download](https://xview2.org/) | |
|
||||
|
||||
*\* for time-series data.*
|
||||
|
||||
## Implementation Code
|
||||
|
||||
### Structure
|
||||
|
||||
This project mainly contains the following parts.
|
||||
|
||||
```plain
|
||||
./
|
||||
├── antmmf/ # antmmf framework code
|
||||
├── configs/
|
||||
│ ├── eval_skysense_pp_flood3i.yml # eval cfg on flood3i
|
||||
│ └── pretrain_skysensepp.yml # pretrain cfg
|
||||
├── finetune/ # finetuning code
|
||||
│ ├── configs/ # finetuning configs
|
||||
│ ├── mmseg/ # mmseg library
|
||||
│ ├── requirements/ # mmseg install requirements folder
|
||||
│ ├── requirements.txt # mmseg install requirements
|
||||
│ ├── setup.py # mmseg setup file
|
||||
│ └── tools/ # mmseg utils
|
||||
├── lib/ # model implementation
|
||||
│ ├── datasets/ # datasets for evaluation
|
||||
│ ├── evaluation/ # evaluation code
|
||||
│ ├── models/ # model architecture
|
||||
│ ├── predictors/ # inference code
|
||||
│ ├── task/ # task code
|
||||
│ ├── trainer/ # trainer code
|
||||
│ ├── utils/ # library code
|
||||
│ └── __init__.py # packages init file
|
||||
├── pretrain/ # pretrain ckpts
|
||||
├── tools/ # tools ckpts
|
||||
│ ├── pretraining_data_builder # pretraining dataset builder
|
||||
│ ├── run_1shot_flood3i.sh # datasets for evaluation
|
||||
│ ├── run_ft_atlantic.sh # run ft script
|
||||
│ ├── run_pretrain.sh # run pretrain script
|
||||
│ └── run.py # Program entry point
|
||||
└── readme.md # project readme
|
||||
```
|
||||
|
||||
### Environment
|
||||
|
||||
Each machine for implementating the pretraining or fintuning are with *Alibaba Group Enterprise Linux(7.2)* and *Python 3.8.10*. The pretraining and finetuning code are implemented on severs with *Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz* and *Nvidia A100 GPUS*.
|
||||
|
||||
### Pretraining
|
||||
|
||||
To run our pretraining code, please install dependency packages. (Instalazation takes about 14 minutes on a node with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 8 A100 GPUs.)
|
||||
|
||||
```plain
|
||||
torch==1.13.1
|
||||
atorch==0.1.3
|
||||
torchvision==0.14.1
|
||||
mmcv-full==1.7.1
|
||||
mmsegmentation==0.30.0
|
||||
mmcls==0.25.0
|
||||
timm==0.6.13
|
||||
gdal==3.4.0
|
||||
scikit-image==0.19.3
|
||||
```
|
||||
|
||||
Step1. Install the above packages and clone [antmmf framework](https://github.com/alipay/Ant-Multi-Modal-Framework):
|
||||
|
||||
```bash
|
||||
git clone https://github.com/alipay/Ant-Multi-Modal-Framework.git antmmf/
|
||||
```
|
||||
|
||||
Step2. Download the pretraining datasets in [Zenodo](https://zenodo.org/records/15010418) and orgnize them as follows:
|
||||
|
||||
```
|
||||
pretrain_datasets
|
||||
├── dynamic-mm # multi-modal dynamic-mm datasets
|
||||
│ ├── images_hr # hr images
|
||||
│ ├── images_s2 # sentinel-2 images
|
||||
│ ├── images_s1 # sentinel-1 images
|
||||
│ ├── labels # segmentation annotations
|
||||
│ ├── dynamic-mm_train.json # train list file
|
||||
│ └── dynamic-mm_val.json # val list file
|
||||
├── fbp # single-modal fbp datasets
|
||||
│ ├── images # input gaofen-2 images
|
||||
│ ├── labels # segmentation annotations
|
||||
│ ├── fbp_train.json # train list file
|
||||
│ └── fbp_val.json # val list file
|
||||
└── ......
|
||||
```
|
||||
|
||||
The `<dataset>_<train/val>.json` is used to read information for training and validation, with a unified organizational format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"hr_path": "dataset_name/images_hr/<img_name>.png", // hr info c,h,w
|
||||
"s2_path": ["dataset_name/images_s2/<img_name>_20240101.npz", "dataset_name/images_s2/<img_name>_20240103.npz"], // s2 c,h,w
|
||||
"s1_path": ["dataset_name/images_s1/<img_name>_20240104.npz", "dataset_name/images_s1/<img_name>_20240108.npz"], // s1 c,h,w
|
||||
"target_path": "dataset_name/labels/<img_name>.png", // annotation info
|
||||
"type": "dataset_name", // dataset_name
|
||||
"classes": [ // Included categories
|
||||
0,
|
||||
2,
|
||||
4,
|
||||
5
|
||||
]
|
||||
},
|
||||
{
|
||||
...
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Step3. Download the pretraining weights of SkySense [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509) and move it to `pretrain/`
|
||||
|
||||
Step4. Run the pretrain code on 4 nodes (each node with 8 A100 gpus):
|
||||
|
||||
```
|
||||
bash tools/run_pretrain.sh <node_rank:0-3> <master_ip_address>
|
||||
```
|
||||
|
||||
For example, if the ip adress of master node is *192.168.112.10*, the command for node *1* is:
|
||||
|
||||
```
|
||||
bash tools/run_pretrain.sh 1 192.168.112.10
|
||||
```
|
||||
|
||||
### Downstream 1-shot application
|
||||
|
||||
#### Requirments
|
||||
|
||||
To run our code, please install dependency packages. ( Instalazation takes about 10 minutes on a sever with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 2 A100 GPUs.)
|
||||
|
||||
```plain
|
||||
torch==1.13.1
|
||||
atorch==0.1.3
|
||||
torchvision==0.14.1
|
||||
mmcv-full==1.7.1
|
||||
mmcls==0.25.0
|
||||
mmsegmentation==0.30.0
|
||||
timm==0.6.13
|
||||
gdal==3.4.0
|
||||
scikit-image==0.19.3
|
||||
```
|
||||
|
||||
#### Run steps
|
||||
|
||||
step1. Clone [antmmf framework](https://github.com/alipay/Ant-Multi-Modal-Framework). and install the above packages:
|
||||
|
||||
```plain
|
||||
git clone https://github.com/alipay/Ant-Multi-Modal-Framework.git antmmf/
|
||||
```
|
||||
|
||||
step1. Download the flood-3i dataset (`Images.zip`/`Semantic_mask.zip` at [here](https://drive.google.com/drive/folders/1FMAKf2sszoFKjq0UrUmSLnJDbwQSpfxR), `val.txt` at [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509). Testing dataset should be organized as follows:
|
||||
```plain
|
||||
eval_datasets/
|
||||
└── flood3i/
|
||||
├── Images/
|
||||
│ ├── 10165_0_2.jpg
|
||||
│ ├── 10165_1_0.jpg
|
||||
│ └── ...
|
||||
├── Semantic_mask/
|
||||
│ ├── 10165_lab_0_2.png
|
||||
│ ├── 10165_lab_1_0.png
|
||||
│ └── ...
|
||||
└── val.txt
|
||||
```
|
||||
|
||||
step2. Using the above pretraining wieights or download the pretrained model weights [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509).
|
||||
|
||||
step3. Run the script for evaluating 1-shot performance on flood-3i:
|
||||
|
||||
```plain
|
||||
bash tools/run_1shot.sh <gpu_idx> flood-3i(dataset_name)
|
||||
```
|
||||
|
||||
### Downstream finetuning application
|
||||
|
||||
#### Requirments
|
||||
|
||||
We build our fine-tuning application code on the [openmmlab framework](https://github.com/open-mmlab/mmcv).
|
||||
|
||||
To run our code, please install dependency packages. ( Instalazation takes about 10 minutes on a sever with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 2 A100 GPUs.)
|
||||
|
||||
```plain
|
||||
torch==1.13.1
|
||||
torchvision==0.14.1
|
||||
mmcv-full==2.1.0
|
||||
mmpretrain==1.2.0
|
||||
mmsegmentation==1.2.2
|
||||
mmdetection==3.3.0
|
||||
timm==0.6.13
|
||||
gdal==3.4.0
|
||||
scikit-image==0.19.3
|
||||
```
|
||||
|
||||
#### Run steps
|
||||
|
||||
Step1. Install the mmsegmentation framework under the instrction in [here](https://mmsegmentation.readthedocs.io/en/latest/index.html)
|
||||
|
||||
Step2. Download the evaluation datsets. We take Atlantic dataset for deforestation segmentation as an example. Download the Atlantic dataset at [here](https://github.com/davej23/attention-mechanism-unet). Spliting json files of evaluation framwork [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509).
|
||||
```
|
||||
../rs_datasets/deforestation_atlantic/
|
||||
--
|
||||
├── deforestation_atlantic_test.json
|
||||
├── deforestation_atlantic_train.json
|
||||
├── deforestation_atlantic_val.json
|
||||
├── Test/
|
||||
│ ├── image/
|
||||
│ └── label/
|
||||
├── Training/
|
||||
│ ├── image/
|
||||
│ └── label/
|
||||
└── Validation/
|
||||
├── image/
|
||||
└── label/
|
||||
```
|
||||
|
||||
Step3. Use your pretrained model weights or download the model weights: [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509)
|
||||
|
||||
Step4. Run the finetuning script. We take the Atlantic dataset as an example:
|
||||
|
||||
```bash
|
||||
bash tools/run_finetune.sh configs/atlantic.py
|
||||
```
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
This projects are mainly built on the following projects:
|
||||
|
||||
+ [antmmf](https://github.com/alipay/Ant-Multi-Modal-Framework)
|
||||
+ [mmcv](https://github.com/open-mmlab/mmcv)
|
||||
+ [mmpretrain](https://github.com/open-mmlab/mmpretrain)
|
||||
+ [mmsegmentation](https://github.com/open-mmlab/mmsegmentation)
|
||||
+ [mmdetection](https://github.com/open-mmlab/mmdetection)
|
||||
+ [Painter](https://github.com/baaivision/Painter)
|
||||
|
||||
## License
|
||||
|
||||
The pre-trained model weight and pre-training code are only available for the non-commercial research. For any commercial use or cooperation, please contact Yansheng Li at Wuhan University (e-mail: yansheng.li@whu.edu.cn).
|
||||
|
||||
## Citation
|
||||
If you find our repo useful, please consider giving a star and citation:
|
||||
|
||||
```
|
||||
@article{wu2025semantic,
|
||||
author = {Wu, Kang and Zhang, Yingying and Ru, Lixiang and Dang, Bo and Lao, Jiangwei and Yu, Lei and Luo, Junwei and Zhu, Zifan and Sun, Yue and Zhang, Jiahao and Zhu, Qi and Wang, Jian and Yang, Ming and Chen, Jingdong and Zhang, Yongjun and Li, Yansheng},
|
||||
title = {A semantic‑enhanced multi‑modal remote sensing foundation model for Earth observation},
|
||||
journal = {Nature Machine Intelligence},
|
||||
year = {2025},
|
||||
doi = {10.1038/s42256-025-01078-8},
|
||||
url = {https://doi.org/10.1038/s42256-025-01078-8}
|
||||
}
|
||||
|
||||
@inproceedings{guo2024skysense,
|
||||
author = {Guo, Xin and Lao, Jiangwei and Dang, Bo and Zhang, Yingying and Yu, Lei and Ru, Lixiang and Zhong, Liheng and Huang, Ziyuan and Wu, Kang and Hu, Dingxiang and He, Huimei and Wang, Jian and Chen, Jingdong and Yang, Ming and Zhang, Yongjun and Li, Yansheng},
|
||||
title = {SkySense: A Multi-Modal Remote Sensing Foundation Model Towards Universal Interpretation for Earth Observation Imagery},
|
||||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2024},
|
||||
pages = {27672-27683}
|
||||
}
|
||||
|
||||
@inproceedings{zhu2025skysenseo,
|
||||
title={Skysense-o: Towards open-world remote sensing interpretation with vision-centric visual-language modeling},
|
||||
author={Zhu, Qi and Lao, Jiangwei and Ji, Deyi and Luo, Junwei and Wu, Kang and Zhang, Yingying and Ru, Lixiang and Wang, Jian and Chen, Jingdong and Yang, Ming and others},
|
||||
booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
|
||||
pages={14733--14744},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{luo2024skysensegpt,
|
||||
title={Skysensegpt: A fine-grained instruction tuning dataset and model for remote sensing vision-language understanding},
|
||||
author={Luo, Junwei and Pang, Zhen and Zhang, Yongjun and Wang, Tingzhu and Wang, Linlin and Dang, Bo and Lao, Jiangwei and Wang, Jian and Chen, Jingdong and Tan, Yihua and others},
|
||||
journal={arXiv preprint arXiv:2406.10100},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/#kang-wu/SkySensePlusPlus&Date)
|
||||
171
configs/eval_skysense_pp_flood3i.yml
Normal file
171
configs/eval_skysense_pp_flood3i.yml
Normal file
@@ -0,0 +1,171 @@
|
||||
task_attributes:
|
||||
segmentation:
|
||||
dataset_attributes:
|
||||
few_shot_flood_segmentation:
|
||||
data_root_dir: 'eval_datasets/flood3i'
|
||||
data_txt: 'eval_datasets/flood3i/val.txt'
|
||||
img_dir: 'eval_datasets/flood3i/Images'
|
||||
tgt_dir: 'eval_datasets/flood3i/Semantic_mask'
|
||||
num_shot: 1
|
||||
seq_len: 1
|
||||
npz_key: 'arr_0'
|
||||
image_size:
|
||||
hr: (512, 512)
|
||||
s2: (16, 16)
|
||||
s1: (16, 16)
|
||||
anno: (512, 512)
|
||||
mim:
|
||||
input_size: (1024, 512)
|
||||
patch_size: 128
|
||||
mask_ratio: 0.5
|
||||
|
||||
model_attributes:
|
||||
SkySensePP:
|
||||
sources: ['hr', 's2', 's1']
|
||||
use_glbank: False
|
||||
use_modal_vae: True
|
||||
use_ctpe: False
|
||||
use_cls_token_uper_head: False
|
||||
upsacle_results: True
|
||||
calendar_time: 365
|
||||
vocabulary_size: 64
|
||||
backbone_hr:
|
||||
type: 'SwinTransformerV2MSL'
|
||||
arch: 'huge'
|
||||
use_attn: True
|
||||
merge_stage: 2
|
||||
vocabulary_size: 64
|
||||
img_size: 224
|
||||
patch_size: 4
|
||||
in_channels: 3
|
||||
window_size: 8
|
||||
drop_rate: 0.
|
||||
drop_path_rate: 0.2
|
||||
out_indices: (0,1,2,3)
|
||||
use_abs_pos_embed: False
|
||||
interpolate_mode: 'bicubic'
|
||||
with_cp: True
|
||||
frozen_stages: -1
|
||||
norm_eval: False
|
||||
pad_small_map: False
|
||||
pretrained_window_sizes: [0, 0, 0, 0]
|
||||
|
||||
backbone_s2:
|
||||
type: 'VisionTransformerMSL'
|
||||
img_size: (16, 16)
|
||||
use_attn: False
|
||||
merge_stage: 4
|
||||
vocabulary_size: 64
|
||||
patch_size: 4
|
||||
in_channels: 10
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
out_indices: (5,11,17,23)
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: False
|
||||
output_cls_token: False
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
eps: 1e-6
|
||||
with_cp: True
|
||||
interpolate_mode: 'bicubic'
|
||||
|
||||
head_s2:
|
||||
type: 'UPHead'
|
||||
in_dim: 1024
|
||||
out_dim: 2816
|
||||
up_scale: 4
|
||||
|
||||
backbone_s1:
|
||||
type: 'VisionTransformerMSL'
|
||||
img_size: (16, 16)
|
||||
use_attn: False
|
||||
merge_stage: 4
|
||||
vocabulary_size: 64
|
||||
patch_size: 4
|
||||
in_channels: 2
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
out_indices: (5,11,17,23)
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: False
|
||||
output_cls_token: False
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
eps: 1e-6
|
||||
with_cp: True
|
||||
interpolate_mode: 'bicubic'
|
||||
|
||||
head_s1:
|
||||
type: 'UPHead'
|
||||
in_dim: 1024
|
||||
out_dim: 2816
|
||||
up_scale: 4
|
||||
|
||||
rec_head_hr:
|
||||
type: 'UPerHead'
|
||||
in_channels: [704, 704, 1408, 2816, 1024]
|
||||
in_index: [0, 1, 2, 3, 4]
|
||||
pool_scales: (1, 2, 3, 6)
|
||||
channels: 512
|
||||
dropout_ratio: 0.1
|
||||
num_classes: 65
|
||||
norm_cfg:
|
||||
type: 'SyncBN'
|
||||
requires_grad: true
|
||||
align_corners: false
|
||||
|
||||
necks:
|
||||
type: 'TransformerEncoder'
|
||||
input_dims: 2816
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: True
|
||||
output_cls_token: True
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
num_fcs: 2
|
||||
norm_eval: False
|
||||
with_cp: True
|
||||
|
||||
modality_vae:
|
||||
type: 'ModalityCompletion'
|
||||
input_shape_hr: [2816, 32, 16]
|
||||
input_shape_s2: [2816, 32, 16]
|
||||
input_shape_s1: [2816, 32, 16]
|
||||
conv_dim: 256
|
||||
z_dim: 256
|
||||
n_codebook: 8192
|
||||
|
||||
amp_attributes:
|
||||
amp_escapes: Conv2d
|
||||
opt_level: O1
|
||||
init_scale: 1
|
||||
|
||||
predictor_parameters:
|
||||
predictor: 'OneshotPredictor'
|
||||
replace_speedup_op: True
|
||||
device: cuda
|
||||
local_rank: 0
|
||||
248
configs/pretrain_skysensepp.yml
Normal file
248
configs/pretrain_skysensepp.yml
Normal file
@@ -0,0 +1,248 @@
|
||||
task_attributes:
|
||||
segmentation:
|
||||
dataset_attributes:
|
||||
pretraining_loader:
|
||||
data_root_dir: 'pretrain_datasets/'
|
||||
train_json_path_list: ['dynamic-mm/dynamic-mm_train.json', 'pastis-mm/pastis-mm_train.json', 'vaihingen/vaihingen_train.json', 'deepglobe/deepglobe_train.json', 'potsdam/potsdam_train.json', 'fbp/fbp_train.json', 'loveda/loveda_train.json', 'isaid/isaid_train.json', 'jl16-mm/jl16-mm_train.json', 'flair-mm/flair-mm_train.json', 's2naip-mm/s2naip-mm_train.json', 'dfc20-mm/dfc20-mm_train.json', 'c2segab-mm/c2segab-mm_train.json']
|
||||
val_json_path_list: ['dynamic-mm/dynamic-mm_val.json', 'pastis-mm/pastis-mm_val.json', 'vaihingen/vaihingen_val.json', 'deepglobe/deepglobe_val.json', 'potsdam/potsdam_val.json', 'fbp/fbp_val.json', 'loveda/loveda_val.json', 'isaid/isaid_val.json', 'jl16-mm/jl16-mm_val.json', 'flair-mm/flair-mm_val.json', 's2naip-mm/s2naip-mm_val.json', 'dfc20-mm/dfc20-mm_val.json', 'c2segab-mm/c2segab-mm_val.json']
|
||||
use_multi_pairs: True
|
||||
seq_len: 1
|
||||
half_mask_ratio: 0.3
|
||||
min_random_scale: 0.3
|
||||
cls_repeat_cnt: 2000
|
||||
image_size:
|
||||
hr: (512, 512)
|
||||
s2: (16, 16)
|
||||
s1: (16, 16)
|
||||
anno: (512, 512)
|
||||
mim:
|
||||
input_size: (1024, 512)
|
||||
patch_size: 128
|
||||
mask_ratio: 0.5
|
||||
|
||||
model_attributes:
|
||||
SkySensePP:
|
||||
sources: ['hr', 's2', 's1']
|
||||
use_modal_vae: True
|
||||
use_ctpe: False
|
||||
use_cls_token_uper_head: False
|
||||
upsacle_results: True
|
||||
calendar_time: 365
|
||||
vocabulary_size: 64
|
||||
backbone_hr:
|
||||
type: 'SwinTransformerV2MSL'
|
||||
arch: 'huge'
|
||||
use_attn: True
|
||||
merge_stage: 2
|
||||
vocabulary_size: 64
|
||||
img_size: 224
|
||||
patch_size: 4
|
||||
in_channels: 3
|
||||
window_size: 8
|
||||
drop_rate: 0.
|
||||
drop_path_rate: 0.2
|
||||
out_indices: (0,1,2,3)
|
||||
use_abs_pos_embed: False
|
||||
interpolate_mode: 'bicubic'
|
||||
with_cp: True
|
||||
frozen_stages: -1
|
||||
norm_eval: False
|
||||
pad_small_map: False
|
||||
pretrained_window_sizes: [0, 0, 0, 0]
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_backbone_hr.pth'
|
||||
|
||||
backbone_s2:
|
||||
type: 'VisionTransformerMSL'
|
||||
img_size: (16, 16)
|
||||
use_attn: False
|
||||
merge_stage: 4
|
||||
vocabulary_size: 64
|
||||
patch_size: 4
|
||||
in_channels: 10
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
out_indices: (5,11,17,23)
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: False
|
||||
output_cls_token: False
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
eps: 1e-6
|
||||
with_cp: True
|
||||
interpolate_mode: 'bicubic'
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_backbone_s2.pth'
|
||||
|
||||
head_s2:
|
||||
type: 'UPHead'
|
||||
in_dim: 1024
|
||||
out_dim: 2816 #2816
|
||||
up_scale: 4
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_head_s2.pth'
|
||||
|
||||
backbone_s1:
|
||||
type: 'VisionTransformerMSL'
|
||||
img_size: (16, 16)
|
||||
use_attn: False
|
||||
merge_stage: 4
|
||||
vocabulary_size: 64
|
||||
patch_size: 4
|
||||
in_channels: 2
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
out_indices: (5,11,17,23)
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: False
|
||||
output_cls_token: False
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
eps: 1e-6
|
||||
with_cp: True
|
||||
interpolate_mode: 'bicubic'
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_backbone_s1.pth'
|
||||
|
||||
head_s1:
|
||||
type: 'UPHead'
|
||||
in_dim: 1024
|
||||
out_dim: 2816 #2816
|
||||
up_scale: 4
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_head_s1.pth'
|
||||
|
||||
rec_head_hr:
|
||||
type: 'UPerHead'
|
||||
in_channels: [704, 704, 1408, 2816, 1024]
|
||||
in_index: [0, 1, 2, 3, 4]
|
||||
pool_scales: (1, 2, 3, 6)
|
||||
channels: 512
|
||||
dropout_ratio: 0.1
|
||||
num_classes: 65
|
||||
norm_cfg:
|
||||
type: 'SyncBN'
|
||||
requires_grad: true
|
||||
align_corners: false
|
||||
|
||||
necks:
|
||||
type: 'TransformerEncoder'
|
||||
input_dims: 2816
|
||||
embed_dims: 1024
|
||||
num_layers: 24
|
||||
num_heads: 16
|
||||
mlp_ratio: 4
|
||||
qkv_bias: True
|
||||
drop_rate: 0.
|
||||
attn_drop_rate: 0.
|
||||
drop_path_rate: 0.3
|
||||
with_cls_token: True
|
||||
output_cls_token: True
|
||||
norm_cfg:
|
||||
type: 'LN'
|
||||
act_cfg:
|
||||
type: 'GELU'
|
||||
num_fcs: 2
|
||||
norm_eval: False
|
||||
with_cp: True
|
||||
init_cfg:
|
||||
type: Pretrained
|
||||
checkpoint: 'pretrain/skysense_model_fusion.pth'
|
||||
|
||||
modality_vae:
|
||||
type: 'ModalityCompletion'
|
||||
input_shape_hr: [2816, 32, 16]
|
||||
input_shape_s2: [2816, 32, 16]
|
||||
input_shape_s1: [2816, 32, 16]
|
||||
conv_dim: 256
|
||||
z_dim: 256
|
||||
n_codebook: 8192
|
||||
|
||||
metrics:
|
||||
- type: 'sem_metric'
|
||||
|
||||
losses:
|
||||
- type: 'RecLoss'
|
||||
params:
|
||||
weight: 1.0
|
||||
patch_size: 4
|
||||
balance: True
|
||||
use_all_patch: True
|
||||
vocabulary_size: 64
|
||||
feature_merged: True
|
||||
pred_key: 'logits_hr'
|
||||
mask_key: 'mask_hr'
|
||||
target_key: 'mapped_targets'
|
||||
use_bg: True
|
||||
|
||||
- type: 'ModalityVAELoss'
|
||||
params:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
optimizer_attributes:
|
||||
type: AdamW
|
||||
params:
|
||||
lr: 2e-04
|
||||
betas: (0.9, 0.999)
|
||||
weight_decay: 0.04
|
||||
|
||||
lr_parameters:
|
||||
layer_decay: 0.7
|
||||
frozen_blocks: 12
|
||||
frozen_fusion_blocks_start: 3
|
||||
|
||||
training_parameters:
|
||||
trainer: 'seg_trainer'
|
||||
run_type: train
|
||||
seed: 24042301
|
||||
pin_memory: True
|
||||
batch_size: 256
|
||||
test_batch_size: 128
|
||||
num_workers: 16
|
||||
max_iterations: 30000
|
||||
num_warmup_steps: 1000
|
||||
log_interval: 50
|
||||
snapshot_interval: 3000
|
||||
cos_lr: False
|
||||
|
||||
clip_norm_mode: all
|
||||
clip_gradients: true
|
||||
max_grad_l2_norm: 5
|
||||
|
||||
enable_tf32: False
|
||||
enable_amp: True
|
||||
find_unused_parameters: True
|
||||
synchronized_loss: True
|
||||
|
||||
static_graph: True
|
||||
replace_speedup_op: True
|
||||
|
||||
ema: False
|
||||
|
||||
distributed_batch_sampler:
|
||||
batch_size: 8
|
||||
|
||||
amp_attributes:
|
||||
amp_escapes: Conv2d
|
||||
opt_level: O1
|
||||
init_scale: 1
|
||||
303
finetune/configs/atlantic.py
Normal file
303
finetune/configs/atlantic.py
Normal file
@@ -0,0 +1,303 @@
|
||||
crop_size = (
|
||||
256,
|
||||
256,
|
||||
)
|
||||
data_preprocessor = dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[
|
||||
537.9411629981602,
|
||||
615.7886221108977,
|
||||
343.4481583821405,
|
||||
3010.641650390625,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=crop_size,
|
||||
std=[
|
||||
367.4598430230881,
|
||||
254.2473100510193,
|
||||
187.5437562223154,
|
||||
921.0792775874182,
|
||||
],
|
||||
type='SegDataPreProcessor')
|
||||
data_root = 'rs_datasets/'
|
||||
dataset_type = 'AtlanticDataset'
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
by_epoch=False, interval=1000, save_best='mIoU',max_keep_ckpts=1,
|
||||
type='CheckpointHook'),
|
||||
logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
|
||||
img_ratios = [
|
||||
0.5,
|
||||
0.75,
|
||||
1.0,
|
||||
1.25,
|
||||
1.5,
|
||||
1.75,
|
||||
]
|
||||
launcher = 'pytorch'
|
||||
load_from = None
|
||||
log_level = 'INFO'
|
||||
log_processor = dict(by_epoch=False)
|
||||
mean = [
|
||||
537.9411629981602,
|
||||
615.7886221108977,
|
||||
343.4481583821405,
|
||||
3010.641650390625,
|
||||
]
|
||||
model = dict(
|
||||
auxiliary_head=dict(
|
||||
align_corners=False,
|
||||
channels=256,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=512,
|
||||
in_index=3,
|
||||
loss_decode=dict(
|
||||
loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=2,
|
||||
num_convs=1,
|
||||
type='FCNHead'),
|
||||
backbone=dict(
|
||||
act_cfg=dict(type='GELU'),
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
drop_rate=0.1,
|
||||
embed_dims=1024,
|
||||
img_size=crop_size,
|
||||
in_channels=4,
|
||||
interpolate_mode='bilinear',
|
||||
mlp_ratio=4,
|
||||
norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
|
||||
norm_eval=False,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_indices=(
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23,
|
||||
),
|
||||
patch_size=4,
|
||||
qkv_bias=True,
|
||||
type='VisionTransformer',
|
||||
with_cls_token=False),
|
||||
data_preprocessor=dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[
|
||||
537.9411629981602,
|
||||
615.7886221108977,
|
||||
343.4481583821405,
|
||||
3010.641650390625,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=crop_size,
|
||||
std=[
|
||||
367.4598430230881,
|
||||
254.2473100510193,
|
||||
187.5437562223154,
|
||||
921.0792775874182,
|
||||
],
|
||||
type='SegDataPreProcessor'),
|
||||
decode_head=dict(
|
||||
align_corners=False,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=[
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
],
|
||||
in_index=[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
loss_decode=dict(
|
||||
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=2,
|
||||
pool_scales=(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
),
|
||||
type='UPerHead'),
|
||||
neck=dict(
|
||||
in_channels=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
out_channels=512,
|
||||
scales=[
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
type='MultiLevelNeck'),
|
||||
pretrained="pretrain/skysensepp_release_s2.pth",
|
||||
test_cfg=dict(mode='whole'),
|
||||
train_cfg=dict(),
|
||||
type='EncoderDecoder')
|
||||
norm_cfg = dict(requires_grad=True, type='SyncBN')
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
betas=(
|
||||
0.9,
|
||||
0.999,
|
||||
), lr=6e-05, type='AdamW', weight_decay=0.01),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(
|
||||
cls_token=dict(decay_mult=0.0),
|
||||
norm=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0))),
|
||||
type='OptimWrapper')
|
||||
optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005)
|
||||
param_scheduler = [
|
||||
dict(
|
||||
begin=0, by_epoch=False, end=1000, start_factor=1e-06,
|
||||
type='LinearLR'),
|
||||
dict(
|
||||
begin=1000,
|
||||
by_epoch=False,
|
||||
end=10000,
|
||||
eta_min=0.0,
|
||||
power=1.0,
|
||||
type='PolyLR'),
|
||||
]
|
||||
randomness = dict(seed=20240315)
|
||||
resume = False
|
||||
std = [
|
||||
367.4598430230881,
|
||||
254.2473100510193,
|
||||
187.5437562223154,
|
||||
921.0792775874182,
|
||||
]
|
||||
test_cfg = dict(type='TestLoop')
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'deforestation_atlantic/deforestation_atlantic_test.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadSingleRSImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='AtlanticDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
test_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
test_pipeline = [
|
||||
dict(type='LoadSingleRSImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
|
||||
train_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'deforestation_atlantic/deforestation_atlantic_train.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadSingleRSImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='AtlanticDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=True, type='InfiniteSampler'))
|
||||
train_pipeline = [
|
||||
dict(type='LoadSingleRSImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
tta_pipeline = [
|
||||
dict(backend_args=None, type='LoadImageFromFile'),
|
||||
dict(
|
||||
transforms=[
|
||||
[
|
||||
dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
|
||||
],
|
||||
[
|
||||
dict(direction='horizontal', prob=0.0, type='RandomFlip'),
|
||||
dict(direction='horizontal', prob=1.0, type='RandomFlip'),
|
||||
],
|
||||
[
|
||||
dict(type='LoadAnnotations'),
|
||||
],
|
||||
[
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
],
|
||||
type='TestTimeAug'),
|
||||
]
|
||||
val_cfg = dict(type='ValLoop')
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'deforestation_atlantic/deforestation_atlantic_val.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadSingleRSImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='AtlanticDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
val_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[
|
||||
dict(type='LocalVisBackend'),
|
||||
])
|
||||
work_dir = 'save/atlantic_skysensepp'
|
||||
284
finetune/configs/c2smsflood.py
Normal file
284
finetune/configs/c2smsflood.py
Normal file
@@ -0,0 +1,284 @@
|
||||
dataset_type = 'C2SFloodDataset'
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
by_epoch=False,
|
||||
interval=2000,
|
||||
max_keep_ckpts=2,
|
||||
save_best='mIoU',
|
||||
type='CheckpointHook'),
|
||||
logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
|
||||
launcher = 'pytorch'
|
||||
load_from = None
|
||||
log_level = 'INFO'
|
||||
log_processor = dict(by_epoch=False)
|
||||
model = dict(
|
||||
auxiliary_head=dict(
|
||||
align_corners=False,
|
||||
channels=256,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=512,
|
||||
in_index=3,
|
||||
loss_decode=dict(
|
||||
loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=4,
|
||||
num_convs=1,
|
||||
type='FCNHead'),
|
||||
backbone=dict(
|
||||
act_cfg=dict(type='GELU'),
|
||||
attn_drop_rate=0.0,
|
||||
downscale_indices=(
|
||||
5,
|
||||
11,
|
||||
),
|
||||
drop_path_rate=0.0,
|
||||
drop_rate=0.1,
|
||||
embed_dims=1024,
|
||||
img_size=(
|
||||
256,
|
||||
256,
|
||||
),
|
||||
in_channels=10,
|
||||
interpolate_mode='bilinear',
|
||||
mlp_ratio=4,
|
||||
norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
|
||||
norm_eval=False,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_indices=(
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23,
|
||||
),
|
||||
patch_size=4,
|
||||
qkv_bias=True,
|
||||
type='VisionTransformer',
|
||||
with_cls_token=False),
|
||||
data_preprocessor=dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[
|
||||
1381.26,
|
||||
1302.05,
|
||||
1179.27,
|
||||
1393.56,
|
||||
2164.76,
|
||||
2561.75,
|
||||
2377.94,
|
||||
2791.13,
|
||||
1998.09,
|
||||
1196.08,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(
|
||||
256,
|
||||
256,
|
||||
),
|
||||
std=[
|
||||
653.25,
|
||||
659.61,
|
||||
779.58,
|
||||
720.45,
|
||||
871.09,
|
||||
1035.57,
|
||||
965.36,
|
||||
1141.71,
|
||||
1019.73,
|
||||
825.01,
|
||||
],
|
||||
type='SegDataPreProcessor'),
|
||||
decode_head=dict(
|
||||
align_corners=False,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=[
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
],
|
||||
in_index=[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
loss_decode=dict(
|
||||
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=4,
|
||||
pool_scales=(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
),
|
||||
type='UPerHead'),
|
||||
neck=dict(
|
||||
in_channels=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
out_channels=512,
|
||||
scales=[
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
type='MultiLevelNeck'),
|
||||
pretrained=
|
||||
'pretrain/skysensepp_mmcvt_s2.pth',
|
||||
test_cfg=dict(mode='whole'),
|
||||
train_cfg=dict(),
|
||||
type='EncoderDecoder')
|
||||
norm_cfg = None
|
||||
optim_wrapper = dict(
|
||||
constructor='LearningRateDecayOptimizerConstructor',
|
||||
optimizer=dict(
|
||||
betas=(
|
||||
0.9,
|
||||
0.999,
|
||||
), lr=0.0005, type='AdamW', weight_decay=0.015),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(
|
||||
cls_token=dict(decay_mult=0.0),
|
||||
norm=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0)),
|
||||
decay_rate=0.84,
|
||||
decay_type='layer_wise',
|
||||
num_layers=24),
|
||||
type='OptimWrapper')
|
||||
param_scheduler = [
|
||||
dict(
|
||||
begin=0, by_epoch=False, end=1000, start_factor=1e-06,
|
||||
type='LinearLR'),
|
||||
dict(
|
||||
begin=1000,
|
||||
by_epoch=False,
|
||||
end=10000,
|
||||
eta_min=0.0,
|
||||
power=1.0,
|
||||
type='PolyLR'),
|
||||
]
|
||||
randomness = dict(seed=1)
|
||||
resume = False
|
||||
test_cfg = dict(type='TestLoop')
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/c2smsfloods/c2s_s2_val_mm.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromNpz'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='C2SFloodDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
test_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore'
|
||||
], type='IoUMetric')
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromNpz'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/c2smsfloods/c2s_s2_train_mm.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromNpz'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
|
||||
dict(
|
||||
cat_max_ratio=0.75, crop_size=(
|
||||
256,
|
||||
256,
|
||||
), type='RandomCrop'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='C2SFloodDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=True, type='InfiniteSampler'))
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
tta_pipeline = [
|
||||
dict(backend_args=None, type='LoadImageFromFile'),
|
||||
dict(
|
||||
transforms=[
|
||||
[
|
||||
dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.2, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.4, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.6, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.8, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=2.0, type='Resize'),
|
||||
],
|
||||
[
|
||||
dict(direction='horizontal', prob=0.0, type='RandomFlip'),
|
||||
dict(direction='horizontal', prob=1.0, type='RandomFlip'),
|
||||
],
|
||||
[
|
||||
dict(type='LoadAnnotations'),
|
||||
],
|
||||
[
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
],
|
||||
type='TestTimeAug'),
|
||||
]
|
||||
val_cfg = dict(type='ValLoop')
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/c2smsfloods/c2s_s2_val_mm.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromNpz'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='C2SFloodDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
val_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
], type='IoUMetric')
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[
|
||||
dict(type='LocalVisBackend'),
|
||||
])
|
||||
227
finetune/configs/cabuar.py
Normal file
227
finetune/configs/cabuar.py
Normal file
@@ -0,0 +1,227 @@
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
by_epoch=False,
|
||||
interval=1000,
|
||||
max_keep_ckpts=1,
|
||||
save_best='mIoU',
|
||||
type='CheckpointHook'),
|
||||
logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
|
||||
launcher = 'pytorch'
|
||||
load_from = None
|
||||
log_level = 'INFO'
|
||||
log_processor = dict(by_epoch=False)
|
||||
model = dict(
|
||||
auxiliary_head=dict(
|
||||
align_corners=False,
|
||||
channels=256,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=512,
|
||||
in_index=3,
|
||||
loss_decode=dict(
|
||||
loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=2,
|
||||
num_convs=1,
|
||||
type='FCNHead'),
|
||||
backbone=dict(
|
||||
act_cfg=dict(type='GELU'),
|
||||
attn_drop_rate=0.0,
|
||||
downscale_indices=(5, 11,),
|
||||
drop_path_rate=0.0,
|
||||
drop_rate=0.1,
|
||||
embed_dims=1024,
|
||||
img_size=(256, 256,),
|
||||
in_channels=10,
|
||||
interpolate_mode='bilinear',
|
||||
mlp_ratio=4,
|
||||
norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
|
||||
norm_eval=False,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_indices=(5, 11, 17, 23,),
|
||||
patch_size=4,
|
||||
qkv_bias=True,
|
||||
type='VisionTransformer',
|
||||
with_cls_token=False),
|
||||
data_preprocessor=dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[4.50021132, 6.09891466, 7.50766315, 9.54643074, 12.82568112, 14.29062133, 15.24644993, 15.73945708, 16.60374872, 12.31011599,],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(256, 256,),
|
||||
std=[2.6094148, 2.49566825, 1.37103968, 2.6094148, 2.49566825, 1.37103968, 2.6094148, 2.49566825, 1.37103968, 1.37103968, ],
|
||||
type='SegDataPreProcessor'),
|
||||
decode_head=dict(
|
||||
align_corners=False,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=[512, 512, 512, 512,],
|
||||
in_index=[0, 1, 2, 3,],
|
||||
loss_decode=dict(
|
||||
loss_weight=1.0, type='CrossEntropyLoss', balance=True, max_scale=4.0, use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=2,
|
||||
pool_scales=(1, 2, 3, 6,),
|
||||
type='UPerHead'),
|
||||
neck=dict(
|
||||
in_channels=[1024, 1024, 1024, 1024,],
|
||||
out_channels=512,
|
||||
scales=[1, 1, 1, 1,],
|
||||
type='MultiLevelNeck'),
|
||||
pretrained='pretrain/skysensepp_mmcvt_s2.pth',
|
||||
test_cfg=dict(mode='whole'),
|
||||
train_cfg=dict(),
|
||||
type='EncoderDecoder')
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(betas=(0.9, 0.999,), lr=None, type='AdamW', weight_decay=None,),
|
||||
constructor='LearningRateDecayOptimizerConstructor',
|
||||
paramwise_cfg=dict(
|
||||
num_layers=24,
|
||||
decay_rate=None,
|
||||
decay_type='layer_wise',
|
||||
custom_keys=dict(
|
||||
cls_token=dict(decay_mult=0.0),
|
||||
norm=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0))),
|
||||
type='OptimWrapper')
|
||||
param_scheduler = [
|
||||
dict(
|
||||
begin=0, by_epoch=False, end=1000, start_factor=1e-06,
|
||||
type='LinearLR'),
|
||||
dict(
|
||||
begin=1000,
|
||||
by_epoch=False,
|
||||
end=10000,
|
||||
eta_min=0.0,
|
||||
power=1.0,
|
||||
type='PolyLR'),
|
||||
]
|
||||
randomness = dict(seed=20240315)
|
||||
resume = False
|
||||
test_cfg = dict(type='TestLoop')
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/cabuar/cabura_val_fold_4.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', type='LoadImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=False,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='CABURADataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
test_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
test_pipeline = [
|
||||
dict(data_key='image', type='LoadImageFromNpz'),
|
||||
dict(data_key='image', reduce_zero_label=False, type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500)
|
||||
train_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/cabuar/cabura_train_fold0_3.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', type='LoadImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=False,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(
|
||||
cat_max_ratio=0.75, crop_size=(
|
||||
256,
|
||||
256,
|
||||
), type='RandomCrop'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='CABURADataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=True, type='InfiniteSampler'))
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
tta_pipeline = [
|
||||
dict(backend_args=None, type='LoadImageFromFile'),
|
||||
dict(
|
||||
transforms=[
|
||||
[
|
||||
dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.2, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.4, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.6, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.8, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=2.0, type='Resize'),
|
||||
],
|
||||
[
|
||||
dict(direction='horizontal', prob=0.0, type='RandomFlip'),
|
||||
dict(direction='horizontal', prob=1.0, type='RandomFlip'),
|
||||
],
|
||||
[
|
||||
dict(type='LoadAnnotations'),
|
||||
],
|
||||
[
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
],
|
||||
type='TestTimeAug'),
|
||||
]
|
||||
val_cfg = dict(type='ValLoop')
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/cabuar/cabura_val_fold_4.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', type='LoadImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=False,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='CABURADataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
val_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[
|
||||
dict(type='LocalVisBackend'),
|
||||
])
|
||||
work_dir = 'work_dirs/ft_cabura_test'
|
||||
414
finetune/configs/germany.py
Normal file
414
finetune/configs/germany.py
Normal file
@@ -0,0 +1,414 @@
|
||||
crop_size = (
|
||||
24,
|
||||
24,
|
||||
)
|
||||
data_preprocessor = dict(
|
||||
bgr_to_rgb=False,
|
||||
mean=[
|
||||
2482.0061841829206,
|
||||
2456.642580060208,
|
||||
2667.8229979675334,
|
||||
2744.9377076257624,
|
||||
3620.1499158373827,
|
||||
4063.9126981046647,
|
||||
3922.2406108776354,
|
||||
4264.908986788407,
|
||||
2453.0070206816135,
|
||||
1774.0019119673998,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(
|
||||
24,
|
||||
24,
|
||||
),
|
||||
std=[
|
||||
2392.1256366526068,
|
||||
2100.1364646122875,
|
||||
2262.6154840764625,
|
||||
2353.899770400333,
|
||||
2089.598452203458,
|
||||
2057.1247114077073,
|
||||
2013.2108514271458,
|
||||
2041.0248949410561,
|
||||
1380.4643757742374,
|
||||
1243.547946113518,
|
||||
],
|
||||
ts_size=30,
|
||||
type='RSTsSegDataPreProcessor')
|
||||
data_root = 'rs_datasets/'
|
||||
dataset_type = 'GermanyCropDataset'
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
by_epoch=False,
|
||||
interval=2000,
|
||||
max_keep_ckpts=1,
|
||||
save_best='mIoU',
|
||||
type='CheckpointHook'),
|
||||
logger=dict(interval=20, log_metric_by_epoch=False, type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
|
||||
find_unused_parameters = True
|
||||
img_ratios = [
|
||||
0.5,
|
||||
0.75,
|
||||
1.0,
|
||||
1.25,
|
||||
1.5,
|
||||
1.75,
|
||||
]
|
||||
launcher = 'pytorch'
|
||||
load_from = None
|
||||
log_level = 'INFO'
|
||||
log_processor = dict(by_epoch=False)
|
||||
mean = [
|
||||
2482.0061841829206,
|
||||
2456.642580060208,
|
||||
2667.8229979675334,
|
||||
2744.9377076257624,
|
||||
3620.1499158373827,
|
||||
4063.9126981046647,
|
||||
3922.2406108776354,
|
||||
4264.908986788407,
|
||||
2453.0070206816135,
|
||||
1774.0019119673998,
|
||||
]
|
||||
model = dict(
|
||||
auxiliary_head=dict(
|
||||
align_corners=False,
|
||||
channels=256,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=1024,
|
||||
in_index=3,
|
||||
loss_decode=dict(
|
||||
loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=18,
|
||||
num_convs=1,
|
||||
type='FCNHead'),
|
||||
backbone=dict(
|
||||
act_cfg=dict(type='GELU'),
|
||||
attn_drop_rate=0.0,
|
||||
downscale_indices=[
|
||||
-1,
|
||||
],
|
||||
drop_path_rate=0.0,
|
||||
drop_rate=0.1,
|
||||
embed_dims=1024,
|
||||
img_size=(
|
||||
24,
|
||||
24,
|
||||
),
|
||||
in_channels=10,
|
||||
init_cfg=dict(
|
||||
checkpoint=
|
||||
'pretrain/skysensepp_mmcvt_s2.pth',
|
||||
type='Pretrained'),
|
||||
interpolate_mode='bilinear',
|
||||
mlp_ratio=4,
|
||||
norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
|
||||
norm_eval=False,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_indices=(
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23,
|
||||
),
|
||||
patch_size=4,
|
||||
qkv_bias=True,
|
||||
type='VisionTransformer',
|
||||
with_cls_token=False),
|
||||
data_preprocessor=dict(
|
||||
bgr_to_rgb=False,
|
||||
mean=[
|
||||
2482.0061841829206,
|
||||
2456.642580060208,
|
||||
2667.8229979675334,
|
||||
2744.9377076257624,
|
||||
3620.1499158373827,
|
||||
4063.9126981046647,
|
||||
3922.2406108776354,
|
||||
4264.908986788407,
|
||||
2453.0070206816135,
|
||||
1774.0019119673998,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(
|
||||
24,
|
||||
24,
|
||||
),
|
||||
std=[
|
||||
2392.1256366526068,
|
||||
2100.1364646122875,
|
||||
2262.6154840764625,
|
||||
2353.899770400333,
|
||||
2089.598452203458,
|
||||
2057.1247114077073,
|
||||
2013.2108514271458,
|
||||
2041.0248949410561,
|
||||
1380.4643757742374,
|
||||
1243.547946113518,
|
||||
],
|
||||
ts_size=30,
|
||||
type='RSTsSegDataPreProcessor'),
|
||||
decode_head=dict(
|
||||
align_corners=False,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
in_index=[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
loss_decode=dict(
|
||||
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=18,
|
||||
pool_scales=(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
),
|
||||
type='UPerHead'),
|
||||
neck=dict(
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.3,
|
||||
drop_rate=0.0,
|
||||
embed_dims=1024,
|
||||
in_channels=[
|
||||
768,
|
||||
768,
|
||||
768,
|
||||
768,
|
||||
],
|
||||
in_channels_ml=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
init_cfg=dict(
|
||||
checkpoint=
|
||||
'pretrain/skysensepp_mmcvt_fusion.pth',
|
||||
type='Pretrained'),
|
||||
input_dims=1024,
|
||||
mlp_ratio=4,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_channels=768,
|
||||
out_channels_ml=1024,
|
||||
output_cls_token=True,
|
||||
qkv_bias=True,
|
||||
scales=[
|
||||
4,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
],
|
||||
scales_ml=[
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
],
|
||||
ts_size=30,
|
||||
type='FusionMultiLevelNeck',
|
||||
with_cls_token=True),
|
||||
test_cfg=dict(mode='whole'),
|
||||
train_cfg=dict(),
|
||||
type='EncoderDecoder')
|
||||
norm_cfg = dict(requires_grad=True, type='SyncBN')
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
betas=(
|
||||
0.9,
|
||||
0.999,
|
||||
), lr=0.0001, type='AdamW', weight_decay=0.01),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(
|
||||
cls_token=dict(decay_mult=0.0),
|
||||
norm=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0))),
|
||||
type='OptimWrapper')
|
||||
optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005)
|
||||
param_scheduler = [
|
||||
dict(
|
||||
begin=0, by_epoch=False, end=2000, start_factor=1e-06,
|
||||
type='LinearLR'),
|
||||
dict(
|
||||
begin=2000,
|
||||
by_epoch=False,
|
||||
end=20000,
|
||||
eta_min=0.0,
|
||||
power=1.0,
|
||||
type='PolyLR'),
|
||||
]
|
||||
randomness = dict(seed=20240311)
|
||||
resume = False
|
||||
static_graph = True
|
||||
std = [
|
||||
2392.1256366526068,
|
||||
2100.1364646122875,
|
||||
2262.6154840764625,
|
||||
2353.899770400333,
|
||||
2089.598452203458,
|
||||
2057.1247114077073,
|
||||
2013.2108514271458,
|
||||
2041.0248949410561,
|
||||
1380.4643757742374,
|
||||
1243.547946113518,
|
||||
]
|
||||
test_cfg = dict(type='TestLoop')
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/germany_crop/germany_crop_val.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=True,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='GermanyCropDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
test_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
test_pipeline = [
|
||||
dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
|
||||
dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
train_cfg = dict(
|
||||
dynamic_intervals=[
|
||||
(
|
||||
0,
|
||||
1000,
|
||||
),
|
||||
(
|
||||
4000,
|
||||
2000,
|
||||
),
|
||||
(
|
||||
8000,
|
||||
4000,
|
||||
),
|
||||
],
|
||||
max_iters=20000,
|
||||
type='IterBasedTrainLoop',
|
||||
val_interval=2000)
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/germany_crop/germany_crop_train.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=True,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='GermanyCropDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=True, type='InfiniteSampler'))
|
||||
train_pipeline = [
|
||||
dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
|
||||
dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
tta_pipeline = [
|
||||
dict(backend_args=None, type='LoadImageFromFile'),
|
||||
dict(
|
||||
transforms=[
|
||||
[
|
||||
dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
|
||||
],
|
||||
[
|
||||
dict(direction='horizontal', prob=0.0, type='RandomFlip'),
|
||||
dict(direction='horizontal', prob=1.0, type='RandomFlip'),
|
||||
],
|
||||
[
|
||||
dict(type='LoadAnnotations'),
|
||||
],
|
||||
[
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
],
|
||||
type='TestTimeAug'),
|
||||
]
|
||||
val_cfg = dict(type='ValLoop')
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/germany_crop/germany_crop_val.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'),
|
||||
dict(
|
||||
data_key='image',
|
||||
reduce_zero_label=True,
|
||||
type='LoadAnnotationsNpz'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='GermanyCropDataset'),
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
val_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[
|
||||
dict(type='LocalVisBackend'),
|
||||
])
|
||||
work_dir = 'save_germany'
|
||||
297
finetune/configs/sos.py
Normal file
297
finetune/configs/sos.py
Normal file
@@ -0,0 +1,297 @@
|
||||
crop_size = (
|
||||
256,
|
||||
256,
|
||||
)
|
||||
data_preprocessor = dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[
|
||||
83.37,
|
||||
83.37,
|
||||
83.37,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(
|
||||
256,
|
||||
256,
|
||||
),
|
||||
std=[
|
||||
40.45,
|
||||
40.45,
|
||||
40.45,
|
||||
],
|
||||
type='SegDataPreProcessor')
|
||||
data_root = 'rs_datasets/'
|
||||
dataset_type = 'SOSDataset'
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
by_epoch=False,
|
||||
interval=1000,
|
||||
max_keep_ckpts=2,
|
||||
save_best='mIoU',
|
||||
type='CheckpointHook'),
|
||||
logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
timer=dict(type='IterTimerHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
|
||||
img_ratios = [
|
||||
0.5,
|
||||
0.75,
|
||||
1.0,
|
||||
1.25,
|
||||
1.5,
|
||||
1.75,
|
||||
]
|
||||
launcher = 'pytorch'
|
||||
load_from = None
|
||||
log_level = 'INFO'
|
||||
log_processor = dict(by_epoch=False)
|
||||
mean = [
|
||||
83.37,
|
||||
83.37,
|
||||
83.37,
|
||||
]
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
act_cfg=dict(type='GELU'),
|
||||
attn_drop_rate=0.0,
|
||||
downscale_indices=[
|
||||
-1,
|
||||
],
|
||||
drop_path_rate=0.0,
|
||||
drop_rate=0.1,
|
||||
embed_dims=1024,
|
||||
img_size=(
|
||||
256,
|
||||
256,
|
||||
),
|
||||
in_channels=2,
|
||||
interpolate_mode='bilinear',
|
||||
mlp_ratio=4,
|
||||
norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'),
|
||||
norm_eval=False,
|
||||
num_heads=16,
|
||||
num_layers=24,
|
||||
out_indices=(
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23,
|
||||
),
|
||||
patch_size=4,
|
||||
qkv_bias=True,
|
||||
type='VisionTransformer',
|
||||
with_cls_token=False),
|
||||
data_preprocessor=dict(
|
||||
bgr_to_rgb=True,
|
||||
mean=[
|
||||
83.37,
|
||||
83.37,
|
||||
83.37,
|
||||
],
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(
|
||||
256,
|
||||
256,
|
||||
),
|
||||
std=[
|
||||
40.45,
|
||||
40.45,
|
||||
40.45,
|
||||
],
|
||||
type='SegDataPreProcessor'),
|
||||
decode_head=dict(
|
||||
align_corners=False,
|
||||
channels=512,
|
||||
dropout_ratio=0.1,
|
||||
in_channels=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
in_index=[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
loss_decode=dict(
|
||||
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
|
||||
norm_cfg=dict(requires_grad=True, type='SyncBN'),
|
||||
num_classes=2,
|
||||
pool_scales=(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
6,
|
||||
),
|
||||
type='UPerHead'),
|
||||
neck=dict(
|
||||
in_channels=[
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
1024,
|
||||
],
|
||||
out_channels=1024,
|
||||
scales=[
|
||||
4,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
],
|
||||
type='MultiLevelNeck'),
|
||||
pretrained=
|
||||
'pretrain/skysensepp_mmcvt_s1.pth',
|
||||
test_cfg=dict(mode='whole'),
|
||||
train_cfg=dict(),
|
||||
type='EncoderDecoder')
|
||||
norm_cfg = dict(requires_grad=True, type='SyncBN')
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
betas=(
|
||||
0.9,
|
||||
0.999,
|
||||
), lr=6e-05, type='AdamW', weight_decay=0.01),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys=dict(
|
||||
cls_token=dict(decay_mult=0.0),
|
||||
norm=dict(decay_mult=0.0),
|
||||
pos_embed=dict(decay_mult=0.0))),
|
||||
type='OptimWrapper')
|
||||
param_scheduler = [
|
||||
dict(
|
||||
begin=0, by_epoch=False, end=1000, start_factor=1e-06,
|
||||
type='LinearLR'),
|
||||
dict(
|
||||
begin=1000,
|
||||
by_epoch=False,
|
||||
end=10000,
|
||||
eta_min=0.0,
|
||||
power=1.0,
|
||||
type='PolyLR'),
|
||||
]
|
||||
randomness = dict(seed=0)
|
||||
resume = False
|
||||
std = [
|
||||
40.45,
|
||||
40.45,
|
||||
40.45,
|
||||
]
|
||||
test_cfg = dict(type='TestLoop')
|
||||
test_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/sos/sos_test_sentinel.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='SOSDataset'),
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
test_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
train_cfg = dict(max_iters=0, type='IterBasedTrainLoop', val_interval=500)
|
||||
train_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/sos/sos_train_sentinel.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='SOSDataset'),
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=True, type='InfiniteSampler'))
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(prob=0.5, type='RandomFlip'),
|
||||
dict(type='PackSegInputs'),
|
||||
]
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
tta_pipeline = [
|
||||
dict(backend_args=None, type='LoadImageFromFile'),
|
||||
dict(
|
||||
transforms=[
|
||||
[
|
||||
dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
|
||||
dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
|
||||
],
|
||||
[
|
||||
dict(direction='horizontal', prob=0.0, type='RandomFlip'),
|
||||
dict(direction='horizontal', prob=1.0, type='RandomFlip'),
|
||||
],
|
||||
[
|
||||
dict(type='LoadAnnotations'),
|
||||
],
|
||||
[
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
],
|
||||
type='TestTimeAug'),
|
||||
]
|
||||
val_cfg = dict(type='ValLoop')
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
dataset=dict(
|
||||
ann_file=
|
||||
'rs_datasets/sos/sos_test_sentinel.json',
|
||||
data_prefix=dict(img_path='images', seg_map_path='idx_labels'),
|
||||
data_root='rs_datasets/',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(reduce_zero_label=False, type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs'),
|
||||
],
|
||||
type='SOSDataset'),
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
sampler=dict(shuffle=False, type='DefaultSampler'))
|
||||
val_evaluator = dict(
|
||||
iou_metrics=[
|
||||
'mIoU',
|
||||
'mFscore',
|
||||
], type='IoUMetric')
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
name='visualizer',
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[
|
||||
dict(type='LocalVisBackend'),
|
||||
])
|
||||
work_dir = 'save_sos/'
|
||||
74
finetune/mmseg/__init__.py
Normal file
74
finetune/mmseg/__init__.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
from packaging.version import parse
|
||||
|
||||
from .version import __version__, version_info
|
||||
|
||||
MMCV_MIN = '2.0.0rc4'
|
||||
MMCV_MAX = '2.2.0'
|
||||
MMENGINE_MIN = '0.5.0'
|
||||
MMENGINE_MAX = '1.0.0'
|
||||
|
||||
|
||||
def digit_version(version_str: str, length: int = 4):
|
||||
"""Convert a version string into a tuple of integers.
|
||||
|
||||
This method is usually used for comparing two versions. For pre-release
|
||||
versions: alpha < beta < rc.
|
||||
|
||||
Args:
|
||||
version_str (str): The version string.
|
||||
length (int): The maximum number of version levels. Default: 4.
|
||||
|
||||
Returns:
|
||||
tuple[int]: The version info in digits (integers).
|
||||
"""
|
||||
version = parse(version_str)
|
||||
assert version.release, f'failed to parse version {version_str}'
|
||||
release = list(version.release)
|
||||
release = release[:length]
|
||||
if len(release) < length:
|
||||
release = release + [0] * (length - len(release))
|
||||
if version.is_prerelease:
|
||||
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
||||
val = -4
|
||||
# version.pre can be None
|
||||
if version.pre:
|
||||
if version.pre[0] not in mapping:
|
||||
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
||||
'version checking may go wrong')
|
||||
else:
|
||||
val = mapping[version.pre[0]]
|
||||
release.extend([val, version.pre[-1]])
|
||||
else:
|
||||
release.extend([val, 0])
|
||||
|
||||
elif version.is_postrelease:
|
||||
release.extend([1, version.post])
|
||||
else:
|
||||
release.extend([0, 0])
|
||||
return tuple(release)
|
||||
|
||||
|
||||
mmcv_min_version = digit_version(MMCV_MIN)
|
||||
mmcv_max_version = digit_version(MMCV_MAX)
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
|
||||
|
||||
assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
|
||||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||
f'Please install mmcv>=2.0.0rc4.'
|
||||
|
||||
mmengine_min_version = digit_version(MMENGINE_MIN)
|
||||
mmengine_max_version = digit_version(MMENGINE_MAX)
|
||||
mmengine_version = digit_version(mmengine.__version__)
|
||||
|
||||
assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \
|
||||
f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
|
||||
f'Please install mmengine>={mmengine_min_version}, '\
|
||||
f'<{mmengine_max_version}.'
|
||||
|
||||
__all__ = ['__version__', 'version_info', 'digit_version']
|
||||
9
finetune/mmseg/apis/__init__.py
Normal file
9
finetune/mmseg/apis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import inference_model, init_model, show_result_pyplot
|
||||
from .mmseg_inferencer import MMSegInferencer
|
||||
from .remote_sense_inferencer import RSImage, RSInferencer
|
||||
|
||||
__all__ = [
|
||||
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
|
||||
'RSInferencer', 'RSImage'
|
||||
]
|
||||
189
finetune/mmseg/apis/inference.py
Normal file
189
finetune/mmseg/apis/inference.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from mmseg.models import BaseSegmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
from .utils import ImageType, _preprare_data
|
||||
|
||||
|
||||
def init_model(config: Union[str, Path, Config],
|
||||
checkpoint: Optional[str] = None,
|
||||
device: str = 'cuda:0',
|
||||
cfg_options: Optional[dict] = None):
|
||||
"""Initialize a segmentor from config file.
|
||||
|
||||
Args:
|
||||
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
|
||||
:obj:`Path`, or the config object.
|
||||
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
||||
will not load any weights.
|
||||
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
||||
Use 'cpu' for loading model on CPU.
|
||||
cfg_options (dict, optional): Options to override some settings in
|
||||
the used config.
|
||||
Returns:
|
||||
nn.Module: The constructed segmentor.
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = Config.fromfile(config)
|
||||
elif not isinstance(config, Config):
|
||||
raise TypeError('config must be a filename or Config object, '
|
||||
'but got {}'.format(type(config)))
|
||||
if cfg_options is not None:
|
||||
config.merge_from_dict(cfg_options)
|
||||
if config.model.type == 'EncoderDecoder':
|
||||
if 'init_cfg' in config.model.backbone:
|
||||
config.model.backbone.init_cfg = None
|
||||
elif config.model.type == 'MultimodalEncoderDecoder':
|
||||
for k, v in config.model.items():
|
||||
if isinstance(v, dict) and 'init_cfg' in v:
|
||||
config.model[k].init_cfg = None
|
||||
config.model.pretrained = None
|
||||
config.model.train_cfg = None
|
||||
init_default_scope(config.get('default_scope', 'mmseg'))
|
||||
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
|
||||
# save the dataset_meta in the model for convenience
|
||||
if 'dataset_meta' in checkpoint.get('meta', {}):
|
||||
# mmseg 1.x
|
||||
model.dataset_meta = dataset_meta
|
||||
elif 'CLASSES' in checkpoint.get('meta', {}):
|
||||
# < mmseg 1.x
|
||||
classes = checkpoint['meta']['CLASSES']
|
||||
palette = checkpoint['meta']['PALETTE']
|
||||
model.dataset_meta = {'classes': classes, 'palette': palette}
|
||||
else:
|
||||
warnings.simplefilter('once')
|
||||
warnings.warn(
|
||||
'dataset_meta or class names are not saved in the '
|
||||
'checkpoint\'s meta data, classes and palette will be'
|
||||
'set according to num_classes ')
|
||||
num_classes = model.decode_head.num_classes
|
||||
dataset_name = None
|
||||
for name in dataset_aliases.keys():
|
||||
if len(get_classes(name)) == num_classes:
|
||||
dataset_name = name
|
||||
break
|
||||
if dataset_name is None:
|
||||
warnings.warn(
|
||||
'No suitable dataset found, use Cityscapes by default')
|
||||
dataset_name = 'cityscapes'
|
||||
model.dataset_meta = {
|
||||
'classes': get_classes(dataset_name),
|
||||
'palette': get_palette(dataset_name)
|
||||
}
|
||||
model.cfg = config # save the config in the model for convenience
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def inference_model(model: BaseSegmentor,
|
||||
img: ImageType) -> Union[SegDataSample, SampleList]:
|
||||
"""Inference image(s) with the segmentor.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The loaded segmentor.
|
||||
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
|
||||
images.
|
||||
|
||||
Returns:
|
||||
:obj:`SegDataSample` or list[:obj:`SegDataSample`]:
|
||||
If imgs is a list or tuple, the same length list type results
|
||||
will be returned, otherwise return the segmentation results directly.
|
||||
"""
|
||||
# prepare data
|
||||
data, is_batch = _preprare_data(img, model)
|
||||
|
||||
# forward the model
|
||||
with torch.no_grad():
|
||||
results = model.test_step(data)
|
||||
|
||||
return results if is_batch else results[0]
|
||||
|
||||
|
||||
def show_result_pyplot(model: BaseSegmentor,
|
||||
img: Union[str, np.ndarray],
|
||||
result: SegDataSample,
|
||||
opacity: float = 0.5,
|
||||
title: str = '',
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
wait_time: float = 0,
|
||||
show: bool = True,
|
||||
with_labels: Optional[bool] = True,
|
||||
save_dir=None,
|
||||
out_file=None):
|
||||
"""Visualize the segmentation results on the image.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The loaded segmentor.
|
||||
img (str or np.ndarray): Image filename or loaded image.
|
||||
result (SegDataSample): The prediction SegDataSample result.
|
||||
opacity(float): Opacity of painted segmentation map.
|
||||
Default 0.5. Must be in (0, 1] range.
|
||||
title (str): The title of pyplot figure.
|
||||
Default is ''.
|
||||
draw_gt (bool): Whether to draw GT SegDataSample. Default to True.
|
||||
draw_pred (bool): Whether to draw Prediction SegDataSample.
|
||||
Defaults to True.
|
||||
wait_time (float): The interval of show (s). 0 is the special value
|
||||
that means "forever". Defaults to 0.
|
||||
show (bool): Whether to display the drawn image.
|
||||
Default to True.
|
||||
with_labels(bool, optional): Add semantic labels in visualization
|
||||
result, Default to True.
|
||||
save_dir (str, optional): Save file dir for all storage backends.
|
||||
If it is None, the backend storage will not save any data.
|
||||
out_file (str, optional): Path to output file. Default to None.
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
np.ndarray: the drawn image which channel is RGB.
|
||||
"""
|
||||
if hasattr(model, 'module'):
|
||||
model = model.module
|
||||
if isinstance(img, str):
|
||||
image = mmcv.imread(img, channel_order='rgb')
|
||||
else:
|
||||
image = img
|
||||
if save_dir is not None:
|
||||
mkdir_or_exist(save_dir)
|
||||
# init visualizer
|
||||
visualizer = SegLocalVisualizer(
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
save_dir=save_dir,
|
||||
alpha=opacity)
|
||||
visualizer.dataset_meta = dict(
|
||||
classes=model.dataset_meta['classes'],
|
||||
palette=model.dataset_meta['palette'])
|
||||
visualizer.add_datasample(
|
||||
name=title,
|
||||
image=image,
|
||||
data_sample=result,
|
||||
draw_gt=draw_gt,
|
||||
draw_pred=draw_pred,
|
||||
wait_time=wait_time,
|
||||
out_file=out_file,
|
||||
show=show,
|
||||
with_labels=with_labels)
|
||||
vis_img = visualizer.get_image()
|
||||
|
||||
return vis_img
|
||||
382
finetune/mmseg/apis/mmseg_inferencer.py
Normal file
382
finetune/mmseg/apis/mmseg_inferencer.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.transforms import Compose
|
||||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
from mmengine.model import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner.checkpoint import _load_checkpoint_to_model
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList, get_classes, get_palette
|
||||
from mmseg.visualization import SegLocalVisualizer
|
||||
|
||||
InputType = Union[str, np.ndarray]
|
||||
InputsType = Union[InputType, Sequence[InputType]]
|
||||
PredType = Union[SegDataSample, SampleList]
|
||||
|
||||
|
||||
class MMSegInferencer(BaseInferencer):
|
||||
"""Semantic segmentation inferencer, provides inference and visualization
|
||||
interfaces. Note: MMEngine >= 0.5.0 is required.
|
||||
|
||||
Args:
|
||||
model (str, optional): Path to the config file or the model name
|
||||
defined in metafile. Take the `mmseg metafile <https://github.com/open-mmlab/mmsegmentation/blob/main/configs/fcn/metafile.yaml>`_
|
||||
as an example the `model` could be
|
||||
"fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model
|
||||
will be download automatically. If use config file, like
|
||||
"configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the
|
||||
`weights` should be defined.
|
||||
weights (str, optional): Path to the checkpoint. If it is not specified
|
||||
and model is a model name of metafile, the weights will be loaded
|
||||
from metafile. Defaults to None.
|
||||
classes (list, optional): Input classes for result rendering, as the
|
||||
prediction of segmentation model is a segment map with label
|
||||
indices, `classes` is a list which includes items responding to the
|
||||
label indices. If classes is not defined, visualizer will take
|
||||
`cityscapes` classes by default. Defaults to None.
|
||||
palette (list, optional): Input palette for result rendering, which is
|
||||
a list of color palette responding to the classes. If palette is
|
||||
not defined, visualizer will take `cityscapes` palette by default.
|
||||
Defaults to None.
|
||||
dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
|
||||
visulizer will use the meta information of the dataset i.e. classes
|
||||
and palette, but the `classes` and `palette` have higher priority.
|
||||
Defaults to None.
|
||||
device (str, optional): Device to run inference. If None, the available
|
||||
device will be automatically used. Defaults to None.
|
||||
scope (str, optional): The scope of the model. Defaults to 'mmseg'.
|
||||
""" # noqa
|
||||
|
||||
preprocess_kwargs: set = set()
|
||||
forward_kwargs: set = {'mode', 'out_dir'}
|
||||
visualize_kwargs: set = {
|
||||
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis',
|
||||
'with_labels'
|
||||
}
|
||||
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
|
||||
|
||||
def __init__(self,
|
||||
model: Union[ModelType, str],
|
||||
weights: Optional[str] = None,
|
||||
classes: Optional[Union[str, List]] = None,
|
||||
palette: Optional[Union[str, List]] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
scope: Optional[str] = 'mmseg') -> None:
|
||||
# A global counter tracking the number of images processes, for
|
||||
# naming of the output images
|
||||
self.num_visualized_imgs = 0
|
||||
self.num_pred_imgs = 0
|
||||
init_default_scope(scope if scope else 'mmseg')
|
||||
super().__init__(
|
||||
model=model, weights=weights, device=device, scope=scope)
|
||||
|
||||
if device == 'cpu' or not torch.cuda.is_available():
|
||||
self.model = revert_sync_batchnorm(self.model)
|
||||
|
||||
assert isinstance(self.visualizer, SegLocalVisualizer)
|
||||
self.visualizer.set_dataset_meta(classes, palette, dataset_name)
|
||||
|
||||
def _load_weights_to_model(self, model: nn.Module,
|
||||
checkpoint: Optional[dict],
|
||||
cfg: Optional[ConfigType]) -> None:
|
||||
"""Loading model weights and meta information from cfg and checkpoint.
|
||||
|
||||
Subclasses could override this method to load extra meta information
|
||||
from ``checkpoint`` and ``cfg`` to model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model to load weights and meta information.
|
||||
checkpoint (dict, optional): The loaded checkpoint.
|
||||
cfg (Config or ConfigDict, optional): The loaded config.
|
||||
"""
|
||||
|
||||
if checkpoint is not None:
|
||||
_load_checkpoint_to_model(model, checkpoint)
|
||||
checkpoint_meta = checkpoint.get('meta', {})
|
||||
# save the dataset_meta in the model for convenience
|
||||
if 'dataset_meta' in checkpoint_meta:
|
||||
# mmsegmentation 1.x
|
||||
model.dataset_meta = {
|
||||
'classes': checkpoint_meta['dataset_meta'].get('classes'),
|
||||
'palette': checkpoint_meta['dataset_meta'].get('palette')
|
||||
}
|
||||
elif 'CLASSES' in checkpoint_meta:
|
||||
# mmsegmentation 0.x
|
||||
classes = checkpoint_meta['CLASSES']
|
||||
palette = checkpoint_meta.get('PALETTE', None)
|
||||
model.dataset_meta = {'classes': classes, 'palette': palette}
|
||||
else:
|
||||
warnings.warn(
|
||||
'dataset_meta or class names are not saved in the '
|
||||
'checkpoint\'s meta data, use classes of Cityscapes by '
|
||||
'default.')
|
||||
model.dataset_meta = {
|
||||
'classes': get_classes('cityscapes'),
|
||||
'palette': get_palette('cityscapes')
|
||||
}
|
||||
else:
|
||||
warnings.warn('Checkpoint is not loaded, and the inference '
|
||||
'result is calculated by the randomly initialized '
|
||||
'model!')
|
||||
warnings.warn(
|
||||
'weights is None, use cityscapes classes by default.')
|
||||
model.dataset_meta = {
|
||||
'classes': get_classes('cityscapes'),
|
||||
'palette': get_palette('cityscapes')
|
||||
}
|
||||
|
||||
def __call__(self,
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
out_dir: str = '',
|
||||
img_out_dir: str = 'vis',
|
||||
pred_out_dir: str = 'pred',
|
||||
**kwargs) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
||||
Args:
|
||||
inputs (Union[list, str, np.ndarray]): Inputs for the inferencer.
|
||||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`SegDataSample`. Defaults to False.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
show (bool): Whether to display the rendering color segmentation
|
||||
mask in a popup window. Defaults to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
out_dir (str): Output directory of inference results. Defaults
|
||||
to ''.
|
||||
img_out_dir (str): Subdirectory of `out_dir`, used to save
|
||||
rendering color segmentation mask, so `out_dir` must be defined
|
||||
if you would like to save predicted mask. Defaults to 'vis'.
|
||||
pred_out_dir (str): Subdirectory of `out_dir`, used to save
|
||||
predicted mask file, so `out_dir` must be defined if you would
|
||||
like to save predicted mask. Defaults to 'pred'.
|
||||
|
||||
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
|
||||
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
||||
Each key in kwargs should be in the corresponding set of
|
||||
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
|
||||
and ``postprocess_kwargs``.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
"""
|
||||
|
||||
if out_dir != '':
|
||||
pred_out_dir = osp.join(out_dir, pred_out_dir)
|
||||
img_out_dir = osp.join(out_dir, img_out_dir)
|
||||
else:
|
||||
pred_out_dir = ''
|
||||
img_out_dir = ''
|
||||
|
||||
return super().__call__(
|
||||
inputs=inputs,
|
||||
return_datasamples=return_datasamples,
|
||||
batch_size=batch_size,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
img_out_dir=img_out_dir,
|
||||
pred_out_dir=pred_out_dir,
|
||||
return_vis=return_vis,
|
||||
**kwargs)
|
||||
|
||||
def visualize(self,
|
||||
inputs: list,
|
||||
preds: List[dict],
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
img_out_dir: str = '',
|
||||
opacity: float = 0.8,
|
||||
with_labels: Optional[bool] = True) -> List[np.ndarray]:
|
||||
"""Visualize predictions.
|
||||
|
||||
Args:
|
||||
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
|
||||
preds (Any): Predictions of the model.
|
||||
show (bool): Whether to display the image in a popup window.
|
||||
Defaults to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
img_out_dir (str): Output directory of rendering prediction i.e.
|
||||
color segmentation mask. Defaults: ''
|
||||
opacity (int, float): The transparency of segmentation mask.
|
||||
Defaults to 0.8.
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]: Visualization results.
|
||||
"""
|
||||
if not show and img_out_dir == '' and not return_vis:
|
||||
return None
|
||||
if self.visualizer is None:
|
||||
raise ValueError('Visualization needs the "visualizer" term'
|
||||
'defined in the config, but got None.')
|
||||
|
||||
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
|
||||
self.visualizer.alpha = opacity
|
||||
|
||||
results = []
|
||||
|
||||
for single_input, pred in zip(inputs, preds):
|
||||
if isinstance(single_input, str):
|
||||
img_bytes = mmengine.fileio.get(single_input)
|
||||
img = mmcv.imfrombytes(img_bytes)
|
||||
img = img[:, :, ::-1]
|
||||
img_name = osp.basename(single_input)
|
||||
elif isinstance(single_input, np.ndarray):
|
||||
img = single_input.copy()
|
||||
img_num = str(self.num_visualized_imgs).zfill(8) + '_vis'
|
||||
img_name = f'{img_num}.jpg'
|
||||
else:
|
||||
raise ValueError('Unsupported input type:'
|
||||
f'{type(single_input)}')
|
||||
|
||||
out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\
|
||||
else None
|
||||
|
||||
self.visualizer.add_datasample(
|
||||
img_name,
|
||||
img,
|
||||
pred,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
draw_gt=False,
|
||||
draw_pred=True,
|
||||
out_file=out_file,
|
||||
with_labels=with_labels)
|
||||
if return_vis:
|
||||
results.append(self.visualizer.get_image())
|
||||
self.num_visualized_imgs += 1
|
||||
|
||||
return results if return_vis else None
|
||||
|
||||
def postprocess(self,
|
||||
preds: PredType,
|
||||
visualization: List[np.ndarray],
|
||||
return_datasample: bool = False,
|
||||
pred_out_dir: str = '') -> dict:
|
||||
"""Process the predictions and visualization results from ``forward``
|
||||
and ``visualize``.
|
||||
|
||||
This method should be responsible for the following tasks:
|
||||
|
||||
1. Pack the predictions and visualization results and return them.
|
||||
2. Save the predictions, if it needed.
|
||||
|
||||
Args:
|
||||
preds (List[Dict]): Predictions of the model.
|
||||
visualization (List[np.ndarray]): The list of rendering color
|
||||
segmentation mask.
|
||||
return_datasample (bool): Whether to return results as datasamples.
|
||||
Defaults to False.
|
||||
pred_out_dir: File to save the inference results w/o
|
||||
visualization. If left as empty, no file will be saved.
|
||||
Defaults to ''.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results with key ``predictions``
|
||||
and ``visualization``
|
||||
|
||||
- ``visualization (Any)``: Returned by :meth:`visualize`
|
||||
- ``predictions`` (List[np.ndarray], np.ndarray): Returned by
|
||||
:meth:`forward` and processed in :meth:`postprocess`.
|
||||
If ``return_datasample=False``, it will be the segmentation mask
|
||||
with label indice.
|
||||
"""
|
||||
if return_datasample:
|
||||
if len(preds) == 1:
|
||||
return preds[0]
|
||||
else:
|
||||
return preds
|
||||
|
||||
results_dict = {}
|
||||
|
||||
results_dict['predictions'] = []
|
||||
results_dict['visualization'] = []
|
||||
|
||||
for i, pred in enumerate(preds):
|
||||
pred_data = dict()
|
||||
if 'pred_sem_seg' in pred.keys():
|
||||
pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0]
|
||||
elif 'pred_depth_map' in pred.keys():
|
||||
pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0]
|
||||
|
||||
if visualization is not None:
|
||||
vis = visualization[i]
|
||||
results_dict['visualization'].append(vis)
|
||||
if pred_out_dir != '':
|
||||
mmengine.mkdir_or_exist(pred_out_dir)
|
||||
for key, data in pred_data.items():
|
||||
post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy'
|
||||
img_name = str(self.num_pred_imgs).zfill(8) + post_fix
|
||||
img_path = osp.join(pred_out_dir, img_name)
|
||||
if key == 'sem_seg':
|
||||
output = Image.fromarray(data.astype(np.uint8))
|
||||
output.save(img_path)
|
||||
else:
|
||||
np.save(img_path, data)
|
||||
pred_data = next(iter(pred_data.values()))
|
||||
results_dict['predictions'].append(pred_data)
|
||||
self.num_pred_imgs += 1
|
||||
|
||||
if len(results_dict['predictions']) == 1:
|
||||
results_dict['predictions'] = results_dict['predictions'][0]
|
||||
if visualization is not None:
|
||||
results_dict['visualization'] = \
|
||||
results_dict['visualization'][0]
|
||||
return results_dict
|
||||
|
||||
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
||||
"""Initialize the test pipeline.
|
||||
|
||||
Return a pipeline to handle various input data, such as ``str``,
|
||||
``np.ndarray``. It is an abstract method in BaseInferencer, and should
|
||||
be implemented in subclasses.
|
||||
|
||||
The returned pipeline will be used to process a single data.
|
||||
It will be used in :meth:`preprocess` like this:
|
||||
|
||||
.. code-block:: python
|
||||
def preprocess(self, inputs, batch_size, **kwargs):
|
||||
...
|
||||
dataset = map(self.pipeline, dataset)
|
||||
...
|
||||
"""
|
||||
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
# Loading annotations is also not applicable
|
||||
for transform in ('LoadAnnotations', 'LoadDepthAnnotation'):
|
||||
idx = self._get_transform_idx(pipeline_cfg, transform)
|
||||
if idx != -1:
|
||||
del pipeline_cfg[idx]
|
||||
|
||||
load_img_idx = self._get_transform_idx(pipeline_cfg,
|
||||
'LoadImageFromFile')
|
||||
if load_img_idx == -1:
|
||||
raise ValueError(
|
||||
'LoadImageFromFile is not found in the test pipeline')
|
||||
pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader'
|
||||
return Compose(pipeline_cfg)
|
||||
|
||||
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
|
||||
"""Returns the index of the transform in a pipeline.
|
||||
|
||||
If the transform is not found, returns -1.
|
||||
"""
|
||||
for i, transform in enumerate(pipeline_cfg):
|
||||
if transform['type'] == name:
|
||||
return i
|
||||
return -1
|
||||
279
finetune/mmseg/apis/remote_sense_inferencer.py
Normal file
279
finetune/mmseg/apis/remote_sense_inferencer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import _preprare_data
|
||||
|
||||
|
||||
class RSImage:
|
||||
"""Remote sensing image class.
|
||||
|
||||
Args:
|
||||
img (str or gdal.Dataset): Image file path or gdal.Dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, image):
|
||||
self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance(
|
||||
image, str) else image
|
||||
assert isinstance(self.dataset, gdal.Dataset), \
|
||||
f'{image} is not a image'
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.channel = self.dataset.RasterCount
|
||||
self.trans = self.dataset.GetGeoTransform()
|
||||
self.proj = self.dataset.GetProjection()
|
||||
self.band_list = []
|
||||
self.band_list.extend(
|
||||
self.dataset.GetRasterBand(c + 1) for c in range(self.channel))
|
||||
self.grids = []
|
||||
|
||||
def read(self, grid: Optional[List] = None) -> np.ndarray:
|
||||
"""Read image data. If grid is None, read the whole image.
|
||||
|
||||
Args:
|
||||
grid (Optional[List], optional): Grid to read. Defaults to None.
|
||||
Returns:
|
||||
np.ndarray: Image data.
|
||||
"""
|
||||
if grid is None:
|
||||
return np.einsum('ijk->jki', self.dataset.ReadAsArray())
|
||||
assert len(
|
||||
grid) >= 4, 'grid must be a list containing at least 4 elements'
|
||||
data = self.dataset.ReadAsArray(*grid[:4])
|
||||
if data.ndim == 2:
|
||||
data = data[np.newaxis, ...]
|
||||
return np.einsum('ijk->jki', data)
|
||||
|
||||
def write(self, data: Optional[np.ndarray], grid: Optional[List] = None):
|
||||
"""Write image data.
|
||||
|
||||
Args:
|
||||
grid (Optional[List], optional): Grid to write. Defaults to None.
|
||||
data (Optional[np.ndarray], optional): Data to write.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either grid or data must be provided.
|
||||
"""
|
||||
if grid is not None:
|
||||
assert len(grid) == 8, 'grid must be a list of 8 elements'
|
||||
for band in self.band_list:
|
||||
band.WriteArray(
|
||||
data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]],
|
||||
grid[0] + grid[4], grid[1] + grid[5])
|
||||
elif data is not None:
|
||||
for i in range(self.channel):
|
||||
self.band_list[i].WriteArray(data[..., i])
|
||||
else:
|
||||
raise ValueError('Either grid or data must be provided.')
|
||||
|
||||
def create_seg_map(self, output_path: Optional[str] = None):
|
||||
if output_path is None:
|
||||
output_path = 'output_label.tif'
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
seg_map = driver.Create(output_path, self.width, self.height, 1,
|
||||
gdal.GDT_Byte)
|
||||
seg_map.SetGeoTransform(self.trans)
|
||||
seg_map.SetProjection(self.proj)
|
||||
seg_map_img = RSImage(seg_map)
|
||||
seg_map_img.path = output_path
|
||||
return seg_map_img
|
||||
|
||||
def create_grids(self,
|
||||
window_size: Tuple[int, int],
|
||||
stride: Tuple[int, int] = (0, 0)):
|
||||
"""Create grids for image inference.
|
||||
|
||||
Args:
|
||||
window_size (Tuple[int, int]): the size of the sliding window.
|
||||
stride (Tuple[int, int], optional): the stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
|
||||
Raises:
|
||||
AssertionError: window_size must be a tuple of 2 elements.
|
||||
AssertionError: stride must be a tuple of 2 elements.
|
||||
"""
|
||||
assert len(
|
||||
window_size) == 2, 'window_size must be a tuple of 2 elements'
|
||||
assert len(stride) == 2, 'stride must be a tuple of 2 elements'
|
||||
win_w, win_h = window_size
|
||||
stride_x, stride_y = stride
|
||||
|
||||
stride_x = win_w if stride_x == 0 else stride_x
|
||||
stride_y = win_h if stride_y == 0 else stride_y
|
||||
|
||||
x_half_overlap = (win_w - stride_x + 1) // 2
|
||||
y_half_overlap = (win_h - stride_y + 1) // 2
|
||||
|
||||
for y in range(0, self.height, stride_y):
|
||||
y_end = y + win_h >= self.height
|
||||
y_offset = self.height - win_h if y_end else y
|
||||
y_size = win_h
|
||||
y_crop_off = 0 if y_offset == 0 else y_half_overlap
|
||||
y_crop_size = y_size if y_end else win_h - y_crop_off
|
||||
|
||||
for x in range(0, self.width, stride_x):
|
||||
x_end = x + win_w >= self.width
|
||||
x_offset = self.width - win_w if x_end else x
|
||||
x_size = win_w
|
||||
x_crop_off = 0 if x_offset == 0 else x_half_overlap
|
||||
x_crop_size = x_size if x_end else win_w - x_crop_off
|
||||
|
||||
self.grids.append([
|
||||
x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off,
|
||||
x_crop_size, y_crop_size
|
||||
])
|
||||
|
||||
|
||||
class RSInferencer:
|
||||
"""Remote sensing inference class.
|
||||
|
||||
Args:
|
||||
model (BaseModel): The loaded model.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
thread (int, optional): Number of threads. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1):
|
||||
self.model = model
|
||||
self.batch_size = batch_size
|
||||
self.END_FLAG = object()
|
||||
self.read_buffer = Queue(self.batch_size)
|
||||
self.write_buffer = Queue(self.batch_size)
|
||||
self.thread = thread
|
||||
|
||||
@classmethod
|
||||
def from_config_path(cls,
|
||||
config_path: str,
|
||||
checkpoint_path: str,
|
||||
batch_size: int = 1,
|
||||
thread: int = 1,
|
||||
device: Optional[str] = 'cpu'):
|
||||
"""Initialize a segmentor from config file.
|
||||
|
||||
Args:
|
||||
config_path (str): Config file path.
|
||||
checkpoint_path (str): Checkpoint path.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
"""
|
||||
init_default_scope('mmseg')
|
||||
cfg = Config.fromfile(config_path)
|
||||
model = MODELS.build(cfg.model)
|
||||
model.cfg = cfg
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return cls(model, batch_size, thread)
|
||||
|
||||
@classmethod
|
||||
def from_model(cls,
|
||||
model: BaseModel,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
batch_size: int = 1,
|
||||
thread: int = 1,
|
||||
device: Optional[str] = 'cpu'):
|
||||
"""Initialize a segmentor from model.
|
||||
|
||||
Args:
|
||||
model (BaseModel): The loaded model.
|
||||
checkpoint_path (Optional[str]): Checkpoint path.
|
||||
batch_size (int, optional): Batch size. Defaults to 1.
|
||||
"""
|
||||
if checkpoint_path is not None:
|
||||
load_checkpoint(model, checkpoint_path, map_location='cpu')
|
||||
model.to(device)
|
||||
return cls(model, batch_size, thread)
|
||||
|
||||
def read(self,
|
||||
image: RSImage,
|
||||
window_size: Tuple[int, int],
|
||||
strides: Tuple[int, int] = (0, 0)):
|
||||
"""Load image data to read buffer.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to read.
|
||||
window_size (Tuple[int, int]): The size of the sliding window.
|
||||
strides (Tuple[int, int], optional): The stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
"""
|
||||
image.create_grids(window_size, strides)
|
||||
for grid in image.grids:
|
||||
self.read_buffer.put([grid, image.read(grid=grid)])
|
||||
self.read_buffer.put(self.END_FLAG)
|
||||
|
||||
def inference(self):
|
||||
"""Inference image data from read buffer and put the result to write
|
||||
buffer."""
|
||||
while True:
|
||||
item = self.read_buffer.get()
|
||||
if item == self.END_FLAG:
|
||||
self.read_buffer.put(self.END_FLAG)
|
||||
self.write_buffer.put(item)
|
||||
break
|
||||
data, _ = _preprare_data(item[1], self.model)
|
||||
with torch.no_grad():
|
||||
result = self.model.test_step(data)
|
||||
item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0]
|
||||
self.write_buffer.put(item)
|
||||
self.read_buffer.task_done()
|
||||
|
||||
def write(self, image: RSImage, output_path: Optional[str] = None):
|
||||
"""Write image data from write buffer.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to write.
|
||||
output_path (Optional[str], optional): The path to save the
|
||||
segmentation map. Defaults to None.
|
||||
"""
|
||||
seg_map = image.create_seg_map(output_path)
|
||||
while True:
|
||||
item = self.write_buffer.get()
|
||||
if item == self.END_FLAG:
|
||||
break
|
||||
seg_map.write(data=item[1], grid=item[0])
|
||||
self.write_buffer.task_done()
|
||||
|
||||
def run(self,
|
||||
image: RSImage,
|
||||
window_size: Tuple[int, int],
|
||||
strides: Tuple[int, int] = (0, 0),
|
||||
output_path: Optional[str] = None):
|
||||
"""Run inference with multi-threading.
|
||||
|
||||
Args:
|
||||
image (RSImage): The image to inference.
|
||||
window_size (Tuple[int, int]): The size of the sliding window.
|
||||
strides (Tuple[int, int], optional): The stride of the sliding
|
||||
window. Defaults to (0, 0).
|
||||
output_path (Optional[str], optional): The path to save the
|
||||
segmentation map. Defaults to None.
|
||||
"""
|
||||
read_thread = threading.Thread(
|
||||
target=self.read, args=(image, window_size, strides))
|
||||
read_thread.start()
|
||||
inference_threads = []
|
||||
for _ in range(self.thread):
|
||||
inference_thread = threading.Thread(target=self.inference)
|
||||
inference_thread.start()
|
||||
inference_threads.append(inference_thread)
|
||||
write_thread = threading.Thread(
|
||||
target=self.write, args=(image, output_path))
|
||||
write_thread.start()
|
||||
read_thread.join()
|
||||
for inference_thread in inference_threads:
|
||||
inference_thread.join()
|
||||
write_thread.join()
|
||||
41
finetune/mmseg/apis/utils.py
Normal file
41
finetune/mmseg/apis/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import defaultdict
|
||||
from typing import Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
|
||||
|
||||
|
||||
def _preprare_data(imgs: ImageType, model: BaseModel):
|
||||
|
||||
cfg = model.cfg
|
||||
for t in cfg.test_pipeline:
|
||||
if t.get('type') == 'LoadAnnotations':
|
||||
cfg.test_pipeline.remove(t)
|
||||
|
||||
is_batch = True
|
||||
if not isinstance(imgs, (list, tuple)):
|
||||
imgs = [imgs]
|
||||
is_batch = False
|
||||
|
||||
if isinstance(imgs[0], np.ndarray):
|
||||
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
|
||||
|
||||
# TODO: Consider using the singleton pattern to avoid building
|
||||
# a pipeline for each inference
|
||||
pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data = defaultdict(list)
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data_ = dict(img=img)
|
||||
else:
|
||||
data_ = dict(img_path=img)
|
||||
data_ = pipeline(data_)
|
||||
data['inputs'].append(data_['inputs'])
|
||||
data['data_samples'].append(data_['data_samples'])
|
||||
|
||||
return data, is_batch
|
||||
35
finetune/mmseg/datasets/__init__.py
Normal file
35
finetune/mmseg/datasets/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .atlantic import AtlanticDataset
|
||||
from .c2sfloods import C2SFloodDataset
|
||||
from .cabuar import CABURADataset
|
||||
from .germany import GermanyCropDataset
|
||||
from .sos import SOSDataset
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
|
||||
LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray, LoadMultipleRSImageFromFile,
|
||||
LoadSingleRSImageFromFile, PackSegInputs,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
# yapf: enable
|
||||
__all__ = [
|
||||
'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
||||
'Albu', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
|
||||
'ConcatCDInput', 'AtlanticDataset', 'C2SFloodDataset',
|
||||
'CABURADataset', 'GermanyCropDataset', 'SOSDataset'
|
||||
]
|
||||
48
finetune/mmseg/datasets/atlantic.py
Normal file
48
finetune/mmseg/datasets/atlantic.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import List
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
@DATASETS.register_module()
|
||||
class AtlanticDataset(BaseSegDataset):
|
||||
METAINFO = dict(
|
||||
classes=('Background', 'Deforestation area'),
|
||||
palette=[[0,0,0], [255,255,255]]
|
||||
)
|
||||
def __init__(self,
|
||||
img_suffix='.tif',
|
||||
seg_map_suffix='.tif',
|
||||
reduce_zero_label=False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = json.load(open(self.ann_file))
|
||||
for line in lines:
|
||||
# img_name = line.strip()
|
||||
data_info = dict(
|
||||
img_path=osp.join(self.data_root, line['s2_path']))
|
||||
if 'target_path' in line.keys():
|
||||
data_info['seg_map_path'] = osp.join(self.data_root, line['target_path'])
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
552
finetune/mmseg/datasets/basesegdataset.py
Normal file
552
finetune/mmseg/datasets/basesegdataset.py
Normal file
@@ -0,0 +1,552 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import mmengine
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmengine.dataset import BaseDataset, Compose
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class BaseSegDataset(BaseDataset):
|
||||
"""Custom dataset for semantic segmentation. An example of file structure
|
||||
is as followed.
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
├── data
|
||||
│ ├── my_dataset
|
||||
│ │ ├── img_dir
|
||||
│ │ │ ├── train
|
||||
│ │ │ │ ├── xxx{img_suffix}
|
||||
│ │ │ │ ├── yyy{img_suffix}
|
||||
│ │ │ │ ├── zzz{img_suffix}
|
||||
│ │ │ ├── val
|
||||
│ │ ├── ann_dir
|
||||
│ │ │ ├── train
|
||||
│ │ │ │ ├── xxx{seg_map_suffix}
|
||||
│ │ │ │ ├── yyy{seg_map_suffix}
|
||||
│ │ │ │ ├── zzz{seg_map_suffix}
|
||||
│ │ │ ├── val
|
||||
|
||||
The img/gt_semantic_seg pair of BaseSegDataset should be of the same
|
||||
except suffix. A valid img/gt_semantic_seg filename pair should be like
|
||||
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
|
||||
in the suffix). If split is given, then ``xxx`` is specified in txt file.
|
||||
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
|
||||
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
|
||||
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path. Defaults to ''.
|
||||
metainfo (dict, optional): Meta information for dataset, such as
|
||||
specify classes to load. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (dict, optional): Prefix for training data. Defaults to
|
||||
dict(img_path=None, seg_map_path=None).
|
||||
img_suffix (str): Suffix of images. Default: '.jpg'
|
||||
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using first few
|
||||
data in annotation file to facilitate training/testing on a smaller
|
||||
dataset. Defaults to None which means using all ``data_infos``.
|
||||
serialize_data (bool, optional): Whether to hold memory using
|
||||
serialized objects, when enabled, data loader workers can use
|
||||
shared RAM from master process instead of making a copy. Defaults
|
||||
to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. In some cases, such as visualization, only the meta
|
||||
information of the dataset is needed, which is not necessary to
|
||||
load annotation file. ``Basedataset`` can skip load annotations to
|
||||
save time by set ``lazy_init=True``. Defaults to False.
|
||||
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
|
||||
None img. The maximum extra number of cycles to get a valid
|
||||
image. Defaults to 1000.
|
||||
ignore_index (int): The label index to be ignored. Default: 255
|
||||
reduce_zero_label (bool): Whether to mark label zero as ignored.
|
||||
Default to False.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
METAINFO: dict = dict()
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str = '',
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: dict = dict(img_path='', seg_map_path=''),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||
serialize_data: bool = True,
|
||||
pipeline: List[Union[dict, Callable]] = [],
|
||||
test_mode: bool = False,
|
||||
lazy_init: bool = False,
|
||||
max_refetch: int = 1000,
|
||||
ignore_index: int = 255,
|
||||
reduce_zero_label: bool = False,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
|
||||
self.img_suffix = img_suffix
|
||||
self.seg_map_suffix = seg_map_suffix
|
||||
self.ignore_index = ignore_index
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
self.data_root = data_root
|
||||
self.data_prefix = copy.copy(data_prefix)
|
||||
self.ann_file = ann_file
|
||||
self.filter_cfg = copy.deepcopy(filter_cfg)
|
||||
self._indices = indices
|
||||
self.serialize_data = serialize_data
|
||||
self.test_mode = test_mode
|
||||
self.max_refetch = max_refetch
|
||||
self.data_list: List[dict] = []
|
||||
self.data_bytes: np.ndarray
|
||||
|
||||
# Set meta information.
|
||||
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
|
||||
|
||||
# Get label map for custom classes
|
||||
new_classes = self._metainfo.get('classes', None)
|
||||
self.label_map = self.get_label_map(new_classes)
|
||||
self._metainfo.update(
|
||||
dict(
|
||||
label_map=self.label_map,
|
||||
reduce_zero_label=self.reduce_zero_label))
|
||||
|
||||
# Update palette based on label map or generate palette
|
||||
# if it is not defined
|
||||
updated_palette = self._update_palette()
|
||||
self._metainfo.update(dict(palette=updated_palette))
|
||||
|
||||
# Join paths.
|
||||
if self.data_root is not None:
|
||||
self._join_prefix()
|
||||
|
||||
# Build pipeline.
|
||||
self.pipeline = Compose(pipeline)
|
||||
# Full initialize the dataset.
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
if test_mode:
|
||||
assert self._metainfo.get('classes') is not None, \
|
||||
'dataset metainfo `classes` should be specified when testing'
|
||||
|
||||
@classmethod
|
||||
def get_label_map(cls,
|
||||
new_classes: Optional[Sequence] = None
|
||||
) -> Union[Dict, None]:
|
||||
"""Require label mapping.
|
||||
|
||||
The ``label_map`` is a dictionary, its keys are the old label ids and
|
||||
its values are the new label ids, and is used for changing pixel
|
||||
labels in load_annotations. If and only if old classes in cls.METAINFO
|
||||
is not equal to new classes in self._metainfo and nether of them is not
|
||||
None, `label_map` is not None.
|
||||
|
||||
Args:
|
||||
new_classes (list, tuple, optional): The new classes name from
|
||||
metainfo. Default to None.
|
||||
|
||||
|
||||
Returns:
|
||||
dict, optional: The mapping from old classes in cls.METAINFO to
|
||||
new classes in self._metainfo
|
||||
"""
|
||||
old_classes = cls.METAINFO.get('classes', None)
|
||||
if (new_classes is not None and old_classes is not None
|
||||
and list(new_classes) != list(old_classes)):
|
||||
|
||||
label_map = {}
|
||||
if not set(new_classes).issubset(cls.METAINFO['classes']):
|
||||
raise ValueError(
|
||||
f'new classes {new_classes} is not a '
|
||||
f'subset of classes {old_classes} in METAINFO.')
|
||||
for i, c in enumerate(old_classes):
|
||||
if c not in new_classes:
|
||||
label_map[i] = 255
|
||||
else:
|
||||
label_map[i] = new_classes.index(c)
|
||||
return label_map
|
||||
else:
|
||||
return None
|
||||
|
||||
def _update_palette(self) -> list:
|
||||
"""Update palette after loading metainfo.
|
||||
|
||||
If length of palette is equal to classes, just return the palette.
|
||||
If palette is not defined, it will randomly generate a palette.
|
||||
If classes is updated by customer, it will return the subset of
|
||||
palette.
|
||||
|
||||
Returns:
|
||||
Sequence: Palette for current dataset.
|
||||
"""
|
||||
palette = self._metainfo.get('palette', [])
|
||||
classes = self._metainfo.get('classes', [])
|
||||
# palette does match classes
|
||||
if len(palette) == len(classes):
|
||||
return palette
|
||||
|
||||
if len(palette) == 0:
|
||||
# Get random state before set seed, and restore
|
||||
# random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
new_palette = np.random.randint(
|
||||
0, 255, size=(len(classes), 3)).tolist()
|
||||
np.random.set_state(state)
|
||||
elif len(palette) >= len(classes) and self.label_map is not None:
|
||||
new_palette = []
|
||||
# return subset of palette
|
||||
for old_id, new_id in sorted(
|
||||
self.label_map.items(), key=lambda x: x[1]):
|
||||
if new_id != 255:
|
||||
new_palette.append(palette[old_id])
|
||||
new_palette = type(palette)(new_palette)
|
||||
else:
|
||||
raise ValueError('palette does not match classes '
|
||||
f'as metainfo is {self._metainfo}.')
|
||||
return new_palette
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
img_dir = self.data_prefix.get('img_path', None)
|
||||
ann_dir = self.data_prefix.get('seg_map_path', None)
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = mmengine.list_from_file(
|
||||
self.ann_file, backend_args=self.backend_args)
|
||||
for line in lines:
|
||||
img_name = line.strip()
|
||||
data_info = dict(
|
||||
img_path=osp.join(img_dir, img_name + self.img_suffix))
|
||||
if ann_dir is not None:
|
||||
seg_map = img_name + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
else:
|
||||
_suffix_len = len(self.img_suffix)
|
||||
for img in fileio.list_dir_or_file(
|
||||
dir_path=img_dir,
|
||||
list_dir=False,
|
||||
suffix=self.img_suffix,
|
||||
recursive=True,
|
||||
backend_args=self.backend_args):
|
||||
data_info = dict(img_path=osp.join(img_dir, img))
|
||||
if ann_dir is not None:
|
||||
seg_map = img[:-_suffix_len] + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class BaseCDDataset(BaseDataset):
|
||||
"""Custom dataset for change detection. An example of file structure is as
|
||||
followed.
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
├── data
|
||||
│ ├── my_dataset
|
||||
│ │ ├── img_dir
|
||||
│ │ │ ├── train
|
||||
│ │ │ │ ├── xxx{img_suffix}
|
||||
│ │ │ │ ├── yyy{img_suffix}
|
||||
│ │ │ │ ├── zzz{img_suffix}
|
||||
│ │ │ ├── val
|
||||
│ │ ├── img_dir2
|
||||
│ │ │ ├── train
|
||||
│ │ │ │ ├── xxx{img_suffix}
|
||||
│ │ │ │ ├── yyy{img_suffix}
|
||||
│ │ │ │ ├── zzz{img_suffix}
|
||||
│ │ │ ├── val
|
||||
│ │ ├── ann_dir
|
||||
│ │ │ ├── train
|
||||
│ │ │ │ ├── xxx{seg_map_suffix}
|
||||
│ │ │ │ ├── yyy{seg_map_suffix}
|
||||
│ │ │ │ ├── zzz{seg_map_suffix}
|
||||
│ │ │ ├── val
|
||||
|
||||
The image names in img_dir and img_dir2 should be consistent.
|
||||
The img/gt_semantic_seg pair of BaseSegDataset should be of the same
|
||||
except suffix. A valid img/gt_semantic_seg filename pair should be like
|
||||
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
|
||||
in the suffix). If split is given, then ``xxx`` is specified in txt file.
|
||||
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
|
||||
Please refer to ``docs/en/tutorials/new_dataset.md`` for more details.
|
||||
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path. Defaults to ''.
|
||||
metainfo (dict, optional): Meta information for dataset, such as
|
||||
specify classes to load. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (dict, optional): Prefix for training data. Defaults to
|
||||
dict(img_path=None, img_path2=None, seg_map_path=None).
|
||||
img_suffix (str): Suffix of images. Default: '.jpg'
|
||||
img_suffix2 (str): Suffix of images. Default: '.jpg'
|
||||
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using first few
|
||||
data in annotation file to facilitate training/testing on a smaller
|
||||
dataset. Defaults to None which means using all ``data_infos``.
|
||||
serialize_data (bool, optional): Whether to hold memory using
|
||||
serialized objects, when enabled, data loader workers can use
|
||||
shared RAM from master process instead of making a copy. Defaults
|
||||
to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. In some cases, such as visualization, only the meta
|
||||
information of the dataset is needed, which is not necessary to
|
||||
load annotation file. ``Basedataset`` can skip load annotations to
|
||||
save time by set ``lazy_init=True``. Defaults to False.
|
||||
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
|
||||
None img. The maximum extra number of cycles to get a valid
|
||||
image. Defaults to 1000.
|
||||
ignore_index (int): The label index to be ignored. Default: 255
|
||||
reduce_zero_label (bool): Whether to mark label zero as ignored.
|
||||
Default to False.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
METAINFO: dict = dict()
|
||||
|
||||
def __init__(self,
|
||||
ann_file: str = '',
|
||||
img_suffix='.jpg',
|
||||
img_suffix2='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: dict = dict(
|
||||
img_path='', img_path2='', seg_map_path=''),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||
serialize_data: bool = True,
|
||||
pipeline: List[Union[dict, Callable]] = [],
|
||||
test_mode: bool = False,
|
||||
lazy_init: bool = False,
|
||||
max_refetch: int = 1000,
|
||||
ignore_index: int = 255,
|
||||
reduce_zero_label: bool = False,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
|
||||
self.img_suffix = img_suffix
|
||||
self.img_suffix2 = img_suffix2
|
||||
self.seg_map_suffix = seg_map_suffix
|
||||
self.ignore_index = ignore_index
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
self.data_root = data_root
|
||||
self.data_prefix = copy.copy(data_prefix)
|
||||
self.ann_file = ann_file
|
||||
self.filter_cfg = copy.deepcopy(filter_cfg)
|
||||
self._indices = indices
|
||||
self.serialize_data = serialize_data
|
||||
self.test_mode = test_mode
|
||||
self.max_refetch = max_refetch
|
||||
self.data_list: List[dict] = []
|
||||
self.data_bytes: np.ndarray
|
||||
|
||||
# Set meta information.
|
||||
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
|
||||
|
||||
# Get label map for custom classes
|
||||
new_classes = self._metainfo.get('classes', None)
|
||||
self.label_map = self.get_label_map(new_classes)
|
||||
self._metainfo.update(
|
||||
dict(
|
||||
label_map=self.label_map,
|
||||
reduce_zero_label=self.reduce_zero_label))
|
||||
|
||||
# Update palette based on label map or generate palette
|
||||
# if it is not defined
|
||||
updated_palette = self._update_palette()
|
||||
self._metainfo.update(dict(palette=updated_palette))
|
||||
|
||||
# Join paths.
|
||||
if self.data_root is not None:
|
||||
self._join_prefix()
|
||||
|
||||
# Build pipeline.
|
||||
self.pipeline = Compose(pipeline)
|
||||
# Full initialize the dataset.
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
if test_mode:
|
||||
assert self._metainfo.get('classes') is not None, \
|
||||
'dataset metainfo `classes` should be specified when testing'
|
||||
|
||||
@classmethod
|
||||
def get_label_map(cls,
|
||||
new_classes: Optional[Sequence] = None
|
||||
) -> Union[Dict, None]:
|
||||
"""Require label mapping.
|
||||
|
||||
The ``label_map`` is a dictionary, its keys are the old label ids and
|
||||
its values are the new label ids, and is used for changing pixel
|
||||
labels in load_annotations. If and only if old classes in cls.METAINFO
|
||||
is not equal to new classes in self._metainfo and nether of them is not
|
||||
None, `label_map` is not None.
|
||||
|
||||
Args:
|
||||
new_classes (list, tuple, optional): The new classes name from
|
||||
metainfo. Default to None.
|
||||
|
||||
|
||||
Returns:
|
||||
dict, optional: The mapping from old classes in cls.METAINFO to
|
||||
new classes in self._metainfo
|
||||
"""
|
||||
old_classes = cls.METAINFO.get('classes', None)
|
||||
if (new_classes is not None and old_classes is not None
|
||||
and list(new_classes) != list(old_classes)):
|
||||
|
||||
label_map = {}
|
||||
if not set(new_classes).issubset(cls.METAINFO['classes']):
|
||||
raise ValueError(
|
||||
f'new classes {new_classes} is not a '
|
||||
f'subset of classes {old_classes} in METAINFO.')
|
||||
for i, c in enumerate(old_classes):
|
||||
if c not in new_classes:
|
||||
label_map[i] = 255
|
||||
else:
|
||||
label_map[i] = new_classes.index(c)
|
||||
return label_map
|
||||
else:
|
||||
return None
|
||||
|
||||
def _update_palette(self) -> list:
|
||||
"""Update palette after loading metainfo.
|
||||
|
||||
If length of palette is equal to classes, just return the palette.
|
||||
If palette is not defined, it will randomly generate a palette.
|
||||
If classes is updated by customer, it will return the subset of
|
||||
palette.
|
||||
|
||||
Returns:
|
||||
Sequence: Palette for current dataset.
|
||||
"""
|
||||
palette = self._metainfo.get('palette', [])
|
||||
classes = self._metainfo.get('classes', [])
|
||||
# palette does match classes
|
||||
if len(palette) == len(classes):
|
||||
return palette
|
||||
|
||||
if len(palette) == 0:
|
||||
# Get random state before set seed, and restore
|
||||
# random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
new_palette = np.random.randint(
|
||||
0, 255, size=(len(classes), 3)).tolist()
|
||||
np.random.set_state(state)
|
||||
elif len(palette) >= len(classes) and self.label_map is not None:
|
||||
new_palette = []
|
||||
# return subset of palette
|
||||
for old_id, new_id in sorted(
|
||||
self.label_map.items(), key=lambda x: x[1]):
|
||||
if new_id != 255:
|
||||
new_palette.append(palette[old_id])
|
||||
new_palette = type(palette)(new_palette)
|
||||
else:
|
||||
raise ValueError('palette does not match classes '
|
||||
f'as metainfo is {self._metainfo}.')
|
||||
return new_palette
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
img_dir = self.data_prefix.get('img_path', None)
|
||||
img_dir2 = self.data_prefix.get('img_path2', None)
|
||||
ann_dir = self.data_prefix.get('seg_map_path', None)
|
||||
if osp.isfile(self.ann_file):
|
||||
lines = mmengine.list_from_file(
|
||||
self.ann_file, backend_args=self.backend_args)
|
||||
for line in lines:
|
||||
img_name = line.strip()
|
||||
if '.' in osp.basename(img_name):
|
||||
img_name, img_ext = osp.splitext(img_name)
|
||||
self.img_suffix = img_ext
|
||||
self.img_suffix2 = img_ext
|
||||
data_info = dict(
|
||||
img_path=osp.join(img_dir, img_name + self.img_suffix),
|
||||
img_path2=osp.join(img_dir2, img_name + self.img_suffix2))
|
||||
|
||||
if ann_dir is not None:
|
||||
seg_map = img_name + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
else:
|
||||
for img in fileio.list_dir_or_file(
|
||||
dir_path=img_dir,
|
||||
list_dir=False,
|
||||
suffix=self.img_suffix,
|
||||
recursive=True,
|
||||
backend_args=self.backend_args):
|
||||
if '.' in osp.basename(img):
|
||||
img, img_ext = osp.splitext(img)
|
||||
self.img_suffix = img_ext
|
||||
self.img_suffix2 = img_ext
|
||||
data_info = dict(
|
||||
img_path=osp.join(img_dir, img + self.img_suffix),
|
||||
img_path2=osp.join(img_dir2, img + self.img_suffix2))
|
||||
if ann_dir is not None:
|
||||
seg_map = img + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
65
finetune/mmseg/datasets/c2sfloods.py
Normal file
65
finetune/mmseg/datasets/c2sfloods.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
# LEGEND = [
|
||||
# 255 255 255; % Background
|
||||
# 0 0 0; % Roads
|
||||
# 100 100 100; % Buildings
|
||||
# 0 125 0; % Trees
|
||||
# 0 255 0; % Grass
|
||||
# 150 80 0; % Bare Soil
|
||||
# 0 0 150; % Water
|
||||
# 255 255 0; % Railways
|
||||
# 150 150 255]; % Swimming Pools
|
||||
|
||||
@DATASETS.register_module()
|
||||
class C2SFloodDataset(BaseSegDataset):
|
||||
"""Zurich dataset.
|
||||
|
||||
In segmentation map annotation for LoveDA, 0 is the ignore index.
|
||||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Background', 'Water', 'Cloud', 'Cloud shadow'),
|
||||
palette=[[0,0,0], [255,255,255], [255,0,0], [0,255,0]]
|
||||
)
|
||||
def __init__(self,
|
||||
img_suffix='.npz',
|
||||
seg_map_suffix='.npz',
|
||||
reduce_zero_label=False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = json.load(open(self.ann_file))
|
||||
for line in lines:
|
||||
# img_name = line.strip()
|
||||
data_info = dict(
|
||||
img_path=osp.join(self.data_root, line['s2_path']))
|
||||
if 'target_path' in line.keys():
|
||||
data_info['seg_map_path'] = osp.join(self.data_root, line['target_path'])
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
54
finetune/mmseg/datasets/cabuar.py
Normal file
54
finetune/mmseg/datasets/cabuar.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
@DATASETS.register_module()
|
||||
class CABURADataset(BaseSegDataset):
|
||||
"""Zurich dataset.
|
||||
|
||||
In segmentation map annotation for LoveDA, 0 is the ignore index.
|
||||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Background', 'Burned area'),
|
||||
palette=[[0,0,0], [255,255,255]]
|
||||
)
|
||||
def __init__(self,
|
||||
img_suffix='.npz',
|
||||
seg_map_suffix='.npz',
|
||||
reduce_zero_label=False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = json.load(open(self.ann_file))
|
||||
for line in lines:
|
||||
# img_name = line.strip()
|
||||
data_info = dict(
|
||||
img_path=osp.join(self.data_root, line['post_fire']))
|
||||
if 'mask' in line.keys():
|
||||
data_info['seg_map_path'] = osp.join(self.data_root, line['mask'])
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
72
finetune/mmseg/datasets/germany.py
Normal file
72
finetune/mmseg/datasets/germany.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
from mmengine.logging import print_log
|
||||
import pandas as pd
|
||||
# LEGEND = [
|
||||
# 255 255 255; % Background
|
||||
# 0 0 0; % Roads
|
||||
# 100 100 100; % Buildings
|
||||
# 0 125 0; % Trees
|
||||
# 0 255 0; % Grass
|
||||
# 150 80 0; % Bare Soil
|
||||
# 0 0 150; % Water
|
||||
# 255 255 0; % Railways
|
||||
# 150 150 255]; % Swimming Pools
|
||||
|
||||
@DATASETS.register_module()
|
||||
class GermanyCropDataset(BaseSegDataset):
|
||||
"""Zurich dataset.
|
||||
|
||||
In segmentation map annotation for LoveDA, 0 is the ignore index.
|
||||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
# {0: "unknown", 1: "sugar_beet", 2: "summer_oat", 3: "meadow", 5: "rape", 8: "hop",
|
||||
# 9: "winter_spelt", 12: "winter_triticale", 13: "beans", 15: "peas", 16: "potatoes",
|
||||
# 17: "soybeans", 19: "asparagus", 22: "winter_wheat", 23: "winter_barley", 24: "winter_rye",
|
||||
# 25: "summer_barley", 26: "maize"}
|
||||
METAINFO = dict(
|
||||
classes=('sugar_beet', 'summer_oat', 'meadow', 'rape', 'hop', 'winter_spelt', 'winter_triticale', 'beans', 'peas',\
|
||||
'potatoes', 'soybeans', 'asparagus', 'winter_wheat', 'winter_barley', 'winter_rye', 'summer_barley', 'maize'),
|
||||
palette=[(255, 255, 255), (255, 255, 170), (255, 255, 85), (255, 170, 255), (255, 170, 170), (255, 170, 85), \
|
||||
(255, 85, 255), (255, 85, 170), (255, 85, 85), (170, 255, 255), (170, 255, 170), (170, 255, 85), (170, 170, 255), \
|
||||
(170, 170, 170), (170, 170, 85), (170, 85, 255), (170, 85, 170)])
|
||||
def __init__(self,
|
||||
img_suffix='.pickle',
|
||||
seg_map_suffix='.pickle',
|
||||
reduce_zero_label=True,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = json.load(open(self.ann_file))
|
||||
print_log(f'dataset count: {len(lines)}')
|
||||
for line in lines:
|
||||
data_info = dict(
|
||||
img_path=osp.join(self.data_root, line['s2_path']))
|
||||
if 'target_path' in line.keys():
|
||||
data_info['seg_map_path'] = osp.join(self.data_root, line['target_path'])
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
66
finetune/mmseg/datasets/sos.py
Normal file
66
finetune/mmseg/datasets/sos.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
# LEGEND = [
|
||||
# 255 255 255; % Background
|
||||
# 0 0 0; % Roads
|
||||
# 100 100 100; % Buildings
|
||||
# 0 125 0; % Trees
|
||||
# 0 255 0; % Grass
|
||||
# 150 80 0; % Bare Soil
|
||||
# 0 0 150; % Water
|
||||
# 255 255 0; % Railways
|
||||
# 150 150 255]; % Swimming Pools
|
||||
|
||||
@DATASETS.register_module()
|
||||
class SOSDataset(BaseSegDataset):
|
||||
"""Zurich dataset.
|
||||
|
||||
In segmentation map annotation for LoveDA, 0 is the ignore index.
|
||||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Background', 'Oil Spill Area'),
|
||||
palette=[[0,0,0], [255,255,255]]
|
||||
)
|
||||
def __init__(self,
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix,
|
||||
seg_map_suffix=seg_map_suffix,
|
||||
reduce_zero_label=reduce_zero_label,
|
||||
**kwargs)
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
data_list = []
|
||||
if not osp.isdir(self.ann_file) and self.ann_file:
|
||||
assert osp.isfile(self.ann_file), \
|
||||
f'Failed to load `ann_file` {self.ann_file}'
|
||||
lines = json.load(open(self.ann_file))
|
||||
for line in lines:
|
||||
# img_name = line.strip()
|
||||
data_info = dict(
|
||||
img_path=osp.join(self.data_root, line['s1_path']))
|
||||
if 'target_path' in line.keys():
|
||||
data_info['seg_map_path'] = osp.join(self.data_root, line['target_path'])
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['reduce_zero_label'] = self.reduce_zero_label
|
||||
data_info['seg_fields'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
# print(data_list)
|
||||
return data_list
|
||||
32
finetune/mmseg/datasets/transforms/__init__.py
Normal file
32
finetune/mmseg/datasets/transforms/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackSegInputs
|
||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadDepthAnnotation, LoadImageFromNDArray,
|
||||
LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile)
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomDepthMix, RandomFlip, RandomMosaic,
|
||||
RandomRotate, RandomRotFlip, Rerange, Resize,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
from .loading_npz import (LoadAnnotationsNpz, LoadImageFromNpz, LoadTsImageFromNpz, LoadAnnotationsOil, LoadImageOil, LoadImageSingleChannel)
|
||||
|
||||
# yapf: enable
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale',
|
||||
'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange',
|
||||
'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs',
|
||||
'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
|
||||
'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
|
||||
'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix',
|
||||
'RandomFlip', 'Resize', 'LoadAnnotationsNpz', 'LoadImageFromNpz', 'LoadTsImageFromNpz',
|
||||
'LoadAnnotationsOil', 'LoadImageOil', 'LoadImageSingleChannel'
|
||||
]
|
||||
112
finetune/mmseg/datasets/transforms/formatting.py
Normal file
112
finetune/mmseg/datasets/transforms/formatting.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from mmcv.transforms import to_tensor
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class PackSegInputs(BaseTransform):
|
||||
"""Pack the inputs data for the semantic segmentation.
|
||||
|
||||
The ``img_meta`` item is always populated. The contents of the
|
||||
``img_meta`` dictionary depends on ``meta_keys``. By default this includes:
|
||||
|
||||
- ``img_path``: filename of the image
|
||||
|
||||
- ``ori_shape``: original shape of the image as a tuple (h, w, c)
|
||||
|
||||
- ``img_shape``: shape of the image input to the network as a tuple \
|
||||
(h, w, c). Note that images may be zero padded on the \
|
||||
bottom/right if the batch tensor is larger than this shape.
|
||||
|
||||
- ``pad_shape``: shape of padded images
|
||||
|
||||
- ``scale_factor``: a float indicating the preprocessing scale
|
||||
|
||||
- ``flip``: a boolean indicating if image flip transform was used
|
||||
|
||||
- ``flip_direction``: the flipping direction
|
||||
|
||||
Args:
|
||||
meta_keys (Sequence[str], optional): Meta keys to be packed from
|
||||
``SegDataSample`` and collected in ``data[img_metas]``.
|
||||
Default: ``('img_path', 'ori_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction')``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
meta_keys=('img_path', 'seg_map_path', 'ori_shape',
|
||||
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
||||
'flip_direction', 'reduce_zero_label')):
|
||||
self.meta_keys = meta_keys
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Method to pack the input data.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from the data pipeline.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
|
||||
- 'inputs' (obj:`torch.Tensor`): The forward data of models.
|
||||
- 'data_sample' (obj:`SegDataSample`): The annotation info of the
|
||||
sample.
|
||||
"""
|
||||
packed_results = dict()
|
||||
if 'img' in results:
|
||||
img = results['img']
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
if not img.flags.c_contiguous:
|
||||
img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
|
||||
else:
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = to_tensor(img).contiguous()
|
||||
packed_results['inputs'] = img
|
||||
|
||||
data_sample = SegDataSample()
|
||||
if 'gt_seg_map' in results:
|
||||
if len(results['gt_seg_map'].shape) == 2:
|
||||
data = to_tensor(results['gt_seg_map'][None,
|
||||
...].astype(np.int64))
|
||||
else:
|
||||
warnings.warn('Please pay attention your ground truth '
|
||||
'segmentation map, usually the segmentation '
|
||||
'map is 2D, but got '
|
||||
f'{results["gt_seg_map"].shape}')
|
||||
data = to_tensor(results['gt_seg_map'].astype(np.int64))
|
||||
gt_sem_seg_data = dict(data=data)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
|
||||
if 'gt_edge_map' in results:
|
||||
gt_edge_data = dict(
|
||||
data=to_tensor(results['gt_edge_map'][None,
|
||||
...].astype(np.int64)))
|
||||
data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data)))
|
||||
|
||||
if 'gt_depth_map' in results:
|
||||
gt_depth_data = dict(
|
||||
data=to_tensor(results['gt_depth_map'][None, ...]))
|
||||
data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data)))
|
||||
|
||||
img_meta = {}
|
||||
for key in self.meta_keys:
|
||||
if key in results:
|
||||
img_meta[key] = results[key]
|
||||
data_sample.set_metainfo(img_meta)
|
||||
packed_results['data_samples'] = data_sample
|
||||
|
||||
return packed_results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(meta_keys={self.meta_keys})'
|
||||
return repr_str
|
||||
771
finetune/mmseg/datasets/transforms/loading.py
Normal file
771
finetune/mmseg/datasets/transforms/loading.py
Normal file
@@ -0,0 +1,771 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotations(MMCV_LoadAnnotations):
|
||||
"""Load annotations for semantic segmentation provided by dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# Filename of semantic segmentation ground truth file.
|
||||
'seg_map_path': 'a/b/c'
|
||||
}
|
||||
|
||||
After this module, the annotation has been changed to the format below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# in str
|
||||
'seg_fields': List
|
||||
# In uint8 type.
|
||||
'gt_seg_map': np.ndarray (H, W)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path (str): Path of semantic segmentation ground truth file.
|
||||
|
||||
Added Keys:
|
||||
|
||||
- seg_fields (List)
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||
by 1. Usually used for datasets where 0 is background label.
|
||||
Defaults to None.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'pillow'.
|
||||
backend_args (dict): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=None,
|
||||
backend_args=None,
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
super().__init__(
|
||||
with_bbox=False,
|
||||
with_label=False,
|
||||
with_seg=True,
|
||||
with_keypoints=False,
|
||||
imdecode_backend=imdecode_backend,
|
||||
backend_args=backend_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
if self.reduce_zero_label is not None:
|
||||
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||
'if you would like to ignore the zero label, please '
|
||||
'set `reduce_zero_label=True` when dataset '
|
||||
'initialized')
|
||||
self.imdecode_backend = imdecode_backend
|
||||
|
||||
def _load_seg_map(self, results: dict) -> None:
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded semantic segmentation annotations.
|
||||
"""
|
||||
|
||||
img_bytes = fileio.get(
|
||||
results['seg_map_path'], backend_args=self.backend_args)
|
||||
gt_semantic_seg = mmcv.imfrombytes(
|
||||
img_bytes, flag='unchanged',
|
||||
backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||
'Initialize dataset with `reduce_zero_label` as ' \
|
||||
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
|
||||
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
|
||||
repr_str += f'backend_args={self.backend_args})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromNDArray(LoadImageFromFile):
|
||||
"""Load an image from ``results['img']``.
|
||||
|
||||
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
|
||||
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
|
||||
from webcam.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_path
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def transform(self, results: dict) -> dict:
|
||||
"""Transform function to add image meta information.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict with Webcam read image in
|
||||
``results['img']``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
img = results['img']
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img_path'] = None
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalImageFromFile(BaseTransform):
|
||||
"""Load an biomedical mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
|
||||
N is the number of modalities, and data type is float32
|
||||
if set to_float32 = True, or float64 if decode_backend is 'nifti' and
|
||||
to_float32 is False.
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
|
||||
data_bytes = fileio.get(filename, self.backend_args)
|
||||
img = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if len(img.shape) == 3:
|
||||
img = img[None, ...]
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
if self.to_xyz:
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[1:]
|
||||
results['ori_shape'] = img.shape[1:]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalAnnotation(BaseTransform):
|
||||
"""Load ``seg_map`` annotation provided by biomedical dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
|
||||
default, and data type is float32 if set to_float32 = True, or
|
||||
float64 if decode_backend is 'nifti' and to_float32 is False.
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded seg map to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See :class:`mmengine.fileio` for details.
|
||||
Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
super().__init__()
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
data_bytes = fileio.get(results['seg_map_path'], self.backend_args)
|
||||
gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
gt_seg_map = gt_seg_map.astype(np.float32)
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
if self.to_xyz:
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
results['gt_seg_map'] = gt_seg_map
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalData(BaseTransform):
|
||||
"""Load an biomedical image and annotation from file.
|
||||
|
||||
The loading data format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'img': np.ndarray data[:-1, X, Y, Z]
|
||||
'seg_map': np.ndarray data[-1, X, Y, Z]
|
||||
}
|
||||
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
|
||||
N is the number of modalities.
|
||||
- gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
|
||||
(Z, Y, X) by default.
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
with_seg (bool): Whether to parse and load the semantic segmentation
|
||||
annotation. Defaults to False.
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
with_seg=False,
|
||||
decode_backend: str = 'numpy',
|
||||
to_xyz: bool = False,
|
||||
backend_args: Optional[dict] = None) -> None: # noqa
|
||||
self.with_seg = with_seg
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
data_bytes = fileio.get(results['img_path'], self.backend_args)
|
||||
data = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
# img is 4D data (N, X, Y, Z), N is the number of protocol
|
||||
img = data[:-1, :]
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
if self.to_xyz:
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[1:]
|
||||
results['ori_shape'] = img.shape[1:]
|
||||
|
||||
if self.with_seg:
|
||||
gt_seg_map = data[-1, :]
|
||||
if self.decode_backend == 'nifti':
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
if self.to_xyz:
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
results['gt_seg_map'] = gt_seg_map
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'with_seg={self.with_seg}, '
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class InferencerLoader(BaseTransform):
|
||||
"""Load an image from ``results['img']``.
|
||||
|
||||
Similar with :obj:`LoadImageFromFile`, but the image has been loaded as
|
||||
:obj:`np.ndarray` in ``results['img']``. Can be used when loading image
|
||||
from webcam.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_path
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.from_file = TRANSFORMS.build(
|
||||
dict(type='LoadImageFromFile', **kwargs))
|
||||
self.from_ndarray = TRANSFORMS.build(
|
||||
dict(type='LoadImageFromNDArray', **kwargs))
|
||||
|
||||
def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict:
|
||||
"""Transform function to add image meta information.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict with Webcam read image in
|
||||
``results['img']``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
if isinstance(single_input, str):
|
||||
inputs = dict(img_path=single_input)
|
||||
elif isinstance(single_input, np.ndarray):
|
||||
inputs = dict(img=single_input)
|
||||
elif isinstance(single_input, dict):
|
||||
inputs = single_input
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'img' in inputs:
|
||||
return self.from_ndarray(inputs)
|
||||
return self.from_file(inputs)
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadSingleRSImageFromFile(BaseTransform):
|
||||
"""Load a Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, to_float32: bool = True):
|
||||
self.to_float32 = to_float32
|
||||
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
ds = gdal.Open(filename)
|
||||
if ds is None:
|
||||
raise Exception(f'Unable to open file: {filename}')
|
||||
img = np.einsum('ijk->jki', ds.ReadAsArray())
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadMultipleRSImageFromFile(BaseTransform):
|
||||
"""Load two Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
- img_path2
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img2
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, to_float32: bool = True):
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
self.to_float32 = to_float32
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
filename2 = results['img_path2']
|
||||
|
||||
ds = gdal.Open(filename)
|
||||
ds2 = gdal.Open(filename2)
|
||||
|
||||
if ds is None:
|
||||
raise Exception(f'Unable to open file: {filename}')
|
||||
if ds2 is None:
|
||||
raise Exception(f'Unable to open file: {filename2}')
|
||||
|
||||
img = np.einsum('ijk->jki', ds.ReadAsArray())
|
||||
img2 = np.einsum('ijk->jki', ds2.ReadAsArray())
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
img2 = img2.astype(np.float32)
|
||||
|
||||
if img.shape != img2.shape:
|
||||
raise Exception(f'Image shapes do not match:'
|
||||
f' {img.shape} vs {img2.shape}')
|
||||
|
||||
results['img'] = img
|
||||
results['img2'] = img2
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadDepthAnnotation(BaseTransform):
|
||||
"""Load ``depth_map`` annotation provided by depth estimation dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'gt_depth_map': np.ndarray [Y, X]
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_depth_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_depth_map (np.ndarray): Depth map with shape (Y, X) by
|
||||
default, and data type is float32 if set to_float32 = True.
|
||||
- depth_rescale_factor (float): The rescale factor of depth map, which
|
||||
can be used to recover the original value of depth map.
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'.
|
||||
to_float32 (bool): Whether to convert the loaded depth map to a float32
|
||||
numpy array. If set to False, the loaded image is an uint16 array.
|
||||
Defaults to True.
|
||||
depth_rescale_factor (float): Factor to rescale the depth value to
|
||||
limit the range. Defaults to 1.0.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See :class:`mmengine.fileio` for details.
|
||||
Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
decode_backend: str = 'cv2',
|
||||
to_float32: bool = True,
|
||||
depth_rescale_factor: float = 1.0,
|
||||
backend_args: Optional[dict] = None) -> None:
|
||||
super().__init__()
|
||||
self.decode_backend = decode_backend
|
||||
self.to_float32 = to_float32
|
||||
self.depth_rescale_factor = depth_rescale_factor
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load depth map.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded depth map.
|
||||
"""
|
||||
data_bytes = fileio.get(results['depth_map_path'], self.backend_args)
|
||||
gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
gt_depth_map = gt_depth_map.astype(np.float32)
|
||||
|
||||
gt_depth_map *= self.depth_rescale_factor
|
||||
results['gt_depth_map'] = gt_depth_map
|
||||
results['seg_fields'].append('gt_depth_map')
|
||||
results['depth_rescale_factor'] = self.depth_rescale_factor
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'backend_args={self.backend_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromNpyFile(LoadImageFromFile):
|
||||
"""Load an image from ``results['img_path']``.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an uint8 array.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def transform(self, results: dict) -> Optional[dict]:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from
|
||||
:class:`mmengine.dataset.BaseDataset`.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
|
||||
try:
|
||||
if Path(filename).suffix in ['.npy', '.npz']:
|
||||
img = np.load(filename)
|
||||
else:
|
||||
if self.file_client_args is not None:
|
||||
file_client = fileio.FileClient.infer_client(
|
||||
self.file_client_args, filename)
|
||||
img_bytes = file_client.get(filename)
|
||||
else:
|
||||
img_bytes = fileio.get(
|
||||
filename, backend_args=self.backend_args)
|
||||
img = mmcv.imfrombytes(
|
||||
img_bytes,
|
||||
flag=self.color_type,
|
||||
backend=self.imdecode_backend)
|
||||
except Exception as e:
|
||||
if self.ignore_empty:
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
|
||||
# in some cases, images are not read successfully, the img would be
|
||||
# `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427
|
||||
assert img is not None, f'failed to load image: {filename}'
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
493
finetune/mmseg/datasets/transforms/loading_npz.py
Normal file
493
finetune/mmseg/datasets/transforms/loading_npz.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict, Optional, Union
|
||||
import io
|
||||
|
||||
import mmcv
|
||||
import mmengine.fileio as fileio
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
import imageio
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
except ImportError:
|
||||
gdal = None
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotationsNpz(MMCV_LoadAnnotations):
|
||||
"""Load annotations for semantic segmentation provided by dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# Filename of semantic segmentation ground truth file.
|
||||
'seg_map_path': 'a/b/c'
|
||||
}
|
||||
|
||||
After this module, the annotation has been changed to the format below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# in str
|
||||
'seg_fields': List
|
||||
# In uint8 type.
|
||||
'gt_seg_map': np.ndarray (H, W)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path (str): Path of semantic segmentation ground truth file.
|
||||
|
||||
Added Keys:
|
||||
|
||||
- seg_fields (List)
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||
by 1. Usually used for datasets where 0 is background label.
|
||||
Defaults to None.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'pillow'.
|
||||
backend_args (dict): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=None,
|
||||
backend_args=None,
|
||||
data_key='arr_0',
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
super().__init__(
|
||||
with_bbox=False,
|
||||
with_label=False,
|
||||
with_seg=True,
|
||||
with_keypoints=False,
|
||||
imdecode_backend=imdecode_backend,
|
||||
backend_args=backend_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
if self.reduce_zero_label is not None:
|
||||
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||
'if you would like to ignore the zero label, please '
|
||||
'set `reduce_zero_label=True` when dataset '
|
||||
'initialized')
|
||||
self.imdecode_backend = imdecode_backend
|
||||
self.data_key = data_key
|
||||
|
||||
def _load_seg_map(self, results: dict) -> None:
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded semantic segmentation annotations.
|
||||
"""
|
||||
|
||||
# img_bytes = fileio.get(
|
||||
# results['seg_map_path'], backend_args=self.backend_args)
|
||||
# gt_semantic_seg = mmcv.imfrombytes(
|
||||
# img_bytes, flag='unchanged',
|
||||
# backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
gt_semantic_seg = np.load(results['seg_map_path'])[self.data_key].squeeze().astype(np.uint8)
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||
'Initialize dataset with `reduce_zero_label` as ' \
|
||||
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
|
||||
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
|
||||
repr_str += f'backend_args={self.backend_args})'
|
||||
return repr_str
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageSingleChannel(BaseTransform):
|
||||
"""Load a Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, to_float32: bool = True):
|
||||
self.to_float32 = to_float32
|
||||
# self.data_key = data_key
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
img = imageio.imread(filename) # h, w, c
|
||||
img = img[:, :, 0]
|
||||
# if ds is None:
|
||||
# raise Exception(f'Unable to open file: {filename}')
|
||||
# print(img)s
|
||||
# img = np.einsum('ijk->jki', img)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotationsOil(MMCV_LoadAnnotations):
|
||||
"""Load annotations for semantic segmentation provided by dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# Filename of semantic segmentation ground truth file.
|
||||
'seg_map_path': 'a/b/c'
|
||||
}
|
||||
|
||||
After this module, the annotation has been changed to the format below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# in str
|
||||
'seg_fields': List
|
||||
# In uint8 type.
|
||||
'gt_seg_map': np.ndarray (H, W)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path (str): Path of semantic segmentation ground truth file.
|
||||
|
||||
Added Keys:
|
||||
|
||||
- seg_fields (List)
|
||||
- gt_seg_map (np.uint8)
|
||||
|
||||
Args:
|
||||
reduce_zero_label (bool, optional): Whether reduce all label value
|
||||
by 1. Usually used for datasets where 0 is background label.
|
||||
Defaults to None.
|
||||
imdecode_backend (str): The image decoding backend type. The backend
|
||||
argument for :func:``mmcv.imfrombytes``.
|
||||
See :fun:``mmcv.imfrombytes`` for details.
|
||||
Defaults to 'pillow'.
|
||||
backend_args (dict): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduce_zero_label=None,
|
||||
backend_args=None,
|
||||
data_key='arr_0',
|
||||
imdecode_backend='pillow',
|
||||
) -> None:
|
||||
super().__init__(
|
||||
with_bbox=False,
|
||||
with_label=False,
|
||||
with_seg=True,
|
||||
with_keypoints=False,
|
||||
imdecode_backend=imdecode_backend,
|
||||
backend_args=backend_args)
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
if self.reduce_zero_label is not None:
|
||||
warnings.warn('`reduce_zero_label` will be deprecated, '
|
||||
'if you would like to ignore the zero label, please '
|
||||
'set `reduce_zero_label=True` when dataset '
|
||||
'initialized')
|
||||
self.imdecode_backend = imdecode_backend
|
||||
self.data_key = data_key
|
||||
|
||||
def _load_seg_map(self, results: dict) -> None:
|
||||
"""Private function to load semantic segmentation annotations.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded semantic segmentation annotations.
|
||||
"""
|
||||
|
||||
# img_bytes = fileio.get(
|
||||
# results['seg_map_path'], backend_args=self.backend_args)
|
||||
# gt_semantic_seg = mmcv.imfrombytes(
|
||||
# img_bytes, flag='unchanged',
|
||||
# backend=self.imdecode_backend).squeeze().astype(np.uint8)
|
||||
seg_map = gdal.Open(results['seg_map_path']).ReadAsArray()
|
||||
gt_semantic_seg = np.zeros_like(seg_map).astype(np.uint8)
|
||||
gt_semantic_seg[seg_map==3.] = 1
|
||||
# reduce zero_label
|
||||
if self.reduce_zero_label is None:
|
||||
self.reduce_zero_label = results['reduce_zero_label']
|
||||
assert self.reduce_zero_label == results['reduce_zero_label'], \
|
||||
'Initialize dataset with `reduce_zero_label` as ' \
|
||||
f'{results["reduce_zero_label"]} but when load annotation ' \
|
||||
f'the `reduce_zero_label` is {self.reduce_zero_label}'
|
||||
if self.reduce_zero_label:
|
||||
# avoid using underflow conversion
|
||||
gt_semantic_seg[gt_semantic_seg == 0] = 255
|
||||
gt_semantic_seg = gt_semantic_seg - 1
|
||||
gt_semantic_seg[gt_semantic_seg == 254] = 255
|
||||
# modify if custom classes
|
||||
if results.get('label_map', None) is not None:
|
||||
# Add deep copy to solve bug of repeatedly
|
||||
# replace `gt_semantic_seg`, which is reported in
|
||||
# https://github.com/open-mmlab/mmsegmentation/pull/1445/
|
||||
gt_semantic_seg_copy = gt_semantic_seg.copy()
|
||||
for old_id, new_id in results['label_map'].items():
|
||||
gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
|
||||
results['gt_seg_map'] = gt_semantic_seg
|
||||
results['seg_fields'].append('gt_seg_map')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(reduce_zero_label={self.reduce_zero_label}, '
|
||||
repr_str += f"imdecode_backend='{self.imdecode_backend}', "
|
||||
repr_str += f'backend_args={self.backend_args})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromNpz(BaseTransform):
|
||||
"""Load a Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, data_key='arr_0', to_float32: bool = True):
|
||||
self.to_float32 = to_float32
|
||||
self.data_key = data_key
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
img = np.load(filename)[self.data_key]
|
||||
# if ds is None:
|
||||
# raise Exception(f'Unable to open file: {filename}')
|
||||
# print(img)s
|
||||
img = np.einsum('ijk->jki', img)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageOil(BaseTransform):
|
||||
"""Load a Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, data_key='arr_0', to_float32: bool = True):
|
||||
self.to_float32 = to_float32
|
||||
self.data_key = data_key
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
img = gdal.Open(filename).ReadAsArray()
|
||||
img = img[:,:,None]
|
||||
# if ds is None:
|
||||
# raise Exception(f'Unable to open file: {filename}')
|
||||
# print(img)s
|
||||
# img = np.einsum('ijk->jki', img)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadTsImageFromNpz(BaseTransform):
|
||||
"""Load a Remote Sensing mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is a float64 array.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, data_key='arr_0', to_float32: bool = True, ts_size: int=10):
|
||||
self.to_float32 = to_float32
|
||||
self.data_key = data_key
|
||||
if gdal is None:
|
||||
raise RuntimeError('gdal is not installed')
|
||||
self.ts_size = ts_size
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
img = np.load(filename)[self.data_key]
|
||||
ts, c, h, w = img.shape
|
||||
if ts >= self.ts_size:
|
||||
selected_indices = np.random.choice(ts, size=self.ts_size, replace=False)
|
||||
else:
|
||||
selected_indices = np.random.choice(ts, size=self.ts_size, replace=True)
|
||||
selected_indices.sort()
|
||||
img = img[selected_indices, :, :, :]
|
||||
# print(f'after input shape: {img.shape}')
|
||||
img = img.transpose(2, 3, 0, 1).reshape(h, w, self.ts_size*c) # h, w, ts, c -> h, w, ts*c
|
||||
# if ds is None:
|
||||
# raise Exception(f'Unable to open file: {filename}')
|
||||
# print(img)s
|
||||
# img = np.einsum('ijk->jki', img)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'to_float32={self.to_float32})')
|
||||
return repr_str
|
||||
|
||||
|
||||
2537
finetune/mmseg/datasets/transforms/transforms.py
Normal file
2537
finetune/mmseg/datasets/transforms/transforms.py
Normal file
File diff suppressed because it is too large
Load Diff
12
finetune/mmseg/engine/__init__.py
Normal file
12
finetune/mmseg/engine/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hooks import SegVisualizationHook
|
||||
from .optimizers import (ForceDefaultOptimWrapperConstructor,
|
||||
LayerDecayOptimizerConstructor,
|
||||
LearningRateDecayOptimizerConstructor)
|
||||
from .schedulers import PolyLRRatio
|
||||
|
||||
__all__ = [
|
||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||
'SegVisualizationHook', 'PolyLRRatio',
|
||||
'ForceDefaultOptimWrapperConstructor'
|
||||
]
|
||||
4
finetune/mmseg/engine/hooks/__init__.py
Normal file
4
finetune/mmseg/engine/hooks/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .visualization_hook import SegVisualizationHook
|
||||
|
||||
__all__ = ['SegVisualizationHook']
|
||||
129
finetune/mmseg/engine/hooks/visualization_hook.py
Normal file
129
finetune/mmseg/engine/hooks/visualization_hook.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
from mmengine.fileio import get
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.visualization import Visualizer
|
||||
|
||||
from mmseg.registry import HOOKS
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class SegVisualizationHook(Hook):
|
||||
"""Segmentation Visualization Hook. Used to visualize validation and
|
||||
testing process prediction results.
|
||||
|
||||
In the testing phase:
|
||||
|
||||
1. If ``show`` is True, it means that only the prediction results are
|
||||
visualized without storing data, so ``vis_backends`` needs to
|
||||
be excluded.
|
||||
|
||||
Args:
|
||||
draw (bool): whether to draw prediction results. If it is False,
|
||||
it means that no drawing will be done. Defaults to False.
|
||||
interval (int): The interval of visualization. Defaults to 50.
|
||||
show (bool): Whether to display the drawn image. Default to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
backend_args (dict, Optional): Arguments to instantiate a file backend.
|
||||
See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
|
||||
for details. Defaults to None.
|
||||
Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
draw: bool = False,
|
||||
interval: int = 50,
|
||||
show: bool = False,
|
||||
wait_time: float = 0.,
|
||||
backend_args: Optional[dict] = None):
|
||||
self._visualizer: Visualizer = Visualizer.get_current_instance()
|
||||
self.interval = interval
|
||||
self.show = show
|
||||
if self.show:
|
||||
# No need to think about vis backends.
|
||||
self._visualizer._vis_backends = {}
|
||||
warnings.warn('The show is True, it means that only '
|
||||
'the prediction results are visualized '
|
||||
'without storing data, so vis_backends '
|
||||
'needs to be excluded.')
|
||||
|
||||
self.wait_time = wait_time
|
||||
self.backend_args = backend_args.copy() if backend_args else None
|
||||
self.draw = draw
|
||||
if not self.draw:
|
||||
warnings.warn('The draw is False, it means that the '
|
||||
'hook for visualization will not take '
|
||||
'effect. The results will NOT be '
|
||||
'visualized or stored.')
|
||||
self._test_index = 0
|
||||
|
||||
def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[SegDataSample]) -> None:
|
||||
"""Run after every ``self.interval`` validation iterations.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples
|
||||
that contain annotations and predictions.
|
||||
"""
|
||||
if self.draw is False:
|
||||
return
|
||||
|
||||
# There is no guarantee that the same batch of images
|
||||
# is visualized for each evaluation.
|
||||
total_curr_iter = runner.iter + batch_idx
|
||||
|
||||
# Visualize only the first data
|
||||
img_path = outputs[0].img_path
|
||||
img_bytes = get(img_path, backend_args=self.backend_args)
|
||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
||||
window_name = f'val_{osp.basename(img_path)}'
|
||||
|
||||
if total_curr_iter % self.interval == 0:
|
||||
self._visualizer.add_datasample(
|
||||
window_name,
|
||||
img,
|
||||
data_sample=outputs[0],
|
||||
show=self.show,
|
||||
wait_time=self.wait_time,
|
||||
step=total_curr_iter)
|
||||
|
||||
def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
|
||||
outputs: Sequence[SegDataSample]) -> None:
|
||||
"""Run after every testing iterations.
|
||||
|
||||
Args:
|
||||
runner (:obj:`Runner`): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (dict): Data from dataloader.
|
||||
outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples
|
||||
that contain annotations and predictions.
|
||||
"""
|
||||
if self.draw is False:
|
||||
return
|
||||
|
||||
for data_sample in outputs:
|
||||
self._test_index += 1
|
||||
|
||||
img_path = data_sample.img_path
|
||||
window_name = f'test_{osp.basename(img_path)}'
|
||||
|
||||
img_path = data_sample.img_path
|
||||
img_bytes = get(img_path, backend_args=self.backend_args)
|
||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
||||
|
||||
self._visualizer.add_datasample(
|
||||
window_name,
|
||||
img,
|
||||
data_sample=data_sample,
|
||||
show=self.show,
|
||||
wait_time=self.wait_time,
|
||||
step=self._test_index)
|
||||
9
finetune/mmseg/engine/optimizers/__init__.py
Normal file
9
finetune/mmseg/engine/optimizers/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .force_default_constructor import ForceDefaultOptimWrapperConstructor
|
||||
from .layer_decay_optimizer_constructor import (
|
||||
LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor)
|
||||
|
||||
__all__ = [
|
||||
'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor',
|
||||
'ForceDefaultOptimWrapperConstructor'
|
||||
]
|
||||
255
finetune/mmseg/engine/optimizers/force_default_constructor.py
Normal file
255
finetune/mmseg/engine/optimizers/force_default_constructor.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||
from mmengine.utils.dl_utils import mmcv_full_available
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||
from torch.nn import GroupNorm, LayerNorm
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
||||
"""Default constructor with forced optimizer settings.
|
||||
|
||||
This constructor extends the default constructor to add an option for
|
||||
forcing default optimizer settings. This is useful for ensuring that
|
||||
certain parameters or layers strictly adhere to pre-defined default
|
||||
settings, regardless of any custom settings specified.
|
||||
|
||||
By default, each parameter share the same optimizer settings, and we
|
||||
provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
|
||||
It is a dict and may contain various fields like 'custom_keys',
|
||||
'bias_lr_mult', etc., as well as the additional field
|
||||
`force_default_settings` which allows for enforcing default settings on
|
||||
optimizer parameters.
|
||||
|
||||
- ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
|
||||
one of the keys in ``custom_keys`` is a substring of the name of one
|
||||
parameter, then the setting of the parameter will be specified by
|
||||
``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
|
||||
be ignored. It should be noted that the aforementioned ``key`` is the
|
||||
longest key that is a substring of the name of the parameter. If there
|
||||
are multiple matched keys with the same length, then the key with lower
|
||||
alphabet order will be chosen.
|
||||
``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
|
||||
and ``decay_mult``. See Example 2 below.
|
||||
- ``bias_lr_mult`` (float): It will be multiplied to the learning
|
||||
rate for all bias parameters (except for those in normalization
|
||||
layers and offset layers of DCN).
|
||||
- ``bias_decay_mult`` (float): It will be multiplied to the weight
|
||||
decay for all bias parameters (except for those in
|
||||
normalization layers, depthwise conv layers, offset layers of DCN).
|
||||
- ``norm_decay_mult`` (float): It will be multiplied to the weight
|
||||
decay for all weight and bias parameters of normalization
|
||||
layers.
|
||||
- ``flat_decay_mult`` (float): It will be multiplied to the weight
|
||||
decay for all one-dimensional parameters
|
||||
- ``dwconv_decay_mult`` (float): It will be multiplied to the weight
|
||||
decay for all weight and bias parameters of depthwise conv
|
||||
layers.
|
||||
- ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
|
||||
rate for parameters of offset layer in the deformable convs
|
||||
of a model.
|
||||
- ``bypass_duplicate`` (bool): If true, the duplicate parameters
|
||||
would not be added into optimizer. Defaults to False.
|
||||
- ``force_default_settings`` (bool): If true, this will override any
|
||||
custom settings defined by ``custom_keys`` and enforce the use of
|
||||
default settings for optimizer parameters like ``bias_lr_mult``.
|
||||
This is particularly useful when you want to ensure that certain layers
|
||||
or parameters adhere strictly to the pre-defined default settings.
|
||||
|
||||
Note:
|
||||
|
||||
1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
|
||||
override the effect of ``bias_lr_mult`` in the bias of offset layer.
|
||||
So be careful when using both ``bias_lr_mult`` and
|
||||
``dcn_offset_lr_mult``. If you wish to apply both of them to the offset
|
||||
layer in deformable convs, set ``dcn_offset_lr_mult`` to the original
|
||||
``dcn_offset_lr_mult`` * ``bias_lr_mult``.
|
||||
|
||||
2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
|
||||
apply it to all the DCN layers in the model. So be careful when the
|
||||
model contains multiple DCN layers in places other than backbone.
|
||||
|
||||
3. When the option ``force_default_settings`` is true, it will override
|
||||
any custom settings provided in ``custom_keys``. This ensures that the
|
||||
default settings for the optimizer parameters are used.
|
||||
|
||||
Args:
|
||||
optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
|
||||
|
||||
Required fields of ``optim_wrapper_cfg`` are
|
||||
|
||||
- ``type``: class name of the OptimizerWrapper
|
||||
- ``optimizer``: The configuration of optimizer.
|
||||
|
||||
Optional fields of ``optim_wrapper_cfg`` are
|
||||
|
||||
- any arguments of the corresponding optimizer wrapper type,
|
||||
e.g., accumulative_counts, clip_grad, etc.
|
||||
|
||||
Required fields of ``optimizer`` are
|
||||
|
||||
- `type`: class name of the optimizer.
|
||||
|
||||
Optional fields of ``optimizer`` are
|
||||
|
||||
- any arguments of the corresponding optimizer type, e.g.,
|
||||
lr, weight_decay, momentum, etc.
|
||||
|
||||
paramwise_cfg (dict, optional): Parameter-wise options.
|
||||
|
||||
Example 1:
|
||||
>>> model = torch.nn.modules.Conv1d(1, 1, 1)
|
||||
>>> optim_wrapper_cfg = dict(
|
||||
>>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
|
||||
>>> momentum=0.9, weight_decay=0.0001))
|
||||
>>> paramwise_cfg = dict(norm_decay_mult=0.)
|
||||
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
|
||||
>>> optim_wrapper_cfg, paramwise_cfg)
|
||||
>>> optim_wrapper = optim_wrapper_builder(model)
|
||||
|
||||
Example 2:
|
||||
>>> # assume model have attribute model.backbone and model.cls_head
|
||||
>>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
|
||||
>>> type='SGD', lr=0.01, weight_decay=0.95))
|
||||
>>> paramwise_cfg = dict(custom_keys={
|
||||
>>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)})
|
||||
>>> optim_wrapper_builder = DefaultOptimWrapperConstructor(
|
||||
>>> optim_wrapper_cfg, paramwise_cfg)
|
||||
>>> optim_wrapper = optim_wrapper_builder(model)
|
||||
>>> # Then the `lr` and `weight_decay` for model.backbone is
|
||||
>>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
|
||||
>>> # model.cls_head is (0.01, 0.95).
|
||||
"""
|
||||
|
||||
def add_params(self,
|
||||
params: List[dict],
|
||||
module: nn.Module,
|
||||
prefix: str = '',
|
||||
is_dcn_module: Optional[Union[int, float]] = None) -> None:
|
||||
"""Add all parameters of module to the params list.
|
||||
|
||||
The parameters of the given module will be added to the list of param
|
||||
groups, with specific rules defined by paramwise_cfg.
|
||||
|
||||
Args:
|
||||
params (list[dict]): A list of param groups, it will be modified
|
||||
in place.
|
||||
module (nn.Module): The module to be added.
|
||||
prefix (str): The prefix of the module
|
||||
is_dcn_module (int|float|None): If the current module is a
|
||||
submodule of DCN, `is_dcn_module` will be passed to
|
||||
control conv_offset layer's learning rate. Defaults to None.
|
||||
"""
|
||||
# get param-wise options
|
||||
custom_keys = self.paramwise_cfg.get('custom_keys', {})
|
||||
# first sort with alphabet order and then sort with reversed len of str
|
||||
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
|
||||
|
||||
bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None)
|
||||
bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None)
|
||||
norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None)
|
||||
dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None)
|
||||
flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None)
|
||||
bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
|
||||
dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None)
|
||||
force_default_settings = self.paramwise_cfg.get(
|
||||
'force_default_settings', False)
|
||||
|
||||
# special rules for norm layers and depth-wise conv layers
|
||||
is_norm = isinstance(module,
|
||||
(_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
|
||||
is_dwconv = (
|
||||
isinstance(module, torch.nn.Conv2d)
|
||||
and module.in_channels == module.groups)
|
||||
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
param_group = {'params': [param]}
|
||||
if bypass_duplicate and self._is_in(param_group, params):
|
||||
print_log(
|
||||
f'{prefix} is duplicate. It is skipped since '
|
||||
f'bypass_duplicate={bypass_duplicate}',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
continue
|
||||
if not param.requires_grad:
|
||||
params.append(param_group)
|
||||
continue
|
||||
|
||||
# if the parameter match one of the custom keys, ignore other rules
|
||||
is_custom = False
|
||||
for key in sorted_keys:
|
||||
if key in f'{prefix}.{name}':
|
||||
is_custom = True
|
||||
lr_mult = custom_keys[key].get('lr_mult', 1.)
|
||||
param_group['lr'] = self.base_lr * lr_mult
|
||||
if self.base_wd is not None:
|
||||
decay_mult = custom_keys[key].get('decay_mult', 1.)
|
||||
param_group['weight_decay'] = self.base_wd * decay_mult
|
||||
# add custom settings to param_group
|
||||
for k, v in custom_keys[key].items():
|
||||
param_group[k] = v
|
||||
break
|
||||
|
||||
if not is_custom or force_default_settings:
|
||||
# bias_lr_mult affects all bias parameters
|
||||
# except for norm.bias dcn.conv_offset.bias
|
||||
if name == 'bias' and not (
|
||||
is_norm or is_dcn_module) and bias_lr_mult is not None:
|
||||
param_group['lr'] = self.base_lr * bias_lr_mult
|
||||
|
||||
if (prefix.find('conv_offset') != -1 and is_dcn_module
|
||||
and dcn_offset_lr_mult is not None
|
||||
and isinstance(module, torch.nn.Conv2d)):
|
||||
# deal with both dcn_offset's bias & weight
|
||||
param_group['lr'] = self.base_lr * dcn_offset_lr_mult
|
||||
|
||||
# apply weight decay policies
|
||||
if self.base_wd is not None:
|
||||
# norm decay
|
||||
if is_norm and norm_decay_mult is not None:
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * norm_decay_mult
|
||||
# bias lr and decay
|
||||
elif (name == 'bias' and not is_dcn_module
|
||||
and bias_decay_mult is not None):
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * bias_decay_mult
|
||||
# depth-wise conv
|
||||
elif is_dwconv and dwconv_decay_mult is not None:
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * dwconv_decay_mult
|
||||
# flatten parameters except dcn offset
|
||||
elif (param.ndim == 1 and not is_dcn_module
|
||||
and flat_decay_mult is not None):
|
||||
param_group[
|
||||
'weight_decay'] = self.base_wd * flat_decay_mult
|
||||
params.append(param_group)
|
||||
for key, value in param_group.items():
|
||||
if key == 'params':
|
||||
continue
|
||||
full_name = f'{prefix}.{name}' if prefix else name
|
||||
print_log(
|
||||
f'paramwise_options -- {full_name}:{key}={value}',
|
||||
logger='current')
|
||||
|
||||
if mmcv_full_available():
|
||||
from mmcv.ops import DeformConv2d, ModulatedDeformConv2d
|
||||
is_dcn_module = isinstance(module,
|
||||
(DeformConv2d, ModulatedDeformConv2d))
|
||||
else:
|
||||
is_dcn_module = False
|
||||
for child_name, child_mod in module.named_children():
|
||||
child_prefix = f'{prefix}.{child_name}' if prefix else child_name
|
||||
self.add_params(
|
||||
params,
|
||||
child_mod,
|
||||
prefix=child_prefix,
|
||||
is_dcn_module=is_dcn_module)
|
||||
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from mmengine.dist import get_dist_info
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.optim import DefaultOptimWrapperConstructor
|
||||
|
||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
||||
|
||||
|
||||
def get_layer_id_for_convnext(var_name, max_layer_id):
|
||||
"""Get the layer id to set the different learning rates in ``layer_wise``
|
||||
decay_type.
|
||||
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
max_layer_id (int): Maximum number of backbone layers.
|
||||
|
||||
Returns:
|
||||
int: The id number corresponding to different learning rate in
|
||||
``LearningRateDecayOptimizerConstructor``.
|
||||
"""
|
||||
|
||||
if var_name in ('backbone.cls_token', 'backbone.mask_token',
|
||||
'backbone.pos_embed'):
|
||||
return 0
|
||||
elif var_name.startswith('backbone.downsample_layers'):
|
||||
stage_id = int(var_name.split('.')[2])
|
||||
if stage_id == 0:
|
||||
layer_id = 0
|
||||
elif stage_id == 1:
|
||||
layer_id = 2
|
||||
elif stage_id == 2:
|
||||
layer_id = 3
|
||||
elif stage_id == 3:
|
||||
layer_id = max_layer_id
|
||||
return layer_id
|
||||
elif var_name.startswith('backbone.stages'):
|
||||
stage_id = int(var_name.split('.')[2])
|
||||
block_id = int(var_name.split('.')[3])
|
||||
if stage_id == 0:
|
||||
layer_id = 1
|
||||
elif stage_id == 1:
|
||||
layer_id = 2
|
||||
elif stage_id == 2:
|
||||
layer_id = 3 + block_id // 3
|
||||
elif stage_id == 3:
|
||||
layer_id = max_layer_id
|
||||
return layer_id
|
||||
else:
|
||||
return max_layer_id + 1
|
||||
|
||||
|
||||
def get_stage_id_for_convnext(var_name, max_stage_id):
|
||||
"""Get the stage id to set the different learning rates in ``stage_wise``
|
||||
decay_type.
|
||||
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
max_stage_id (int): Maximum number of backbone layers.
|
||||
|
||||
Returns:
|
||||
int: The id number corresponding to different learning rate in
|
||||
``LearningRateDecayOptimizerConstructor``.
|
||||
"""
|
||||
|
||||
if var_name in ('backbone.cls_token', 'backbone.mask_token',
|
||||
'backbone.pos_embed'):
|
||||
return 0
|
||||
elif var_name.startswith('backbone.downsample_layers'):
|
||||
return 0
|
||||
elif var_name.startswith('backbone.stages'):
|
||||
stage_id = int(var_name.split('.')[2])
|
||||
return stage_id + 1
|
||||
else:
|
||||
return max_stage_id - 1
|
||||
|
||||
|
||||
def get_layer_id_for_vit(var_name, max_layer_id):
|
||||
"""Get the layer id to set the different learning rates.
|
||||
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
num_max_layer (int): Maximum number of backbone layers.
|
||||
|
||||
Returns:
|
||||
int: Returns the layer id of the key.
|
||||
"""
|
||||
|
||||
if var_name in ('backbone.cls_token', 'backbone.mask_token',
|
||||
'backbone.pos_embed'):
|
||||
return 0
|
||||
elif var_name.startswith('backbone.patch_embed'):
|
||||
return 0
|
||||
elif var_name.startswith('backbone.layers'):
|
||||
layer_id = int(var_name.split('.')[2])
|
||||
return layer_id + 1
|
||||
else:
|
||||
return max_layer_id - 1
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
|
||||
"""Different learning rates are set for different layers of backbone.
|
||||
|
||||
Note: Currently, this optimizer constructor is built for ConvNeXt,
|
||||
BEiT and MAE.
|
||||
"""
|
||||
|
||||
def add_params(self, params, module, **kwargs):
|
||||
"""Add all parameters of module to the params list.
|
||||
|
||||
The parameters of the given module will be added to the list of param
|
||||
groups, with specific rules defined by paramwise_cfg.
|
||||
|
||||
Args:
|
||||
params (list[dict]): A list of param groups, it will be modified
|
||||
in place.
|
||||
module (nn.Module): The module to be added.
|
||||
"""
|
||||
|
||||
parameter_groups = {}
|
||||
print_log(f'self.paramwise_cfg is {self.paramwise_cfg}')
|
||||
num_layers = self.paramwise_cfg.get('num_layers') + 2
|
||||
decay_rate = self.paramwise_cfg.get('decay_rate')
|
||||
decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
|
||||
print_log('Build LearningRateDecayOptimizerConstructor '
|
||||
f'{decay_type} {decay_rate} - {num_layers}')
|
||||
weight_decay = self.base_wd
|
||||
for name, param in module.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue # frozen weights
|
||||
if len(param.shape) == 1 or name.endswith('.bias') or name in (
|
||||
'pos_embed', 'cls_token'):
|
||||
group_name = 'no_decay'
|
||||
this_weight_decay = 0.
|
||||
else:
|
||||
group_name = 'decay'
|
||||
this_weight_decay = weight_decay
|
||||
if 'layer_wise' in decay_type:
|
||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||
layer_id = get_layer_id_for_convnext(
|
||||
name, self.paramwise_cfg.get('num_layers'))
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
elif 'BEiT' in module.backbone.__class__.__name__ or \
|
||||
'MAE' in module.backbone.__class__.__name__:
|
||||
layer_id = get_layer_id_for_vit(name, num_layers)
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
elif decay_type == 'stage_wise':
|
||||
if 'ConvNeXt' in module.backbone.__class__.__name__:
|
||||
layer_id = get_stage_id_for_convnext(name, num_layers)
|
||||
print_log(f'set param {name} as id {layer_id}')
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
group_name = f'layer_{layer_id}_{group_name}'
|
||||
|
||||
if group_name not in parameter_groups:
|
||||
scale = decay_rate**(num_layers - layer_id - 1)
|
||||
|
||||
parameter_groups[group_name] = {
|
||||
'weight_decay': this_weight_decay,
|
||||
'params': [],
|
||||
'param_names': [],
|
||||
'lr_scale': scale,
|
||||
'group_name': group_name,
|
||||
'lr': scale * self.base_lr,
|
||||
}
|
||||
|
||||
parameter_groups[group_name]['params'].append(param)
|
||||
parameter_groups[group_name]['param_names'].append(name)
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
to_display = {}
|
||||
for key in parameter_groups:
|
||||
to_display[key] = {
|
||||
'param_names': parameter_groups[key]['param_names'],
|
||||
'lr_scale': parameter_groups[key]['lr_scale'],
|
||||
'lr': parameter_groups[key]['lr'],
|
||||
'weight_decay': parameter_groups[key]['weight_decay'],
|
||||
}
|
||||
print_log(f'Param groups = {json.dumps(to_display, indent=2)}')
|
||||
params.extend(parameter_groups.values())
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
|
||||
"""Different learning rates are set for different layers of backbone.
|
||||
|
||||
Note: Currently, this optimizer constructor is built for BEiT,
|
||||
and it will be deprecated.
|
||||
Please use ``LearningRateDecayOptimizerConstructor`` instead.
|
||||
"""
|
||||
|
||||
def __init__(self, optim_wrapper_cfg, paramwise_cfg):
|
||||
warnings.warn('DeprecationWarning: Original '
|
||||
'LayerDecayOptimizerConstructor of BEiT '
|
||||
'will be deprecated. Please use '
|
||||
'LearningRateDecayOptimizerConstructor instead, '
|
||||
'and set decay_type = layer_wise_vit in paramwise_cfg.')
|
||||
paramwise_cfg.update({'decay_type': 'layer_wise_vit'})
|
||||
warnings.warn('DeprecationWarning: Layer_decay_rate will '
|
||||
'be deleted, please use decay_rate instead.')
|
||||
paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate')
|
||||
super().__init__(optim_wrapper_cfg, paramwise_cfg)
|
||||
4
finetune/mmseg/engine/schedulers/__init__.py
Normal file
4
finetune/mmseg/engine/schedulers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .poly_ratio_scheduler import PolyLRRatio
|
||||
|
||||
__all__ = ['PolyLRRatio']
|
||||
62
finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py
Normal file
62
finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.optim.scheduler import PolyLR
|
||||
|
||||
from mmseg.registry import PARAM_SCHEDULERS
|
||||
|
||||
|
||||
@PARAM_SCHEDULERS.register_module()
|
||||
class PolyLRRatio(PolyLR):
|
||||
"""Implements polynomial learning rate decay with ratio.
|
||||
|
||||
This scheduler adjusts the learning rate of each parameter group
|
||||
following a polynomial decay equation. The decay can occur in
|
||||
conjunction with external parameter adjustments made outside this
|
||||
scheduler.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
|
||||
eta_min (float): Minimum learning rate at the end of scheduling.
|
||||
Defaults to 0.
|
||||
eta_min_ratio (float, optional): The ratio of the minimum parameter
|
||||
value to the base parameter value. Either `eta_min` or
|
||||
`eta_min_ratio` should be specified. Defaults to None.
|
||||
power (float): The power of the polynomial. Defaults to 1.0.
|
||||
begin (int): Step at which to start updating the parameters.
|
||||
Defaults to 0.
|
||||
end (int): Step at which to stop updating the parameters.
|
||||
Defaults to INF.
|
||||
last_step (int): The index of last step. Used for resume without
|
||||
state dict. Defaults to -1.
|
||||
by_epoch (bool): Whether the scheduled parameters are updated by
|
||||
epochs. Defaults to True.
|
||||
verbose (bool): Whether to print the value for each update.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self, eta_min_ratio: Optional[int] = None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.eta_min_ratio = eta_min_ratio
|
||||
|
||||
def _get_value(self):
|
||||
"""Compute value using chainable form of the scheduler."""
|
||||
|
||||
if self.last_step == 0:
|
||||
return [
|
||||
group[self.param_name] for group in self.optimizer.param_groups
|
||||
]
|
||||
|
||||
param_groups_value = []
|
||||
for base_value, param_group in zip(self.base_values,
|
||||
self.optimizer.param_groups):
|
||||
eta_min = self.eta_min if self.eta_min_ratio is None else \
|
||||
base_value * self.eta_min_ratio
|
||||
step_ratio = (1 - 1 /
|
||||
(self.total_iters - self.last_step + 1))**self.power
|
||||
step_value = (param_group[self.param_name] -
|
||||
eta_min) * step_ratio + eta_min
|
||||
param_groups_value.append(step_value)
|
||||
|
||||
return param_groups_value
|
||||
4
finetune/mmseg/evaluation/__init__.py
Normal file
4
finetune/mmseg/evaluation/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .metrics import CityscapesMetric, DepthMetric, IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
|
||||
6
finetune/mmseg/evaluation/metrics/__init__.py
Normal file
6
finetune/mmseg/evaluation/metrics/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .citys_metric import CityscapesMetric
|
||||
from .depth_metric import DepthMetric
|
||||
from .iou_metric import IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric']
|
||||
158
finetune/mmseg/evaluation/metrics/citys_metric.py
Normal file
158
finetune/mmseg/evaluation/metrics/citys_metric.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
try:
|
||||
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
except ImportError:
|
||||
CSLabels = None
|
||||
CSEval = None
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dist import is_main_process, master_only
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class CityscapesMetric(BaseMetric):
|
||||
"""Cityscapes evaluation metric.
|
||||
|
||||
Args:
|
||||
output_dir (str): The directory for output prediction
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
Default: 255.
|
||||
format_only (bool): Only format result for results commit without
|
||||
perform evaluation. It is useful when you want to format the result
|
||||
to a specific format and submit it to the test server.
|
||||
Defaults to False.
|
||||
keep_results (bool): Whether to keep the results. When ``format_only``
|
||||
is True, ``keep_results`` must be True. Defaults to False.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
output_dir: str,
|
||||
ignore_index: int = 255,
|
||||
format_only: bool = False,
|
||||
keep_results: bool = False,
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
if CSEval is None:
|
||||
raise ImportError('Please run "pip install cityscapesscripts" to '
|
||||
'install cityscapesscripts first.')
|
||||
self.output_dir = output_dir
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
self.format_only = format_only
|
||||
if format_only:
|
||||
assert keep_results, (
|
||||
'When format_only is True, the results must be keep, please '
|
||||
f'set keep_results as True, but got {keep_results}')
|
||||
self.keep_results = keep_results
|
||||
self.prefix = prefix
|
||||
if is_main_process():
|
||||
mkdir_or_exist(self.output_dir)
|
||||
|
||||
@master_only
|
||||
def __del__(self) -> None:
|
||||
"""Clean up."""
|
||||
if not self.keep_results:
|
||||
shutil.rmtree(self.output_dir)
|
||||
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to computed the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
mkdir_or_exist(self.output_dir)
|
||||
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# labelIds should be used
|
||||
pred_label = self._convert_to_label_id(pred_label)
|
||||
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
|
||||
png_filename = osp.abspath(
|
||||
osp.join(self.output_dir, f'{basename}.png'))
|
||||
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
|
||||
output.save(png_filename)
|
||||
if self.format_only:
|
||||
# format_only always for test dataset without ground truth
|
||||
gt_filename = ''
|
||||
else:
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# **_gtFine_labelIds.png is used
|
||||
gt_filename = data_sample['seg_map_path'].replace(
|
||||
'labelTrainIds.png', 'labelIds.png')
|
||||
self.results.append((png_filename, gt_filename))
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
|
||||
Returns:
|
||||
dict[str: float]: Cityscapes evaluation results.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
if self.format_only:
|
||||
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
|
||||
return OrderedDict()
|
||||
|
||||
msg = 'Evaluating in Cityscapes style'
|
||||
if logger is None:
|
||||
msg = '\n' + msg
|
||||
print_log(msg, logger=logger)
|
||||
|
||||
eval_results = dict()
|
||||
print_log(
|
||||
f'Evaluating results under {self.output_dir} ...', logger=logger)
|
||||
|
||||
CSEval.args.evalInstLevelScore = True
|
||||
CSEval.args.predictionPath = osp.abspath(self.output_dir)
|
||||
CSEval.args.evalPixelAccuracy = True
|
||||
CSEval.args.JSONOutput = False
|
||||
|
||||
pred_list, gt_list = zip(*results)
|
||||
metric = dict()
|
||||
eval_results.update(
|
||||
CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
|
||||
metric['averageScoreCategories'] = eval_results[
|
||||
'averageScoreCategories']
|
||||
metric['averageScoreInstCategories'] = eval_results[
|
||||
'averageScoreInstCategories']
|
||||
return metric
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_label_id(result):
|
||||
"""Convert trainId to id for cityscapes."""
|
||||
if isinstance(result, str):
|
||||
result = np.load(result)
|
||||
result_copy = result.copy()
|
||||
for trainId, label in CSLabels.trainId2label.items():
|
||||
result_copy[result == trainId] = label.id
|
||||
|
||||
return result_copy
|
||||
212
finetune/mmseg/evaluation/metrics/depth_metric.py
Normal file
212
finetune/mmseg/evaluation/metrics/depth_metric.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist import is_main_process
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from prettytable import PrettyTable
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class DepthMetric(BaseMetric):
|
||||
"""Depth estimation evaluation metric.
|
||||
|
||||
Args:
|
||||
depth_metrics (List[str], optional): List of metrics to compute. If
|
||||
not specified, defaults to all metrics in self.METRICS.
|
||||
min_depth_eval (float): Minimum depth value for evaluation.
|
||||
Defaults to 0.0.
|
||||
max_depth_eval (float): Maximum depth value for evaluation.
|
||||
Defaults to infinity.
|
||||
crop_type (str, optional): Specifies the type of cropping to be used
|
||||
during evaluation. This option can affect how the evaluation mask
|
||||
is generated. Currently, 'nyu_crop' is supported, but other
|
||||
types can be added in future. Defaults to None if no cropping
|
||||
should be applied.
|
||||
depth_scale_factor (float): Factor to scale the depth values.
|
||||
Defaults to 1.0.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
output_dir (str): The directory for output prediction. Defaults to
|
||||
None.
|
||||
format_only (bool): Only format result for results commit without
|
||||
perform evaluation. It is useful when you want to save the result
|
||||
to a specific format and submit it to the test server.
|
||||
Defaults to False.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
"""
|
||||
METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
|
||||
'log10', 'silog')
|
||||
|
||||
def __init__(self,
|
||||
depth_metrics: Optional[List[str]] = None,
|
||||
min_depth_eval: float = 0.0,
|
||||
max_depth_eval: float = float('inf'),
|
||||
crop_type: Optional[str] = None,
|
||||
depth_scale_factor: float = 1.0,
|
||||
collect_device: str = 'cpu',
|
||||
output_dir: Optional[str] = None,
|
||||
format_only: bool = False,
|
||||
prefix: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
if depth_metrics is None:
|
||||
self.metrics = self.METRICS
|
||||
elif isinstance(depth_metrics, [tuple, list]):
|
||||
for metric in depth_metrics:
|
||||
assert metric in self.METRICS, f'the metric {metric} is not ' \
|
||||
f'supported. Please use metrics in {self.METRICS}'
|
||||
self.metrics = depth_metrics
|
||||
|
||||
# Validate crop_type, if provided
|
||||
assert crop_type in [
|
||||
None, 'nyu_crop'
|
||||
], (f'Invalid value for crop_type: {crop_type}. Supported values are '
|
||||
'None or \'nyu_crop\'.')
|
||||
self.crop_type = crop_type
|
||||
self.min_depth_eval = min_depth_eval
|
||||
self.max_depth_eval = max_depth_eval
|
||||
self.output_dir = output_dir
|
||||
if self.output_dir and is_main_process():
|
||||
mkdir_or_exist(self.output_dir)
|
||||
self.format_only = format_only
|
||||
self.depth_scale_factor = depth_scale_factor
|
||||
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_depth_map']['data'].squeeze()
|
||||
# format_only always for test dataset without ground truth
|
||||
if not self.format_only:
|
||||
gt_depth = data_sample['gt_depth_map']['data'].squeeze().to(
|
||||
pred_label)
|
||||
|
||||
eval_mask = self._get_eval_mask(gt_depth)
|
||||
self.results.append(
|
||||
(gt_depth[eval_mask], pred_label[eval_mask]))
|
||||
# format_result
|
||||
if self.output_dir is not None:
|
||||
basename = osp.splitext(osp.basename(
|
||||
data_sample['img_path']))[0]
|
||||
png_filename = osp.abspath(
|
||||
osp.join(self.output_dir, f'{basename}.png'))
|
||||
output_mask = pred_label.cpu().numpy(
|
||||
) * self.depth_scale_factor
|
||||
|
||||
cv2.imwrite(png_filename, output_mask.astype(np.uint16),
|
||||
[cv2.IMWRITE_PNG_COMPRESSION, 0])
|
||||
|
||||
def _get_eval_mask(self, gt_depth: Tensor):
|
||||
"""Generates an evaluation mask based on ground truth depth and
|
||||
cropping.
|
||||
|
||||
Args:
|
||||
gt_depth (Tensor): Ground truth depth map.
|
||||
|
||||
Returns:
|
||||
Tensor: Boolean mask where evaluation should be performed.
|
||||
"""
|
||||
valid_mask = torch.logical_and(gt_depth > self.min_depth_eval,
|
||||
gt_depth < self.max_depth_eval)
|
||||
|
||||
if self.crop_type == 'nyu_crop':
|
||||
# this implementation is adapted from
|
||||
# https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py # noqa
|
||||
crop_mask = torch.zeros_like(valid_mask)
|
||||
crop_mask[45:471, 41:601] = 1
|
||||
else:
|
||||
crop_mask = torch.ones_like(valid_mask)
|
||||
|
||||
eval_mask = torch.logical_and(valid_mask, crop_mask)
|
||||
return eval_mask
|
||||
|
||||
@staticmethod
|
||||
def _calc_all_metrics(gt_depth, pred_depth):
|
||||
"""Computes final evaluation metrics based on accumulated results."""
|
||||
assert gt_depth.shape == pred_depth.shape
|
||||
|
||||
thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth))
|
||||
diff = pred_depth - gt_depth
|
||||
diff_log = torch.log(pred_depth) - torch.log(gt_depth)
|
||||
|
||||
d1 = torch.sum(thresh < 1.25).float() / len(thresh)
|
||||
d2 = torch.sum(thresh < 1.25**2).float() / len(thresh)
|
||||
d3 = torch.sum(thresh < 1.25**3).float() / len(thresh)
|
||||
|
||||
abs_rel = torch.mean(torch.abs(diff) / gt_depth)
|
||||
sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth)
|
||||
|
||||
rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))
|
||||
rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))
|
||||
|
||||
log10 = torch.mean(
|
||||
torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth)))
|
||||
silog = torch.sqrt(
|
||||
torch.pow(diff_log, 2).mean() -
|
||||
0.5 * torch.pow(diff_log.mean(), 2))
|
||||
|
||||
return {
|
||||
'd1': d1.item(),
|
||||
'd2': d2.item(),
|
||||
'd3': d3.item(),
|
||||
'abs_rel': abs_rel.item(),
|
||||
'sq_rel': sq_rel.item(),
|
||||
'rmse': rmse.item(),
|
||||
'rmse_log': rmse_log.item(),
|
||||
'log10': log10.item(),
|
||||
'silog': silog.item()
|
||||
}
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: The computed metrics. The keys are the names of
|
||||
the metrics, and the values are corresponding results. The keys
|
||||
are identical with self.metrics.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
if self.format_only:
|
||||
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
|
||||
return OrderedDict()
|
||||
|
||||
metrics = defaultdict(list)
|
||||
for gt_depth, pred_depth in results:
|
||||
for key, value in self._calc_all_metrics(gt_depth,
|
||||
pred_depth).items():
|
||||
metrics[key].append(value)
|
||||
metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics}
|
||||
|
||||
table_data = PrettyTable()
|
||||
for key, val in metrics.items():
|
||||
table_data.add_column(key, [round(val, 5)])
|
||||
|
||||
print_log('results:', logger)
|
||||
print_log('\n' + table_data.get_string(), logger=logger)
|
||||
|
||||
return metrics
|
||||
286
finetune/mmseg/evaluation/metrics/iou_metric.py
Normal file
286
finetune/mmseg/evaluation/metrics/iou_metric.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist import is_main_process
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from PIL import Image
|
||||
from prettytable import PrettyTable
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class IoUMetric(BaseMetric):
|
||||
"""IoU evaluation metric.
|
||||
|
||||
Args:
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
Default: 255.
|
||||
iou_metrics (list[str] | str): Metrics to be calculated, the options
|
||||
includes 'mIoU', 'mDice' and 'mFscore'.
|
||||
nan_to_num (int, optional): If specified, NaN values will be replaced
|
||||
by the numbers defined by the user. Default: None.
|
||||
beta (int): Determines the weight of recall in the combined score.
|
||||
Default: 1.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
output_dir (str): The directory for output prediction. Defaults to
|
||||
None.
|
||||
format_only (bool): Only format result for results commit without
|
||||
perform evaluation. It is useful when you want to save the result
|
||||
to a specific format and submit it to the test server.
|
||||
Defaults to False.
|
||||
prefix (str, optional): The prefix that will be added in the metric
|
||||
names to disambiguate homonymous metrics of different evaluators.
|
||||
If prefix is not provided in the argument, self.default_prefix
|
||||
will be used instead. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ignore_index: int = 255,
|
||||
iou_metrics: List[str] = ['mIoU'],
|
||||
nan_to_num: Optional[int] = None,
|
||||
beta: int = 1,
|
||||
collect_device: str = 'cpu',
|
||||
output_dir: Optional[str] = None,
|
||||
format_only: bool = False,
|
||||
prefix: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
self.ignore_index = ignore_index
|
||||
self.metrics = iou_metrics
|
||||
self.nan_to_num = nan_to_num
|
||||
self.beta = beta
|
||||
self.output_dir = output_dir
|
||||
if self.output_dir and is_main_process():
|
||||
mkdir_or_exist(self.output_dir)
|
||||
self.format_only = format_only
|
||||
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
||||
The processed results should be stored in ``self.results``, which will
|
||||
be used to compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
num_classes = len(self.dataset_meta['classes'])
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_sem_seg']['data'].squeeze()
|
||||
# format_only always for test dataset without ground truth
|
||||
if not self.format_only:
|
||||
label = data_sample['gt_sem_seg']['data'].squeeze().to(
|
||||
pred_label)
|
||||
self.results.append(
|
||||
self.intersect_and_union(pred_label, label, num_classes,
|
||||
self.ignore_index))
|
||||
# format_result
|
||||
if self.output_dir is not None:
|
||||
basename = osp.splitext(osp.basename(
|
||||
data_sample['img_path']))[0]
|
||||
png_filename = osp.abspath(
|
||||
osp.join(self.output_dir, f'{basename}.png'))
|
||||
output_mask = pred_label.cpu().numpy()
|
||||
# The index range of official ADE20k dataset is from 0 to 150.
|
||||
# But the index range of output is from 0 to 149.
|
||||
# That is because we set reduce_zero_label=True.
|
||||
if data_sample.get('reduce_zero_label', False):
|
||||
output_mask = output_mask + 1
|
||||
output = Image.fromarray(output_mask.astype(np.uint8))
|
||||
output.save(png_filename)
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
|
||||
Args:
|
||||
results (list): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: The computed metrics. The keys are the names of
|
||||
the metrics, and the values are corresponding results. The key
|
||||
mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,
|
||||
mRecall.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
if self.format_only:
|
||||
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
|
||||
return OrderedDict()
|
||||
# convert list of tuples to tuple of lists, e.g.
|
||||
# [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to
|
||||
# ([A_1, ..., A_n], ..., [D_1, ..., D_n])
|
||||
results = tuple(zip(*results))
|
||||
assert len(results) == 4
|
||||
|
||||
total_area_intersect = sum(results[0])
|
||||
total_area_union = sum(results[1])
|
||||
total_area_pred_label = sum(results[2])
|
||||
total_area_label = sum(results[3])
|
||||
ret_metrics = self.total_area_to_metrics(
|
||||
total_area_intersect, total_area_union, total_area_pred_label,
|
||||
total_area_label, self.metrics, self.nan_to_num, self.beta)
|
||||
|
||||
class_names = self.dataset_meta['classes']
|
||||
|
||||
# summary table
|
||||
ret_metrics_summary = OrderedDict({
|
||||
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
|
||||
for ret_metric, ret_metric_value in ret_metrics.items()
|
||||
})
|
||||
metrics = dict()
|
||||
for key, val in ret_metrics_summary.items():
|
||||
if key == 'aAcc':
|
||||
metrics[key] = val
|
||||
else:
|
||||
metrics['m' + key] = val
|
||||
|
||||
# each class table
|
||||
ret_metrics.pop('aAcc', None)
|
||||
ret_metrics_class = OrderedDict({
|
||||
ret_metric: np.round(ret_metric_value * 100, 2)
|
||||
for ret_metric, ret_metric_value in ret_metrics.items()
|
||||
})
|
||||
ret_metrics_class.update({'Class': class_names})
|
||||
ret_metrics_class.move_to_end('Class', last=False)
|
||||
class_table_data = PrettyTable()
|
||||
for key, val in ret_metrics_class.items():
|
||||
class_table_data.add_column(key, val)
|
||||
|
||||
print_log('per class results:', logger)
|
||||
print_log('\n' + class_table_data.get_string(), logger=logger)
|
||||
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
|
||||
num_classes: int, ignore_index: int):
|
||||
"""Calculate Intersection and Union.
|
||||
|
||||
Args:
|
||||
pred_label (torch.tensor): Prediction segmentation map
|
||||
or predict result filename. The shape is (H, W).
|
||||
label (torch.tensor): Ground truth segmentation map
|
||||
or label filename. The shape is (H, W).
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The intersection of prediction and ground truth
|
||||
histogram on all classes.
|
||||
torch.Tensor: The union of prediction and ground truth histogram on
|
||||
all classes.
|
||||
torch.Tensor: The prediction histogram on all classes.
|
||||
torch.Tensor: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
mask = (label != ignore_index)
|
||||
pred_label = pred_label[mask]
|
||||
label = label[mask]
|
||||
|
||||
intersect = pred_label[pred_label == label]
|
||||
area_intersect = torch.histc(
|
||||
intersect.float(), bins=(num_classes), min=0,
|
||||
max=num_classes - 1).cpu()
|
||||
area_pred_label = torch.histc(
|
||||
pred_label.float(), bins=(num_classes), min=0,
|
||||
max=num_classes - 1).cpu()
|
||||
area_label = torch.histc(
|
||||
label.float(), bins=(num_classes), min=0,
|
||||
max=num_classes - 1).cpu()
|
||||
area_union = area_pred_label + area_label - area_intersect
|
||||
return area_intersect, area_union, area_pred_label, area_label
|
||||
|
||||
@staticmethod
|
||||
def total_area_to_metrics(total_area_intersect: np.ndarray,
|
||||
total_area_union: np.ndarray,
|
||||
total_area_pred_label: np.ndarray,
|
||||
total_area_label: np.ndarray,
|
||||
metrics: List[str] = ['mIoU'],
|
||||
nan_to_num: Optional[int] = None,
|
||||
beta: int = 1):
|
||||
"""Calculate evaluation metrics
|
||||
Args:
|
||||
total_area_intersect (np.ndarray): The intersection of prediction
|
||||
and ground truth histogram on all classes.
|
||||
total_area_union (np.ndarray): The union of prediction and ground
|
||||
truth histogram on all classes.
|
||||
total_area_pred_label (np.ndarray): The prediction histogram on
|
||||
all classes.
|
||||
total_area_label (np.ndarray): The ground truth histogram on
|
||||
all classes.
|
||||
metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and
|
||||
'mDice'.
|
||||
nan_to_num (int, optional): If specified, NaN values will be
|
||||
replaced by the numbers defined by the user. Default: None.
|
||||
beta (int): Determines the weight of recall in the combined score.
|
||||
Default: 1.
|
||||
Returns:
|
||||
Dict[str, np.ndarray]: per category evaluation metrics,
|
||||
shape (num_classes, ).
|
||||
"""
|
||||
|
||||
def f_score(precision, recall, beta=1):
|
||||
"""calculate the f-score value.
|
||||
|
||||
Args:
|
||||
precision (float | torch.Tensor): The precision value.
|
||||
recall (float | torch.Tensor): The recall value.
|
||||
beta (int): Determines the weight of recall in the combined
|
||||
score. Default: 1.
|
||||
|
||||
Returns:
|
||||
[torch.tensor]: The f-score value.
|
||||
"""
|
||||
score = (1 + beta**2) * (precision * recall) / (
|
||||
(beta**2 * precision) + recall)
|
||||
return score
|
||||
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
|
||||
if not set(metrics).issubset(set(allowed_metrics)):
|
||||
raise KeyError(f'metrics {metrics} is not supported')
|
||||
|
||||
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
||||
ret_metrics = OrderedDict({'aAcc': all_acc})
|
||||
for metric in metrics:
|
||||
if metric == 'mIoU':
|
||||
iou = total_area_intersect / total_area_union
|
||||
acc = total_area_intersect / total_area_label
|
||||
ret_metrics['IoU'] = iou
|
||||
ret_metrics['Acc'] = acc
|
||||
elif metric == 'mDice':
|
||||
dice = 2 * total_area_intersect / (
|
||||
total_area_pred_label + total_area_label)
|
||||
acc = total_area_intersect / total_area_label
|
||||
ret_metrics['Dice'] = dice
|
||||
ret_metrics['Acc'] = acc
|
||||
elif metric == 'mFscore':
|
||||
precision = total_area_intersect / total_area_pred_label
|
||||
recall = total_area_intersect / total_area_label
|
||||
f_value = torch.tensor([
|
||||
f_score(x[0], x[1], beta) for x in zip(precision, recall)
|
||||
])
|
||||
ret_metrics['Fscore'] = f_value
|
||||
ret_metrics['Precision'] = precision
|
||||
ret_metrics['Recall'] = recall
|
||||
|
||||
ret_metrics = {
|
||||
metric: value.numpy()
|
||||
for metric, value in ret_metrics.items()
|
||||
}
|
||||
if nan_to_num is not None:
|
||||
ret_metrics = OrderedDict({
|
||||
metric: np.nan_to_num(metric_value, nan=nan_to_num)
|
||||
for metric, metric_value in ret_metrics.items()
|
||||
})
|
||||
return ret_metrics
|
||||
16
finetune/mmseg/models/__init__.py
Normal file
16
finetune/mmseg/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .assigners import * # noqa: F401,F403
|
||||
from .backbones import * # noqa: F401,F403
|
||||
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
|
||||
build_head, build_loss, build_segmentor)
|
||||
from .data_preprocessor import SegDataPreProcessor
|
||||
from .decode_heads import * # noqa: F401,F403
|
||||
from .losses import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .segmentors import * # noqa: F401,F403
|
||||
from .text_encoder import * # noqa: F401,F403
|
||||
|
||||
__all__ = [
|
||||
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
|
||||
'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
|
||||
]
|
||||
12
finetune/mmseg/models/assigners/__init__.py
Normal file
12
finetune/mmseg/models/assigners/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_assigner import BaseAssigner
|
||||
from .hungarian_assigner import HungarianAssigner
|
||||
from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost
|
||||
|
||||
__all__ = [
|
||||
'BaseAssigner',
|
||||
'HungarianAssigner',
|
||||
'ClassificationCost',
|
||||
'CrossEntropyLossCost',
|
||||
'DiceCost',
|
||||
]
|
||||
18
finetune/mmseg/models/assigners/base_assigner.py
Normal file
18
finetune/mmseg/models/assigners/base_assigner.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
|
||||
|
||||
class BaseAssigner(metaclass=ABCMeta):
|
||||
"""Base assigner that assigns masks to ground truth class labels."""
|
||||
|
||||
@abstractmethod
|
||||
def assign(self,
|
||||
pred_instances: InstanceData,
|
||||
gt_instances: InstanceData,
|
||||
gt_instances_ignore: Optional[InstanceData] = None,
|
||||
**kwargs):
|
||||
"""Assign masks to either a ground truth class label or a negative
|
||||
label."""
|
||||
86
finetune/mmseg/models/assigners/hungarian_assigner.py
Normal file
86
finetune/mmseg/models/assigners/hungarian_assigner.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.structures import InstanceData
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
from .base_assigner import BaseAssigner
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class HungarianAssigner(BaseAssigner):
|
||||
"""Computes one-to-one matching between prediction masks and ground truth.
|
||||
|
||||
This class uses bipartite matching-based assignment to computes an
|
||||
assignment between the prediction masks and the ground truth. The
|
||||
assignment result is based on the weighted sum of match costs. The
|
||||
Hungarian algorithm is used to calculate the best matching with the
|
||||
minimum cost. The prediction masks that are not matched are classified
|
||||
as background.
|
||||
|
||||
Args:
|
||||
match_costs (ConfigDict|List[ConfigDict]): Match cost configs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
|
||||
ConfigDict]
|
||||
) -> None:
|
||||
|
||||
if isinstance(match_costs, dict):
|
||||
match_costs = [match_costs]
|
||||
elif isinstance(match_costs, list):
|
||||
assert len(match_costs) > 0, \
|
||||
'match_costs must not be a empty list.'
|
||||
|
||||
self.match_costs = [
|
||||
TASK_UTILS.build(match_cost) for match_cost in match_costs
|
||||
]
|
||||
|
||||
def assign(self, pred_instances: InstanceData, gt_instances: InstanceData,
|
||||
**kwargs):
|
||||
"""Computes one-to-one matching based on the weighted costs.
|
||||
|
||||
This method assign each query prediction to a ground truth or
|
||||
background. The assignment first calculates the cost for each
|
||||
category assigned to each query mask, and then uses the
|
||||
Hungarian algorithm to calculate the minimum cost as the best
|
||||
match.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): Instances of model
|
||||
predictions. It includes "masks", with shape
|
||||
(n, h, w) or (n, l), and "cls", with shape (n, num_classes+1)
|
||||
gt_instances (InstanceData): Ground truth of instance
|
||||
annotations. It includes "labels", with shape (k, ),
|
||||
and "masks", with shape (k, h, w) or (k, l).
|
||||
|
||||
Returns:
|
||||
matched_quiery_inds (Tensor): The indexes of matched quieres.
|
||||
matched_label_inds (Tensor): The indexes of matched labels.
|
||||
"""
|
||||
# compute weighted cost
|
||||
cost_list = []
|
||||
with autocast(enabled=False):
|
||||
for match_cost in self.match_costs:
|
||||
cost = match_cost(
|
||||
pred_instances=pred_instances, gt_instances=gt_instances)
|
||||
cost_list.append(cost)
|
||||
cost = torch.stack(cost_list).sum(dim=0)
|
||||
|
||||
device = cost.device
|
||||
# do Hungarian matching on CPU using linear_sum_assignment
|
||||
cost = cost.detach().cpu()
|
||||
if linear_sum_assignment is None:
|
||||
raise ImportError('Please run "pip install scipy" '
|
||||
'to install scipy first.')
|
||||
|
||||
matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost)
|
||||
matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device)
|
||||
matched_label_inds = torch.from_numpy(matched_label_inds).to(device)
|
||||
|
||||
return matched_quiery_inds, matched_label_inds
|
||||
231
finetune/mmseg/models/assigners/match_cost.py
Normal file
231
finetune/mmseg/models/assigners/match_cost.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
|
||||
|
||||
class BaseMatchCost:
|
||||
"""Base match cost class.
|
||||
|
||||
Args:
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, weight: Union[float, int] = 1.) -> None:
|
||||
self.weight = weight
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): Instances of model predictions.
|
||||
It often includes "labels" and "scores".
|
||||
gt_instances (InstanceData): Ground truth of instance
|
||||
annotations. It usually includes "labels".
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class ClassificationCost(BaseMatchCost):
|
||||
"""ClsSoftmaxCost.
|
||||
|
||||
Args:
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
|
||||
Examples:
|
||||
>>> from mmseg.models.assigners import ClassificationCost
|
||||
>>> import torch
|
||||
>>> self = ClassificationCost()
|
||||
>>> cls_pred = torch.rand(4, 3)
|
||||
>>> gt_labels = torch.tensor([0, 1, 2])
|
||||
>>> factor = torch.tensor([10, 8, 10, 8])
|
||||
>>> self(cls_pred, gt_labels)
|
||||
tensor([[-0.3430, -0.3525, -0.3045],
|
||||
[-0.3077, -0.2931, -0.3992],
|
||||
[-0.3664, -0.3455, -0.2881],
|
||||
[-0.3343, -0.2701, -0.3956]])
|
||||
"""
|
||||
|
||||
def __init__(self, weight: Union[float, int] = 1) -> None:
|
||||
super().__init__(weight=weight)
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): "scores" inside is
|
||||
predicted classification logits, of shape
|
||||
(num_queries, num_class).
|
||||
gt_instances (InstanceData): "labels" inside should have
|
||||
shape (num_gt, ).
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'scores'), \
|
||||
"pred_instances must contain 'scores'"
|
||||
assert hasattr(gt_instances, 'labels'), \
|
||||
"gt_instances must contain 'labels'"
|
||||
pred_scores = pred_instances.scores
|
||||
gt_labels = gt_instances.labels
|
||||
|
||||
pred_scores = pred_scores.softmax(-1)
|
||||
cls_cost = -pred_scores[:, gt_labels]
|
||||
|
||||
return cls_cost * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class DiceCost(BaseMatchCost):
|
||||
"""Cost of mask assignments based on dice losses.
|
||||
|
||||
Args:
|
||||
pred_act (bool): Whether to apply sigmoid to mask_pred.
|
||||
Defaults to False.
|
||||
eps (float): Defaults to 1e-3.
|
||||
naive_dice (bool): If True, use the naive dice loss
|
||||
in which the power of the number in the denominator is
|
||||
the first power. If False, use the second power that
|
||||
is adopted by K-Net and SOLO. Defaults to True.
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pred_act: bool = False,
|
||||
eps: float = 1e-3,
|
||||
naive_dice: bool = True,
|
||||
weight: Union[float, int] = 1.) -> None:
|
||||
super().__init__(weight=weight)
|
||||
self.pred_act = pred_act
|
||||
self.eps = eps
|
||||
self.naive_dice = naive_dice
|
||||
|
||||
def _binary_mask_dice_loss(self, mask_preds: Tensor,
|
||||
gt_masks: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
mask_preds (Tensor): Mask prediction in shape (num_queries, *).
|
||||
gt_masks (Tensor): Ground truth in shape (num_gt, *)
|
||||
store 0 or 1, 0 for negative class and 1 for
|
||||
positive class.
|
||||
|
||||
Returns:
|
||||
Tensor: Dice cost matrix in shape (num_queries, num_gt).
|
||||
"""
|
||||
mask_preds = mask_preds.flatten(1)
|
||||
gt_masks = gt_masks.flatten(1).float()
|
||||
numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
|
||||
if self.naive_dice:
|
||||
denominator = mask_preds.sum(-1)[:, None] + \
|
||||
gt_masks.sum(-1)[None, :]
|
||||
else:
|
||||
denominator = mask_preds.pow(2).sum(1)[:, None] + \
|
||||
gt_masks.pow(2).sum(1)[None, :]
|
||||
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
|
||||
return loss
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (InstanceData): Predicted instances which
|
||||
must contain "masks".
|
||||
gt_instances (InstanceData): Ground truth which must contain
|
||||
"mask".
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'masks'), \
|
||||
"pred_instances must contain 'masks'"
|
||||
assert hasattr(gt_instances, 'masks'), \
|
||||
"gt_instances must contain 'masks'"
|
||||
pred_masks = pred_instances.masks
|
||||
gt_masks = gt_instances.masks
|
||||
|
||||
if self.pred_act:
|
||||
pred_masks = pred_masks.sigmoid()
|
||||
dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks)
|
||||
return dice_cost * self.weight
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class CrossEntropyLossCost(BaseMatchCost):
|
||||
"""CrossEntropyLossCost.
|
||||
|
||||
Args:
|
||||
use_sigmoid (bool): Whether the prediction uses sigmoid
|
||||
of softmax. Defaults to True.
|
||||
weight (Union[float, int]): Cost weight. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
use_sigmoid: bool = True,
|
||||
weight: Union[float, int] = 1.) -> None:
|
||||
super().__init__(weight=weight)
|
||||
self.use_sigmoid = use_sigmoid
|
||||
|
||||
def _binary_cross_entropy(self, cls_pred: Tensor,
|
||||
gt_labels: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or
|
||||
(num_queries, *).
|
||||
gt_labels (Tensor): The learning label of prediction with
|
||||
shape (num_gt, *).
|
||||
|
||||
Returns:
|
||||
Tensor: Cross entropy cost matrix in shape (num_queries, num_gt).
|
||||
"""
|
||||
cls_pred = cls_pred.flatten(1).float()
|
||||
gt_labels = gt_labels.flatten(1).float()
|
||||
n = cls_pred.shape[1]
|
||||
pos = F.binary_cross_entropy_with_logits(
|
||||
cls_pred, torch.ones_like(cls_pred), reduction='none')
|
||||
neg = F.binary_cross_entropy_with_logits(
|
||||
cls_pred, torch.zeros_like(cls_pred), reduction='none')
|
||||
cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
|
||||
torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
|
||||
cls_cost = cls_cost / n
|
||||
|
||||
return cls_cost
|
||||
|
||||
def __call__(self, pred_instances: InstanceData,
|
||||
gt_instances: InstanceData, **kwargs) -> Tensor:
|
||||
"""Compute match cost.
|
||||
|
||||
Args:
|
||||
pred_instances (:obj:`InstanceData`): Predicted instances which
|
||||
must contain ``masks``.
|
||||
gt_instances (:obj:`InstanceData`): Ground truth which must contain
|
||||
``masks``.
|
||||
|
||||
Returns:
|
||||
Tensor: Match Cost matrix of shape (num_preds, num_gts).
|
||||
"""
|
||||
assert hasattr(pred_instances, 'masks'), \
|
||||
"pred_instances must contain 'masks'"
|
||||
assert hasattr(gt_instances, 'masks'), \
|
||||
"gt_instances must contain 'masks'"
|
||||
pred_masks = pred_instances.masks
|
||||
gt_masks = gt_instances.masks
|
||||
if self.use_sigmoid:
|
||||
cls_cost = self._binary_cross_entropy(pred_masks, gt_masks)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return cls_cost * self.weight
|
||||
35
finetune/mmseg/models/backbones/__init__.py
Normal file
35
finetune/mmseg/models/backbones/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .beit import BEiT
|
||||
from .bisenetv1 import BiSeNetV1
|
||||
from .bisenetv2 import BiSeNetV2
|
||||
from .cgnet import CGNet
|
||||
from .ddrnet import DDRNet
|
||||
from .erfnet import ERFNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .icnet import ICNet
|
||||
from .mae import MAE
|
||||
from .mit import MixVisionTransformer
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .mobilenet_v3 import MobileNetV3
|
||||
from .mscan import MSCAN
|
||||
from .pidnet import PIDNet
|
||||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .stdc import STDCContextPathNet, STDCNet
|
||||
from .swin import SwinTransformer
|
||||
from .timm_backbone import TIMMBackbone
|
||||
from .twins import PCPVT, SVT
|
||||
from .unet import UNet
|
||||
from .vit import VisionTransformer
|
||||
from .vpd import VPD
|
||||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
|
||||
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
|
||||
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
|
||||
'DDRNet', 'VPD'
|
||||
]
|
||||
554
finetune/mmseg/models/backbones/beit.py
Normal file
554
finetune/mmseg/models/backbones/beit.py
Normal file
@@ -0,0 +1,554 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from scipy import interpolate
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed
|
||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||
|
||||
|
||||
class BEiTAttention(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
bias (bool): The option to add leanable bias for q, k, v. If bias is
|
||||
True, it will add leanable bias. If bias is 'qv_bias', it will only
|
||||
add leanable bias for q, v. If bias is False, it will not add bias
|
||||
for q, k, v. Default to 'qv_bias'.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
bias='qv_bias',
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.bias = bias
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
qkv_bias = bias
|
||||
if bias == 'qv_bias':
|
||||
self._init_qv_bias()
|
||||
qkv_bias = False
|
||||
|
||||
self.window_size = window_size
|
||||
self._init_rel_pos_embedding()
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
|
||||
def _init_qv_bias(self):
|
||||
self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))
|
||||
|
||||
def _init_rel_pos_embedding(self):
|
||||
Wh, Ww = self.window_size
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
|
||||
# relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, self.num_heads))
|
||||
|
||||
# get pair-wise relative position index for
|
||||
# each token inside the window
|
||||
coords_h = torch.arange(Wh)
|
||||
coords_w = torch.arange(Ww)
|
||||
# coords shape is (2, Wh, Ww)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
# coords_flatten shape is (2, Wh*Ww)
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :])
|
||||
# relative_coords shape is (Wh*Ww, Wh*Ww, 2)
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
# shift to start from 0
|
||||
relative_coords[:, :, 0] += Wh - 1
|
||||
relative_coords[:, :, 1] += Ww - 1
|
||||
relative_coords[:, :, 0] *= 2 * Ww - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
# relative_position_index shape is (Wh*Ww, Wh*Ww)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1)
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x (tensor): input features with shape of (num_windows*B, N, C).
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.bias == 'qv_bias':
|
||||
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
|
||||
qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
else:
|
||||
qkv = self.qkv(x)
|
||||
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
if self.relative_position_bias_table is not None:
|
||||
Wh = self.window_size[0]
|
||||
Ww = self.window_size[1]
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
Wh * Ww + 1, Wh * Ww + 1, -1)
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
bias (bool): The option to add leanable bias for q, k, v. If bias is
|
||||
True, it will add leanable bias. If bias is 'qv_bias', it will only
|
||||
add leanable bias for q, v. If bias is False, it will not add bias
|
||||
for q, k, v. Default to 'qv_bias'.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
window_size (tuple[int], optional): The height and width of the window.
|
||||
Default: None.
|
||||
init_values (float, optional): Initialize the values of BEiTAttention
|
||||
and FFN with learnable scaling. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
bias='qv_bias',
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
window_size=None,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(add_identity=False),
|
||||
init_values=None):
|
||||
attn_cfg.update(dict(window_size=window_size, qk_scale=None))
|
||||
|
||||
super().__init__(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
attn_cfg=attn_cfg,
|
||||
ffn_cfg=ffn_cfg)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if
|
||||
# this is better than dropout here
|
||||
dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
self.drop_path = build_dropout(
|
||||
dropout_layer) if dropout_layer else nn.Identity()
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones(embed_dims), requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones(embed_dims), requires_grad=True)
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = BEiTAttention(**attn_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BEiT(BaseModule):
|
||||
"""BERT Pre-Training of Image Transformers.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): Embedding dimension. Default: 768.
|
||||
num_layers (int): Depth of transformer. Default: 12.
|
||||
num_heads (int): Number of attention heads. Default: 12.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qv_bias (bool): Enable bias for qv if True. Default: True.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_values (float): Initialize the values of BEiTAttention and FFN
|
||||
with learnable scaling.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
qv_bias=True,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
pretrained=None,
|
||||
init_values=0.1,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.norm_eval = norm_eval
|
||||
self.pretrained = pretrained
|
||||
self.num_layers = num_layers
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.attn_drop_rate = attn_drop_rate
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.num_fcs = num_fcs
|
||||
self.qv_bias = qv_bias
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.patch_norm = patch_norm
|
||||
self.init_values = init_values
|
||||
self.window_size = (img_size[0] // patch_size,
|
||||
img_size[1] // patch_size)
|
||||
self.patch_shape = self.window_size
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
self._build_patch_embedding()
|
||||
self._build_layers()
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
def _build_patch_embedding(self):
|
||||
"""Build patch embedding layer."""
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=self.in_channels,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding=0,
|
||||
norm_cfg=self.norm_cfg if self.patch_norm else None,
|
||||
init_cfg=None)
|
||||
|
||||
def _build_layers(self):
|
||||
"""Build transformer encoding layers."""
|
||||
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
|
||||
]
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.layers.append(
|
||||
BEiTTransformerEncoderLayer(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.num_heads,
|
||||
feedforward_channels=self.mlp_ratio * self.embed_dims,
|
||||
attn_drop_rate=self.attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=self.num_fcs,
|
||||
bias='qv_bias' if self.qv_bias else False,
|
||||
act_cfg=self.act_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
window_size=self.window_size,
|
||||
init_values=self.init_values))
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
|
||||
num):
|
||||
"""Get new sequence via geometric sequence interpolation.
|
||||
|
||||
Args:
|
||||
src_size (int): Pos_embedding size in pre-trained model.
|
||||
dst_size (int): Pos_embedding size in the current model.
|
||||
sequence (tensor): The relative position bias of the pretrain
|
||||
model after removing the extra tokens.
|
||||
num (int): Number of attention heads.
|
||||
Returns:
|
||||
new_sequence (tensor): Geometric sequence interpolate the
|
||||
pre-trained relative position bias to the size of
|
||||
the current model.
|
||||
"""
|
||||
|
||||
def geometric_progression(a, r, n):
|
||||
return a * (1.0 - r**n) / (1.0 - r)
|
||||
|
||||
# Here is a binary function.
|
||||
left, right = 1.01, 1.5
|
||||
while right - left > 1e-6:
|
||||
q = (left + right) / 2.0
|
||||
gp = geometric_progression(1, q, src_size // 2)
|
||||
if gp > dst_size // 2:
|
||||
right = q
|
||||
else:
|
||||
left = q
|
||||
# The position of each interpolated point is determined
|
||||
# by the ratio obtained by dichotomy.
|
||||
dis = []
|
||||
cur = 1
|
||||
for i in range(src_size // 2):
|
||||
dis.append(cur)
|
||||
cur += q**(i + 1)
|
||||
r_ids = [-_ for _ in reversed(dis)]
|
||||
x = r_ids + [0] + dis
|
||||
y = r_ids + [0] + dis
|
||||
t = dst_size // 2.0
|
||||
dx = np.arange(-t, t + 0.1, 1.0)
|
||||
dy = np.arange(-t, t + 0.1, 1.0)
|
||||
# Interpolation functions are being executed and called.
|
||||
new_sequence = []
|
||||
for i in range(num):
|
||||
z = sequence[:, i].view(src_size, src_size).float().numpy()
|
||||
f = interpolate.interp2d(x, y, z, kind='cubic')
|
||||
new_sequence.append(
|
||||
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
|
||||
new_sequence = torch.cat(new_sequence, dim=-1)
|
||||
return new_sequence
|
||||
|
||||
def resize_rel_pos_embed(self, checkpoint):
|
||||
"""Resize relative pos_embed weights.
|
||||
|
||||
This function is modified from
|
||||
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT License
|
||||
Args:
|
||||
checkpoint (dict): Key and value of the pretrain model.
|
||||
Returns:
|
||||
state_dict (dict): Interpolate the relative pos_embed weights
|
||||
in the pre-train model to the current model size.
|
||||
"""
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
for key in all_keys:
|
||||
if 'relative_position_index' in key:
|
||||
state_dict.pop(key)
|
||||
# In order to keep the center of pos_bias as consistent as
|
||||
# possible after interpolation, and vice versa in the edge
|
||||
# area, the geometric sequence interpolation method is adopted.
|
||||
if 'relative_position_bias_table' in key:
|
||||
rel_pos_bias = state_dict[key]
|
||||
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
||||
dst_num_pos, _ = self.state_dict()[key].size()
|
||||
dst_patch_shape = self.patch_shape
|
||||
if dst_patch_shape[0] != dst_patch_shape[1]:
|
||||
raise NotImplementedError()
|
||||
# Count the number of extra tokens.
|
||||
num_extra_tokens = dst_num_pos - (
|
||||
dst_patch_shape[0] * 2 - 1) * (
|
||||
dst_patch_shape[1] * 2 - 1)
|
||||
src_size = int((src_num_pos - num_extra_tokens)**0.5)
|
||||
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
|
||||
if src_size != dst_size:
|
||||
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
||||
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
||||
new_rel_pos_bias = self._geometric_sequence_interpolation(
|
||||
src_size, dst_size, rel_pos_bias, num_attn_heads)
|
||||
new_rel_pos_bias = torch.cat(
|
||||
(new_rel_pos_bias, extra_tokens), dim=0)
|
||||
state_dict[key] = new_rel_pos_bias
|
||||
|
||||
return state_dict
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# Copyright 2019 Ross Wightman
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
332
finetune/mmseg/models/backbones/bisenetv1.py
Normal file
332
finetune/mmseg/models/backbones/bisenetv1.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class SpatialPath(BaseModule):
|
||||
"""Spatial Path to preserve the spatial size of the original input image
|
||||
and encode affluent spatial information.
|
||||
|
||||
Args:
|
||||
in_channels(int): The number of channels of input
|
||||
image. Default: 3.
|
||||
num_channels (Tuple[int]): The number of channels of
|
||||
each layers in Spatial Path.
|
||||
Default: (64, 64, 64, 128).
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map for Feature Fusion Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
num_channels=(64, 64, 64, 128),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(num_channels) == 4, 'Length of input channels \
|
||||
of Spatial Path must be 4!'
|
||||
|
||||
self.layers = []
|
||||
for i in range(len(num_channels)):
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.layers.append(layer_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
elif i == len(num_channels) - 1:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=num_channels[i - 1],
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
else:
|
||||
self.add_module(
|
||||
layer_name,
|
||||
ConvModule(
|
||||
in_channels=num_channels[i - 1],
|
||||
out_channels=num_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer_stage = getattr(self, layer_name)
|
||||
x = layer_stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionRefinementModule(BaseModule):
|
||||
"""Attention Refinement Module (ARM) to refine the features of each stage.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
Returns:
|
||||
x_out (torch.Tensor): Feature map of Attention Refinement Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channel,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.atten_conv_layer = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
in_channels=out_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_layer(x)
|
||||
x_atten = self.atten_conv_layer(x)
|
||||
x_out = x * x_atten
|
||||
return x_out
|
||||
|
||||
|
||||
class ContextPath(BaseModule):
|
||||
"""Context Path to provide sufficient receptive field.
|
||||
|
||||
Args:
|
||||
backbone_cfg:(dict): Config of backbone of
|
||||
Context Path.
|
||||
context_channels (Tuple[int]): The number of channel numbers
|
||||
of various modules in Context Path.
|
||||
Default: (128, 256, 512).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation. Default: False.
|
||||
Returns:
|
||||
x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
|
||||
undergoing upsampling from 1/16 and 1/32 downsampling
|
||||
feature maps. These two feature maps are used for Feature
|
||||
Fusion Module and Auxiliary Head.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
context_channels=(128, 256, 512),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(context_channels) == 3, 'Length of input channels \
|
||||
of Context Path must be 3!'
|
||||
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
self.align_corners = align_corners
|
||||
self.arm16 = AttentionRefinementModule(context_channels[1],
|
||||
context_channels[0])
|
||||
self.arm32 = AttentionRefinementModule(context_channels[2],
|
||||
context_channels[0])
|
||||
self.conv_head32 = ConvModule(
|
||||
in_channels=context_channels[0],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.conv_head16 = ConvModule(
|
||||
in_channels=context_channels[0],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.gap_conv = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
in_channels=context_channels[2],
|
||||
out_channels=context_channels[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
x_4, x_8, x_16, x_32 = self.backbone(x)
|
||||
x_gap = self.gap_conv(x_32)
|
||||
|
||||
x_32_arm = self.arm32(x_32)
|
||||
x_32_sum = x_32_arm + x_gap
|
||||
x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
|
||||
x_32_up = self.conv_head32(x_32_up)
|
||||
|
||||
x_16_arm = self.arm16(x_16)
|
||||
x_16_sum = x_16_arm + x_32_up
|
||||
x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
|
||||
x_16_up = self.conv_head16(x_16_up)
|
||||
|
||||
return x_16_up, x_32_up
|
||||
|
||||
|
||||
class FeatureFusionModule(BaseModule):
|
||||
"""Feature Fusion Module to fuse low level output feature of Spatial Path
|
||||
and high level output feature of Context Path.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
Returns:
|
||||
x_out (torch.Tensor): Feature map of Feature Fusion Module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.conv_atten = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg), nn.Sigmoid())
|
||||
|
||||
def forward(self, x_sp, x_cp):
|
||||
x_concat = torch.cat([x_sp, x_cp], dim=1)
|
||||
x_fuse = self.conv1(x_concat)
|
||||
x_atten = self.gap(x_fuse)
|
||||
# Note: No BN and more 1x1 conv in paper.
|
||||
x_atten = self.conv_atten(x_atten)
|
||||
x_atten = x_fuse * x_atten
|
||||
x_out = x_atten + x_fuse
|
||||
return x_out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV1(BaseModule):
|
||||
"""BiSeNetV1 backbone.
|
||||
|
||||
This backbone is the implementation of `BiSeNet: Bilateral
|
||||
Segmentation Network for Real-time Semantic
|
||||
Segmentation <https://arxiv.org/abs/1808.00897>`_.
|
||||
|
||||
Args:
|
||||
backbone_cfg:(dict): Config of backbone of
|
||||
Context Path.
|
||||
in_channels (int): The number of channels of input
|
||||
image. Default: 3.
|
||||
spatial_channels (Tuple[int]): Size of channel numbers of
|
||||
various layers in Spatial Path.
|
||||
Default: (64, 64, 64, 128).
|
||||
context_channels (Tuple[int]): Size of channel numbers of
|
||||
various modules in Context Path.
|
||||
Default: (128, 256, 512).
|
||||
out_indices (Tuple[int] | int, optional): Output from which stages.
|
||||
Default: (0, 1, 2).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation in Bilateral Guided Aggregation Layer.
|
||||
Default: False.
|
||||
out_channels(int): The number of channels of output.
|
||||
It must be the same with `in_channels` of decode_head.
|
||||
Default: 256.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
in_channels=3,
|
||||
spatial_channels=(64, 64, 64, 128),
|
||||
context_channels=(128, 256, 512),
|
||||
out_indices=(0, 1, 2),
|
||||
align_corners=False,
|
||||
out_channels=256,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(spatial_channels) == 4, 'Length of input channels \
|
||||
of Spatial Path must be 4!'
|
||||
|
||||
assert len(context_channels) == 3, 'Length of input channels \
|
||||
of Context Path must be 3!'
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.align_corners = align_corners
|
||||
self.context_path = ContextPath(backbone_cfg, context_channels,
|
||||
self.align_corners)
|
||||
self.spatial_path = SpatialPath(in_channels, spatial_channels)
|
||||
self.ffm = FeatureFusionModule(context_channels[1], out_channels)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
def forward(self, x):
|
||||
# stole refactoring code from Coin Cheung, thanks
|
||||
x_context8, x_context16 = self.context_path(x)
|
||||
x_spatial = self.spatial_path(x)
|
||||
x_fuse = self.ffm(x_spatial, x_context8)
|
||||
|
||||
outs = [x_fuse, x_context8, x_context16]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
||||
622
finetune/mmseg/models/backbones/bisenetv2.py
Normal file
622
finetune/mmseg/models/backbones/bisenetv2.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
"""Detail Branch with wide channels and shallow layers to capture low-level
|
||||
details and generate high-resolution feature representation.
|
||||
|
||||
Args:
|
||||
detail_channels (Tuple[int]): Size of channel numbers of each stage
|
||||
in Detail Branch, in paper it has 3 stages.
|
||||
Default: (64, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map of Detail Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
detail_channels=(64, 64, 128),
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
detail_branch = []
|
||||
for i in range(len(detail_channels)):
|
||||
if i == 0:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
else:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i - 1],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
self.detail_branch = nn.ModuleList(detail_branch)
|
||||
|
||||
def forward(self, x):
|
||||
for stage in self.detail_branch:
|
||||
x = stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): First feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_first = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.convs = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.pool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=False)
|
||||
self.fuse_last = ConvModule(
|
||||
in_channels=out_channels * 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_first(x)
|
||||
x_left = self.convs(x)
|
||||
x_right = self.pool(x)
|
||||
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
|
||||
return x
|
||||
|
||||
|
||||
class GELayer(BaseModule):
|
||||
"""Gather-and-Expansion Layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
stride (int): Stride of GELayer. Default: 1
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Intermediate feature map in
|
||||
Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
exp_ratio=6,
|
||||
stride=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
mid_channel = in_channels * exp_ratio
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
if stride == 1:
|
||||
self.dwconv = nn.Sequential(
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.shortcut = None
|
||||
else:
|
||||
self.dwconv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=mid_channel,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
)
|
||||
self.shortcut = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None,
|
||||
))
|
||||
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.dwconv(x)
|
||||
x = self.conv2(x)
|
||||
if self.shortcut is not None:
|
||||
shortcut = self.shortcut(identity)
|
||||
x = x + shortcut
|
||||
else:
|
||||
x = x + identity
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
|
||||
"""Context Embedding Block for large receptive filed in Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Last feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.gap = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
build_norm_layer(norm_cfg, self.in_channels)[1])
|
||||
self.conv_gap = ConvModule(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# Note: in paper here is naive conv2d, no bn-relu
|
||||
self.conv_last = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.gap(x)
|
||||
x = self.conv_gap(x)
|
||||
x = identity + x
|
||||
x = self.conv_last(x)
|
||||
return x
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
|
||||
"""Semantic Branch which is lightweight with narrow channels and deep
|
||||
layers to obtain high-level semantic context.
|
||||
|
||||
Args:
|
||||
semantic_channels(Tuple[int]): Size of channel numbers of
|
||||
various stages in Semantic Branch.
|
||||
Default: (16, 32, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
semantic_outs (List[torch.Tensor]): List of several feature maps
|
||||
for auxiliary heads (Booster) and Bilateral
|
||||
Guided Aggregation Layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
in_channels=3,
|
||||
exp_ratio=6,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_stages = []
|
||||
for i in range(len(semantic_channels)):
|
||||
stage_name = f'stage{i + 1}'
|
||||
self.semantic_stages.append(stage_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
StemBlock(self.in_channels, semantic_channels[i]))
|
||||
elif i == (len(semantic_channels) - 1):
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
else:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
|
||||
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
|
||||
CEBlock(semantic_channels[-1], semantic_channels[-1]))
|
||||
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
|
||||
|
||||
def forward(self, x):
|
||||
semantic_outs = []
|
||||
for stage_name in self.semantic_stages:
|
||||
semantic_stage = getattr(self, stage_name)
|
||||
x = semantic_stage(x)
|
||||
semantic_outs.append(x)
|
||||
return semantic_outs
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
|
||||
"""Bilateral Guided Aggregation Layer to fuse the complementary information
|
||||
from both Detail Branch and Semantic Branch.
|
||||
|
||||
Args:
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
output (torch.Tensor): Output feature map for Segment heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_channels=128,
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
self.align_corners = align_corners
|
||||
self.detail_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.detail_down = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
|
||||
self.semantic_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
self.semantic_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.conv = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
inplace=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
def forward(self, x_d, x_s):
|
||||
detail_dwconv = self.detail_dwconv(x_d)
|
||||
detail_down = self.detail_down(x_d)
|
||||
semantic_conv = self.semantic_conv(x_s)
|
||||
semantic_dwconv = self.semantic_dwconv(x_s)
|
||||
semantic_conv = resize(
|
||||
input=semantic_conv,
|
||||
size=detail_dwconv.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
|
||||
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
|
||||
fuse_2 = resize(
|
||||
input=fuse_2,
|
||||
size=fuse_1.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = self.conv(fuse_1 + fuse_2)
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV2(BaseModule):
|
||||
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
|
||||
Real-time Semantic Segmentation.
|
||||
|
||||
This backbone is the implementation of
|
||||
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channel of input image. Default: 3.
|
||||
detail_channels (Tuple[int], optional): Channels of each stage
|
||||
in Detail Branch. Default: (64, 64, 128).
|
||||
semantic_channels (Tuple[int], optional): Channels of each stage
|
||||
in Semantic Branch. Default: (16, 32, 64, 128).
|
||||
See Table 1 and Figure 3 of paper for more details.
|
||||
semantic_expansion_ratio (int, optional): The expansion factor
|
||||
expanding channel number of middle channels in Semantic Branch.
|
||||
Default: 6.
|
||||
bga_channels (int, optional): Number of middle channels in
|
||||
Bilateral Guided Aggregation Layer. Default: 128.
|
||||
out_indices (Tuple[int] | int, optional): Output from which stages.
|
||||
Default: (0, 1, 2, 3, 4).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation in Bilateral Guided Aggregation Layer.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
detail_channels=(64, 64, 128),
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
semantic_expansion_ratio=6,
|
||||
bga_channels=128,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
if init_cfg is None:
|
||||
init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_indices = out_indices
|
||||
self.detail_channels = detail_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_expansion_ratio = semantic_expansion_ratio
|
||||
self.bga_channels = bga_channels
|
||||
self.align_corners = align_corners
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.detail = DetailBranch(self.detail_channels, self.in_channels)
|
||||
self.semantic = SemanticBranch(self.semantic_channels,
|
||||
self.in_channels,
|
||||
self.semantic_expansion_ratio)
|
||||
self.bga = BGALayer(self.bga_channels, self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
# stole refactoring code from Coin Cheung, thanks
|
||||
x_detail = self.detail(x)
|
||||
x_semantic_lst = self.semantic(x)
|
||||
x_head = self.bga(x_detail, x_semantic_lst[-1])
|
||||
outs = [x_head] + x_semantic_lst[:-1]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
||||
372
finetune/mmseg/models/backbones/cgnet.py
Normal file
372
finetune/mmseg/models/backbones/cgnet.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class GlobalContextExtractor(nn.Module):
|
||||
"""Global Context Extractor for CGNet.
|
||||
|
||||
This class is employed to refine the joint feature of both local feature
|
||||
and surrounding context.
|
||||
|
||||
Args:
|
||||
channel (int): Number of input feature channels.
|
||||
reduction (int): Reductions for global context extractor. Default: 16.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, channel, reduction=16, with_cp=False):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
self.reduction = reduction
|
||||
assert reduction >= 1 and channel >= reduction
|
||||
self.with_cp = with_cp
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
num_batch, num_channel = x.size()[:2]
|
||||
y = self.avg_pool(x).view(num_batch, num_channel)
|
||||
y = self.fc(y).view(num_batch, num_channel, 1, 1)
|
||||
return x * y
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ContextGuidedBlock(nn.Module):
|
||||
"""Context Guided Block for CGNet.
|
||||
|
||||
This class consists of four components: local feature extractor,
|
||||
surrounding feature extractor, joint feature extractor and global
|
||||
context extractor.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input feature channels.
|
||||
out_channels (int): Number of output feature channels.
|
||||
dilation (int): Dilation rate for surrounding context extractor.
|
||||
Default: 2.
|
||||
reduction (int): Reduction for global context extractor. Default: 16.
|
||||
skip_connect (bool): Add input to output or not. Default: True.
|
||||
downsample (bool): Downsample the input to 1/2 or not. Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dilation=2,
|
||||
reduction=16,
|
||||
skip_connect=True,
|
||||
downsample=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
self.with_cp = with_cp
|
||||
self.downsample = downsample
|
||||
|
||||
channels = out_channels if downsample else out_channels // 2
|
||||
if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
|
||||
act_cfg['num_parameters'] = channels
|
||||
kernel_size = 3 if downsample else 1
|
||||
stride = 2 if downsample else 1
|
||||
padding = (kernel_size - 1) // 2
|
||||
|
||||
self.conv1x1 = ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.f_loc = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
groups=channels,
|
||||
bias=False)
|
||||
self.f_sur = build_conv_layer(
|
||||
conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=dilation,
|
||||
groups=channels,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
|
||||
self.activate = nn.PReLU(2 * channels)
|
||||
|
||||
if downsample:
|
||||
self.bottleneck = build_conv_layer(
|
||||
conv_cfg,
|
||||
2 * channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
|
||||
self.skip_connect = skip_connect and not downsample
|
||||
self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
out = self.conv1x1(x)
|
||||
loc = self.f_loc(out)
|
||||
sur = self.f_sur(out)
|
||||
|
||||
joi_feat = torch.cat([loc, sur], 1) # the joint feature
|
||||
joi_feat = self.bn(joi_feat)
|
||||
joi_feat = self.activate(joi_feat)
|
||||
if self.downsample:
|
||||
joi_feat = self.bottleneck(joi_feat) # channel = out_channels
|
||||
# f_glo is employed to refine the joint feature
|
||||
out = self.f_glo(joi_feat)
|
||||
|
||||
if self.skip_connect:
|
||||
return x + out
|
||||
else:
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class InputInjection(nn.Module):
|
||||
"""Downsampling module for CGNet."""
|
||||
|
||||
def __init__(self, num_downsampling):
|
||||
super().__init__()
|
||||
self.pool = nn.ModuleList()
|
||||
for i in range(num_downsampling):
|
||||
self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
for pool in self.pool:
|
||||
x = pool(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CGNet(BaseModule):
|
||||
"""CGNet backbone.
|
||||
|
||||
This backbone is the implementation of `A Light-weight Context Guided
|
||||
Network for Semantic Segmentation <https://arxiv.org/abs/1811.08201>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
num_channels (tuple[int]): Numbers of feature channels at each stages.
|
||||
Default: (32, 64, 128).
|
||||
num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
|
||||
Default: (3, 21).
|
||||
dilations (tuple[int]): Dilation rate for surrounding context
|
||||
extractors at stage 1 and stage 2. Default: (2, 4).
|
||||
reductions (tuple[int]): Reductions for global context extractors at
|
||||
stage 1 and stage 2. Default: (8, 16).
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='PReLU').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
num_channels=(32, 64, 128),
|
||||
num_blocks=(3, 21),
|
||||
dilations=(2, 4),
|
||||
reductions=(8, 16),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='PReLU'),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm']),
|
||||
dict(type='Constant', val=0, layer='PReLU')
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_channels = num_channels
|
||||
assert isinstance(self.num_channels, tuple) and len(
|
||||
self.num_channels) == 3
|
||||
self.num_blocks = num_blocks
|
||||
assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
|
||||
self.dilations = dilations
|
||||
assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
|
||||
self.reductions = reductions
|
||||
assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
|
||||
self.act_cfg['num_parameters'] = num_channels[0]
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
cur_channels = in_channels
|
||||
self.stem = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.stem.append(
|
||||
ConvModule(
|
||||
cur_channels,
|
||||
num_channels[0],
|
||||
3,
|
||||
2 if i == 0 else 1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
cur_channels = num_channels[0]
|
||||
|
||||
self.inject_2x = InputInjection(1) # down-sample for Input, factor=2
|
||||
self.inject_4x = InputInjection(2) # down-sample for Input, factor=4
|
||||
|
||||
cur_channels += in_channels
|
||||
self.norm_prelu_0 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 1
|
||||
self.level1 = nn.ModuleList()
|
||||
for i in range(num_blocks[0]):
|
||||
self.level1.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[1],
|
||||
num_channels[1],
|
||||
dilations[0],
|
||||
reductions[0],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[1] + in_channels
|
||||
self.norm_prelu_1 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
# stage 2
|
||||
self.level2 = nn.ModuleList()
|
||||
for i in range(num_blocks[1]):
|
||||
self.level2.append(
|
||||
ContextGuidedBlock(
|
||||
cur_channels if i == 0 else num_channels[2],
|
||||
num_channels[2],
|
||||
dilations[1],
|
||||
reductions[1],
|
||||
downsample=(i == 0),
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
with_cp=with_cp)) # CG block
|
||||
|
||||
cur_channels = 2 * num_channels[2]
|
||||
self.norm_prelu_2 = nn.Sequential(
|
||||
build_norm_layer(norm_cfg, cur_channels)[1],
|
||||
nn.PReLU(cur_channels))
|
||||
|
||||
def forward(self, x):
|
||||
output = []
|
||||
|
||||
# stage 0
|
||||
inp_2x = self.inject_2x(x)
|
||||
inp_4x = self.inject_4x(x)
|
||||
for layer in self.stem:
|
||||
x = layer(x)
|
||||
x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 1
|
||||
for i, layer in enumerate(self.level1):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down1 = x
|
||||
x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
|
||||
output.append(x)
|
||||
|
||||
# stage 2
|
||||
for i, layer in enumerate(self.level2):
|
||||
x = layer(x)
|
||||
if i == 0:
|
||||
down2 = x
|
||||
x = self.norm_prelu_2(torch.cat([down2, x], 1))
|
||||
output.append(x)
|
||||
|
||||
return output
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode will keeping the normalization
|
||||
layer freezed."""
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
222
finetune/mmseg/models/backbones/ddrnet.py
Normal file
222
finetune/mmseg/models/backbones/ddrnet.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DDRNet(BaseModule):
|
||||
"""DDRNet backbone.
|
||||
|
||||
This backbone is the implementation of `Deep Dual-resolution Networks for
|
||||
Real-time and Accurate Semantic Segmentation of Road Scenes
|
||||
<http://arxiv.org/abs/2101.06085>`_.
|
||||
Modified from https://github.com/ydhongHIT/DDRNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
channels: (int): The base channels of DDRNet. Default: 32.
|
||||
ppm_channels (int): The channels of PPM module. Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
norm_cfg (dict): Config dict to build norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 32,
|
||||
ppm_channels: int = 128,
|
||||
align_corners: bool = False,
|
||||
norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.ppm_channels = ppm_channels
|
||||
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
|
||||
# stage 0-2
|
||||
self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# low resolution(context) branch
|
||||
self.context_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.context_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
inplanes=channels * 2**(i + 1),
|
||||
planes=channels * 8 if i > 0 else channels * 4,
|
||||
num_blocks=2 if i < 2 else 1,
|
||||
stride=2))
|
||||
|
||||
# bilateral fusion
|
||||
self.compression_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.down_1 = ConvModule(
|
||||
channels * 2,
|
||||
channels * 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
self.compression_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.down_2 = nn.Sequential(
|
||||
ConvModule(
|
||||
channels * 2,
|
||||
channels * 4,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels * 4,
|
||||
channels * 8,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None))
|
||||
|
||||
# high resolution(spatial) branch
|
||||
self.spatial_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.spatial_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
inplanes=channels * 2,
|
||||
planes=channels * 2,
|
||||
num_blocks=2 if i < 2 else 1,
|
||||
))
|
||||
|
||||
self.spp = DAPPM(
|
||||
channels * 16, ppm_channels, channels * 4, num_scales=5)
|
||||
|
||||
def _make_stem_layer(self, in_channels, channels, num_blocks):
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
|
||||
layers.extend([
|
||||
self._make_layer(BasicBlock, channels, channels, num_blocks),
|
||||
nn.ReLU(),
|
||||
self._make_layer(
|
||||
BasicBlock, channels, channels * 2, num_blocks, stride=2),
|
||||
nn.ReLU(),
|
||||
])
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
||||
|
||||
layers = [
|
||||
block(
|
||||
in_channels=inplanes,
|
||||
channels=planes,
|
||||
stride=stride,
|
||||
downsample=downsample)
|
||||
]
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, num_blocks):
|
||||
layers.append(
|
||||
block(
|
||||
in_channels=inplanes,
|
||||
channels=planes,
|
||||
stride=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
out_size = (x.shape[-2] // 8, x.shape[-1] // 8)
|
||||
|
||||
# stage 0-2
|
||||
x = self.stem(x)
|
||||
|
||||
# stage3
|
||||
x_c = self.context_branch_layers[0](x)
|
||||
x_s = self.spatial_branch_layers[0](x)
|
||||
comp_c = self.compression_1(self.relu(x_c))
|
||||
x_c += self.down_1(self.relu(x_s))
|
||||
x_s += resize(
|
||||
comp_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_context = x_s.clone()
|
||||
|
||||
# stage4
|
||||
x_c = self.context_branch_layers[1](self.relu(x_c))
|
||||
x_s = self.spatial_branch_layers[1](self.relu(x_s))
|
||||
comp_c = self.compression_2(self.relu(x_c))
|
||||
x_c += self.down_2(self.relu(x_s))
|
||||
x_s += resize(
|
||||
comp_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
# stage5
|
||||
x_s = self.spatial_branch_layers[2](self.relu(x_s))
|
||||
x_c = self.context_branch_layers[2](self.relu(x_c))
|
||||
x_c = self.spp(x_c)
|
||||
x_c = resize(
|
||||
x_c,
|
||||
size=out_size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
return (temp_context, x_s + x_c) if self.training else x_s + x_c
|
||||
329
finetune/mmseg/models/backbones/erfnet.py
Normal file
329
finetune/mmseg/models/backbones/erfnet.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
|
||||
"""Downsampler block of ERFNet.
|
||||
|
||||
This module is a little different from basical ConvModule.
|
||||
The features from Conv and MaxPool layers are
|
||||
concatenated before BatchNorm.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
out_channels - in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
conv_out = self.conv(input)
|
||||
pool_out = self.pool(input)
|
||||
pool_out = resize(
|
||||
input=pool_out,
|
||||
size=conv_out.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
output = torch.cat([conv_out, pool_out], 1)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
class NonBottleneck1d(BaseModule):
|
||||
"""Non-bottleneck block of ERFNet.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels in Non-bottleneck block.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.
|
||||
dilation (int): Dilation rate for last two conv layers.
|
||||
Default 1.
|
||||
num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
|
||||
Default 2.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
drop_rate=0,
|
||||
dilation=1,
|
||||
num_conv_layer=2,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
self.convs_layers = nn.ModuleList()
|
||||
for conv_layer in range(num_conv_layer):
|
||||
first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
|
||||
first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
|
||||
second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
|
||||
second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)
|
||||
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(3, 1),
|
||||
stride=1,
|
||||
padding=first_conv_padding,
|
||||
bias=True,
|
||||
dilation=first_conv_dilation))
|
||||
self.convs_layers.append(self.act)
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(1, 3),
|
||||
stride=1,
|
||||
padding=second_conv_padding,
|
||||
bias=True,
|
||||
dilation=second_conv_dilation))
|
||||
self.convs_layers.append(
|
||||
build_norm_layer(self.norm_cfg, channels)[1])
|
||||
if conv_layer == 0:
|
||||
self.convs_layers.append(self.act)
|
||||
else:
|
||||
self.convs_layers.append(nn.Dropout(p=drop_rate))
|
||||
|
||||
def forward(self, input):
|
||||
output = input
|
||||
for conv in self.convs_layers:
|
||||
output = conv(output)
|
||||
output = self.act(output + input)
|
||||
return output
|
||||
|
||||
|
||||
class UpsamplerBlock(BaseModule):
|
||||
"""Upsampler block of ERFNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = nn.ConvTranspose2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
bias=True)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv(input)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ERFNet(BaseModule):
|
||||
"""ERFNet backbone.
|
||||
|
||||
This backbone is the implementation of `ERFNet: Efficient Residual
|
||||
Factorized ConvNet for Real-time SemanticSegmentation
|
||||
<https://ieeexplore.ieee.org/document/8063438>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input
|
||||
image. Default: 3.
|
||||
enc_downsample_channels (Tuple[int]): Size of channel
|
||||
numbers of various Downsampler block in encoder.
|
||||
Default: (16, 64, 128).
|
||||
enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in encoder.
|
||||
Default: (5, 8).
|
||||
enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
|
||||
stage of Non-bottleneck block of encoder.
|
||||
Default: (2, 4, 8, 16).
|
||||
enc_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in encoder.
|
||||
Default: (64, 128).
|
||||
dec_upsample_channels (Tuple[int]): Size of channel numbers of
|
||||
various Deconvolution block in decoder.
|
||||
Default: (64, 16).
|
||||
dec_stages_non_bottleneck (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in decoder.
|
||||
Default: (2, 2).
|
||||
dec_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in decoder.
|
||||
Default: (64, 16).
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(dec_upsample_channels)+1, 'Number of downsample\
|
||||
block of encoder does not \
|
||||
match number of upsample block of decoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_stage_non_bottlenecks)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of Non-bottleneck block of encoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_non_bottleneck_channels)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of channels of Non-bottleneck block of encoder!'
|
||||
assert enc_stage_non_bottlenecks[-1] \
|
||||
% len(enc_non_bottleneck_dilations) == 0, 'Number of \
|
||||
Non-bottleneck block of encoder does not match \
|
||||
number of Non-bottleneck block of encoder!'
|
||||
assert len(dec_upsample_channels) \
|
||||
== len(dec_stages_non_bottleneck), 'Number of \
|
||||
upsample block of decoder does not match \
|
||||
number of Non-bottleneck block of decoder!'
|
||||
assert len(dec_stages_non_bottleneck) \
|
||||
== len(dec_non_bottleneck_channels), 'Number of \
|
||||
Non-bottleneck block of decoder does not match \
|
||||
number of channels of Non-bottleneck block of decoder!'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.enc_downsample_channels = enc_downsample_channels
|
||||
self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
|
||||
self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
|
||||
self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
|
||||
self.dec_upsample_channels = dec_upsample_channels
|
||||
self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
|
||||
self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))
|
||||
|
||||
for i in range(len(enc_downsample_channels) - 1):
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(enc_downsample_channels[i],
|
||||
enc_downsample_channels[i + 1]))
|
||||
# Last part of encoder is some dilated NonBottleneck1d blocks.
|
||||
if i == len(enc_downsample_channels) - 2:
|
||||
iteration_times = int(enc_stage_non_bottlenecks[-1] /
|
||||
len(enc_non_bottleneck_dilations))
|
||||
for j in range(iteration_times):
|
||||
for k in range(len(enc_non_bottleneck_dilations)):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[-1],
|
||||
self.dropout_ratio,
|
||||
enc_non_bottleneck_dilations[k]))
|
||||
else:
|
||||
for j in range(enc_stage_non_bottlenecks[i]):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[i + 1],
|
||||
self.dropout_ratio))
|
||||
|
||||
for i in range(len(dec_upsample_channels)):
|
||||
if i == 0:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(enc_downsample_channels[-1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
else:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
for j in range(dec_stages_non_bottleneck[i]):
|
||||
self.decoder.append(
|
||||
NonBottleneck1d(dec_non_bottleneck_channels[i]))
|
||||
|
||||
def forward(self, x):
|
||||
for enc in self.encoder:
|
||||
x = enc(x)
|
||||
for dec in self.decoder:
|
||||
x = dec(x)
|
||||
return [x]
|
||||
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
408
finetune/mmseg/models/backbones/fast_scnn.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, resize
|
||||
|
||||
|
||||
class LearningToDownsample(nn.Module):
|
||||
"""Learning to downsample module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
dw_channels (tuple[int]): Number of output channels of the first and
|
||||
the second depthwise conv (dwconv) layers.
|
||||
out_channels (int): Number of output channels of the whole
|
||||
'learning to downsample' module.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
dw_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dw_act_cfg=None):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.dw_act_cfg = dw_act_cfg
|
||||
dw_channels1 = dw_channels[0]
|
||||
dw_channels2 = dw_channels[1]
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
dw_channels1,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.dsconv1 = DepthwiseSeparableConvModule(
|
||||
dw_channels1,
|
||||
dw_channels2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
self.dsconv2 = DepthwiseSeparableConvModule(
|
||||
dw_channels2,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dw_act_cfg=self.dw_act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.dsconv1(x)
|
||||
x = self.dsconv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class GlobalFeatureExtractor(nn.Module):
|
||||
"""Global feature extractor module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels of the GFE module.
|
||||
Default: 64
|
||||
block_channels (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of output channels of each Inverted Residual module.
|
||||
Default: (64, 96, 128)
|
||||
out_channels(int): Number of output channels of the GFE module.
|
||||
Default: 128
|
||||
expand_ratio (int): Adjusts number of channels of the hidden layer
|
||||
in InvertedResidual by this amount.
|
||||
Default: 6
|
||||
num_blocks (tuple[int]): Tuple of ints. Each int specifies the
|
||||
number of times each Inverted Residual module is repeated.
|
||||
The repeated Inverted Residual modules are called a 'group'.
|
||||
Default: (3, 3, 3)
|
||||
strides (tuple[int]): Tuple of ints. Each int specifies
|
||||
the downsampling factor of each 'group'.
|
||||
Default: (2, 2, 1)
|
||||
pool_scales (tuple[int]): Tuple of ints. Each int specifies
|
||||
the parameter required in 'global average pooling' within PPM.
|
||||
Default: (1, 2, 3, 6)
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=64,
|
||||
block_channels=(64, 96, 128),
|
||||
out_channels=128,
|
||||
expand_ratio=6,
|
||||
num_blocks=(3, 3, 3),
|
||||
strides=(2, 2, 1),
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
assert len(block_channels) == len(num_blocks) == 3
|
||||
self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
|
||||
num_blocks[0], strides[0],
|
||||
expand_ratio)
|
||||
self.bottleneck2 = self._make_layer(block_channels[0],
|
||||
block_channels[1], num_blocks[1],
|
||||
strides[1], expand_ratio)
|
||||
self.bottleneck3 = self._make_layer(block_channels[1],
|
||||
block_channels[2], num_blocks[2],
|
||||
strides[2], expand_ratio)
|
||||
self.ppm = PPM(
|
||||
pool_scales,
|
||||
block_channels[2],
|
||||
block_channels[2] // 4,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=align_corners)
|
||||
|
||||
self.out = ConvModule(
|
||||
block_channels[2] * 2,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _make_layer(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
blocks,
|
||||
stride=1,
|
||||
expand_ratio=6):
|
||||
layers = [
|
||||
InvertedResidual(
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
expand_ratio,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bottleneck1(x)
|
||||
x = self.bottleneck2(x)
|
||||
x = self.bottleneck3(x)
|
||||
x = torch.cat([x, *self.ppm(x)], dim=1)
|
||||
x = self.out(x)
|
||||
return x
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
|
||||
"""Feature fusion module.
|
||||
|
||||
Args:
|
||||
higher_in_channels (int): Number of input channels of the
|
||||
higher-resolution branch.
|
||||
lower_in_channels (int): Number of input channels of the
|
||||
lower-resolution branch.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
dwconv_act_cfg (dict): Config of activation layers in 3x3 conv.
|
||||
Default: dict(type='ReLU').
|
||||
conv_act_cfg (dict): Config of activation layers in the two 1x1 conv.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dwconv_act_cfg=dict(type='ReLU'),
|
||||
conv_act_cfg=None,
|
||||
align_corners=False):
|
||||
super().__init__()
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.dwconv_act_cfg = dwconv_act_cfg
|
||||
self.conv_act_cfg = conv_act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.dwconv = ConvModule(
|
||||
lower_in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
groups=out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.dwconv_act_cfg)
|
||||
self.conv_lower_res = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.conv_higher_res = ConvModule(
|
||||
higher_in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.conv_act_cfg)
|
||||
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
def forward(self, higher_res_feature, lower_res_feature):
|
||||
lower_res_feature = resize(
|
||||
lower_res_feature,
|
||||
size=higher_res_feature.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
lower_res_feature = self.dwconv(lower_res_feature)
|
||||
lower_res_feature = self.conv_lower_res(lower_res_feature)
|
||||
|
||||
higher_res_feature = self.conv_higher_res(higher_res_feature)
|
||||
out = higher_res_feature + lower_res_feature
|
||||
return self.relu(out)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FastSCNN(BaseModule):
|
||||
"""Fast-SCNN Backbone.
|
||||
|
||||
This backbone is the implementation of `Fast-SCNN: Fast Semantic
|
||||
Segmentation Network <https://arxiv.org/abs/1902.04502>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
downsample_dw_channels (tuple[int]): Number of output channels after
|
||||
the first conv layer & the second conv layer in
|
||||
Learning-To-Downsample (LTD) module.
|
||||
Default: (32, 48).
|
||||
global_in_channels (int): Number of input channels of
|
||||
Global Feature Extractor(GFE).
|
||||
Equal to number of output channels of LTD.
|
||||
Default: 64.
|
||||
global_block_channels (tuple[int]): Tuple of integers that describe
|
||||
the output channels for each of the MobileNet-v2 bottleneck
|
||||
residual blocks in GFE.
|
||||
Default: (64, 96, 128).
|
||||
global_block_strides (tuple[int]): Tuple of integers
|
||||
that describe the strides (downsampling factors) for each of the
|
||||
MobileNet-v2 bottleneck residual blocks in GFE.
|
||||
Default: (2, 2, 1).
|
||||
global_out_channels (int): Number of output channels of GFE.
|
||||
Default: 128.
|
||||
higher_in_channels (int): Number of input channels of the higher
|
||||
resolution branch in FFM.
|
||||
Equal to global_in_channels.
|
||||
Default: 64.
|
||||
lower_in_channels (int): Number of input channels of the lower
|
||||
resolution branch in FFM.
|
||||
Equal to global_out_channels.
|
||||
Default: 128.
|
||||
fusion_out_channels (int): Number of output channels of FFM.
|
||||
Default: 128.
|
||||
out_indices (tuple): Tuple of indices of list
|
||||
[higher_res_features, lower_res_features, fusion_output].
|
||||
Often set to (0,1,2) to enable aux. heads.
|
||||
Default: (0, 1, 2).
|
||||
conv_cfg (dict | None): Config of conv layers. Default: None
|
||||
norm_cfg (dict | None): Config of norm layers. Default:
|
||||
dict(type='BN')
|
||||
act_cfg (dict): Config of activation layers. Default:
|
||||
dict(type='ReLU')
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False
|
||||
dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config
|
||||
of depthwise ConvModule. If it is 'default', it will be the same
|
||||
as `act_cfg`. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
downsample_dw_channels=(32, 48),
|
||||
global_in_channels=64,
|
||||
global_block_channels=(64, 96, 128),
|
||||
global_block_strides=(2, 2, 1),
|
||||
global_out_channels=128,
|
||||
higher_in_channels=64,
|
||||
lower_in_channels=128,
|
||||
fusion_out_channels=128,
|
||||
out_indices=(0, 1, 2),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
dw_act_cfg=None,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
|
||||
if global_in_channels != higher_in_channels:
|
||||
raise AssertionError('Global Input Channels must be the same \
|
||||
with Higher Input Channels!')
|
||||
elif global_out_channels != lower_in_channels:
|
||||
raise AssertionError('Global Output Channels must be the same \
|
||||
with Lower Input Channels!')
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.downsample_dw_channels1 = downsample_dw_channels[0]
|
||||
self.downsample_dw_channels2 = downsample_dw_channels[1]
|
||||
self.global_in_channels = global_in_channels
|
||||
self.global_block_channels = global_block_channels
|
||||
self.global_block_strides = global_block_strides
|
||||
self.global_out_channels = global_out_channels
|
||||
self.higher_in_channels = higher_in_channels
|
||||
self.lower_in_channels = lower_in_channels
|
||||
self.fusion_out_channels = fusion_out_channels
|
||||
self.out_indices = out_indices
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
self.learning_to_downsample = LearningToDownsample(
|
||||
in_channels,
|
||||
downsample_dw_channels,
|
||||
global_in_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
dw_act_cfg=dw_act_cfg)
|
||||
self.global_feature_extractor = GlobalFeatureExtractor(
|
||||
global_in_channels,
|
||||
global_block_channels,
|
||||
global_out_channels,
|
||||
strides=self.global_block_strides,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
self.feature_fusion = FeatureFusionModule(
|
||||
higher_in_channels,
|
||||
lower_in_channels,
|
||||
fusion_out_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dwconv_act_cfg=self.act_cfg,
|
||||
align_corners=self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
higher_res_features = self.learning_to_downsample(x)
|
||||
lower_res_features = self.global_feature_extractor(higher_res_features)
|
||||
fusion_output = self.feature_fusion(higher_res_features,
|
||||
lower_res_features)
|
||||
|
||||
outs = [higher_res_features, lower_res_features, fusion_output]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
||||
642
finetune/mmseg/models/backbones/hrnet.py
Normal file
642
finetune/mmseg/models/backbones/hrnet.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class HRModule(BaseModule):
|
||||
"""High-Resolution Module for HRNet.
|
||||
|
||||
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
|
||||
is in this module.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_branches,
|
||||
blocks,
|
||||
num_blocks,
|
||||
in_channels,
|
||||
num_channels,
|
||||
multiscale_output=True,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
block_init_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.block_init_cfg = block_init_cfg
|
||||
self._check_branches(num_branches, num_blocks, in_channels,
|
||||
num_channels)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.num_branches = num_branches
|
||||
|
||||
self.multiscale_output = multiscale_output
|
||||
self.norm_cfg = norm_cfg
|
||||
self.conv_cfg = conv_cfg
|
||||
self.with_cp = with_cp
|
||||
self.branches = self._make_branches(num_branches, blocks, num_blocks,
|
||||
num_channels)
|
||||
self.fuse_layers = self._make_fuse_layers()
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def _check_branches(self, num_branches, num_blocks, in_channels,
|
||||
num_channels):
|
||||
"""Check branches configuration."""
|
||||
if num_branches != len(num_blocks):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
|
||||
f'{len(num_blocks)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(num_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
|
||||
f'{len(num_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if num_branches != len(in_channels):
|
||||
error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
|
||||
f'{len(in_channels)})'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _make_one_branch(self,
|
||||
branch_index,
|
||||
block,
|
||||
num_blocks,
|
||||
num_channels,
|
||||
stride=1):
|
||||
"""Build one branch."""
|
||||
downsample = None
|
||||
if stride != 1 or \
|
||||
self.in_channels[branch_index] != \
|
||||
num_channels[branch_index] * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index] * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
|
||||
block.expansion)[1])
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
stride,
|
||||
downsample=downsample,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
self.in_channels[branch_index] = \
|
||||
num_channels[branch_index] * block.expansion
|
||||
for i in range(1, num_blocks[branch_index]):
|
||||
layers.append(
|
||||
block(
|
||||
self.in_channels[branch_index],
|
||||
num_channels[branch_index],
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=self.block_init_cfg))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
||||
"""Build multiple branch."""
|
||||
branches = []
|
||||
|
||||
for i in range(num_branches):
|
||||
branches.append(
|
||||
self._make_one_branch(i, block, num_blocks, num_channels))
|
||||
|
||||
return ModuleList(branches)
|
||||
|
||||
def _make_fuse_layers(self):
|
||||
"""Build fuse layer."""
|
||||
if self.num_branches == 1:
|
||||
return None
|
||||
|
||||
num_branches = self.num_branches
|
||||
in_channels = self.in_channels
|
||||
fuse_layers = []
|
||||
num_out_branches = num_branches if self.multiscale_output else 1
|
||||
for i in range(num_out_branches):
|
||||
fuse_layer = []
|
||||
for j in range(num_branches):
|
||||
if j > i:
|
||||
fuse_layer.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[i],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, in_channels[i])[1],
|
||||
# we set align_corners=False for HRNet
|
||||
Upsample(
|
||||
scale_factor=2**(j - i),
|
||||
mode='bilinear',
|
||||
align_corners=False)))
|
||||
elif j == i:
|
||||
fuse_layer.append(None)
|
||||
else:
|
||||
conv_downsamples = []
|
||||
for k in range(i - j):
|
||||
if k == i - j - 1:
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
in_channels[i])[1]))
|
||||
else:
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels[j],
|
||||
in_channels[j],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
in_channels[j])[1],
|
||||
nn.ReLU(inplace=False)))
|
||||
fuse_layer.append(nn.Sequential(*conv_downsamples))
|
||||
fuse_layers.append(nn.ModuleList(fuse_layer))
|
||||
|
||||
return nn.ModuleList(fuse_layers)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
if self.num_branches == 1:
|
||||
return [self.branches[0](x[0])]
|
||||
|
||||
for i in range(self.num_branches):
|
||||
x[i] = self.branches[i](x[i])
|
||||
|
||||
x_fuse = []
|
||||
for i in range(len(self.fuse_layers)):
|
||||
y = 0
|
||||
for j in range(self.num_branches):
|
||||
if i == j:
|
||||
y += x[j]
|
||||
elif j > i:
|
||||
y = y + resize(
|
||||
self.fuse_layers[i][j](x[j]),
|
||||
size=x[i].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
else:
|
||||
y += self.fuse_layers[i][j](x[j])
|
||||
x_fuse.append(self.relu(y))
|
||||
return x_fuse
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HRNet(BaseModule):
|
||||
"""HRNet backbone.
|
||||
|
||||
This backbone is the implementation of `High-Resolution Representations
|
||||
for Labeling Pixels and Regions <https://arxiv.org/abs/1904.04514>`_.
|
||||
|
||||
Args:
|
||||
extra (dict): Detailed configuration for each stage of HRNet.
|
||||
There must be 4 stages, the configuration for each stage must have
|
||||
5 keys:
|
||||
|
||||
- num_modules (int): The number of HRModule in this stage.
|
||||
- num_branches (int): The number of branches in the HRModule.
|
||||
- block (str): The type of convolution block.
|
||||
- num_blocks (tuple): The number of blocks in each branch.
|
||||
The length must be equal to num_branches.
|
||||
- num_channels (tuple): The number of channels in each branch.
|
||||
The length must be equal to num_branches.
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Use `BN` by default.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: False.
|
||||
multiscale_output (bool): Whether to output multi-level features
|
||||
produced by multiple branches. If False, only the first level
|
||||
feature will be output. Default: True.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import HRNet
|
||||
>>> import torch
|
||||
>>> extra = dict(
|
||||
>>> stage1=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=1,
|
||||
>>> block='BOTTLENECK',
|
||||
>>> num_blocks=(4, ),
|
||||
>>> num_channels=(64, )),
|
||||
>>> stage2=dict(
|
||||
>>> num_modules=1,
|
||||
>>> num_branches=2,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4),
|
||||
>>> num_channels=(32, 64)),
|
||||
>>> stage3=dict(
|
||||
>>> num_modules=4,
|
||||
>>> num_branches=3,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128)),
|
||||
>>> stage4=dict(
|
||||
>>> num_modules=3,
|
||||
>>> num_branches=4,
|
||||
>>> block='BASIC',
|
||||
>>> num_blocks=(4, 4, 4, 4),
|
||||
>>> num_channels=(32, 64, 128, 256)))
|
||||
>>> self = HRNet(extra, in_channels=1)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 1, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 32, 8, 8)
|
||||
(1, 64, 4, 4)
|
||||
(1, 128, 2, 2)
|
||||
(1, 256, 1, 1)
|
||||
"""
|
||||
|
||||
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
|
||||
|
||||
def __init__(self,
|
||||
extra,
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
frozen_stages=-1,
|
||||
zero_init_residual=False,
|
||||
multiscale_output=True,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
self.zero_init_residual = zero_init_residual
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
# Assert configurations of 4 stages are in extra
|
||||
assert 'stage1' in extra and 'stage2' in extra \
|
||||
and 'stage3' in extra and 'stage4' in extra
|
||||
# Assert whether the length of `num_blocks` and `num_channels` are
|
||||
# equal to `num_branches`
|
||||
for i in range(4):
|
||||
cfg = extra[f'stage{i + 1}']
|
||||
assert len(cfg['num_blocks']) == cfg['num_branches'] and \
|
||||
len(cfg['num_channels']) == cfg['num_branches']
|
||||
|
||||
self.extra = extra
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
# stem net
|
||||
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.conv2 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
64,
|
||||
64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
# stage 1
|
||||
self.stage1_cfg = self.extra['stage1']
|
||||
num_channels = self.stage1_cfg['num_channels'][0]
|
||||
block_type = self.stage1_cfg['block']
|
||||
num_blocks = self.stage1_cfg['num_blocks'][0]
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
stage1_out_channels = num_channels * block.expansion
|
||||
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
|
||||
|
||||
# stage 2
|
||||
self.stage2_cfg = self.extra['stage2']
|
||||
num_channels = self.stage2_cfg['num_channels']
|
||||
block_type = self.stage2_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition1 = self._make_transition_layer([stage1_out_channels],
|
||||
num_channels)
|
||||
self.stage2, pre_stage_channels = self._make_stage(
|
||||
self.stage2_cfg, num_channels)
|
||||
|
||||
# stage 3
|
||||
self.stage3_cfg = self.extra['stage3']
|
||||
num_channels = self.stage3_cfg['num_channels']
|
||||
block_type = self.stage3_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition2 = self._make_transition_layer(pre_stage_channels,
|
||||
num_channels)
|
||||
self.stage3, pre_stage_channels = self._make_stage(
|
||||
self.stage3_cfg, num_channels)
|
||||
|
||||
# stage 4
|
||||
self.stage4_cfg = self.extra['stage4']
|
||||
num_channels = self.stage4_cfg['num_channels']
|
||||
block_type = self.stage4_cfg['block']
|
||||
|
||||
block = self.blocks_dict[block_type]
|
||||
num_channels = [channel * block.expansion for channel in num_channels]
|
||||
self.transition3 = self._make_transition_layer(pre_stage_channels,
|
||||
num_channels)
|
||||
self.stage4, pre_stage_channels = self._make_stage(
|
||||
self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: the normalization layer named "norm2" """
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def _make_transition_layer(self, num_channels_pre_layer,
|
||||
num_channels_cur_layer):
|
||||
"""Make transition layer."""
|
||||
num_branches_cur = len(num_channels_cur_layer)
|
||||
num_branches_pre = len(num_channels_pre_layer)
|
||||
|
||||
transition_layers = []
|
||||
for i in range(num_branches_cur):
|
||||
if i < num_branches_pre:
|
||||
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
||||
transition_layers.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
num_channels_pre_layer[i],
|
||||
num_channels_cur_layer[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg,
|
||||
num_channels_cur_layer[i])[1],
|
||||
nn.ReLU(inplace=True)))
|
||||
else:
|
||||
transition_layers.append(None)
|
||||
else:
|
||||
conv_downsamples = []
|
||||
for j in range(i + 1 - num_branches_pre):
|
||||
in_channels = num_channels_pre_layer[-1]
|
||||
out_channels = num_channels_cur_layer[i] \
|
||||
if j == i - num_branches_pre else in_channels
|
||||
conv_downsamples.append(
|
||||
nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, out_channels)[1],
|
||||
nn.ReLU(inplace=True)))
|
||||
transition_layers.append(nn.Sequential(*conv_downsamples))
|
||||
|
||||
return nn.ModuleList(transition_layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
||||
"""Make each layer."""
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
|
||||
|
||||
layers = []
|
||||
block_init_cfg = None
|
||||
if self.pretrained is None and not hasattr(
|
||||
self, 'init_cfg') and self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm3'))
|
||||
|
||||
layers.append(
|
||||
block(
|
||||
inplanes,
|
||||
planes,
|
||||
stride,
|
||||
downsample=downsample,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=block_init_cfg))
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
block(
|
||||
inplanes,
|
||||
planes,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
init_cfg=block_init_cfg))
|
||||
|
||||
return Sequential(*layers)
|
||||
|
||||
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
|
||||
"""Make each stage."""
|
||||
num_modules = layer_config['num_modules']
|
||||
num_branches = layer_config['num_branches']
|
||||
num_blocks = layer_config['num_blocks']
|
||||
num_channels = layer_config['num_channels']
|
||||
block = self.blocks_dict[layer_config['block']]
|
||||
|
||||
hr_modules = []
|
||||
block_init_cfg = None
|
||||
if self.pretrained is None and not hasattr(
|
||||
self, 'init_cfg') and self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant', val=0, override=dict(name='norm3'))
|
||||
|
||||
for i in range(num_modules):
|
||||
# multi_scale_output is only used for the last module
|
||||
if not multiscale_output and i == num_modules - 1:
|
||||
reset_multiscale_output = False
|
||||
else:
|
||||
reset_multiscale_output = True
|
||||
|
||||
hr_modules.append(
|
||||
HRModule(
|
||||
num_branches,
|
||||
block,
|
||||
num_blocks,
|
||||
in_channels,
|
||||
num_channels,
|
||||
reset_multiscale_output,
|
||||
with_cp=self.with_cp,
|
||||
norm_cfg=self.norm_cfg,
|
||||
conv_cfg=self.conv_cfg,
|
||||
block_init_cfg=block_init_cfg))
|
||||
|
||||
return Sequential(*hr_modules), in_channels
|
||||
|
||||
def _freeze_stages(self):
|
||||
"""Freeze stages param and norm stats."""
|
||||
if self.frozen_stages >= 0:
|
||||
|
||||
self.norm1.eval()
|
||||
self.norm2.eval()
|
||||
for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
if i == 1:
|
||||
m = getattr(self, f'layer{i}')
|
||||
t = getattr(self, f'transition{i}')
|
||||
elif i == 4:
|
||||
m = getattr(self, f'stage{i}')
|
||||
else:
|
||||
m = getattr(self, f'stage{i}')
|
||||
t = getattr(self, f'transition{i}')
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
t.eval()
|
||||
for param in t.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.relu(x)
|
||||
x = self.layer1(x)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage2_cfg['num_branches']):
|
||||
if self.transition1[i] is not None:
|
||||
x_list.append(self.transition1[i](x))
|
||||
else:
|
||||
x_list.append(x)
|
||||
y_list = self.stage2(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage3_cfg['num_branches']):
|
||||
if self.transition2[i] is not None:
|
||||
x_list.append(self.transition2[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage3(x_list)
|
||||
|
||||
x_list = []
|
||||
for i in range(self.stage4_cfg['num_branches']):
|
||||
if self.transition3[i] is not None:
|
||||
x_list.append(self.transition3[i](y_list[-1]))
|
||||
else:
|
||||
x_list.append(y_list[i])
|
||||
y_list = self.stage4(x_list)
|
||||
|
||||
return y_list
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode will keeping the normalization
|
||||
layer freezed."""
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
166
finetune/mmseg/models/backbones/icnet.py
Normal file
166
finetune/mmseg/models/backbones/icnet.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..decode_heads.psp_head import PPM
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ICNet(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
This backbone is the implementation of
|
||||
`ICNet <https://arxiv.org/abs/1704.08545>`_.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config dict to build backbone. Usually it is
|
||||
ResNet but it can also be other backbones.
|
||||
in_channels (int): The number of input image channels. Default: 3.
|
||||
layer_channels (Sequence[int]): The numbers of feature channels at
|
||||
layer 2 and layer 4 in ResNet. It can also be other backbones.
|
||||
Default: (512, 2048).
|
||||
light_branch_middle_channels (int): The number of channels of the
|
||||
middle layer in light branch. Default: 32.
|
||||
psp_out_channels (int): The number of channels of the output of PSP
|
||||
module. Default: 512.
|
||||
out_channels (Sequence[int]): The numbers of output feature channels
|
||||
at each branches. Default: (64, 256, 256).
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
conv_cfg (dict): Dictionary to construct and config conv layer.
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Dictionary to construct and config act layer.
|
||||
Default: dict(type='ReLU').
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
in_channels=3,
|
||||
layer_channels=(512, 2048),
|
||||
light_branch_middle_channels=32,
|
||||
psp_out_channels=512,
|
||||
out_channels=(64, 256, 256),
|
||||
pool_scales=(1, 2, 3, 6),
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
align_corners=False,
|
||||
init_cfg=None):
|
||||
if backbone_cfg is None:
|
||||
raise TypeError('backbone_cfg must be passed from config file!')
|
||||
if init_cfg is None:
|
||||
init_cfg = [
|
||||
dict(type='Kaiming', mode='fan_out', layer='Conv2d'),
|
||||
dict(type='Constant', val=1, layer='_BatchNorm'),
|
||||
dict(type='Normal', mean=0.01, layer='Linear')
|
||||
]
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.align_corners = align_corners
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set
|
||||
# `ceil_mode=True` to keep information in the corner of feature map.
|
||||
self.backbone.maxpool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=True)
|
||||
|
||||
self.psp_modules = PPM(
|
||||
pool_scales=pool_scales,
|
||||
in_channels=layer_channels[1],
|
||||
channels=psp_out_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
align_corners=align_corners)
|
||||
|
||||
self.psp_bottleneck = ConvModule(
|
||||
layer_channels[1] + len(pool_scales) * psp_out_channels,
|
||||
psp_out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.conv_sub1 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=light_branch_middle_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg),
|
||||
ConvModule(
|
||||
in_channels=light_branch_middle_channels,
|
||||
out_channels=light_branch_middle_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg),
|
||||
ConvModule(
|
||||
in_channels=light_branch_middle_channels,
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg))
|
||||
|
||||
self.conv_sub2 = ConvModule(
|
||||
layer_channels[0],
|
||||
out_channels[1],
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
self.conv_sub4 = ConvModule(
|
||||
psp_out_channels,
|
||||
out_channels[2],
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
output = []
|
||||
|
||||
# sub 1
|
||||
output.append(self.conv_sub1(x))
|
||||
|
||||
# sub 2
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=0.5,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.backbone.stem(x)
|
||||
x = self.backbone.maxpool(x)
|
||||
x = self.backbone.layer1(x)
|
||||
x = self.backbone.layer2(x)
|
||||
output.append(self.conv_sub2(x))
|
||||
|
||||
# sub 4
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=0.5,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.backbone.layer3(x)
|
||||
x = self.backbone.layer4(x)
|
||||
psp_outs = self.psp_modules(x) + [x]
|
||||
psp_outs = torch.cat(psp_outs, dim=1)
|
||||
x = self.psp_bottleneck(psp_outs)
|
||||
|
||||
output.append(self.conv_sub4(x))
|
||||
|
||||
return output
|
||||
260
finetune/mmseg/models/backbones/mae.py
Normal file
260
finetune/mmseg/models/backbones/mae.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.import math
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||
|
||||
|
||||
class MAEAttention(BEiTAttention):
|
||||
"""Multi-head self-attention with relative position bias used in MAE.
|
||||
|
||||
This module is different from ``BEiTAttention`` by initializing the
|
||||
relative bias table with zeros.
|
||||
"""
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize relative position bias with zeros."""
|
||||
|
||||
# As MAE initializes relative position bias as zeros and this class
|
||||
# inherited from BEiT which initializes relative position bias
|
||||
# with `trunc_normal`, `init_weights` here does
|
||||
# nothing and just passes directly
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
This module is different from ``BEiTTransformerEncoderLayer`` by replacing
|
||||
``BEiTAttention`` with ``MAEAttention``.
|
||||
"""
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MAEAttention(**attn_cfg)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAE(BEiT):
|
||||
"""VisionTransformer with support for patch.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_values (float): Initialize the values of Attention and FFN
|
||||
with learnable scaling. Defaults to 0.1.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_indices=-1,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
final_norm=False,
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
pretrained=None,
|
||||
init_values=0.1,
|
||||
init_cfg=None):
|
||||
super().__init__(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
out_indices=out_indices,
|
||||
qv_bias=False,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
patch_norm=patch_norm,
|
||||
final_norm=final_norm,
|
||||
num_fcs=num_fcs,
|
||||
norm_eval=norm_eval,
|
||||
pretrained=pretrained,
|
||||
init_values=init_values,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
self.num_patches = self.patch_shape[0] * self.patch_shape[1]
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.num_patches + 1, embed_dims))
|
||||
|
||||
def _build_layers(self):
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
|
||||
]
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.layers.append(
|
||||
MAETransformerEncoderLayer(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.num_heads,
|
||||
feedforward_channels=self.mlp_ratio * self.embed_dims,
|
||||
attn_drop_rate=self.attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=self.num_fcs,
|
||||
bias=True,
|
||||
act_cfg=self.act_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
window_size=self.patch_shape,
|
||||
init_values=self.init_values))
|
||||
|
||||
def fix_init_weight(self):
|
||||
"""Rescale the initialization according to layer id.
|
||||
|
||||
This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
|
||||
Copyright (c) Microsoft Corporation
|
||||
Licensed under the MIT License
|
||||
"""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
|
||||
|
||||
def init_weights(self):
|
||||
|
||||
def _init_weights(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
self.apply(_init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg.get('type') == 'Pretrained'):
|
||||
checkpoint = _load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
state_dict = self.resize_rel_pos_embed(checkpoint)
|
||||
state_dict = self.resize_abs_pos_embed(state_dict)
|
||||
self.load_state_dict(state_dict, False)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
# Copyright 2019 Ross Wightman
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def resize_abs_pos_embed(self, state_dict):
|
||||
if 'pos_embed' in state_dict:
|
||||
pos_embed_checkpoint = state_dict['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int(
|
||||
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(self.num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
|
||||
embedding_size).permute(
|
||||
0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
state_dict['pos_embed'] = new_pos_embed
|
||||
return state_dict
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
out = x[:, 1:]
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
450
finetune/mmseg/models/backbones/mit.py
Normal file
450
finetune/mmseg/models/backbones/mit.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
|
||||
|
||||
|
||||
class MixFFN(BaseModule):
|
||||
"""An implementation of MixFFN of Segformer.
|
||||
|
||||
The differences between MixFFN & FFN:
|
||||
1. Use 1X1 Conv to replace Linear layer.
|
||||
2. Introduce 3X3 Conv to encode positional information.
|
||||
Args:
|
||||
embed_dims (int): The feature dimension. Same as
|
||||
`MultiheadAttention`. Defaults: 256.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 1024.
|
||||
act_cfg (dict, optional): The activation config for FFNs.
|
||||
Default: dict(type='ReLU')
|
||||
ffn_drop (float, optional): Probability of an element to be
|
||||
zeroed in FFN. Default 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
feedforward_channels,
|
||||
act_cfg=dict(type='GELU'),
|
||||
ffn_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.feedforward_channels = feedforward_channels
|
||||
self.act_cfg = act_cfg
|
||||
self.activate = build_activation_layer(act_cfg)
|
||||
|
||||
in_channels = embed_dims
|
||||
fc1 = Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
# 3x3 depth wise conv to provide positional encode information
|
||||
pe_conv = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=feedforward_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=(3 - 1) // 2,
|
||||
bias=True,
|
||||
groups=feedforward_channels)
|
||||
fc2 = Conv2d(
|
||||
in_channels=feedforward_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias=True)
|
||||
drop = nn.Dropout(ffn_drop)
|
||||
layers = [fc1, pe_conv, self.activate, drop, fc2, drop]
|
||||
self.layers = Sequential(*layers)
|
||||
self.dropout_layer = build_dropout(
|
||||
dropout_layer) if dropout_layer else torch.nn.Identity()
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
out = nlc_to_nchw(x, hw_shape)
|
||||
out = self.layers(out)
|
||||
out = nchw_to_nlc(out)
|
||||
if identity is None:
|
||||
identity = x
|
||||
return identity + self.dropout_layer(out)
|
||||
|
||||
|
||||
class EfficientMultiheadAttention(MultiheadAttention):
|
||||
"""An implementation of Efficient Multi-head Attention of Segformer.
|
||||
|
||||
This module is modified from MultiheadAttention which is a module from
|
||||
mmcv.cnn.bricks.transformer.
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
attn_drop (float): A Dropout layer on attn_output_weights.
|
||||
Default: 0.0.
|
||||
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
||||
Default: 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut. Default: None.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
dropout_layer=None,
|
||||
init_cfg=None,
|
||||
batch_first=True,
|
||||
qkv_bias=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=1):
|
||||
super().__init__(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop,
|
||||
proj_drop,
|
||||
dropout_layer=dropout_layer,
|
||||
init_cfg=init_cfg,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias)
|
||||
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1:
|
||||
self.sr = Conv2d(
|
||||
in_channels=embed_dims,
|
||||
out_channels=embed_dims,
|
||||
kernel_size=sr_ratio,
|
||||
stride=sr_ratio)
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
|
||||
from mmseg import digit_version, mmcv_version
|
||||
if mmcv_version < digit_version('1.3.17'):
|
||||
warnings.warn('The legacy version of forward function in'
|
||||
'EfficientMultiheadAttention is deprecated in'
|
||||
'mmcv>=1.3.17 and will no longer support in the'
|
||||
'future. Please upgrade your mmcv.')
|
||||
self.forward = self.legacy_forward
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
# Because the dataflow('key', 'query', 'value') of
|
||||
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
|
||||
# embed_dims), We should adjust the shape of dataflow from
|
||||
# batch_first (batch, num_query, embed_dims) to num_query_first
|
||||
# (num_query ,batch, embed_dims), and recover ``attn_output``
|
||||
# from num_query_first to batch_first.
|
||||
if self.batch_first:
|
||||
x_q = x_q.transpose(0, 1)
|
||||
x_kv = x_kv.transpose(0, 1)
|
||||
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
||||
|
||||
if self.batch_first:
|
||||
out = out.transpose(0, 1)
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
def legacy_forward(self, x, hw_shape, identity=None):
|
||||
"""multi head attention forward in mmcv version < 1.3.17."""
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
# `need_weights=True` will let nn.MultiHeadAttention
|
||||
# `return attn_output, attn_output_weights.sum(dim=1) / num_heads`
|
||||
# The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set
|
||||
# `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`.
|
||||
# This issue - `https://github.com/pytorch/pytorch/issues/37583` report
|
||||
# the error that large scale tensor sum operation may cause cuda error.
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Segformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
after the feed forward layer. Default 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
qkv_bias (bool): enable bias for qkv if True.
|
||||
Default: True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: False.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Default:None.
|
||||
sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head
|
||||
Attention of Segformer. Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
sr_ratio=1,
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.attn = EfficientMultiheadAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
batch_first=batch_first,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratio)
|
||||
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
self.ffn = MixFFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.with_cp = with_cp
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
||||
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MixVisionTransformer(BaseModule):
|
||||
"""The backbone of Segformer.
|
||||
|
||||
This backbone is the implementation of `SegFormer: Simple and
|
||||
Efficient Design for Semantic Segmentation with
|
||||
Transformers <https://arxiv.org/abs/2105.15203>`_.
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): Embedding dimension. Default: 768.
|
||||
num_stags (int): The num of stages. Default: 4.
|
||||
num_layers (Sequence[int]): The layer number of each transformer encode
|
||||
layer. Default: [3, 4, 6, 3].
|
||||
num_heads (Sequence[int]): The attention heads of each transformer
|
||||
encode layer. Default: [1, 2, 4, 8].
|
||||
patch_sizes (Sequence[int]): The patch_size of each overlapped patch
|
||||
embedding. Default: [7, 3, 3, 3].
|
||||
strides (Sequence[int]): The stride of each overlapped patch embedding.
|
||||
Default: [4, 2, 2, 2].
|
||||
sr_ratios (Sequence[int]): The spatial reduction rate of each
|
||||
transformer encode layer. Default: [8, 4, 2, 1].
|
||||
out_indices (Sequence[int] | int): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=64,
|
||||
num_stages=4,
|
||||
num_layers=[3, 4, 6, 3],
|
||||
num_heads=[1, 2, 4, 8],
|
||||
patch_sizes=[7, 3, 3, 3],
|
||||
strides=[4, 2, 2, 2],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
pretrained=None,
|
||||
init_cfg=None,
|
||||
with_cp=False):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.num_stages = num_stages
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.patch_sizes = patch_sizes
|
||||
self.strides = strides
|
||||
self.sr_ratios = sr_ratios
|
||||
self.with_cp = with_cp
|
||||
assert num_stages == len(num_layers) == len(num_heads) \
|
||||
== len(patch_sizes) == len(strides) == len(sr_ratios)
|
||||
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < self.num_stages
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, sum(num_layers))
|
||||
] # stochastic num_layer decay rule
|
||||
|
||||
cur = 0
|
||||
self.layers = ModuleList()
|
||||
for i, num_layer in enumerate(num_layers):
|
||||
embed_dims_i = embed_dims * num_heads[i]
|
||||
patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims_i,
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding=patch_sizes[i] // 2,
|
||||
norm_cfg=norm_cfg)
|
||||
layer = ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims_i,
|
||||
num_heads=num_heads[i],
|
||||
feedforward_channels=mlp_ratio * embed_dims_i,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[cur + idx],
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
sr_ratio=sr_ratios[i]) for idx in range(num_layer)
|
||||
])
|
||||
in_channels = embed_dims_i
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
|
||||
self.layers.append(ModuleList([patch_embed, layer, norm]))
|
||||
cur += num_layer
|
||||
|
||||
def init_weights(self):
|
||||
if self.init_cfg is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x, hw_shape = layer[0](x)
|
||||
for block in layer[1]:
|
||||
x = block(x, hw_shape)
|
||||
x = layer[2](x)
|
||||
x = nlc_to_nchw(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
||||
197
finetune/mmseg/models/backbones/mobilenet_v2.py
Normal file
197
finetune/mmseg/models/backbones/mobilenet_v2.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, make_divisible
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV2(BaseModule):
|
||||
"""MobileNetV2 backbone.
|
||||
|
||||
This backbone is the implementation of
|
||||
`MobileNetV2: Inverted Residuals and Linear Bottlenecks
|
||||
<https://arxiv.org/abs/1801.04381>`_.
|
||||
|
||||
Args:
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
strides (Sequence[int], optional): Strides of the first block of each
|
||||
layer. If not specified, default config in ``arch_setting`` will
|
||||
be used.
|
||||
dilations (Sequence[int]): Dilation of each layer.
|
||||
out_indices (None or Sequence[int]): Output from which stages.
|
||||
Default: (7, ).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU6').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
# Parameters to build layers. 3 parameters are needed to construct a
|
||||
# layer, from left to right: expand_ratio, channel, num_blocks.
|
||||
arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
|
||||
[6, 96, 3], [6, 160, 3], [6, 320, 1]]
|
||||
|
||||
def __init__(self,
|
||||
widen_factor=1.,
|
||||
strides=(1, 2, 2, 2, 1, 2, 1),
|
||||
dilations=(1, 1, 1, 1, 1, 1, 1),
|
||||
out_indices=(1, 2, 4, 6),
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU6'),
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.widen_factor = widen_factor
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
assert len(strides) == len(dilations) == len(self.arch_settings)
|
||||
self.out_indices = out_indices
|
||||
for index in out_indices:
|
||||
if index not in range(0, 7):
|
||||
raise ValueError('the item in out_indices must in '
|
||||
f'range(0, 7). But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, 7):
|
||||
raise ValueError('frozen_stages must be in range(-1, 7). '
|
||||
f'But received {frozen_stages}')
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.in_channels = make_divisible(32 * widen_factor, 8)
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=self.in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.layers = []
|
||||
|
||||
for i, layer_cfg in enumerate(self.arch_settings):
|
||||
expand_ratio, channel, num_blocks = layer_cfg
|
||||
stride = self.strides[i]
|
||||
dilation = self.dilations[i]
|
||||
out_channels = make_divisible(channel * widen_factor, 8)
|
||||
inverted_res_layer = self.make_layer(
|
||||
out_channels=out_channels,
|
||||
num_blocks=num_blocks,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
expand_ratio=expand_ratio)
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, inverted_res_layer)
|
||||
self.layers.append(layer_name)
|
||||
|
||||
def make_layer(self, out_channels, num_blocks, stride, dilation,
|
||||
expand_ratio):
|
||||
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
|
||||
|
||||
Args:
|
||||
out_channels (int): out_channels of block.
|
||||
num_blocks (int): Number of blocks.
|
||||
stride (int): Stride of the first block.
|
||||
dilation (int): Dilation of the first block.
|
||||
expand_ratio (int): Expand the number of channels of the
|
||||
hidden layer in InvertedResidual by this ratio.
|
||||
"""
|
||||
layers = []
|
||||
for i in range(num_blocks):
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
self.in_channels,
|
||||
out_channels,
|
||||
stride if i == 0 else 1,
|
||||
expand_ratio=expand_ratio,
|
||||
dilation=dilation if i == 0 else 1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
with_cp=self.with_cp))
|
||||
self.in_channels = out_channels
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
if len(outs) == 1:
|
||||
return outs[0]
|
||||
else:
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
for param in self.conv1.parameters():
|
||||
param.requires_grad = False
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
267
finetune/mmseg/models/backbones/mobilenet_v3.py
Normal file
267
finetune/mmseg/models/backbones/mobilenet_v3.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidualV3 as InvertedResidual
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV3(BaseModule):
|
||||
"""MobileNetV3 backbone.
|
||||
|
||||
This backbone is the improved implementation of `Searching for MobileNetV3
|
||||
<https://ieeexplore.ieee.org/document/9008835>`_.
|
||||
|
||||
Args:
|
||||
arch (str): Architecture of mobilnetv3, from {'small', 'large'}.
|
||||
Default: 'small'.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
out_indices (tuple[int]): Output from which layer.
|
||||
Default: (0, 1, 12).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
# Parameters to build each block:
|
||||
# [kernel size, mid channels, out channels, with_se, act type, stride]
|
||||
arch_settings = {
|
||||
'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
|
||||
[3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
|
||||
[3, 88, 24, False, 'ReLU', 1],
|
||||
[5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 240, 40, True, 'HSwish', 1],
|
||||
[5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
|
||||
[5, 144, 48, True, 'HSwish', 1],
|
||||
[5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
|
||||
[5, 576, 96, True, 'HSwish', 1],
|
||||
[5, 576, 96, True, 'HSwish', 1]],
|
||||
'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
|
||||
[3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
|
||||
[3, 72, 24, False, 'ReLU', 1],
|
||||
[5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[5, 120, 40, True, 'ReLU', 1],
|
||||
[3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
|
||||
[3, 200, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 184, 80, False, 'HSwish', 1],
|
||||
[3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
|
||||
[3, 672, 112, True, 'HSwish', 1],
|
||||
[5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
|
||||
[5, 960, 160, True, 'HSwish', 1],
|
||||
[5, 960, 160, True, 'HSwish', 1]]
|
||||
} # yapf: disable
|
||||
|
||||
def __init__(self,
|
||||
arch='small',
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
out_indices=(0, 1, 12),
|
||||
frozen_stages=-1,
|
||||
reduction_factor=1,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
assert arch in self.arch_settings
|
||||
assert isinstance(reduction_factor, int) and reduction_factor > 0
|
||||
assert is_tuple_of(out_indices, int)
|
||||
for index in out_indices:
|
||||
if index not in range(0, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError(
|
||||
'the item in out_indices must in '
|
||||
f'range(0, {len(self.arch_settings[arch])+2}). '
|
||||
f'But received {index}')
|
||||
|
||||
if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
|
||||
raise ValueError('frozen_stages must be in range(-1, '
|
||||
f'{len(self.arch_settings[arch])+2}). '
|
||||
f'But received {frozen_stages}')
|
||||
self.arch = arch
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
self.reduction_factor = reduction_factor
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.layers = self._make_layer()
|
||||
|
||||
def _make_layer(self):
|
||||
layers = []
|
||||
|
||||
# build the first layer (layer0)
|
||||
in_channels = 16
|
||||
layer = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=dict(type='Conv2dAdaptivePadding'),
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
self.add_module('layer0', layer)
|
||||
layers.append('layer0')
|
||||
|
||||
layer_setting = self.arch_settings[self.arch]
|
||||
for i, params in enumerate(layer_setting):
|
||||
(kernel_size, mid_channels, out_channels, with_se, act,
|
||||
stride) = params
|
||||
|
||||
if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
|
||||
i >= 8:
|
||||
mid_channels = mid_channels // self.reduction_factor
|
||||
out_channels = out_channels // self.reduction_factor
|
||||
|
||||
if with_se:
|
||||
se_cfg = dict(
|
||||
channels=mid_channels,
|
||||
ratio=4,
|
||||
act_cfg=(dict(type='ReLU'),
|
||||
dict(type='HSigmoid', bias=3.0, divisor=6.0)))
|
||||
else:
|
||||
se_cfg = None
|
||||
|
||||
layer = InvertedResidual(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
mid_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
se_cfg=se_cfg,
|
||||
with_expand_conv=(in_channels != mid_channels),
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type=act),
|
||||
with_cp=self.with_cp)
|
||||
in_channels = out_channels
|
||||
layer_name = f'layer{i + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# build the last layer
|
||||
# block5 layer12 os=32 for small model
|
||||
# block6 layer16 os=32 for large model
|
||||
layer = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=576 if self.arch == 'small' else 960,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
dilation=4,
|
||||
padding=0,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=dict(type='HSwish'))
|
||||
layer_name = f'layer{len(layer_setting) + 1}'
|
||||
self.add_module(layer_name, layer)
|
||||
layers.append(layer_name)
|
||||
|
||||
# next, convert backbone MobileNetV3 to a semantic segmentation version
|
||||
if self.arch == 'small':
|
||||
self.layer4.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer9.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(4, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 9:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
else:
|
||||
self.layer7.depthwise_conv.conv.stride = (1, 1)
|
||||
self.layer13.depthwise_conv.conv.stride = (1, 1)
|
||||
for i in range(7, len(layers)):
|
||||
layer = getattr(self, layers[i])
|
||||
if isinstance(layer, InvertedResidual):
|
||||
modified_module = layer.depthwise_conv.conv
|
||||
else:
|
||||
modified_module = layer.conv
|
||||
|
||||
if i < 13:
|
||||
modified_module.dilation = (2, 2)
|
||||
pad = 2
|
||||
else:
|
||||
modified_module.dilation = (4, 4)
|
||||
pad = 4
|
||||
|
||||
if not isinstance(modified_module, Conv2dAdaptivePadding):
|
||||
# Adjust padding
|
||||
pad *= (modified_module.kernel_size[0] - 1) // 2
|
||||
modified_module.padding = (pad, pad)
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
def _freeze_stages(self):
|
||||
for i in range(self.frozen_stages + 1):
|
||||
layer = getattr(self, f'layer{i}')
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
467
finetune/mmseg/models/backbones/mscan.py
Normal file
467
finetune/mmseg/models/backbones/mscan.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Originally from https://github.com/visual-attention-network/segnext
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class Mlp(BaseModule):
|
||||
"""Multi Layer Perceptron (MLP) Module.
|
||||
|
||||
Args:
|
||||
in_features (int): The dimension of input features.
|
||||
hidden_features (int): The dimension of hidden features.
|
||||
Defaults: None.
|
||||
out_features (int): The dimension of output features.
|
||||
Defaults: None.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
drop (float): The number of dropout rate in MLP block.
|
||||
Defaults: 0.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
|
||||
self.dwconv = nn.Conv2d(
|
||||
hidden_features,
|
||||
hidden_features,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
bias=True,
|
||||
groups=hidden_features)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.fc1(x)
|
||||
|
||||
x = self.dwconv(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StemConv(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): The dimension of input channels.
|
||||
out_channels (int): The dimension of output channels.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels // 2,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1)),
|
||||
build_norm_layer(norm_cfg, out_channels // 2)[1],
|
||||
build_activation_layer(act_cfg),
|
||||
nn.Conv2d(
|
||||
out_channels // 2,
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
stride=(2, 2),
|
||||
padding=(1, 1)),
|
||||
build_norm_layer(norm_cfg, out_channels)[1],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.size()
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x, H, W
|
||||
|
||||
|
||||
class MSCAAttention(BaseModule):
|
||||
"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
|
||||
|
||||
Args:
|
||||
channels (int): The dimension of channels.
|
||||
kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
paddings (list): The number of
|
||||
corresponding padding value in attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
paddings=[2, [0, 3], [0, 5], [0, 10]]):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=kernel_sizes[0],
|
||||
padding=paddings[0],
|
||||
groups=channels)
|
||||
for i, (kernel_size,
|
||||
padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
|
||||
kernel_size_ = [kernel_size, kernel_size[::-1]]
|
||||
padding_ = [padding, padding[::-1]]
|
||||
conv_name = [f'conv{i}_1', f'conv{i}_2']
|
||||
for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
|
||||
conv_name):
|
||||
self.add_module(
|
||||
i_conv,
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
tuple(i_kernel),
|
||||
padding=i_pad,
|
||||
groups=channels))
|
||||
self.conv3 = nn.Conv2d(channels, channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
u = x.clone()
|
||||
|
||||
attn = self.conv0(x)
|
||||
|
||||
# Multi-Scale Feature extraction
|
||||
attn_0 = self.conv0_1(attn)
|
||||
attn_0 = self.conv0_2(attn_0)
|
||||
|
||||
attn_1 = self.conv1_1(attn)
|
||||
attn_1 = self.conv1_2(attn_1)
|
||||
|
||||
attn_2 = self.conv2_1(attn)
|
||||
attn_2 = self.conv2_2(attn_2)
|
||||
|
||||
attn = attn + attn_0 + attn_1 + attn_2
|
||||
# Channel Mixing
|
||||
attn = self.conv3(attn)
|
||||
|
||||
# Convolutional Attention
|
||||
x = attn * u
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MSCASpatialAttention(BaseModule):
|
||||
"""Spatial Attention Module in Multi-Scale Convolutional Attention Module
|
||||
(MSCA).
|
||||
|
||||
Args:
|
||||
in_channels (int): The dimension of channels.
|
||||
attention_kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): The number of
|
||||
corresponding padding value in attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
act_cfg=dict(type='GELU')):
|
||||
super().__init__()
|
||||
self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
|
||||
self.activation = build_activation_layer(act_cfg)
|
||||
self.spatial_gating_unit = MSCAAttention(in_channels,
|
||||
attention_kernel_sizes,
|
||||
attention_kernel_paddings)
|
||||
self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
shorcut = x.clone()
|
||||
x = self.proj_1(x)
|
||||
x = self.activation(x)
|
||||
x = self.spatial_gating_unit(x)
|
||||
x = self.proj_2(x)
|
||||
x = x + shorcut
|
||||
return x
|
||||
|
||||
|
||||
class MSCABlock(BaseModule):
|
||||
"""Basic Multi-Scale Convolutional Attention Block. It leverage the large-
|
||||
kernel attention (LKA) mechanism to build both channel and spatial
|
||||
attention. In each branch, it uses two depth-wise strip convolutions to
|
||||
approximate standard depth-wise convolutions with large kernels. The kernel
|
||||
size for each branch is set to 7, 11, and 21, respectively.
|
||||
|
||||
Args:
|
||||
channels (int): The dimension of channels.
|
||||
attention_kernel_sizes (list): The size of attention
|
||||
kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): The number of
|
||||
corresponding padding value in attention module.
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
mlp_ratio (float): The ratio of multiple input dimension to
|
||||
calculate hidden feature in MLP layer. Defaults: 4.0.
|
||||
drop (float): The number of dropout rate in MLP block.
|
||||
Defaults: 0.0.
|
||||
drop_path (float): The ratio of drop paths.
|
||||
Defaults: 0.0.
|
||||
act_cfg (dict): Config dict for activation layer in block.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
mlp_ratio=4.,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
self.norm1 = build_norm_layer(norm_cfg, channels)[1]
|
||||
self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
|
||||
attention_kernel_paddings, act_cfg)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = build_norm_layer(norm_cfg, channels)[1]
|
||||
mlp_hidden_channels = int(channels * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=channels,
|
||||
hidden_features=mlp_hidden_channels,
|
||||
act_cfg=act_cfg,
|
||||
drop=drop)
|
||||
layer_scale_init_value = 1e-2
|
||||
self.layer_scale_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(channels), requires_grad=True)
|
||||
self.layer_scale_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(channels), requires_grad=True)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
"""Forward function."""
|
||||
|
||||
B, N, C = x.shape
|
||||
x = x.permute(0, 2, 1).view(B, C, H, W)
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
|
||||
self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
|
||||
self.mlp(self.norm2(x)))
|
||||
x = x.view(B, C, N).permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class OverlapPatchEmbed(BaseModule):
|
||||
"""Image to Patch Embedding.
|
||||
|
||||
Args:
|
||||
patch_size (int): The patch size.
|
||||
Defaults: 7.
|
||||
stride (int): Stride of the convolutional layer.
|
||||
Default: 4.
|
||||
in_channels (int): The number of input channels.
|
||||
Defaults: 3.
|
||||
embed_dims (int): The dimensions of embedding.
|
||||
Defaults: 768.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=7,
|
||||
stride=4,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True)):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
padding=patch_size // 2)
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = self.norm(x)
|
||||
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MSCAN(BaseModule):
|
||||
"""SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
|
||||
|
||||
This backbone is the implementation of `SegNeXt: Rethinking
|
||||
Convolutional Attention Design for Semantic
|
||||
Segmentation <https://arxiv.org/abs/2209.08575>`_.
|
||||
Inspiration from https://github.com/visual-attention-network/segnext.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels. Defaults: 3.
|
||||
embed_dims (list[int]): Embedding dimension.
|
||||
Defaults: [64, 128, 256, 512].
|
||||
mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
|
||||
Defaults: [4, 4, 4, 4].
|
||||
drop_rate (float): Dropout rate. Defaults: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults: 0.
|
||||
depths (list[int]): Depths of each Swin Transformer stage.
|
||||
Default: [3, 4, 6, 3].
|
||||
num_stages (int): MSCAN stages. Default: 4.
|
||||
attention_kernel_sizes (list): Size of attention kernel in
|
||||
Attention Module (Figure 2(b) of original paper).
|
||||
Defaults: [5, [1, 7], [1, 11], [1, 21]].
|
||||
attention_kernel_paddings (list): Size of attention paddings
|
||||
in Attention Module (Figure 2(b) of original paper).
|
||||
Defaults: [2, [0, 3], [0, 5], [0, 10]].
|
||||
norm_cfg (dict): Config of norm layers.
|
||||
Defaults: dict(type='SyncBN', requires_grad=True).
|
||||
pretrained (str, optional): model pretrained path.
|
||||
Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=[64, 128, 256, 512],
|
||||
mlp_ratios=[4, 4, 4, 4],
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
depths=[3, 4, 6, 3],
|
||||
num_stages=4,
|
||||
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
|
||||
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.depths = depths
|
||||
self.num_stages = num_stages
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
cur = 0
|
||||
|
||||
for i in range(num_stages):
|
||||
if i == 0:
|
||||
patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg)
|
||||
else:
|
||||
patch_embed = OverlapPatchEmbed(
|
||||
patch_size=7 if i == 0 else 3,
|
||||
stride=4 if i == 0 else 2,
|
||||
in_channels=in_channels if i == 0 else embed_dims[i - 1],
|
||||
embed_dim=embed_dims[i],
|
||||
norm_cfg=norm_cfg)
|
||||
|
||||
block = nn.ModuleList([
|
||||
MSCABlock(
|
||||
channels=embed_dims[i],
|
||||
attention_kernel_sizes=attention_kernel_sizes,
|
||||
attention_kernel_paddings=attention_kernel_paddings,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
drop=drop_rate,
|
||||
drop_path=dpr[cur + j],
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg) for j in range(depths[i])
|
||||
])
|
||||
norm = nn.LayerNorm(embed_dims[i])
|
||||
cur += depths[i]
|
||||
|
||||
setattr(self, f'patch_embed{i + 1}', patch_embed)
|
||||
setattr(self, f'block{i + 1}', block)
|
||||
setattr(self, f'norm{i + 1}', norm)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize modules of MSCAN."""
|
||||
|
||||
print('init cfg', self.init_cfg)
|
||||
if self.init_cfg is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
else:
|
||||
super().init_weights()
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
B = x.shape[0]
|
||||
outs = []
|
||||
|
||||
for i in range(self.num_stages):
|
||||
patch_embed = getattr(self, f'patch_embed{i + 1}')
|
||||
block = getattr(self, f'block{i + 1}')
|
||||
norm = getattr(self, f'norm{i + 1}')
|
||||
x, H, W = patch_embed(x)
|
||||
for blk in block:
|
||||
x = blk(x, H, W)
|
||||
x = norm(x)
|
||||
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
return outs
|
||||
522
finetune/mmseg/models/backbones/pidnet.py
Normal file
522
finetune/mmseg/models/backbones/pidnet.py
Normal file
@@ -0,0 +1,522 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType
|
||||
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
|
||||
|
||||
|
||||
class PagFM(BaseModule):
|
||||
"""Pixel-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
channels (int): The number of channels.
|
||||
after_relu (bool): Whether to use ReLU before attention.
|
||||
Default: False.
|
||||
with_channel (bool): Whether to use channel attention.
|
||||
Default: False.
|
||||
upsample_mode (str): The mode of upsample. Default: 'bilinear'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(typ='ReLU', inplace=True).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
after_relu: bool = False,
|
||||
with_channel: bool = False,
|
||||
upsample_mode: str = 'bilinear',
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(typ='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.after_relu = after_relu
|
||||
self.with_channel = with_channel
|
||||
self.upsample_mode = upsample_mode
|
||||
self.f_i = ConvModule(
|
||||
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
self.f_p = ConvModule(
|
||||
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
if with_channel:
|
||||
self.up = ConvModule(
|
||||
channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
if after_relu:
|
||||
self.relu = MODELS.build(act_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x_p (Tensor): The featrue map from P branch.
|
||||
x_i (Tensor): The featrue map from I branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with pixel-attention-guided fusion.
|
||||
"""
|
||||
if self.after_relu:
|
||||
x_p = self.relu(x_p)
|
||||
x_i = self.relu(x_i)
|
||||
|
||||
f_i = self.f_i(x_i)
|
||||
f_i = F.interpolate(
|
||||
f_i,
|
||||
size=x_p.shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=False)
|
||||
|
||||
f_p = self.f_p(x_p)
|
||||
|
||||
if self.with_channel:
|
||||
sigma = torch.sigmoid(self.up(f_p * f_i))
|
||||
else:
|
||||
sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))
|
||||
|
||||
x_i = F.interpolate(
|
||||
x_i,
|
||||
size=x_p.shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=False)
|
||||
|
||||
out = sigma * x_i + (1 - sigma) * x_p
|
||||
return out
|
||||
|
||||
|
||||
class Bag(BaseModule):
|
||||
"""Boundary-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
kernel_size (int): The kernel size of the convolution. Default: 3.
|
||||
padding (int): The padding of the convolution. Default: 1.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: dict(order=('norm', 'act', 'conv')).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
padding: int = 1,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=padding,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**conv_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x_p (Tensor): The featrue map from P branch.
|
||||
x_i (Tensor): The featrue map from I branch.
|
||||
x_d (Tensor): The featrue map from D branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with boundary-attention-guided fusion.
|
||||
"""
|
||||
sigma = torch.sigmoid(x_d)
|
||||
return self.conv(sigma * x_p + (1 - sigma) * x_i)
|
||||
|
||||
|
||||
class LightBag(BaseModule):
|
||||
"""Light Boundary-attention-guided fusion module.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer. Default: None.
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = None,
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.f_p = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.f_i = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
x_p (Tensor): The featrue map from P branch.
|
||||
x_i (Tensor): The featrue map from I branch.
|
||||
x_d (Tensor): The featrue map from D branch.
|
||||
|
||||
Returns:
|
||||
Tensor: The feature map with light boundary-attention-guided
|
||||
fusion.
|
||||
"""
|
||||
sigma = torch.sigmoid(x_d)
|
||||
|
||||
f_p = self.f_p((1 - sigma) * x_i + x_p)
|
||||
f_i = self.f_i(x_i + sigma * x_p)
|
||||
|
||||
return f_p + f_i
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PIDNet(BaseModule):
|
||||
"""PIDNet backbone.
|
||||
|
||||
This backbone is the implementation of `PIDNet: A Real-time Semantic
|
||||
Segmentation Network Inspired from PID Controller
|
||||
<https://arxiv.org/abs/2206.02066>`_.
|
||||
Modified from https://github.com/XuJiacong/PIDNet.
|
||||
|
||||
Licensed under the MIT License.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels. Default: 3.
|
||||
channels (int): The number of channels in the stem layer. Default: 64.
|
||||
ppm_channels (int): The number of channels in the PPM layer.
|
||||
Default: 96.
|
||||
num_stem_blocks (int): The number of blocks in the stem layer.
|
||||
Default: 2.
|
||||
num_branch_blocks (int): The number of blocks in the branch layer.
|
||||
Default: 3.
|
||||
align_corners (bool): The align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict): Config dict for initialization. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 64,
|
||||
ppm_channels: int = 96,
|
||||
num_stem_blocks: int = 2,
|
||||
num_branch_blocks: int = 3,
|
||||
align_corners: bool = False,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None,
|
||||
**kwargs):
|
||||
super().__init__(init_cfg)
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.align_corners = align_corners
|
||||
|
||||
# stem layer
|
||||
self.stem = self._make_stem_layer(in_channels, channels,
|
||||
num_stem_blocks)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# I Branch
|
||||
self.i_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.i_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
in_channels=channels * 2**(i + 1),
|
||||
channels=channels * 8 if i > 0 else channels * 4,
|
||||
num_blocks=num_branch_blocks if i < 2 else 2,
|
||||
stride=2))
|
||||
|
||||
# P Branch
|
||||
self.p_branch_layers = nn.ModuleList()
|
||||
for i in range(3):
|
||||
self.p_branch_layers.append(
|
||||
self._make_layer(
|
||||
block=BasicBlock if i < 2 else Bottleneck,
|
||||
in_channels=channels * 2,
|
||||
channels=channels * 2,
|
||||
num_blocks=num_stem_blocks if i < 2 else 1))
|
||||
self.compression_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.compression_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.pag_1 = PagFM(channels * 2, channels)
|
||||
self.pag_2 = PagFM(channels * 2, channels)
|
||||
|
||||
# D Branch
|
||||
if num_stem_blocks == 2:
|
||||
self.d_branch_layers = nn.ModuleList([
|
||||
self._make_single_layer(BasicBlock, channels * 2, channels),
|
||||
self._make_layer(Bottleneck, channels, channels, 1)
|
||||
])
|
||||
channel_expand = 1
|
||||
spp_module = PAPPM
|
||||
dfm_module = LightBag
|
||||
act_cfg_dfm = None
|
||||
else:
|
||||
self.d_branch_layers = nn.ModuleList([
|
||||
self._make_single_layer(BasicBlock, channels * 2,
|
||||
channels * 2),
|
||||
self._make_single_layer(BasicBlock, channels * 2, channels * 2)
|
||||
])
|
||||
channel_expand = 2
|
||||
spp_module = DAPPM
|
||||
dfm_module = Bag
|
||||
act_cfg_dfm = act_cfg
|
||||
|
||||
self.diff_1 = ConvModule(
|
||||
channels * 4,
|
||||
channels * channel_expand,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
self.diff_2 = ConvModule(
|
||||
channels * 8,
|
||||
channels * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
self.spp = spp_module(
|
||||
channels * 16, ppm_channels, channels * 4, num_scales=5)
|
||||
self.dfm = dfm_module(
|
||||
channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)
|
||||
|
||||
self.d_branch_layers.append(
|
||||
self._make_layer(Bottleneck, channels * 2, channels * 2, 1))
|
||||
|
||||
def _make_stem_layer(self, in_channels: int, channels: int,
|
||||
num_blocks: int) -> nn.Sequential:
|
||||
"""Make stem layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_blocks (int): Number of blocks.
|
||||
|
||||
Returns:
|
||||
nn.Sequential: The stem layer.
|
||||
"""
|
||||
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
]
|
||||
|
||||
layers.append(
|
||||
self._make_layer(BasicBlock, channels, channels, num_blocks))
|
||||
layers.append(nn.ReLU())
|
||||
layers.append(
|
||||
self._make_layer(
|
||||
BasicBlock, channels, channels * 2, num_blocks, stride=2))
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_layer(self,
|
||||
block: BasicBlock,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_blocks: int,
|
||||
stride: int = 1) -> nn.Sequential:
|
||||
"""Make layer for PIDNet backbone.
|
||||
Args:
|
||||
block (BasicBlock): Basic block.
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_blocks (int): Number of blocks.
|
||||
stride (int): Stride of the first block. Default: 1.
|
||||
|
||||
Returns:
|
||||
nn.Sequential: The Branch Layer.
|
||||
"""
|
||||
downsample = None
|
||||
if stride != 1 or in_channels != channels * block.expansion:
|
||||
downsample = ConvModule(
|
||||
in_channels,
|
||||
channels * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
layers = [block(in_channels, channels, stride, downsample)]
|
||||
in_channels = channels * block.expansion
|
||||
for i in range(1, num_blocks):
|
||||
layers.append(
|
||||
block(
|
||||
in_channels,
|
||||
channels,
|
||||
stride=1,
|
||||
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _make_single_layer(self,
|
||||
block: Union[BasicBlock, Bottleneck],
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
stride: int = 1) -> nn.Module:
|
||||
"""Make single layer for PIDNet backbone.
|
||||
Args:
|
||||
block (BasicBlock or Bottleneck): Basic block or Bottleneck.
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
stride (int): Stride of the first block. Default: 1.
|
||||
|
||||
Returns:
|
||||
nn.Module
|
||||
"""
|
||||
|
||||
downsample = None
|
||||
if stride != 1 or in_channels != channels * block.expansion:
|
||||
downsample = ConvModule(
|
||||
in_channels,
|
||||
channels * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
return block(
|
||||
in_channels, channels, stride, downsample, act_cfg_out=None)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Since the D branch is not initialized by the pre-trained model, we
|
||||
initialize it with the same method as the ResNet.
|
||||
"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
if self.init_cfg is not None:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specify `Pretrained` in ' \
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], map_location='cpu')
|
||||
self.load_state_dict(ckpt, strict=False)
|
||||
|
||||
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
Tensor or tuple[Tensor]: If self.training is True, return
|
||||
tuple[Tensor], else return Tensor.
|
||||
"""
|
||||
w_out = x.shape[-1] // 8
|
||||
h_out = x.shape[-2] // 8
|
||||
|
||||
# stage 0-2
|
||||
x = self.stem(x)
|
||||
|
||||
# stage 3
|
||||
x_i = self.relu(self.i_branch_layers[0](x))
|
||||
x_p = self.p_branch_layers[0](x)
|
||||
x_d = self.d_branch_layers[0](x)
|
||||
|
||||
comp_i = self.compression_1(x_i)
|
||||
x_p = self.pag_1(x_p, comp_i)
|
||||
diff_i = self.diff_1(x_i)
|
||||
x_d += F.interpolate(
|
||||
diff_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_p = x_p.clone()
|
||||
|
||||
# stage 4
|
||||
x_i = self.relu(self.i_branch_layers[1](x_i))
|
||||
x_p = self.p_branch_layers[1](self.relu(x_p))
|
||||
x_d = self.d_branch_layers[1](self.relu(x_d))
|
||||
|
||||
comp_i = self.compression_2(x_i)
|
||||
x_p = self.pag_2(x_p, comp_i)
|
||||
diff_i = self.diff_2(x_i)
|
||||
x_d += F.interpolate(
|
||||
diff_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.training:
|
||||
temp_d = x_d.clone()
|
||||
|
||||
# stage 5
|
||||
x_i = self.i_branch_layers[2](x_i)
|
||||
x_p = self.p_branch_layers[2](self.relu(x_p))
|
||||
x_d = self.d_branch_layers[2](self.relu(x_d))
|
||||
|
||||
x_i = self.spp(x_i)
|
||||
x_i = F.interpolate(
|
||||
x_i,
|
||||
size=[h_out, w_out],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
out = self.dfm(x_p, x_i, x_d)
|
||||
return (temp_p, out, temp_d) if self.training else out
|
||||
318
finetune/mmseg/models/backbones/resnest.py
Normal file
318
finetune/mmseg/models/backbones/resnest.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNetV1d
|
||||
|
||||
|
||||
class RSoftmax(nn.Module):
|
||||
"""Radix Softmax module in ``SplitAttentionConv2d``.
|
||||
|
||||
Args:
|
||||
radix (int): Radix of input.
|
||||
groups (int): Groups of input.
|
||||
"""
|
||||
|
||||
def __init__(self, radix, groups):
|
||||
super().__init__()
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
|
||||
def forward(self, x):
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
|
||||
x = F.softmax(x, dim=1)
|
||||
x = x.reshape(batch, -1)
|
||||
else:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class SplitAttentionConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d in ResNeSt.
|
||||
|
||||
Args:
|
||||
in_channels (int): Same as nn.Conv2d.
|
||||
out_channels (int): Same as nn.Conv2d.
|
||||
kernel_size (int | tuple[int]): Same as nn.Conv2d.
|
||||
stride (int | tuple[int]): Same as nn.Conv2d.
|
||||
padding (int | tuple[int]): Same as nn.Conv2d.
|
||||
dilation (int | tuple[int]): Same as nn.Conv2d.
|
||||
groups (int): Same as nn.Conv2d.
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels. Default: 4.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None,
|
||||
which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
dcn (dict): Config dict for DCN. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None):
|
||||
super().__init__()
|
||||
inter_channels = max(in_channels * radix // reduction_factor, 32)
|
||||
self.radix = radix
|
||||
self.groups = groups
|
||||
self.channels = channels
|
||||
self.with_dcn = dcn is not None
|
||||
self.dcn = dcn
|
||||
fallback_on_stride = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
|
||||
if self.with_dcn and not fallback_on_stride:
|
||||
assert conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
conv_cfg = dcn
|
||||
self.conv = build_conv_layer(
|
||||
conv_cfg,
|
||||
in_channels,
|
||||
channels * radix,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups * radix,
|
||||
bias=False)
|
||||
self.norm0_name, norm0 = build_norm_layer(
|
||||
norm_cfg, channels * radix, postfix=0)
|
||||
self.add_module(self.norm0_name, norm0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc1 = build_conv_layer(
|
||||
None, channels, inter_channels, 1, groups=self.groups)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, inter_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.fc2 = build_conv_layer(
|
||||
None, inter_channels, channels * radix, 1, groups=self.groups)
|
||||
self.rsoftmax = RSoftmax(radix, groups)
|
||||
|
||||
@property
|
||||
def norm0(self):
|
||||
"""nn.Module: the normalization layer named "norm0" """
|
||||
return getattr(self, self.norm0_name)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.norm0(x)
|
||||
x = self.relu(x)
|
||||
|
||||
batch, rchannel = x.shape[:2]
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
splits = x.view(batch, self.radix, -1, *x.shape[2:])
|
||||
gap = splits.sum(dim=1)
|
||||
else:
|
||||
gap = x
|
||||
gap = F.adaptive_avg_pool2d(gap, 1)
|
||||
gap = self.fc1(gap)
|
||||
|
||||
gap = self.norm1(gap)
|
||||
gap = self.relu(gap)
|
||||
|
||||
atten = self.fc2(gap)
|
||||
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
|
||||
|
||||
if self.radix > 1:
|
||||
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
|
||||
out = torch.sum(attens * splits, dim=1)
|
||||
else:
|
||||
out = atten * x
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
|
||||
"""Bottleneck block for ResNeSt.
|
||||
|
||||
Args:
|
||||
inplane (int): Input planes of this block.
|
||||
planes (int): Middle planes of this block.
|
||||
groups (int): Groups of conv2.
|
||||
width_per_group (int): Width per group of conv2. 64x4d indicates
|
||||
``groups=64, width_per_group=4`` and 32x8d indicates
|
||||
``groups=32, width_per_group=8``.
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Key word arguments for base class.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
base_channels=64,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
"""Bottleneck block for ResNeSt."""
|
||||
super().__init__(inplanes, planes, **kwargs)
|
||||
|
||||
if groups == 1:
|
||||
width = self.planes
|
||||
else:
|
||||
width = math.floor(self.planes *
|
||||
(base_width / base_channels)) * groups
|
||||
|
||||
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=1)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
self.norm_cfg, self.planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.inplanes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.with_modulated_dcn = False
|
||||
self.conv2 = SplitAttentionConv2d(
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=1 if self.avg_down_stride else self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
radix=radix,
|
||||
reduction_factor=reduction_factor,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
dcn=self.dcn)
|
||||
delattr(self, self.norm2_name)
|
||||
|
||||
if self.avg_down_stride:
|
||||
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
|
||||
|
||||
self.conv3 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
self.planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv1_plugin_names)
|
||||
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.avg_down_stride:
|
||||
out = self.avd_layer(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv2_plugin_names)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.norm3(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv3_plugin_names)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNeSt(ResNetV1d):
|
||||
"""ResNeSt backbone.
|
||||
|
||||
This backbone is the implementation of `ResNeSt:
|
||||
Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.
|
||||
|
||||
Args:
|
||||
groups (int): Number of groups of Bottleneck. Default: 1
|
||||
base_width (int): Base width of Bottleneck. Default: 4
|
||||
radix (int): Radix of SpltAtConv2d. Default: 2
|
||||
reduction_factor (int): Reduction factor of inter_channels in
|
||||
SplitAttentionConv2d. Default: 4.
|
||||
avg_down_stride (bool): Whether to use average pool for stride in
|
||||
Bottleneck. Default: True.
|
||||
kwargs (dict): Keyword arguments for ResNet.
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3)),
|
||||
200: (Bottleneck, (3, 24, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
radix=2,
|
||||
reduction_factor=4,
|
||||
avg_down_stride=True,
|
||||
**kwargs):
|
||||
self.groups = groups
|
||||
self.base_width = base_width
|
||||
self.radix = radix
|
||||
self.reduction_factor = reduction_factor
|
||||
self.avg_down_stride = avg_down_stride
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``."""
|
||||
return ResLayer(
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
base_channels=self.base_channels,
|
||||
radix=self.radix,
|
||||
reduction_factor=self.reduction_factor,
|
||||
avg_down_stride=self.avg_down_stride,
|
||||
**kwargs)
|
||||
712
finetune/mmseg/models/backbones/resnet.py
Normal file
712
finetune/mmseg/models/backbones/resnet.py
Normal file
@@ -0,0 +1,712 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
|
||||
|
||||
class BasicBlock(BaseModule):
|
||||
"""Basic block for ResNet."""
|
||||
|
||||
expansion = 1
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
downsample=None,
|
||||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None,
|
||||
plugins=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
conv_cfg,
|
||||
inplanes,
|
||||
planes,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.conv2 = build_conv_layer(
|
||||
conv_cfg, planes, planes, 3, padding=1, bias=False)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.with_cp = with_cp
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: normalization layer after the first convolution layer"""
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: normalization layer after the second convolution layer"""
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.norm2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(BaseModule):
|
||||
"""Bottleneck block for ResNet.
|
||||
|
||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||
"caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
downsample=None,
|
||||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
dcn=None,
|
||||
plugins=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
assert style in ['pytorch', 'caffe']
|
||||
assert dcn is None or isinstance(dcn, dict)
|
||||
assert plugins is None or isinstance(plugins, list)
|
||||
if plugins is not None:
|
||||
allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
|
||||
assert all(p['position'] in allowed_position for p in plugins)
|
||||
|
||||
self.inplanes = inplanes
|
||||
self.planes = planes
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
self.style = style
|
||||
self.with_cp = with_cp
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.dcn = dcn
|
||||
self.with_dcn = dcn is not None
|
||||
self.plugins = plugins
|
||||
self.with_plugins = plugins is not None
|
||||
|
||||
if self.with_plugins:
|
||||
# collect plugins for conv1/conv2/conv3
|
||||
self.after_conv1_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv1'
|
||||
]
|
||||
self.after_conv2_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv2'
|
||||
]
|
||||
self.after_conv3_plugins = [
|
||||
plugin['cfg'] for plugin in plugins
|
||||
if plugin['position'] == 'after_conv3'
|
||||
]
|
||||
|
||||
if self.style == 'pytorch':
|
||||
self.conv1_stride = 1
|
||||
self.conv2_stride = stride
|
||||
else:
|
||||
self.conv1_stride = stride
|
||||
self.conv2_stride = 1
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
norm_cfg, planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
conv_cfg,
|
||||
inplanes,
|
||||
planes,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
fallback_on_stride = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = dcn.pop('fallback_on_stride', False)
|
||||
if not self.with_dcn or fallback_on_stride:
|
||||
self.conv2 = build_conv_layer(
|
||||
conv_cfg,
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
else:
|
||||
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
self.conv2 = build_conv_layer(
|
||||
dcn,
|
||||
planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.conv3 = build_conv_layer(
|
||||
conv_cfg,
|
||||
planes,
|
||||
planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
if self.with_plugins:
|
||||
self.after_conv1_plugin_names = self.make_block_plugins(
|
||||
planes, self.after_conv1_plugins)
|
||||
self.after_conv2_plugin_names = self.make_block_plugins(
|
||||
planes, self.after_conv2_plugins)
|
||||
self.after_conv3_plugin_names = self.make_block_plugins(
|
||||
planes * self.expansion, self.after_conv3_plugins)
|
||||
|
||||
def make_block_plugins(self, in_channels, plugins):
|
||||
"""make plugins for block.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of plugin.
|
||||
plugins (list[dict]): List of plugins cfg to build.
|
||||
|
||||
Returns:
|
||||
list[str]: List of the names of plugin.
|
||||
"""
|
||||
assert isinstance(plugins, list)
|
||||
plugin_names = []
|
||||
for plugin in plugins:
|
||||
plugin = plugin.copy()
|
||||
name, layer = build_plugin_layer(
|
||||
plugin,
|
||||
in_channels=in_channels,
|
||||
postfix=plugin.pop('postfix', ''))
|
||||
assert not hasattr(self, name), f'duplicate plugin {name}'
|
||||
self.add_module(name, layer)
|
||||
plugin_names.append(name)
|
||||
return plugin_names
|
||||
|
||||
def forward_plugin(self, x, plugin_names):
|
||||
"""Forward function for plugins."""
|
||||
out = x
|
||||
for name in plugin_names:
|
||||
out = getattr(self, name)(x)
|
||||
return out
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: normalization layer after the first convolution layer"""
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
"""nn.Module: normalization layer after the second convolution layer"""
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
@property
|
||||
def norm3(self):
|
||||
"""nn.Module: normalization layer after the third convolution layer"""
|
||||
return getattr(self, self.norm3_name)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv1_plugin_names)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.norm2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv2_plugin_names)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.norm3(out)
|
||||
|
||||
if self.with_plugins:
|
||||
out = self.forward_plugin(out, self.after_conv3_plugin_names)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNet(BaseModule):
|
||||
"""ResNet backbone.
|
||||
|
||||
This backbone is the improved implementation of `Deep Residual Learning
|
||||
for Image Recognition <https://arxiv.org/abs/1512.03385>`_.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
stem_channels (int): Number of stem channels. Default: 64.
|
||||
base_channels (int): Number of base channels of res layer. Default: 64.
|
||||
num_stages (int): Resnet stages, normally 4. Default: 4.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
Default: (1, 2, 2, 2).
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
Default: (1, 1, 1, 1).
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer. Default: 'pytorch'.
|
||||
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
|
||||
Default: False.
|
||||
avg_down (bool): Use AvgPool instead of stride conv when
|
||||
downsampling in the bottleneck. Default: False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Default: -1.
|
||||
conv_cfg (dict | None): Dictionary to construct and config conv layer.
|
||||
When conv_cfg is None, cfg will be set to dict(type='Conv2d').
|
||||
Default: None.
|
||||
norm_cfg (dict): Dictionary to construct and config norm layer.
|
||||
Default: dict(type='BN', requires_grad=True).
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
dcn (dict | None): Dictionary to construct and config DCN conv layer.
|
||||
When dcn is not None, conv_cfg must be None. Default: None.
|
||||
stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
|
||||
stage. The length of stage_with_dcn is equal to num_stages.
|
||||
Default: (False, False, False, False).
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
|
||||
- position (str, required): Position inside block to insert plugin,
|
||||
options: 'after_conv1', 'after_conv2', 'after_conv3'.
|
||||
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
Default: None.
|
||||
multi_grid (Sequence[int]|None): Multi grid dilation rates of last
|
||||
stage. Default: None.
|
||||
contract_dilation (bool): Whether contract first dilation of each layer
|
||||
Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
zero_init_residual (bool): Whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity. Default: True.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import ResNet
|
||||
>>> import torch
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 64, 8, 8)
|
||||
(1, 128, 4, 4)
|
||||
(1, 256, 2, 2)
|
||||
(1, 512, 1, 1)
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
18: (BasicBlock, (2, 2, 2, 2)),
|
||||
34: (BasicBlock, (3, 4, 6, 3)),
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
depth,
|
||||
in_channels=3,
|
||||
stem_channels=64,
|
||||
base_channels=64,
|
||||
num_stages=4,
|
||||
strides=(1, 2, 2, 2),
|
||||
dilations=(1, 1, 1, 1),
|
||||
out_indices=(0, 1, 2, 3),
|
||||
style='pytorch',
|
||||
deep_stem=False,
|
||||
avg_down=False,
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
norm_eval=False,
|
||||
dcn=None,
|
||||
stage_with_dcn=(False, False, False, False),
|
||||
plugins=None,
|
||||
multi_grid=None,
|
||||
contract_dilation=False,
|
||||
with_cp=False,
|
||||
zero_init_residual=True,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for resnet')
|
||||
|
||||
self.pretrained = pretrained
|
||||
self.zero_init_residual = zero_init_residual
|
||||
block_init_cfg = None
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
block = self.arch_settings[depth][0]
|
||||
if self.zero_init_residual:
|
||||
if block is BasicBlock:
|
||||
block_init_cfg = dict(
|
||||
type='Constant',
|
||||
val=0,
|
||||
override=dict(name='norm2'))
|
||||
elif block is Bottleneck:
|
||||
block_init_cfg = dict(
|
||||
type='Constant',
|
||||
val=0,
|
||||
override=dict(name='norm3'))
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.depth = depth
|
||||
self.stem_channels = stem_channels
|
||||
self.base_channels = base_channels
|
||||
self.num_stages = num_stages
|
||||
assert num_stages >= 1 and num_stages <= 4
|
||||
self.strides = strides
|
||||
self.dilations = dilations
|
||||
assert len(strides) == len(dilations) == num_stages
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < num_stages
|
||||
self.style = style
|
||||
self.deep_stem = deep_stem
|
||||
self.avg_down = avg_down
|
||||
self.frozen_stages = frozen_stages
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.with_cp = with_cp
|
||||
self.norm_eval = norm_eval
|
||||
self.dcn = dcn
|
||||
self.stage_with_dcn = stage_with_dcn
|
||||
if dcn is not None:
|
||||
assert len(stage_with_dcn) == num_stages
|
||||
self.plugins = plugins
|
||||
self.multi_grid = multi_grid
|
||||
self.contract_dilation = contract_dilation
|
||||
self.block, stage_blocks = self.arch_settings[depth]
|
||||
self.stage_blocks = stage_blocks[:num_stages]
|
||||
self.inplanes = stem_channels
|
||||
|
||||
self._make_stem_layer(in_channels, stem_channels)
|
||||
|
||||
self.res_layers = []
|
||||
for i, num_blocks in enumerate(self.stage_blocks):
|
||||
stride = strides[i]
|
||||
dilation = dilations[i]
|
||||
dcn = self.dcn if self.stage_with_dcn[i] else None
|
||||
if plugins is not None:
|
||||
stage_plugins = self.make_stage_plugins(plugins, i)
|
||||
else:
|
||||
stage_plugins = None
|
||||
# multi grid is applied to last layer only
|
||||
stage_multi_grid = multi_grid if i == len(
|
||||
self.stage_blocks) - 1 else None
|
||||
planes = base_channels * 2**i
|
||||
res_layer = self.make_res_layer(
|
||||
block=self.block,
|
||||
inplanes=self.inplanes,
|
||||
planes=planes,
|
||||
num_blocks=num_blocks,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
style=self.style,
|
||||
avg_down=self.avg_down,
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
dcn=dcn,
|
||||
plugins=stage_plugins,
|
||||
multi_grid=stage_multi_grid,
|
||||
contract_dilation=contract_dilation,
|
||||
init_cfg=block_init_cfg)
|
||||
self.inplanes = planes * self.block.expansion
|
||||
layer_name = f'layer{i+1}'
|
||||
self.add_module(layer_name, res_layer)
|
||||
self.res_layers.append(layer_name)
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
self.feat_dim = self.block.expansion * base_channels * 2**(
|
||||
len(self.stage_blocks) - 1)
|
||||
|
||||
def make_stage_plugins(self, plugins, stage_idx):
|
||||
"""make plugins for ResNet 'stage_idx'th stage .
|
||||
|
||||
Currently we support to insert 'context_block',
|
||||
'empirical_attention_block', 'nonlocal_block' into the backbone like
|
||||
ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
|
||||
Bottleneck.
|
||||
|
||||
An example of plugins format could be :
|
||||
>>> plugins=[
|
||||
... dict(cfg=dict(type='xxx', arg1='xxx'),
|
||||
... stages=(False, True, True, True),
|
||||
... position='after_conv2'),
|
||||
... dict(cfg=dict(type='yyy'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3'),
|
||||
... dict(cfg=dict(type='zzz', postfix='1'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3'),
|
||||
... dict(cfg=dict(type='zzz', postfix='2'),
|
||||
... stages=(True, True, True, True),
|
||||
... position='after_conv3')
|
||||
... ]
|
||||
>>> self = ResNet(depth=18)
|
||||
>>> stage_plugins = self.make_stage_plugins(plugins, 0)
|
||||
>>> assert len(stage_plugins) == 3
|
||||
|
||||
Suppose 'stage_idx=0', the structure of blocks in the stage would be:
|
||||
conv1-> conv2->conv3->yyy->zzz1->zzz2
|
||||
Suppose 'stage_idx=1', the structure of blocks in the stage would be:
|
||||
conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
|
||||
|
||||
If stages is missing, the plugin would be applied to all stages.
|
||||
|
||||
Args:
|
||||
plugins (list[dict]): List of plugins cfg to build. The postfix is
|
||||
required if multiple same type plugins are inserted.
|
||||
stage_idx (int): Index of stage to build
|
||||
|
||||
Returns:
|
||||
list[dict]: Plugins for current stage
|
||||
"""
|
||||
stage_plugins = []
|
||||
for plugin in plugins:
|
||||
plugin = plugin.copy()
|
||||
stages = plugin.pop('stages', None)
|
||||
assert stages is None or len(stages) == self.num_stages
|
||||
# whether to insert plugin into current stage
|
||||
if stages is None or stages[stage_idx]:
|
||||
stage_plugins.append(plugin)
|
||||
|
||||
return stage_plugins
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``."""
|
||||
return ResLayer(**kwargs)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
"""nn.Module: the normalization layer named "norm1" """
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def _make_stem_layer(self, in_channels, stem_channels):
|
||||
"""Make stem layer for ResNet."""
|
||||
if self.deep_stem:
|
||||
self.stem = nn.Sequential(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
stem_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
|
||||
nn.ReLU(inplace=True),
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
stem_channels // 2,
|
||||
stem_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
|
||||
nn.ReLU(inplace=True),
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
stem_channels // 2,
|
||||
stem_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False),
|
||||
build_norm_layer(self.norm_cfg, stem_channels)[1],
|
||||
nn.ReLU(inplace=True))
|
||||
else:
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
stem_channels,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
bias=False)
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, stem_channels, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def _freeze_stages(self):
|
||||
"""Freeze stages param and norm stats."""
|
||||
if self.frozen_stages >= 0:
|
||||
if self.deep_stem:
|
||||
self.stem.eval()
|
||||
for param in self.stem.parameters():
|
||||
param.requires_grad = False
|
||||
else:
|
||||
self.norm1.eval()
|
||||
for m in [self.conv1, self.norm1]:
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = getattr(self, f'layer{i}')
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
if self.deep_stem:
|
||||
x = self.stem(x)
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
outs = []
|
||||
for i, layer_name in enumerate(self.res_layers):
|
||||
res_layer = getattr(self, layer_name)
|
||||
x = res_layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep normalization layer
|
||||
freezed."""
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNetV1c(ResNet):
|
||||
"""ResNetV1c variant described in [1]_.
|
||||
|
||||
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in
|
||||
the input stem with three 3x3 convs. For more details please refer to `Bag
|
||||
of Tricks for Image Classification with Convolutional Neural Networks
|
||||
<https://arxiv.org/abs/1812.01187>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(deep_stem=True, avg_down=False, **kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNetV1d(ResNet):
|
||||
"""ResNetV1d variant described in [1]_.
|
||||
|
||||
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
|
||||
the input stem with three 3x3 convs. And in the downsampling block, a 2x2
|
||||
avg_pool with stride 2 is added before conv, whose stride is changed to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(deep_stem=True, avg_down=True, **kwargs)
|
||||
150
finetune/mmseg/models/backbones/resnext.py
Normal file
150
finetune/mmseg/models/backbones/resnext.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNet
|
||||
|
||||
|
||||
class Bottleneck(_Bottleneck):
|
||||
"""Bottleneck block for ResNeXt.
|
||||
|
||||
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
|
||||
"caffe", the stride-two layer is the first 1x1 conv layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
groups=1,
|
||||
base_width=4,
|
||||
base_channels=64,
|
||||
**kwargs):
|
||||
super().__init__(inplanes, planes, **kwargs)
|
||||
|
||||
if groups == 1:
|
||||
width = self.planes
|
||||
else:
|
||||
width = math.floor(self.planes *
|
||||
(base_width / base_channels)) * groups
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=1)
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
self.norm_cfg, width, postfix=2)
|
||||
self.norm3_name, norm3 = build_norm_layer(
|
||||
self.norm_cfg, self.planes * self.expansion, postfix=3)
|
||||
|
||||
self.conv1 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
self.inplanes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=self.conv1_stride,
|
||||
bias=False)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
fallback_on_stride = False
|
||||
self.with_modulated_dcn = False
|
||||
if self.with_dcn:
|
||||
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
|
||||
if not self.with_dcn or fallback_on_stride:
|
||||
self.conv2 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
bias=False)
|
||||
else:
|
||||
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
|
||||
self.conv2 = build_conv_layer(
|
||||
self.dcn,
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
stride=self.conv2_stride,
|
||||
padding=self.dilation,
|
||||
dilation=self.dilation,
|
||||
groups=groups,
|
||||
bias=False)
|
||||
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
self.conv3 = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
width,
|
||||
self.planes * self.expansion,
|
||||
kernel_size=1,
|
||||
bias=False)
|
||||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ResNeXt(ResNet):
|
||||
"""ResNeXt backbone.
|
||||
|
||||
This backbone is the implementation of `Aggregated
|
||||
Residual Transformations for Deep Neural
|
||||
Networks <https://arxiv.org/abs/1611.05431>`_.
|
||||
|
||||
Args:
|
||||
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
|
||||
in_channels (int): Number of input image channels. Normally 3.
|
||||
num_stages (int): Resnet stages, normally 4.
|
||||
groups (int): Group of resnext.
|
||||
base_width (int): Base width of resnext.
|
||||
strides (Sequence[int]): Strides of the first block of each stage.
|
||||
dilations (Sequence[int]): Dilation of each stage.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
|
||||
layer is the 3x3 conv layer, otherwise the stride-two layer is
|
||||
the first 1x1 conv layer.
|
||||
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
|
||||
not freezing any parameters.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed.
|
||||
zero_init_residual (bool): whether to use zero init for last norm layer
|
||||
in resblocks to let them behave as identity.
|
||||
|
||||
Example:
|
||||
>>> from mmseg.models import ResNeXt
|
||||
>>> import torch
|
||||
>>> self = ResNeXt(depth=50)
|
||||
>>> self.eval()
|
||||
>>> inputs = torch.rand(1, 3, 32, 32)
|
||||
>>> level_outputs = self.forward(inputs)
|
||||
>>> for level_out in level_outputs:
|
||||
... print(tuple(level_out.shape))
|
||||
(1, 256, 8, 8)
|
||||
(1, 512, 4, 4)
|
||||
(1, 1024, 2, 2)
|
||||
(1, 2048, 1, 1)
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
50: (Bottleneck, (3, 4, 6, 3)),
|
||||
101: (Bottleneck, (3, 4, 23, 3)),
|
||||
152: (Bottleneck, (3, 8, 36, 3))
|
||||
}
|
||||
|
||||
def __init__(self, groups=1, base_width=4, **kwargs):
|
||||
self.groups = groups
|
||||
self.base_width = base_width
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def make_res_layer(self, **kwargs):
|
||||
"""Pack all blocks in a stage into a ``ResLayer``"""
|
||||
return ResLayer(
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
base_channels=self.base_channels,
|
||||
**kwargs)
|
||||
422
finetune/mmseg/models/backbones/stdc.py
Normal file
422
finetune/mmseg/models/backbones/stdc.py
Normal file
@@ -0,0 +1,422 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .bisenetv1 import AttentionRefinementModule
|
||||
|
||||
|
||||
class STDCModule(BaseModule):
|
||||
"""STDCModule.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels before scaling.
|
||||
stride (int): The number of stride for the first conv layer.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
num_convs (int): Numbers of conv layers.
|
||||
fusion_type (str): Type of fusion operation. Default: 'add'.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
num_convs=4,
|
||||
fusion_type='add',
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert num_convs > 1
|
||||
assert fusion_type in ['add', 'cat']
|
||||
self.stride = stride
|
||||
self.with_downsample = True if self.stride == 2 else False
|
||||
self.fusion_type = fusion_type
|
||||
|
||||
self.layers = ModuleList()
|
||||
conv_0 = ConvModule(
|
||||
in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)
|
||||
|
||||
if self.with_downsample:
|
||||
self.downsample = ConvModule(
|
||||
out_channels // 2,
|
||||
out_channels // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=out_channels // 2,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
if self.fusion_type == 'add':
|
||||
self.layers.append(nn.Sequential(conv_0, self.downsample))
|
||||
self.skip = Sequential(
|
||||
ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
else:
|
||||
self.layers.append(conv_0)
|
||||
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
|
||||
else:
|
||||
self.layers.append(conv_0)
|
||||
|
||||
for i in range(1, num_convs):
|
||||
out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
|
||||
self.layers.append(
|
||||
ConvModule(
|
||||
out_channels // 2**i,
|
||||
out_channels // out_factor,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.fusion_type == 'add':
|
||||
out = self.forward_add(inputs)
|
||||
else:
|
||||
out = self.forward_cat(inputs)
|
||||
return out
|
||||
|
||||
def forward_add(self, inputs):
|
||||
layer_outputs = []
|
||||
x = inputs.clone()
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
layer_outputs.append(x)
|
||||
if self.with_downsample:
|
||||
inputs = self.skip(inputs)
|
||||
|
||||
return torch.cat(layer_outputs, dim=1) + inputs
|
||||
|
||||
def forward_cat(self, inputs):
|
||||
x0 = self.layers[0](inputs)
|
||||
layer_outputs = [x0]
|
||||
for i, layer in enumerate(self.layers[1:]):
|
||||
if i == 0:
|
||||
if self.with_downsample:
|
||||
x = layer(self.downsample(x0))
|
||||
else:
|
||||
x = layer(x0)
|
||||
else:
|
||||
x = layer(x)
|
||||
layer_outputs.append(x)
|
||||
if self.with_downsample:
|
||||
layer_outputs[0] = self.skip(x0)
|
||||
return torch.cat(layer_outputs, dim=1)
|
||||
|
||||
|
||||
class FeatureFusionModule(BaseModule):
|
||||
"""Feature Fusion Module. This module is different from FeatureFusionModule
|
||||
in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter
|
||||
channel number is calculated by given `scale_factor`, while
|
||||
FeatureFusionModule in BiSeNetV1 only uses one ConvModule in
|
||||
`self.conv_atten`.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
scale_factor (int): The number of channel scale factor.
|
||||
Default: 4.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
scale_factor=4,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
channels = out_channels // scale_factor
|
||||
self.conv0 = ConvModule(
|
||||
in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
|
||||
self.attention = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
ConvModule(
|
||||
out_channels,
|
||||
channels,
|
||||
1,
|
||||
norm_cfg=None,
|
||||
bias=False,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
channels,
|
||||
out_channels,
|
||||
1,
|
||||
norm_cfg=None,
|
||||
bias=False,
|
||||
act_cfg=None), nn.Sigmoid())
|
||||
|
||||
def forward(self, spatial_inputs, context_inputs):
|
||||
inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
|
||||
x = self.conv0(inputs)
|
||||
attn = self.attention(x)
|
||||
x_attn = x * attn
|
||||
return x_attn + x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDCNet(BaseModule):
|
||||
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
|
||||
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
|
||||
|
||||
Args:
|
||||
stdc_type (int): The type of backbone structure,
|
||||
`STDCNet1` and`STDCNet2` denotes two main backbones in paper,
|
||||
whose FLOPs is 813M and 1446M, respectively.
|
||||
in_channels (int): The num of input_channels.
|
||||
channels (tuple[int]): The output channels for each stage.
|
||||
bottleneck_type (str): The type of STDC Module type, the value must
|
||||
be 'add' or 'cat'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
act_cfg (dict): The activation config for conv layers.
|
||||
num_convs (int): Numbers of conv layer at each STDC Module.
|
||||
Default: 4.
|
||||
with_final_conv (bool): Whether add a conv layer at the Module output.
|
||||
Default: True.
|
||||
pretrained (str, optional): Model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> stdc_type = 'STDCNet1'
|
||||
>>> in_channels = 3
|
||||
>>> channels = (32, 64, 256, 512, 1024)
|
||||
>>> bottleneck_type = 'cat'
|
||||
>>> inputs = torch.rand(1, 3, 1024, 2048)
|
||||
>>> self = STDCNet(stdc_type, in_channels,
|
||||
... channels, bottleneck_type).eval()
|
||||
>>> outputs = self.forward(inputs)
|
||||
>>> for i in range(len(outputs)):
|
||||
... print(f'outputs[{i}].shape = {outputs[i].shape}')
|
||||
outputs[0].shape = torch.Size([1, 256, 128, 256])
|
||||
outputs[1].shape = torch.Size([1, 512, 64, 128])
|
||||
outputs[2].shape = torch.Size([1, 1024, 32, 64])
|
||||
"""
|
||||
|
||||
arch_settings = {
|
||||
'STDCNet1': [(2, 1), (2, 1), (2, 1)],
|
||||
'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
stdc_type,
|
||||
in_channels,
|
||||
channels,
|
||||
bottleneck_type,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_convs=4,
|
||||
with_final_conv=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert stdc_type in self.arch_settings, \
|
||||
f'invalid structure {stdc_type} for STDCNet.'
|
||||
assert bottleneck_type in ['add', 'cat'],\
|
||||
f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'
|
||||
|
||||
assert len(channels) == 5,\
|
||||
f'invalid channels length {len(channels)} for STDCNet.'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.stage_strides = self.arch_settings[stdc_type]
|
||||
self.prtrained = pretrained
|
||||
self.num_convs = num_convs
|
||||
self.with_final_conv = with_final_conv
|
||||
|
||||
self.stages = ModuleList([
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels[0],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
self.channels[0],
|
||||
self.channels[1],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
])
|
||||
# `self.num_shallow_features` is the number of shallow modules in
|
||||
# `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
|
||||
# They are both not used for following modules like Attention
|
||||
# Refinement Module and Feature Fusion Module.
|
||||
# Thus they would be cut from `outs`. Please refer to Figure 4
|
||||
# of original paper for more details.
|
||||
self.num_shallow_features = len(self.stages)
|
||||
|
||||
for strides in self.stage_strides:
|
||||
idx = len(self.stages) - 1
|
||||
self.stages.append(
|
||||
self._make_stage(self.channels[idx], self.channels[idx + 1],
|
||||
strides, norm_cfg, act_cfg, bottleneck_type))
|
||||
# After appending, `self.stages` is a ModuleList including several
|
||||
# shallow modules and STDCModules.
|
||||
# (len(self.stages) ==
|
||||
# self.num_shallow_features + len(self.stage_strides))
|
||||
if self.with_final_conv:
|
||||
self.final_conv = ConvModule(
|
||||
self.channels[-1],
|
||||
max(1024, self.channels[-1]),
|
||||
1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
|
||||
act_cfg, bottleneck_type):
|
||||
layers = []
|
||||
for i, stride in enumerate(strides):
|
||||
layers.append(
|
||||
STDCModule(
|
||||
in_channels if i == 0 else out_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
norm_cfg,
|
||||
act_cfg,
|
||||
num_convs=self.num_convs,
|
||||
fusion_type=bottleneck_type))
|
||||
return Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
for stage in self.stages:
|
||||
x = stage(x)
|
||||
outs.append(x)
|
||||
if self.with_final_conv:
|
||||
outs[-1] = self.final_conv(outs[-1])
|
||||
outs = outs[self.num_shallow_features:]
|
||||
return tuple(outs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class STDCContextPathNet(BaseModule):
|
||||
"""STDCNet with Context Path. The `outs` below is a list of three feature
|
||||
maps from deep to shallow, whose height and width is from small to big,
|
||||
respectively. The biggest feature map of `outs` is outputted for
|
||||
`STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
|
||||
The other two feature maps are used for Attention Refinement Module,
|
||||
respectively. Besides, the biggest feature map of `outs` and the last
|
||||
output of Attention Refinement Module are concatenated for Feature Fusion
|
||||
Module. Then, this fusion feature map `feat_fuse` would be outputted for
|
||||
`decode_head`. More details please refer to Figure 4 of original paper.
|
||||
|
||||
Args:
|
||||
backbone_cfg (dict): Config dict for stdc backbone.
|
||||
last_in_channels (tuple(int)), The number of channels of last
|
||||
two feature maps from stdc backbone. Default: (1024, 512).
|
||||
out_channels (int): The channels of output feature maps.
|
||||
Default: 128.
|
||||
ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
|
||||
`dict(in_channels=512, out_channels=256, scale_factor=4)`.
|
||||
upsample_mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``.
|
||||
align_corners (str): align_corners argument of F.interpolate. It
|
||||
must be `None` if upsample_mode is ``'nearest'``. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
|
||||
Return:
|
||||
outputs (tuple): The tuple of list of output feature map for
|
||||
auxiliary heads and decoder head.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone_cfg,
|
||||
last_in_channels=(1024, 512),
|
||||
out_channels=128,
|
||||
ffm_cfg=dict(
|
||||
in_channels=512, out_channels=256, scale_factor=4),
|
||||
upsample_mode='nearest',
|
||||
align_corners=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
self.arms = ModuleList()
|
||||
self.convs = ModuleList()
|
||||
for channels in last_in_channels:
|
||||
self.arms.append(AttentionRefinementModule(channels, out_channels))
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg))
|
||||
self.conv_avg = ConvModule(
|
||||
last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)
|
||||
|
||||
self.ffm = FeatureFusionModule(**ffm_cfg)
|
||||
|
||||
self.upsample_mode = upsample_mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
outs = list(self.backbone(x))
|
||||
avg = F.adaptive_avg_pool2d(outs[-1], 1)
|
||||
avg_feat = self.conv_avg(avg)
|
||||
|
||||
feature_up = resize(
|
||||
avg_feat,
|
||||
size=outs[-1].shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=self.align_corners)
|
||||
arms_out = []
|
||||
for i in range(len(self.arms)):
|
||||
x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
|
||||
feature_up = resize(
|
||||
x_arm,
|
||||
size=outs[len(outs) - 1 - i - 1].shape[2:],
|
||||
mode=self.upsample_mode,
|
||||
align_corners=self.align_corners)
|
||||
feature_up = self.convs[i](feature_up)
|
||||
arms_out.append(feature_up)
|
||||
|
||||
feat_fuse = self.ffm(outs[0], arms_out[1])
|
||||
|
||||
# The `outputs` has four feature maps.
|
||||
# `outs[0]` is outputted for `STDCHead` auxiliary head.
|
||||
# Two feature maps of `arms_out` are outputted for auxiliary head.
|
||||
# `feat_fuse` is outputted for decoder head.
|
||||
outputs = [outs[0]] + list(arms_out) + [feat_fuse]
|
||||
return tuple(outputs)
|
||||
757
finetune/mmseg/models/backbones/swin.py
Normal file
757
finetune/mmseg/models/backbones/swin.py
Normal file
@@ -0,0 +1,757 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, trunc_normal_,
|
||||
trunc_normal_init)
|
||||
from mmengine.runner import CheckpointLoader
|
||||
from mmengine.utils import to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils.embed import PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
class WindowMSA(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module with relative
|
||||
position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
|
||||
init_cfg (dict | None, optional): The Config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
|
||||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# About 2x faster than original impl
|
||||
Wh, Ww = self.window_size
|
||||
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
|
||||
rel_position_index = rel_index_coords + rel_index_coords.T
|
||||
rel_position_index = rel_position_index.flip(1).contiguous()
|
||||
self.register_buffer('relative_position_index', rel_position_index)
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.relative_position_bias_table, std=0.02)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
|
||||
x (tensor): input features with shape of (num_windows*B, N, C)
|
||||
mask (tensor | None, Optional): mask with shape of (num_windows,
|
||||
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
|
||||
"""
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
# make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1],
|
||||
self.window_size[0] * self.window_size[1],
|
||||
-1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def double_step_seq(step1, len1, step2, len2):
|
||||
seq1 = torch.arange(0, step1 * len1, step1)
|
||||
seq2 = torch.arange(0, step2 * len2, step2)
|
||||
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
||||
|
||||
|
||||
class ShiftWindowMSA(BaseModule):
|
||||
"""Shifted Window Multihead Self-Attention Module.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window.
|
||||
shift_size (int, optional): The shift step of each window towards
|
||||
right-bottom. If zero, act as regular window-msa. Defaults to 0.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Defaults: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Defaults: 0.
|
||||
proj_drop_rate (float, optional): Dropout ratio of output.
|
||||
Defaults: 0.
|
||||
dropout_layer (dict, optional): The dropout_layer used before output.
|
||||
Defaults: dict(type='DropPath', drop_prob=0.).
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size,
|
||||
shift_size=0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0,
|
||||
proj_drop_rate=0,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=0.),
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
assert 0 <= self.shift_size < self.window_size
|
||||
|
||||
self.w_msa = WindowMSA(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
window_size=to_2tuple(window_size),
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
proj_drop_rate=proj_drop_rate,
|
||||
init_cfg=None)
|
||||
|
||||
self.drop = build_dropout(dropout_layer)
|
||||
|
||||
def forward(self, query, hw_shape):
|
||||
B, L, C = query.shape
|
||||
H, W = hw_shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
query = query.view(B, H, W, C)
|
||||
|
||||
# pad feature maps to multiples of window size
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
|
||||
H_pad, W_pad = query.shape[1], query.shape[2]
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_query = torch.roll(
|
||||
query,
|
||||
shifts=(-self.shift_size, -self.shift_size),
|
||||
dims=(1, 2))
|
||||
|
||||
# calculate attention mask for SW-MSA
|
||||
img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
# nW, window_size, window_size, 1
|
||||
mask_windows = self.window_partition(img_mask)
|
||||
mask_windows = mask_windows.view(
|
||||
-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
else:
|
||||
shifted_query = query
|
||||
attn_mask = None
|
||||
|
||||
# nW*B, window_size, window_size, C
|
||||
query_windows = self.window_partition(shifted_query)
|
||||
# nW*B, window_size*window_size, C
|
||||
query_windows = query_windows.view(-1, self.window_size**2, C)
|
||||
|
||||
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
|
||||
attn_windows = self.w_msa(query_windows, mask=attn_mask)
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size,
|
||||
self.window_size, C)
|
||||
|
||||
# B H' W' C
|
||||
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(
|
||||
shifted_x,
|
||||
shifts=(self.shift_size, self.shift_size),
|
||||
dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
def window_reverse(self, windows, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
window_size = self.window_size
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
def window_partition(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
window_size = self.window_size
|
||||
x = x.view(B, H // window_size, window_size, W // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
|
||||
windows = windows.view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
class SwinBlock(BaseModule):
|
||||
""""
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
window_size (int, optional): The local window scale. Default: 7.
|
||||
shift (bool, optional): whether to shift window or not. Default False.
|
||||
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
|
||||
act_cfg (dict, optional): The config dict of activation function.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict, optional): The config dict of normalization.
|
||||
Default: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
init_cfg (dict | list | None, optional): The init config.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
window_size=7,
|
||||
shift=False,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.attn = ShiftWindowMSA(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=window_size // 2 if shift else 0,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
proj_drop_rate=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
init_cfg=None)
|
||||
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=2,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg,
|
||||
add_identity=True,
|
||||
init_cfg=None)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
|
||||
x = x + identity
|
||||
|
||||
identity = x
|
||||
x = self.norm2(x)
|
||||
x = self.ffn(x, identity=identity)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinBlockSequence(BaseModule):
|
||||
"""Implements one stage in Swin Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
depth (int): The number of blocks in this stage.
|
||||
window_size (int, optional): The local window scale. Default: 7.
|
||||
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
drop_rate (float, optional): Dropout rate. Default: 0.
|
||||
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float | list[float], optional): Stochastic depth
|
||||
rate. Default: 0.
|
||||
downsample (BaseModule | None, optional): The downsample operation
|
||||
module. Default: None.
|
||||
act_cfg (dict, optional): The config dict of activation function.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict, optional): The config dict of normalization.
|
||||
Default: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
init_cfg (dict | list | None, optional): The init config.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
depth,
|
||||
window_size=7,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
downsample=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(drop_path_rate, list):
|
||||
drop_path_rates = drop_path_rate
|
||||
assert len(drop_path_rates) == depth
|
||||
else:
|
||||
drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
|
||||
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
block = SwinBlock(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=feedforward_channels,
|
||||
window_size=window_size,
|
||||
shift=False if i % 2 == 0 else True,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rates[i],
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
init_cfg=None)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
for block in self.blocks:
|
||||
x = block(x, hw_shape)
|
||||
|
||||
if self.downsample:
|
||||
x_down, down_hw_shape = self.downsample(x, hw_shape)
|
||||
return x_down, down_hw_shape, x, hw_shape
|
||||
else:
|
||||
return x, hw_shape, x, hw_shape
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SwinTransformer(BaseModule):
|
||||
"""Swin Transformer backbone.
|
||||
|
||||
This backbone is the implementation of `Swin Transformer:
|
||||
Hierarchical Vision Transformer using Shifted
|
||||
Windows <https://arxiv.org/abs/2103.14030>`_.
|
||||
Inspiration from https://github.com/microsoft/Swin-Transformer.
|
||||
|
||||
Args:
|
||||
pretrain_img_size (int | tuple[int]): The size of input image when
|
||||
pretrain. Defaults: 224.
|
||||
in_channels (int): The num of input channels.
|
||||
Defaults: 3.
|
||||
embed_dims (int): The feature dimension. Default: 96.
|
||||
patch_size (int | tuple[int]): Patch size. Default: 4.
|
||||
window_size (int): Window size. Default: 7.
|
||||
mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||
Default: (2, 2, 6, 2).
|
||||
num_heads (tuple[int]): Parallel attention heads of each Swin
|
||||
Transformer stage. Default: (3, 6, 12, 24).
|
||||
strides (tuple[int]): The patch merging or patch embedding stride of
|
||||
each Swin Transformer stage. (In swin, we set kernel size equal to
|
||||
stride.) Default: (4, 2, 2, 2).
|
||||
out_indices (tuple[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
|
||||
value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
patch_norm (bool): If add a norm layer for patch embed and patch
|
||||
merging. Default: True.
|
||||
drop_rate (float): Dropout rate. Defaults: 0.
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults: False.
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='LN').
|
||||
norm_cfg (dict): Config dict for normalization layer at
|
||||
output of backone. Defaults: dict(type='LN').
|
||||
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Default: False.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=96,
|
||||
patch_size=4,
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
depths=(2, 2, 6, 2),
|
||||
num_heads=(3, 6, 12, 24),
|
||||
strides=(4, 2, 2, 2),
|
||||
out_indices=(0, 1, 2, 3),
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
patch_norm=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
use_abs_pos_embed=False,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
pretrained=None,
|
||||
frozen_stages=-1,
|
||||
init_cfg=None):
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
if isinstance(pretrain_img_size, int):
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size)
|
||||
elif isinstance(pretrain_img_size, tuple):
|
||||
if len(pretrain_img_size) == 1:
|
||||
pretrain_img_size = to_2tuple(pretrain_img_size[0])
|
||||
assert len(pretrain_img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(pretrain_img_size)}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be specified at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
init_cfg = init_cfg
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
num_layers = len(depths)
|
||||
self.out_indices = out_indices
|
||||
self.use_abs_pos_embed = use_abs_pos_embed
|
||||
|
||||
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=strides[0],
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None)
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
patch_row = pretrain_img_size[0] // patch_size
|
||||
patch_col = pretrain_img_size[1] // patch_size
|
||||
num_patches = patch_row * patch_col
|
||||
self.absolute_pos_embed = nn.Parameter(
|
||||
torch.zeros((1, num_patches, embed_dims)))
|
||||
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
# set stochastic depth decay rule
|
||||
total_depth = sum(depths)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
]
|
||||
|
||||
self.stages = ModuleList()
|
||||
in_channels = embed_dims
|
||||
for i in range(num_layers):
|
||||
if i < num_layers - 1:
|
||||
downsample = PatchMerging(
|
||||
in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
stride=strides[i + 1],
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None)
|
||||
else:
|
||||
downsample = None
|
||||
|
||||
stage = SwinBlockSequence(
|
||||
embed_dims=in_channels,
|
||||
num_heads=num_heads[i],
|
||||
feedforward_channels=int(mlp_ratio * in_channels),
|
||||
depth=depths[i],
|
||||
window_size=window_size,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||
downsample=downsample,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
init_cfg=None)
|
||||
self.stages.append(stage)
|
||||
if downsample:
|
||||
in_channels = downsample.out_channels
|
||||
|
||||
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
|
||||
# Add a norm layer for each output
|
||||
for i in out_indices:
|
||||
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
|
||||
layer_name = f'norm{i}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep layers freezed."""
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
if self.use_abs_pos_embed:
|
||||
self.absolute_pos_embed.requires_grad = False
|
||||
self.drop_after_pos.eval()
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
|
||||
if (i - 1) in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i-1}')
|
||||
norm_layer.eval()
|
||||
for param in norm_layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
m = self.stages[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self):
|
||||
if self.init_cfg is None:
|
||||
print_log(f'No pre-trained weights for '
|
||||
f'{self.__class__.__name__}, '
|
||||
f'training start from scratch')
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
else:
|
||||
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
||||
f'specify `Pretrained` in ' \
|
||||
f'`init_cfg` in ' \
|
||||
f'{self.__class__.__name__} '
|
||||
ckpt = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
if 'state_dict' in ckpt:
|
||||
_state_dict = ckpt['state_dict']
|
||||
elif 'model' in ckpt:
|
||||
_state_dict = ckpt['model']
|
||||
else:
|
||||
_state_dict = ckpt
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for k, v in _state_dict.items():
|
||||
if k.startswith('backbone.'):
|
||||
state_dict[k[9:]] = v
|
||||
else:
|
||||
state_dict[k] = v
|
||||
|
||||
# strip prefix of state_dict
|
||||
if list(state_dict.keys())[0].startswith('module.'):
|
||||
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
||||
|
||||
# reshape absolute position embedding
|
||||
if state_dict.get('absolute_pos_embed') is not None:
|
||||
absolute_pos_embed = state_dict['absolute_pos_embed']
|
||||
N1, L, C1 = absolute_pos_embed.size()
|
||||
N2, C2, H, W = self.absolute_pos_embed.size()
|
||||
if N1 != N2 or C1 != C2 or L != H * W:
|
||||
print_log('Error in loading absolute_pos_embed, pass')
|
||||
else:
|
||||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
|
||||
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
# interpolate position bias table if needed
|
||||
relative_position_bias_table_keys = [
|
||||
k for k in state_dict.keys()
|
||||
if 'relative_position_bias_table' in k
|
||||
]
|
||||
for table_key in relative_position_bias_table_keys:
|
||||
table_pretrained = state_dict[table_key]
|
||||
if table_key in self.state_dict():
|
||||
table_current = self.state_dict()[table_key]
|
||||
L1, nH1 = table_pretrained.size()
|
||||
L2, nH2 = table_current.size()
|
||||
if nH1 != nH2:
|
||||
print_log(f'Error in loading {table_key}, pass')
|
||||
elif L1 != L2:
|
||||
S1 = int(L1**0.5)
|
||||
S2 = int(L2**0.5)
|
||||
table_pretrained_resized = F.interpolate(
|
||||
table_pretrained.permute(1, 0).reshape(
|
||||
1, nH1, S1, S1),
|
||||
size=(S2, S2),
|
||||
mode='bicubic')
|
||||
state_dict[table_key] = table_pretrained_resized.view(
|
||||
nH2, L2).permute(1, 0).contiguous()
|
||||
|
||||
# load state_dict
|
||||
self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(out)
|
||||
out = out.view(-1, *out_hw_shape,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return outs
|
||||
63
finetune/mmseg/models/backbones/timm_backbone.py
Normal file
63
finetune/mmseg/models/backbones/timm_backbone.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
try:
|
||||
import timm
|
||||
except ImportError:
|
||||
timm = None
|
||||
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.registry import MODELS as MMENGINE_MODELS
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TIMMBackbone(BaseModule):
|
||||
"""Wrapper to use backbones from timm library. More details can be found in
|
||||
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
|
||||
|
||||
Args:
|
||||
model_name (str): Name of timm model to instantiate.
|
||||
pretrained (bool): Load pretrained weights if True.
|
||||
checkpoint_path (str): Path of checkpoint to load after
|
||||
model is initialized.
|
||||
in_channels (int): Number of input image channels. Default: 3.
|
||||
init_cfg (dict, optional): Initialization config dict
|
||||
**kwargs: Other timm & model specific arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
features_only=True,
|
||||
pretrained=True,
|
||||
checkpoint_path='',
|
||||
in_channels=3,
|
||||
init_cfg=None,
|
||||
**kwargs,
|
||||
):
|
||||
if timm is None:
|
||||
raise RuntimeError('timm is not installed')
|
||||
super().__init__(init_cfg)
|
||||
if 'norm_layer' in kwargs:
|
||||
kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer'])
|
||||
self.timm_model = timm.create_model(
|
||||
model_name=model_name,
|
||||
features_only=features_only,
|
||||
pretrained=pretrained,
|
||||
in_chans=in_channels,
|
||||
checkpoint_path=checkpoint_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Make unused parameters None
|
||||
self.timm_model.global_pool = None
|
||||
self.timm_model.fc = None
|
||||
self.timm_model.classifier = None
|
||||
|
||||
# Hack to use pretrained weights from timm
|
||||
if pretrained or checkpoint_path:
|
||||
self._is_init = True
|
||||
|
||||
def forward(self, x):
|
||||
features = self.timm_model(x)
|
||||
return features
|
||||
588
finetune/mmseg/models/backbones/twins.py
Normal file
588
finetune/mmseg/models/backbones/twins.py
Normal file
@@ -0,0 +1,588 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, normal_init,
|
||||
trunc_normal_init)
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.models.backbones.mit import EfficientMultiheadAttention
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils.embed import PatchEmbed
|
||||
|
||||
|
||||
class GlobalSubsampledAttention(EfficientMultiheadAttention):
|
||||
"""Global Sub-sampled Attention (Spatial Reduction Attention)
|
||||
|
||||
This module is modified from EfficientMultiheadAttention,
|
||||
which is a module from mmseg.models.backbones.mit.py.
|
||||
Specifically, there is no difference between
|
||||
`GlobalSubsampledAttention` and `EfficientMultiheadAttention`,
|
||||
`GlobalSubsampledAttention` is built as a brand new class
|
||||
because it is renamed as `Global sub-sampled attention (GSA)`
|
||||
in paper.
|
||||
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
attn_drop (float): A Dropout layer on attn_output_weights.
|
||||
Default: 0.0.
|
||||
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
||||
Default: 0.0.
|
||||
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
||||
when adding the shortcut. Default: None.
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dims)
|
||||
or (n, batch, embed_dims). Default: False.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT.
|
||||
Default: 1.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
dropout_layer=None,
|
||||
batch_first=True,
|
||||
qkv_bias=True,
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=1,
|
||||
init_cfg=None):
|
||||
super().__init__(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
dropout_layer=dropout_layer,
|
||||
batch_first=batch_first,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratio,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
|
||||
class GSAEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer with GSA.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=1.,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
|
||||
self.attn = GlobalSubsampledAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
sr_ratio=sr_ratio)
|
||||
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg,
|
||||
add_identity=False)
|
||||
|
||||
self.drop_path = build_dropout(
|
||||
dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.))
|
||||
x = x + self.drop_path(self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class LocallyGroupedSelfAttention(BaseModule):
|
||||
"""Locally-grouped Self Attention (LSA) module.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads. Default: 8
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Default: False.
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
|
||||
window_size(int): Window size of LSA. Default: 1.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
window_size=1,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \
|
||||
f'divided by num_heads ' \
|
||||
f'{num_heads}.'
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
head_dim = embed_dims // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop_rate)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop_rate)
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
b, n, c = x.shape
|
||||
h, w = hw_shape
|
||||
x = x.view(b, h, w, c)
|
||||
|
||||
# pad feature maps to multiples of Local-groups
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - w % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - h % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
|
||||
# calculate attention mask for LSA
|
||||
Hp, Wp = x.shape[1:-1]
|
||||
_h, _w = Hp // self.window_size, Wp // self.window_size
|
||||
mask = torch.zeros((1, Hp, Wp), device=x.device)
|
||||
mask[:, -pad_b:, :].fill_(1)
|
||||
mask[:, :, -pad_r:].fill_(1)
|
||||
|
||||
# [B, _h, _w, window_size, window_size, C]
|
||||
x = x.reshape(b, _h, self.window_size, _w, self.window_size,
|
||||
c).transpose(2, 3)
|
||||
mask = mask.reshape(1, _h, self.window_size, _w,
|
||||
self.window_size).transpose(2, 3).reshape(
|
||||
1, _h * _w,
|
||||
self.window_size * self.window_size)
|
||||
# [1, _h*_w, window_size*window_size, window_size*window_size]
|
||||
attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-1000.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
|
||||
# [3, B, _w*_h, nhead, window_size*window_size, dim]
|
||||
qkv = self.qkv(x).reshape(b, _h * _w,
|
||||
self.window_size * self.window_size, 3,
|
||||
self.num_heads, c // self.num_heads).permute(
|
||||
3, 0, 1, 4, 2, 5)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
# [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn + attn_mask.unsqueeze(2)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size,
|
||||
self.window_size, c)
|
||||
x = attn.transpose(2, 3).reshape(b, _h * self.window_size,
|
||||
_w * self.window_size, c)
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :h, :w, :].contiguous()
|
||||
|
||||
x = x.reshape(b, n, c)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class LSAEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Twins-SVT.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float, optional): Dropout ratio of attention weight.
|
||||
Default: 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of
|
||||
head_dim ** -0.5 if set. Default: None.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
window_size (int): Window size of LSA. Default: 1.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
window_size=1,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
|
||||
self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
|
||||
qkv_bias, qk_scale,
|
||||
attn_drop_rate, drop_rate,
|
||||
window_size)
|
||||
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg,
|
||||
add_identity=False)
|
||||
|
||||
self.drop_path = build_dropout(
|
||||
dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
) if drop_path_rate > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
|
||||
x = x + self.drop_path(self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalPositionEncoding(BaseModule):
|
||||
"""The Conditional Position Encoding (CPE) module.
|
||||
|
||||
The CPE is the implementation of 'Conditional Positional Encodings
|
||||
for Vision Transformers <https://arxiv.org/abs/2102.10882>'_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
embed_dims (int): The feature dimension. Default: 768.
|
||||
stride (int): Stride of conv layer. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels,
|
||||
embed_dims,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=True,
|
||||
groups=embed_dims)
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
b, n, c = x.shape
|
||||
h, w = hw_shape
|
||||
feat_token = x
|
||||
cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w)
|
||||
if self.stride == 1:
|
||||
x = self.proj(cnn_feat) + cnn_feat
|
||||
else:
|
||||
x = self.proj(cnn_feat)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PCPVT(BaseModule):
|
||||
"""The backbone of Twins-PCPVT.
|
||||
|
||||
This backbone is the implementation of `Twins: Revisiting the Design
|
||||
of Spatial Attention in Vision Transformers
|
||||
<https://arxiv.org/abs/1512.03385>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
|
||||
patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
|
||||
strides (list): The strides. Default: [4, 2, 2, 2].
|
||||
num_heads (int): Number of attention heads. Default: [1, 2, 4, 8].
|
||||
mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: [4, 4, 4, 4].
|
||||
out_indices (tuple[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: False.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.0
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
depths (list): Depths of each stage. Default [3, 4, 6, 3]
|
||||
sr_ratios (list): Kernel_size of conv in each Attn module in
|
||||
Transformer encoder layer. Default: [8, 4, 2, 1].
|
||||
norm_after_stage(bool): Add extra norm. Default False.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=[64, 128, 256, 512],
|
||||
patch_sizes=[4, 2, 2, 2],
|
||||
strides=[4, 2, 2, 2],
|
||||
num_heads=[1, 2, 4, 8],
|
||||
mlp_ratios=[4, 4, 4, 4],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
depths=[3, 4, 6, 3],
|
||||
sr_ratios=[8, 4, 2, 1],
|
||||
norm_after_stage=False,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
self.depths = depths
|
||||
|
||||
# patch_embed
|
||||
self.patch_embeds = ModuleList()
|
||||
self.position_encoding_drops = ModuleList()
|
||||
self.layers = ModuleList()
|
||||
|
||||
for i in range(len(depths)):
|
||||
self.patch_embeds.append(
|
||||
PatchEmbed(
|
||||
in_channels=in_channels if i == 0 else embed_dims[i - 1],
|
||||
embed_dims=embed_dims[i],
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_sizes[i],
|
||||
stride=strides[i],
|
||||
padding='corner',
|
||||
norm_cfg=norm_cfg))
|
||||
|
||||
self.position_encoding_drops.append(nn.Dropout(p=drop_rate))
|
||||
|
||||
self.position_encodings = ModuleList([
|
||||
ConditionalPositionEncoding(embed_dim, embed_dim)
|
||||
for embed_dim in embed_dims
|
||||
])
|
||||
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
cur = 0
|
||||
|
||||
for k in range(len(depths)):
|
||||
_block = ModuleList([
|
||||
GSAEncoderLayer(
|
||||
embed_dims=embed_dims[k],
|
||||
num_heads=num_heads[k],
|
||||
feedforward_channels=mlp_ratios[k] * embed_dims[k],
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[cur + i],
|
||||
num_fcs=2,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
sr_ratio=sr_ratios[k]) for i in range(depths[k])
|
||||
])
|
||||
self.layers.append(_block)
|
||||
cur += depths[k]
|
||||
|
||||
self.norm_name, norm = build_norm_layer(
|
||||
norm_cfg, embed_dims[-1], postfix=1)
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.norm_after_stage = norm_after_stage
|
||||
if self.norm_after_stage:
|
||||
self.norm_list = ModuleList()
|
||||
for dim in embed_dims:
|
||||
self.norm_list.append(build_norm_layer(norm_cfg, dim)[1])
|
||||
|
||||
def init_weights(self):
|
||||
if self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_init(m, std=.02, bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[
|
||||
1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
normal_init(
|
||||
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = list()
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
for i in range(len(self.depths)):
|
||||
x, hw_shape = self.patch_embeds[i](x)
|
||||
h, w = hw_shape
|
||||
x = self.position_encoding_drops[i](x)
|
||||
for j, blk in enumerate(self.layers[i]):
|
||||
x = blk(x, hw_shape)
|
||||
if j == 0:
|
||||
x = self.position_encodings[i](x, hw_shape)
|
||||
if self.norm_after_stage:
|
||||
x = self.norm_list[i](x)
|
||||
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
if i in self.out_indices:
|
||||
outputs.append(x)
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SVT(PCPVT):
|
||||
"""The backbone of Twins-SVT.
|
||||
|
||||
This backbone is the implementation of `Twins: Revisiting the Design
|
||||
of Spatial Attention in Vision Transformers
|
||||
<https://arxiv.org/abs/1512.03385>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512].
|
||||
patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2].
|
||||
strides (list): The strides. Default: [4, 2, 2, 2].
|
||||
num_heads (int): Number of attention heads. Default: [1, 2, 4].
|
||||
mlp_ratios (int): Ratio of mlp hidden dim to embedding dim.
|
||||
Default: [4, 4, 4].
|
||||
out_indices (tuple[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default: False.
|
||||
drop_rate (float): Dropout rate. Default 0.
|
||||
attn_drop_rate (float): Dropout ratio of attention weight.
|
||||
Default 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default 0.2.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
depths (list): Depths of each stage. Default [4, 4, 4].
|
||||
sr_ratios (list): Kernel_size of conv in each Attn module in
|
||||
Transformer encoder layer. Default: [4, 2, 1].
|
||||
windiow_sizes (list): Window size of LSA. Default: [7, 7, 7],
|
||||
input_features_slice(bool): Input features need slice. Default: False.
|
||||
norm_after_stage(bool): Add extra norm. Default False.
|
||||
strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2)
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
embed_dims=[64, 128, 256],
|
||||
patch_sizes=[4, 2, 2, 2],
|
||||
strides=[4, 2, 2, 2],
|
||||
num_heads=[1, 2, 4],
|
||||
mlp_ratios=[4, 4, 4],
|
||||
out_indices=(0, 1, 2, 3),
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.2,
|
||||
norm_cfg=dict(type='LN'),
|
||||
depths=[4, 4, 4],
|
||||
sr_ratios=[4, 2, 1],
|
||||
windiow_sizes=[7, 7, 7],
|
||||
norm_after_stage=True,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(in_channels, embed_dims, patch_sizes, strides,
|
||||
num_heads, mlp_ratios, out_indices, qkv_bias,
|
||||
drop_rate, attn_drop_rate, drop_path_rate, norm_cfg,
|
||||
depths, sr_ratios, norm_after_stage, pretrained,
|
||||
init_cfg)
|
||||
# transformer encoder
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
|
||||
for k in range(len(depths)):
|
||||
for i in range(depths[k]):
|
||||
if i % 2 == 0:
|
||||
self.layers[k][i] = \
|
||||
LSAEncoderLayer(
|
||||
embed_dims=embed_dims[k],
|
||||
num_heads=num_heads[k],
|
||||
feedforward_channels=mlp_ratios[k] * embed_dims[k],
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[sum(depths[:k])+i],
|
||||
qkv_bias=qkv_bias,
|
||||
window_size=windiow_sizes[k])
|
||||
436
finetune/mmseg/models/backbones/unet.py
Normal file
436
finetune/mmseg/models/backbones/unet.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import UpConvBlock, Upsample
|
||||
|
||||
|
||||
class BasicConvBlock(nn.Module):
|
||||
"""Basic convolutional block for UNet.
|
||||
|
||||
This module consists of several plain convolutional layers.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
num_convs (int): Number of convolutional layers. Default: 2.
|
||||
stride (int): Whether use stride convolution to downsample
|
||||
the input feature map. If stride=2, it only uses stride convolution
|
||||
in the first convolutional layer to downsample the input feature
|
||||
map. Options are 1 or 2. Default: 1.
|
||||
dilation (int): Whether use dilated convolution to expand the
|
||||
receptive field. Set dilation rate of each convolutional layer and
|
||||
the dilation rate of the first convolutional layer is always 1.
|
||||
Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
dcn (bool): Use deformable convolution in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_convs=2,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super().__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.with_cp = with_cp
|
||||
convs = []
|
||||
for i in range(num_convs):
|
||||
convs.append(
|
||||
ConvModule(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride if i == 0 else 1,
|
||||
dilation=1 if i == 0 else dilation,
|
||||
padding=1 if i == 0 else dilation,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.convs, x)
|
||||
else:
|
||||
out = self.convs(x)
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DeconvModule(nn.Module):
|
||||
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
||||
|
||||
This module uses deconvolution to upsample feature map in the decoder
|
||||
of UNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
with_cp=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
*,
|
||||
kernel_size=4,
|
||||
scale_factor=2):
|
||||
super().__init__()
|
||||
|
||||
assert (kernel_size - scale_factor >= 0) and\
|
||||
(kernel_size - scale_factor) % 2 == 0,\
|
||||
f'kernel_size should be greater than or equal to scale_factor '\
|
||||
f'and (kernel_size - scale_factor) should be even numbers, '\
|
||||
f'while the kernel size is {kernel_size} and scale_factor is '\
|
||||
f'{scale_factor}.'
|
||||
|
||||
stride = scale_factor
|
||||
padding = (kernel_size - scale_factor) // 2
|
||||
self.with_cp = with_cp
|
||||
deconv = nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
|
||||
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
|
||||
activate = build_activation_layer(act_cfg)
|
||||
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.deconv_upsamping, x)
|
||||
else:
|
||||
out = self.deconv_upsamping(x)
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class InterpConv(nn.Module):
|
||||
"""Interpolation upsample module in decoder for UNet.
|
||||
|
||||
This module uses interpolation to upsample feature map in the decoder
|
||||
of UNet. It consists of one interpolation upsample layer and one
|
||||
convolutional layer. It can be one interpolation upsample layer followed
|
||||
by one convolutional layer (conv_first=False) or one convolutional layer
|
||||
followed by one interpolation upsample layer (conv_first=True).
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
conv_first (bool): Whether convolutional layer or interpolation
|
||||
upsample layer first. Default: False. It means interpolation
|
||||
upsample layer followed by one convolutional layer.
|
||||
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
|
||||
stride (int): Stride of the convolutional layer. Default: 1.
|
||||
padding (int): Padding of the convolutional layer. Default: 1.
|
||||
upsample_cfg (dict): Interpolation config of the upsample layer.
|
||||
Default: dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
with_cp=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
*,
|
||||
conv_cfg=None,
|
||||
conv_first=False,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
upsample_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)):
|
||||
super().__init__()
|
||||
|
||||
self.with_cp = with_cp
|
||||
conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
upsample = Upsample(**upsample_cfg)
|
||||
if conv_first:
|
||||
self.interp_upsample = nn.Sequential(conv, upsample)
|
||||
else:
|
||||
self.interp_upsample = nn.Sequential(upsample, conv)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.interp_upsample, x)
|
||||
else:
|
||||
out = self.interp_upsample(x)
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class UNet(BaseModule):
|
||||
"""UNet backbone.
|
||||
|
||||
This backbone is the implementation of `U-Net: Convolutional Networks
|
||||
for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default" 3.
|
||||
base_channels (int): Number of base channels of each stage.
|
||||
The output channels of the first stage. Default: 64.
|
||||
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
||||
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
||||
len(strides) is equal to num_stages. Normally the stride of the
|
||||
first stage in encoder is 1. If strides[i]=2, it uses stride
|
||||
convolution to downsample in the correspondence encoder stage.
|
||||
Default: (1, 1, 1, 1, 1).
|
||||
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
||||
convolution block of the correspondence encoder stage.
|
||||
Default: (2, 2, 2, 2, 2).
|
||||
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
||||
convolution block of the correspondence decoder stage.
|
||||
Default: (2, 2, 2, 2).
|
||||
downsamples (Sequence[int]): Whether use MaxPool to downsample the
|
||||
feature map after the first stage of encoder
|
||||
(stages: [1, num_stages)). If the correspondence encoder stage use
|
||||
stride convolution (strides[i]=2), it will never use MaxPool to
|
||||
downsample, even downsamples[i-1]=True.
|
||||
Default: (True, True, True, True).
|
||||
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
||||
Default: (1, 1, 1, 1, 1).
|
||||
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
||||
Default: (1, 1, 1, 1).
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
upsample_cfg (dict): The upsample config of the upsample module in
|
||||
decoder. Default: dict(type='InterpConv').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
dcn (bool): Use deformable convolution in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
pretrained (str, optional): model pretrained path. Default: None
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
|
||||
Notice:
|
||||
The input image size should be divisible by the whole downsample rate
|
||||
of the encoder. More detail of the whole downsample rate can be found
|
||||
in UNet._check_input_divisible.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
upsample_cfg=dict(type='InterpConv'),
|
||||
norm_eval=False,
|
||||
dcn=None,
|
||||
plugins=None,
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.pretrained = pretrained
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be setting at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is None:
|
||||
if init_cfg is None:
|
||||
self.init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
assert len(strides) == num_stages, \
|
||||
'The length of strides should be equal to num_stages, '\
|
||||
f'while the strides is {strides}, the length of '\
|
||||
f'strides is {len(strides)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(enc_num_convs) == num_stages, \
|
||||
'The length of enc_num_convs should be equal to num_stages, '\
|
||||
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
||||
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(dec_num_convs) == (num_stages-1), \
|
||||
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
||||
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
||||
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(downsamples) == (num_stages-1), \
|
||||
'The length of downsamples should be equal to (num_stages-1), '\
|
||||
f'while the downsamples is {downsamples}, the length of '\
|
||||
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(enc_dilations) == num_stages, \
|
||||
'The length of enc_dilations should be equal to num_stages, '\
|
||||
f'while the enc_dilations is {enc_dilations}, the length of '\
|
||||
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(dec_dilations) == (num_stages-1), \
|
||||
'The length of dec_dilations should be equal to (num_stages-1), '\
|
||||
f'while the dec_dilations is {dec_dilations}, the length of '\
|
||||
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
self.num_stages = num_stages
|
||||
self.strides = strides
|
||||
self.downsamples = downsamples
|
||||
self.norm_eval = norm_eval
|
||||
self.base_channels = base_channels
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
for i in range(num_stages):
|
||||
enc_conv_block = []
|
||||
if i != 0:
|
||||
if strides[i] == 1 and downsamples[i - 1]:
|
||||
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
||||
upsample = (strides[i] != 1 or downsamples[i - 1])
|
||||
self.decoder.append(
|
||||
UpConvBlock(
|
||||
conv_block=BasicConvBlock,
|
||||
in_channels=base_channels * 2**i,
|
||||
skip_channels=base_channels * 2**(i - 1),
|
||||
out_channels=base_channels * 2**(i - 1),
|
||||
num_convs=dec_num_convs[i - 1],
|
||||
stride=1,
|
||||
dilation=dec_dilations[i - 1],
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
upsample_cfg=upsample_cfg if upsample else None,
|
||||
dcn=None,
|
||||
plugins=None))
|
||||
|
||||
enc_conv_block.append(
|
||||
BasicConvBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=base_channels * 2**i,
|
||||
num_convs=enc_num_convs[i],
|
||||
stride=strides[i],
|
||||
dilation=enc_dilations[i],
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
dcn=None,
|
||||
plugins=None))
|
||||
self.encoder.append(nn.Sequential(*enc_conv_block))
|
||||
in_channels = base_channels * 2**i
|
||||
|
||||
def forward(self, x):
|
||||
self._check_input_divisible(x)
|
||||
enc_outs = []
|
||||
for enc in self.encoder:
|
||||
x = enc(x)
|
||||
enc_outs.append(x)
|
||||
dec_outs = [x]
|
||||
for i in reversed(range(len(self.decoder))):
|
||||
x = self.decoder[i](enc_outs[i], x)
|
||||
dec_outs.append(x)
|
||||
|
||||
return dec_outs
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep normalization layer
|
||||
freezed."""
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _check_input_divisible(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
whole_downsample_rate = 1
|
||||
for i in range(1, self.num_stages):
|
||||
if self.strides[i] == 2 or self.downsamples[i - 1]:
|
||||
whole_downsample_rate *= 2
|
||||
assert (h % whole_downsample_rate == 0) \
|
||||
and (w % whole_downsample_rate == 0),\
|
||||
f'The input image size {(h, w)} should be divisible by the whole '\
|
||||
f'downsample rate {whole_downsample_rate}, when num_stages is '\
|
||||
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
|
||||
f'is {self.downsamples}.'
|
||||
501
finetune/mmseg/models/backbones/vit.py
Normal file
501
finetune/mmseg/models/backbones/vit.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import (constant_init, kaiming_init,
|
||||
trunc_normal_)
|
||||
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, resize
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Implements one encoder layer in Vision Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default: 0.0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default: 0.0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim)
|
||||
or (n, batch, embed_dim). Default: True.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=True,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
attn_cfg=dict(),
|
||||
ffn_cfg=dict(),
|
||||
with_cp=False):
|
||||
super().__init__()
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
attn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
batch_first=batch_first,
|
||||
bias=qkv_bias))
|
||||
|
||||
self.build_attn(attn_cfg)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
ffn_cfg.update(
|
||||
dict(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
|
||||
if drop_path_rate > 0 else None,
|
||||
act_cfg=act_cfg))
|
||||
self.build_ffn(ffn_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
def build_attn(self, attn_cfg):
|
||||
self.attn = MultiheadAttention(**attn_cfg)
|
||||
|
||||
def build_ffn(self, ffn_cfg):
|
||||
self.ffn = FFN(**ffn_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.attn(self.norm1(x), identity=x)
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
|
||||
This backbone is the implementation of `An Image is Worth 16x16 Words:
|
||||
Transformers for Image Recognition at
|
||||
Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Default: 224.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
patch_pad (str | int | None): The padding method in patch embedding.
|
||||
Default: 'corner'.
|
||||
in_channels (int): Number of input channels. Default: 3.
|
||||
embed_dims (int): embedding dimension. Default: 768.
|
||||
num_layers (int): depth of transformer. Default: 12.
|
||||
num_heads (int): number of attention heads. Default: 12.
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
||||
Default: 4.
|
||||
out_origin (bool): Whether to output the original input embedding.
|
||||
Default: False
|
||||
out_indices (list | tuple | int): Output from which stages.
|
||||
Default: -1.
|
||||
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.0
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.0
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.0
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Default: True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Default: False.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='LN')
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Default: dict(type='GELU').
|
||||
patch_bias (dict): Whether use bias in convolution of PatchEmbed Block.
|
||||
Default: True.
|
||||
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
|
||||
Default: False.
|
||||
pre_norm (bool): Whether to add a norm before Transformer Layers.
|
||||
Default: False.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Default: False.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Default: bicubic.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Default: 2.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save
|
||||
some memory while slowing down the training speed. Default: False.
|
||||
frozen_exclude (List): List of parameters that are not to be frozen.
|
||||
Default: ["all"], "all" means there are no frozen parameters.
|
||||
pretrained (str, optional): model pretrained path. Default: None.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
patch_pad='corner',
|
||||
in_channels=3,
|
||||
embed_dims=768,
|
||||
num_layers=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
out_origin=False,
|
||||
out_indices=-1,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
with_cls_token=True,
|
||||
output_cls_token=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='GELU'),
|
||||
patch_norm=False,
|
||||
patch_bias=False,
|
||||
pre_norm=False,
|
||||
final_norm=False,
|
||||
interpolate_mode='bicubic',
|
||||
num_fcs=2,
|
||||
norm_eval=False,
|
||||
with_cp=False,
|
||||
frozen_exclude=['all'],
|
||||
pretrained=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(img_size, int):
|
||||
img_size = to_2tuple(img_size)
|
||||
elif isinstance(img_size, tuple):
|
||||
if len(img_size) == 1:
|
||||
img_size = to_2tuple(img_size[0])
|
||||
assert len(img_size) == 2, \
|
||||
f'The size of image should have length 1 or 2, ' \
|
||||
f'but got {len(img_size)}'
|
||||
|
||||
if output_cls_token:
|
||||
assert with_cls_token is True, f'with_cls_token must be True if' \
|
||||
f'set output_cls_token to True, but got {with_cls_token}'
|
||||
|
||||
assert not (init_cfg and pretrained), \
|
||||
'init_cfg and pretrained cannot be set at the same time'
|
||||
if isinstance(pretrained, str):
|
||||
warnings.warn('DeprecationWarning: pretrained is deprecated, '
|
||||
'please use "init_cfg" instead')
|
||||
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
elif pretrained is not None:
|
||||
raise TypeError('pretrained must be a str or None')
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
self.pretrained = pretrained
|
||||
self.out_origin = out_origin
|
||||
self.frozen_exclude = frozen_exclude
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=patch_pad,
|
||||
bias=patch_bias,
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
init_cfg=None,
|
||||
)
|
||||
|
||||
num_patches = (img_size[0] // patch_size) * \
|
||||
(img_size[1] // patch_size)
|
||||
|
||||
self.with_cls_token = with_cls_token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
if self.pre_norm:
|
||||
self.pre_ln_name, pre_ln = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix='_pre')
|
||||
self.add_module(self.pre_ln_name, pre_ln)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
if out_indices == -1:
|
||||
out_indices = num_layers - 1
|
||||
self.out_indices = [out_indices]
|
||||
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
||||
self.out_indices = out_indices
|
||||
else:
|
||||
raise TypeError('out_indices must be type of int, list or tuple')
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, num_layers)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
self.layers = ModuleList()
|
||||
for i in range(num_layers):
|
||||
self.layers.append(
|
||||
TransformerEncoderLayer(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
feedforward_channels=mlp_ratio * embed_dims,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp,
|
||||
batch_first=True))
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self._freeze()
|
||||
|
||||
@property
|
||||
def pre_ln(self):
|
||||
return getattr(self, self.pre_ln_name)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
def init_weights(self):
|
||||
if isinstance(self.init_cfg, dict) and \
|
||||
self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']:
|
||||
checkpoint = CheckpointLoader.load_checkpoint(
|
||||
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
|
||||
|
||||
if self.init_cfg.get('type') == 'Pretrained':
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
elif self.init_cfg.get('type') == 'Pretrained_Part':
|
||||
state_dict = checkpoint.copy()
|
||||
para_prefix = 'image_encoder'
|
||||
prefix_len = len(para_prefix) + 1
|
||||
for k, v in checkpoint.items():
|
||||
state_dict.pop(k)
|
||||
if para_prefix in k:
|
||||
state_dict[k[prefix_len:]] = v
|
||||
|
||||
if 'pos_embed' in state_dict.keys():
|
||||
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
||||
print_log(msg=f'Resize the pos_embed shape from '
|
||||
f'{state_dict["pos_embed"].shape} to '
|
||||
f'{self.pos_embed.shape}')
|
||||
h, w = self.img_size
|
||||
pos_size = int(
|
||||
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
||||
state_dict['pos_embed'] = self.resize_pos_embed(
|
||||
state_dict['pos_embed'],
|
||||
(h // self.patch_size, w // self.patch_size),
|
||||
(pos_size, pos_size), self.interpolate_mode)
|
||||
|
||||
load_state_dict(self, state_dict, strict=False, logger=None)
|
||||
elif self.init_cfg is not None:
|
||||
super().init_weights()
|
||||
else:
|
||||
# We only implement the 'jax_impl' initialization implemented at
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
if 'ffn' in n:
|
||||
nn.init.normal_(m.bias, mean=0., std=1e-6)
|
||||
else:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m, mode='fan_in', bias=0.)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
||||
constant_init(m, val=1.0, bias=0.)
|
||||
|
||||
def _freeze(self):
|
||||
if 'all' in self.frozen_exclude:
|
||||
return
|
||||
for name, param in self.named_parameters():
|
||||
if not any([exclude in name for exclude in self.frozen_exclude]):
|
||||
param.requires_grad = False
|
||||
|
||||
def _pos_embeding(self, patched_img, hw_shape, pos_embed):
|
||||
"""Positioning embeding method.
|
||||
|
||||
Resize the pos_embed, if the input image size doesn't match
|
||||
the training size.
|
||||
Args:
|
||||
patched_img (torch.Tensor): The patched image, it should be
|
||||
shape of [B, L1, C].
|
||||
hw_shape (tuple): The downsampled image resolution.
|
||||
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
||||
shape of [B, L2, c].
|
||||
Return:
|
||||
torch.Tensor: The pos encoded image feature.
|
||||
"""
|
||||
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
||||
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
||||
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
||||
if x_len != pos_len:
|
||||
if pos_len == (self.img_size[0] // self.patch_size) * (
|
||||
self.img_size[1] // self.patch_size) + 1:
|
||||
pos_h = self.img_size[0] // self.patch_size
|
||||
pos_w = self.img_size[1] // self.patch_size
|
||||
else:
|
||||
raise ValueError(
|
||||
'Unexpected shape of pos_embed, got {}.'.format(
|
||||
pos_embed.shape))
|
||||
pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
|
||||
(pos_h, pos_w),
|
||||
self.interpolate_mode)
|
||||
return self.drop_after_pos(patched_img + pos_embed)
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shpae (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
cls_token_weight = pos_embed[:, 0]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(
|
||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, inputs):
|
||||
B = inputs.shape[0]
|
||||
|
||||
x, hw_shape = self.patch_embed(inputs)
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = self._pos_embeding(x, hw_shape, self.pos_embed)
|
||||
|
||||
if not self.with_cls_token:
|
||||
# Remove class token for transformer encoder input
|
||||
x = x[:, 1:]
|
||||
|
||||
if self.pre_norm:
|
||||
x = self.pre_ln(x)
|
||||
|
||||
outs = []
|
||||
if self.out_origin:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i == len(self.layers) - 1:
|
||||
if self.final_norm:
|
||||
x = self.norm1(x)
|
||||
if i in self.out_indices:
|
||||
if self.with_cls_token:
|
||||
# Remove class token and reshape token for decoder head
|
||||
out = x[:, 1:]
|
||||
else:
|
||||
out = x
|
||||
B, _, C = out.shape
|
||||
out = out.reshape(B, hw_shape[0], hw_shape[1],
|
||||
C).permute(0, 3, 1, 2).contiguous()
|
||||
if self.output_cls_token:
|
||||
out = [out, x[:, 0]]
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.LayerNorm):
|
||||
m.eval()
|
||||
395
finetune/mmseg/models/backbones/vpd.py
Normal file
395
finetune/mmseg/models/backbones/vpd.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# ------------------------------------------------------------------------------
|
||||
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
|
||||
# Original licence: MIT License
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner import CheckpointLoader, load_checkpoint
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, OptConfigType
|
||||
|
||||
try:
|
||||
from ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from ldm.util import instantiate_from_config
|
||||
has_ldm = True
|
||||
except ImportError:
|
||||
has_ldm = False
|
||||
|
||||
|
||||
def register_attention_control(model, controller):
|
||||
"""Registers a control function to manage attention within a model.
|
||||
|
||||
Args:
|
||||
model: The model to which attention is to be registered.
|
||||
controller: The control function responsible for managing attention.
|
||||
"""
|
||||
|
||||
def ca_forward(self, place_in_unet):
|
||||
"""Custom forward method for attention.
|
||||
|
||||
Args:
|
||||
self: Reference to the current object.
|
||||
place_in_unet: The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The modified forward method.
|
||||
"""
|
||||
|
||||
def forward(x, context=None, mask=None):
|
||||
h = self.heads
|
||||
is_cross = context is not None
|
||||
context = context or x # if context is None, use x
|
||||
|
||||
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||
q, k, v = (
|
||||
tensor.view(tensor.shape[0] * h, tensor.shape[1],
|
||||
tensor.shape[2] // h) for tensor in [q, k, v])
|
||||
|
||||
sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
attn_mean = attn.view(h, attn.shape[0] // h,
|
||||
*attn.shape[1:]).mean(0)
|
||||
controller(attn_mean, is_cross, place_in_unet)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
|
||||
return self.to_out(out)
|
||||
|
||||
return forward
|
||||
|
||||
def register_recr(net_, count, place_in_unet):
|
||||
"""Recursive function to register the custom forward method to all
|
||||
CrossAttention layers.
|
||||
|
||||
Args:
|
||||
net_: The network layer currently being processed.
|
||||
count: The current count of layers processed.
|
||||
place_in_unet: The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The updated count of layers processed.
|
||||
"""
|
||||
if net_.__class__.__name__ == 'CrossAttention':
|
||||
net_.forward = ca_forward(net_, place_in_unet)
|
||||
return count + 1
|
||||
if hasattr(net_, 'children'):
|
||||
return sum(
|
||||
register_recr(child, 0, place_in_unet)
|
||||
for child in net_.children())
|
||||
return count
|
||||
|
||||
cross_att_count = sum(
|
||||
register_recr(net[1], 0, place) for net, place in [
|
||||
(child, 'down') if 'input_blocks' in name else (
|
||||
child, 'up') if 'output_blocks' in name else
|
||||
(child,
|
||||
'mid') if 'middle_block' in name else (None, None) # Default case
|
||||
for name, child in model.diffusion_model.named_children()
|
||||
] if net is not None)
|
||||
|
||||
controller.num_att_layers = cross_att_count
|
||||
|
||||
|
||||
class AttentionStore:
|
||||
"""A class for storing attention information in the UNet model.
|
||||
|
||||
Attributes:
|
||||
base_size (int): Base size for storing attention information.
|
||||
max_size (int): Maximum size for storing attention information.
|
||||
"""
|
||||
|
||||
def __init__(self, base_size=64, max_size=None):
|
||||
"""Initialize AttentionStore with default or custom sizes."""
|
||||
self.reset()
|
||||
self.base_size = base_size
|
||||
self.max_size = max_size or (base_size // 2)
|
||||
self.num_att_layers = -1
|
||||
|
||||
@staticmethod
|
||||
def get_empty_store():
|
||||
"""Returns an empty store for holding attention values."""
|
||||
return {
|
||||
key: []
|
||||
for key in [
|
||||
'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self',
|
||||
'up_self'
|
||||
]
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Resets the step and attention stores to their initial states."""
|
||||
self.cur_step = 0
|
||||
self.cur_att_layer = 0
|
||||
self.step_store = self.get_empty_store()
|
||||
self.attention_store = {}
|
||||
|
||||
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
||||
"""Processes a single forward step, storing the attention.
|
||||
|
||||
Args:
|
||||
attn: The attention tensor.
|
||||
is_cross (bool): Whether it's cross attention.
|
||||
place_in_unet (str): The location in UNet (down/mid/up).
|
||||
|
||||
Returns:
|
||||
The unmodified attention tensor.
|
||||
"""
|
||||
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
||||
if attn.shape[1] <= (self.max_size)**2:
|
||||
self.step_store[key].append(attn)
|
||||
return attn
|
||||
|
||||
def between_steps(self):
|
||||
"""Processes and stores attention information between steps."""
|
||||
if not self.attention_store:
|
||||
self.attention_store = self.step_store
|
||||
else:
|
||||
for key in self.attention_store:
|
||||
self.attention_store[key] = [
|
||||
stored + step for stored, step in zip(
|
||||
self.attention_store[key], self.step_store[key])
|
||||
]
|
||||
self.step_store = self.get_empty_store()
|
||||
|
||||
def get_average_attention(self):
|
||||
"""Calculates and returns the average attention across all steps."""
|
||||
return {
|
||||
key: [item for item in self.step_store[key]]
|
||||
for key in self.step_store
|
||||
}
|
||||
|
||||
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
||||
"""Allows the class instance to be callable."""
|
||||
return self.forward(attn, is_cross, place_in_unet)
|
||||
|
||||
@property
|
||||
def num_uncond_att_layers(self):
|
||||
"""Returns the number of unconditional attention layers (default is
|
||||
0)."""
|
||||
return 0
|
||||
|
||||
def step_callback(self, x_t):
|
||||
"""A placeholder for a step callback.
|
||||
|
||||
Returns the input unchanged.
|
||||
"""
|
||||
return x_t
|
||||
|
||||
|
||||
class UNetWrapper(nn.Module):
|
||||
"""A wrapper for UNet with optional attention mechanisms.
|
||||
|
||||
Args:
|
||||
unet (nn.Module): The UNet model to wrap
|
||||
use_attn (bool): Whether to use attention. Defaults to True
|
||||
base_size (int): Base size for the attention store. Defaults to 512
|
||||
max_attn_size (int, optional): Maximum size for the attention store.
|
||||
Defaults to None
|
||||
attn_selector (str): The types of attention to use.
|
||||
Defaults to 'up_cross+down_cross'
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
unet,
|
||||
use_attn=True,
|
||||
base_size=512,
|
||||
max_attn_size=None,
|
||||
attn_selector='up_cross+down_cross'):
|
||||
super().__init__()
|
||||
|
||||
assert has_ldm, 'To use UNetWrapper, please install required ' \
|
||||
'packages via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
self.unet = unet
|
||||
self.attention_store = AttentionStore(
|
||||
base_size=base_size // 8, max_size=max_attn_size)
|
||||
self.attn_selector = attn_selector.split('+')
|
||||
self.use_attn = use_attn
|
||||
self.init_sizes(base_size)
|
||||
if self.use_attn:
|
||||
register_attention_control(unet, self.attention_store)
|
||||
|
||||
def init_sizes(self, base_size):
|
||||
"""Initialize sizes based on the base size."""
|
||||
self.size16 = base_size // 32
|
||||
self.size32 = base_size // 16
|
||||
self.size64 = base_size // 8
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""Forward pass through the model."""
|
||||
diffusion_model = self.unet.diffusion_model
|
||||
if self.use_attn:
|
||||
self.attention_store.reset()
|
||||
hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
|
||||
diffusion_model)
|
||||
if self.use_attn:
|
||||
self._append_attn_to_output(out_list)
|
||||
return out_list[::-1]
|
||||
|
||||
def _unet_forward(self, x, timesteps, context, y, diffusion_model):
|
||||
hs = []
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, diffusion_model.model_channels, repeat_only=False)
|
||||
emb = diffusion_model.time_embed(t_emb)
|
||||
h = x.type(diffusion_model.dtype)
|
||||
for module in diffusion_model.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = diffusion_model.middle_block(h, emb, context)
|
||||
out_list = []
|
||||
for i_out, module in enumerate(diffusion_model.output_blocks):
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
if i_out in [1, 4, 7]:
|
||||
out_list.append(h)
|
||||
h = h.type(x.dtype)
|
||||
out_list.append(h)
|
||||
return hs, emb, out_list
|
||||
|
||||
def _append_attn_to_output(self, out_list):
|
||||
avg_attn = self.attention_store.get_average_attention()
|
||||
attns = {self.size16: [], self.size32: [], self.size64: []}
|
||||
for k in self.attn_selector:
|
||||
for up_attn in avg_attn[k]:
|
||||
size = int(math.sqrt(up_attn.shape[1]))
|
||||
up_attn = up_attn.transpose(-1, -2).reshape(
|
||||
*up_attn.shape[:2], size, -1)
|
||||
attns[size].append(up_attn)
|
||||
attn16 = torch.stack(attns[self.size16]).mean(0)
|
||||
attn32 = torch.stack(attns[self.size32]).mean(0)
|
||||
attn64 = torch.stack(attns[self.size64]).mean(0) if len(
|
||||
attns[self.size64]) > 0 else None
|
||||
out_list[1] = torch.cat([out_list[1], attn16], dim=1)
|
||||
out_list[2] = torch.cat([out_list[2], attn32], dim=1)
|
||||
if attn64 is not None:
|
||||
out_list[3] = torch.cat([out_list[3], attn64], dim=1)
|
||||
|
||||
|
||||
class TextAdapter(nn.Module):
|
||||
"""A PyTorch Module that serves as a text adapter.
|
||||
|
||||
This module takes text embeddings and adjusts them based on a scaling
|
||||
factor gamma.
|
||||
"""
|
||||
|
||||
def __init__(self, text_dim=768):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(text_dim, text_dim), nn.GELU(),
|
||||
nn.Linear(text_dim, text_dim))
|
||||
|
||||
def forward(self, texts, gamma):
|
||||
texts_after = self.fc(texts)
|
||||
texts = texts + gamma * texts_after
|
||||
return texts
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VPD(BaseModule):
|
||||
"""VPD (Visual Perception Diffusion) model.
|
||||
|
||||
.. _`VPD`: https://arxiv.org/abs/2303.02153
|
||||
|
||||
Args:
|
||||
diffusion_cfg (dict): Configuration for diffusion model.
|
||||
class_embed_path (str): Path for class embeddings.
|
||||
unet_cfg (dict, optional): Configuration for U-Net.
|
||||
gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
|
||||
class_embed_select (bool, optional): If True, enables class embedding
|
||||
selection. Defaults to False.
|
||||
pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
|
||||
Defaults to None.
|
||||
pad_val (Union[int, List[int]], optional): Padding value.
|
||||
Defaults to 0.
|
||||
init_cfg (dict, optional): Configuration for network initialization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
diffusion_cfg: ConfigType,
|
||||
class_embed_path: str,
|
||||
unet_cfg: OptConfigType = dict(),
|
||||
gamma: float = 1e-4,
|
||||
class_embed_select=False,
|
||||
pad_shape: Optional[Union[int, List[int]]] = None,
|
||||
pad_val: Union[int, List[int]] = 0,
|
||||
init_cfg: OptConfigType = None):
|
||||
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
assert has_ldm, 'To use VPD model, please install required packages' \
|
||||
' via `pip install -r requirements/optional.txt`.'
|
||||
|
||||
if pad_shape is not None:
|
||||
if not isinstance(pad_shape, (list, tuple)):
|
||||
pad_shape = (pad_shape, pad_shape)
|
||||
|
||||
self.pad_shape = pad_shape
|
||||
self.pad_val = pad_val
|
||||
|
||||
# diffusion model
|
||||
diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
|
||||
sd_model = instantiate_from_config(diffusion_cfg)
|
||||
if diffusion_checkpoint is not None:
|
||||
load_checkpoint(sd_model, diffusion_checkpoint, strict=False)
|
||||
|
||||
self.encoder_vq = sd_model.first_stage_model
|
||||
self.unet = UNetWrapper(sd_model.model, **unet_cfg)
|
||||
|
||||
# class embeddings & text adapter
|
||||
class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
|
||||
text_dim = class_embeddings.size(-1)
|
||||
self.text_adapter = TextAdapter(text_dim=text_dim)
|
||||
self.class_embed_select = class_embed_select
|
||||
if class_embed_select:
|
||||
class_embeddings = torch.cat(
|
||||
(class_embeddings, class_embeddings.mean(dim=0,
|
||||
keepdims=True)),
|
||||
dim=0)
|
||||
self.register_buffer('class_embeddings', class_embeddings)
|
||||
self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)
|
||||
|
||||
def forward(self, x):
|
||||
"""Extract features from images."""
|
||||
|
||||
# calculate cross-attn map
|
||||
if self.class_embed_select:
|
||||
if isinstance(x, (tuple, list)):
|
||||
x, class_ids = x[:2]
|
||||
class_ids = class_ids.tolist()
|
||||
else:
|
||||
class_ids = [-1] * x.size(0)
|
||||
class_embeddings = self.class_embeddings[class_ids]
|
||||
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
|
||||
c_crossattn = c_crossattn.unsqueeze(1)
|
||||
else:
|
||||
class_embeddings = self.class_embeddings
|
||||
c_crossattn = self.text_adapter(class_embeddings, self.gamma)
|
||||
c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)
|
||||
|
||||
# pad to required input shape for pretrained diffusion model
|
||||
if self.pad_shape is not None:
|
||||
pad_width = max(0, self.pad_shape[1] - x.shape[-1])
|
||||
pad_height = max(0, self.pad_shape[0] - x.shape[-2])
|
||||
x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)
|
||||
|
||||
# forward the denoising model
|
||||
with torch.no_grad():
|
||||
latents = self.encoder_vq.encode(x).mode().detach()
|
||||
t = torch.ones((x.shape[0], ), device=x.device).long()
|
||||
outs = self.unet(latents, t, context=c_crossattn)
|
||||
|
||||
return outs
|
||||
52
finetune/mmseg/models/builder.py
Normal file
52
finetune/mmseg/models/builder.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
BACKBONES = MODELS
|
||||
NECKS = MODELS
|
||||
HEADS = MODELS
|
||||
LOSSES = MODELS
|
||||
SEGMENTORS = MODELS
|
||||
|
||||
|
||||
def build_backbone(cfg):
|
||||
"""Build backbone."""
|
||||
warnings.warn('``build_backbone`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return BACKBONES.build(cfg)
|
||||
|
||||
|
||||
def build_neck(cfg):
|
||||
"""Build neck."""
|
||||
warnings.warn('``build_neck`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return NECKS.build(cfg)
|
||||
|
||||
|
||||
def build_head(cfg):
|
||||
"""Build head."""
|
||||
warnings.warn('``build_head`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return HEADS.build(cfg)
|
||||
|
||||
|
||||
def build_loss(cfg):
|
||||
"""Build loss."""
|
||||
warnings.warn('``build_loss`` would be deprecated soon, please use '
|
||||
'``mmseg.registry.MODELS.build()`` ')
|
||||
return LOSSES.build(cfg)
|
||||
|
||||
|
||||
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
|
||||
"""Build segmentor."""
|
||||
if train_cfg is not None or test_cfg is not None:
|
||||
warnings.warn(
|
||||
'train_cfg and test_cfg is deprecated, '
|
||||
'please specify them in model', UserWarning)
|
||||
assert cfg.get('train_cfg') is None or train_cfg is None, \
|
||||
'train_cfg specified in both outer field and model field '
|
||||
assert cfg.get('test_cfg') is None or test_cfg is None, \
|
||||
'test_cfg specified in both outer field and model field '
|
||||
return SEGMENTORS.build(
|
||||
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
|
||||
151
finetune/mmseg/models/data_preprocessor.py
Normal file
151
finetune/mmseg/models/data_preprocessor.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseDataPreprocessor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import stack_batch
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SegDataPreProcessor(BaseDataPreprocessor):
|
||||
"""Image pre-processor for segmentation tasks.
|
||||
|
||||
Comparing with the :class:`mmengine.ImgDataPreprocessor`,
|
||||
|
||||
1. It won't do normalization if ``mean`` is not specified.
|
||||
2. It does normalization and color space conversion after stacking batch.
|
||||
3. It supports batch augmentations like mixup and cutmix.
|
||||
|
||||
|
||||
It provides the data pre-processing as follows
|
||||
|
||||
- Collate and move data to the target device.
|
||||
- Pad inputs to the input size with defined ``pad_val``, and pad seg map
|
||||
with defined ``seg_pad_val``.
|
||||
- Stack inputs to batch_inputs.
|
||||
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
|
||||
- Normalize image with defined std and mean.
|
||||
- Do batch augmentations like Mixup and Cutmix during training.
|
||||
|
||||
Args:
|
||||
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
|
||||
Defaults to None.
|
||||
std (Sequence[Number], optional): The pixel standard deviation of
|
||||
R, G, B channels. Defaults to None.
|
||||
size (tuple, optional): Fixed padding size.
|
||||
size_divisor (int, optional): The divisor of padded size.
|
||||
pad_val (float, optional): Padding value. Default: 0.
|
||||
seg_pad_val (float, optional): Padding value of segmentation map.
|
||||
Default: 255.
|
||||
padding_mode (str): Type of padding. Default: constant.
|
||||
- constant: pads with a constant value, this value is specified
|
||||
with pad_val.
|
||||
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
|
||||
Defaults to False.
|
||||
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
|
||||
Defaults to False.
|
||||
batch_augments (list[dict], optional): Batch-level augmentations
|
||||
test_cfg (dict, optional): The padding size config in testing, if not
|
||||
specify, will use `size` and `size_divisor` params as default.
|
||||
Defaults to None, only supports keys `size` or `size_divisor`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mean: Sequence[Number] = None,
|
||||
std: Sequence[Number] = None,
|
||||
size: Optional[tuple] = None,
|
||||
size_divisor: Optional[int] = None,
|
||||
pad_val: Number = 0,
|
||||
seg_pad_val: Number = 255,
|
||||
bgr_to_rgb: bool = False,
|
||||
rgb_to_bgr: bool = False,
|
||||
batch_augments: Optional[List[dict]] = None,
|
||||
test_cfg: dict = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.size_divisor = size_divisor
|
||||
self.pad_val = pad_val
|
||||
self.seg_pad_val = seg_pad_val
|
||||
|
||||
assert not (bgr_to_rgb and rgb_to_bgr), (
|
||||
'`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
|
||||
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
|
||||
|
||||
if mean is not None:
|
||||
assert std is not None, 'To enable the normalization in ' \
|
||||
'preprocessing, please specify both ' \
|
||||
'`mean` and `std`.'
|
||||
# Enable the normalization in preprocessing.
|
||||
self._enable_normalize = True
|
||||
self.register_buffer('mean',
|
||||
torch.tensor(mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('std',
|
||||
torch.tensor(std).view(-1, 1, 1), False)
|
||||
else:
|
||||
self._enable_normalize = False
|
||||
|
||||
# TODO: support batch augmentations.
|
||||
self.batch_augments = batch_augments
|
||||
|
||||
# Support different padding methods in testing
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
|
||||
"""Perform normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (dict): data sampled from dataloader.
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Dict: Data in the same format as the model input.
|
||||
"""
|
||||
data = self.cast_data(data) # type: ignore
|
||||
inputs = data['inputs']
|
||||
data_samples = data.get('data_samples', None)
|
||||
# TODO: whether normalize should be after stack_batch
|
||||
if self.channel_conversion and inputs[0].size(0) == 3:
|
||||
inputs = [_input[[2, 1, 0], ...] for _input in inputs]
|
||||
|
||||
inputs = [_input.float() for _input in inputs]
|
||||
if self._enable_normalize:
|
||||
inputs = [(_input - self.mean) / self.std for _input in inputs]
|
||||
|
||||
if training:
|
||||
assert data_samples is not None, ('During training, ',
|
||||
'`data_samples` must be define.')
|
||||
inputs, data_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
data_samples=data_samples,
|
||||
size=self.size,
|
||||
size_divisor=self.size_divisor,
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
|
||||
if self.batch_augments is not None:
|
||||
inputs, data_samples = self.batch_augments(
|
||||
inputs, data_samples)
|
||||
else:
|
||||
img_size = inputs[0].shape[1:]
|
||||
assert all(input_.shape[1:] == img_size for input_ in inputs), \
|
||||
'The image size in a batch should be the same.'
|
||||
# pad images when testing
|
||||
if self.test_cfg:
|
||||
inputs, padded_samples = stack_batch(
|
||||
inputs=inputs,
|
||||
size=self.test_cfg.get('size', None),
|
||||
size_divisor=self.test_cfg.get('size_divisor', None),
|
||||
pad_val=self.pad_val,
|
||||
seg_pad_val=self.seg_pad_val)
|
||||
for data_sample, pad_info in zip(data_samples, padded_samples):
|
||||
data_sample.set_metainfo({**pad_info})
|
||||
else:
|
||||
inputs = torch.stack(inputs, dim=0)
|
||||
|
||||
return dict(inputs=inputs, data_samples=data_samples)
|
||||
48
finetune/mmseg/models/decode_heads/__init__.py
Normal file
48
finetune/mmseg/models/decode_heads/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ann_head import ANNHead
|
||||
from .apc_head import APCHead
|
||||
from .aspp_head import ASPPHead
|
||||
from .cc_head import CCHead
|
||||
from .da_head import DAHead
|
||||
from .ddr_head import DDRHead
|
||||
from .dm_head import DMHead
|
||||
from .dnl_head import DNLHead
|
||||
from .dpt_head import DPTHead
|
||||
from .ema_head import EMAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
from .fpn_head import FPNHead
|
||||
from .gc_head import GCHead
|
||||
from .ham_head import LightHamHead
|
||||
from .isa_head import ISAHead
|
||||
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
|
||||
from .lraspp_head import LRASPPHead
|
||||
from .mask2former_head import Mask2FormerHead
|
||||
from .maskformer_head import MaskFormerHead
|
||||
from .nl_head import NLHead
|
||||
from .ocr_head import OCRHead
|
||||
from .pid_head import PIDHead
|
||||
from .point_head import PointHead
|
||||
from .psa_head import PSAHead
|
||||
from .psp_head import PSPHead
|
||||
from .san_head import SideAdapterCLIPHead
|
||||
from .segformer_head import SegformerHead
|
||||
from .segmenter_mask_head import SegmenterMaskTransformerHead
|
||||
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
||||
from .sep_fcn_head import DepthwiseSeparableFCNHead
|
||||
from .setr_mla_head import SETRMLAHead
|
||||
from .setr_up_head import SETRUPHead
|
||||
from .stdc_head import STDCHead
|
||||
from .uper_head import UPerHead
|
||||
from .vpd_depth_head import VPDDepthHead
|
||||
|
||||
__all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
||||
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
|
||||
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
|
||||
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
|
||||
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
|
||||
'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead'
|
||||
]
|
||||
245
finetune/mmseg/models/decode_heads/ann_head.py
Normal file
245
finetune/mmseg/models/decode_heads/ann_head.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PPMConcat(nn.ModuleList):
|
||||
"""Pyramid Pooling Module that only concat the features of each layer.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 3, 6, 8)):
|
||||
super().__init__(
|
||||
[nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
ppm_outs = []
|
||||
for ppm in self:
|
||||
ppm_out = ppm(feats)
|
||||
ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
|
||||
concat_outs = torch.cat(ppm_outs, dim=2)
|
||||
return concat_outs
|
||||
|
||||
|
||||
class SelfAttentionBlock(_SelfAttentionBlock):
|
||||
"""Make a ANN used SelfAttentionBlock.
|
||||
|
||||
Args:
|
||||
low_in_channels (int): Input channels of lower level feature,
|
||||
which is the key feature for self-attention.
|
||||
high_in_channels (int): Input channels of higher level feature,
|
||||
which is the query feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
share_key_query (bool): Whether share projection weight between key
|
||||
and query projection.
|
||||
query_scale (int): The scale of query feature map.
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, low_in_channels, high_in_channels, channels,
|
||||
out_channels, share_key_query, query_scale, key_pool_scales,
|
||||
conv_cfg, norm_cfg, act_cfg):
|
||||
key_psp = PPMConcat(key_pool_scales)
|
||||
if query_scale > 1:
|
||||
query_downsample = nn.MaxPool2d(kernel_size=query_scale)
|
||||
else:
|
||||
query_downsample = None
|
||||
super().__init__(
|
||||
key_in_channels=low_in_channels,
|
||||
query_in_channels=high_in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=share_key_query,
|
||||
query_downsample=query_downsample,
|
||||
key_downsample=key_psp,
|
||||
key_query_num_convs=1,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=True,
|
||||
with_out=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
|
||||
class AFNB(nn.Module):
|
||||
"""Asymmetric Fusion Non-local Block(AFNB)
|
||||
|
||||
Args:
|
||||
low_in_channels (int): Input channels of lower level feature,
|
||||
which is the key feature for self-attention.
|
||||
high_in_channels (int): Input channels of higher level feature,
|
||||
which is the query feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
and query projection.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, low_in_channels, high_in_channels, channels,
|
||||
out_channels, query_scales, key_pool_scales, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList()
|
||||
for query_scale in query_scales:
|
||||
self.stages.append(
|
||||
SelfAttentionBlock(
|
||||
low_in_channels=low_in_channels,
|
||||
high_in_channels=high_in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=False,
|
||||
query_scale=query_scale,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.bottleneck = ConvModule(
|
||||
out_channels + high_in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None)
|
||||
|
||||
def forward(self, low_feats, high_feats):
|
||||
"""Forward function."""
|
||||
priors = [stage(high_feats, low_feats) for stage in self.stages]
|
||||
context = torch.stack(priors, dim=0).sum(dim=0)
|
||||
output = self.bottleneck(torch.cat([context, high_feats], 1))
|
||||
return output
|
||||
|
||||
|
||||
class APNB(nn.Module):
|
||||
"""Asymmetric Pyramid Non-local Block (APNB)
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature,
|
||||
which is the key feature for self-attention.
|
||||
channels (int): Output channels of key/query transform.
|
||||
out_channels (int): Output channels.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
||||
Module of key feature.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict|None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels, out_channels, query_scales,
|
||||
key_pool_scales, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList()
|
||||
for query_scale in query_scales:
|
||||
self.stages.append(
|
||||
SelfAttentionBlock(
|
||||
low_in_channels=in_channels,
|
||||
high_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
share_key_query=True,
|
||||
query_scale=query_scale,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.bottleneck = ConvModule(
|
||||
2 * in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
priors = [stage(feats, feats) for stage in self.stages]
|
||||
context = torch.stack(priors, dim=0).sum(dim=0)
|
||||
output = self.bottleneck(torch.cat([context, feats], 1))
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ANNHead(BaseDecodeHead):
|
||||
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `ANNNet
|
||||
<https://arxiv.org/abs/1908.07678>`_.
|
||||
|
||||
Args:
|
||||
project_channels (int): Projection channels for Nonlocal.
|
||||
query_scales (tuple[int]): The scales of query feature map.
|
||||
Default: (1,)
|
||||
key_pool_scales (tuple[int]): The pooling scales of key feature map.
|
||||
Default: (1, 3, 6, 8).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project_channels,
|
||||
query_scales=(1, ),
|
||||
key_pool_scales=(1, 3, 6, 8),
|
||||
**kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(self.in_channels) == 2
|
||||
low_in_channels, high_in_channels = self.in_channels
|
||||
self.project_channels = project_channels
|
||||
self.fusion = AFNB(
|
||||
low_in_channels=low_in_channels,
|
||||
high_in_channels=high_in_channels,
|
||||
out_channels=high_in_channels,
|
||||
channels=project_channels,
|
||||
query_scales=query_scales,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
high_in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.context = APNB(
|
||||
in_channels=self.channels,
|
||||
out_channels=self.channels,
|
||||
channels=project_channels,
|
||||
query_scales=query_scales,
|
||||
key_pool_scales=key_pool_scales,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
low_feats, high_feats = self._transform_inputs(inputs)
|
||||
output = self.fusion(low_feats, high_feats)
|
||||
output = self.dropout(output)
|
||||
output = self.bottleneck(output)
|
||||
output = self.context(output)
|
||||
output = self.cls_seg(output)
|
||||
|
||||
return output
|
||||
159
finetune/mmseg/models/decode_heads/apc_head.py
Normal file
159
finetune/mmseg/models/decode_heads/apc_head.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ACM(nn.Module):
|
||||
"""Adaptive Context Module used in APCNet.
|
||||
|
||||
Args:
|
||||
pool_scale (int): Pooling scale used in Adaptive Context
|
||||
Module to extract region features.
|
||||
fusion (bool): Add one conv to fuse residual feature.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.pool_scale = pool_scale
|
||||
self.fusion = fusion
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.pooled_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.input_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.global_info = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
|
||||
|
||||
self.residual_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
if self.fusion:
|
||||
self.fusion_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
|
||||
# [batch_size, channels, h, w]
|
||||
x = self.input_redu_conv(x)
|
||||
# [batch_size, channels, pool_scale, pool_scale]
|
||||
pooled_x = self.pooled_redu_conv(pooled_x)
|
||||
batch_size = x.size(0)
|
||||
# [batch_size, pool_scale * pool_scale, channels]
|
||||
pooled_x = pooled_x.view(batch_size, self.channels,
|
||||
-1).permute(0, 2, 1).contiguous()
|
||||
# [batch_size, h * w, pool_scale * pool_scale]
|
||||
affinity_matrix = self.gla(x + resize(
|
||||
self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
|
||||
).permute(0, 2, 3, 1).reshape(
|
||||
batch_size, -1, self.pool_scale**2)
|
||||
affinity_matrix = F.sigmoid(affinity_matrix)
|
||||
# [batch_size, h * w, channels]
|
||||
z_out = torch.matmul(affinity_matrix, pooled_x)
|
||||
# [batch_size, channels, h * w]
|
||||
z_out = z_out.permute(0, 2, 1).contiguous()
|
||||
# [batch_size, channels, h, w]
|
||||
z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
|
||||
z_out = self.residual_conv(z_out)
|
||||
z_out = F.relu(z_out + x)
|
||||
if self.fusion:
|
||||
z_out = self.fusion_conv(z_out)
|
||||
|
||||
return z_out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class APCHead(BaseDecodeHead):
|
||||
"""Adaptive Pyramid Context Network for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
|
||||
He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
|
||||
CVPR_2019_paper.pdf>`_.
|
||||
|
||||
Args:
|
||||
pool_scales (tuple[int]): Pooling scales used in Adaptive Context
|
||||
Module. Default: (1, 2, 3, 6).
|
||||
fusion (bool): Add one conv to fuse residual feature.
|
||||
"""
|
||||
|
||||
def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(pool_scales, (list, tuple))
|
||||
self.pool_scales = pool_scales
|
||||
self.fusion = fusion
|
||||
acm_modules = []
|
||||
for pool_scale in self.pool_scales:
|
||||
acm_modules.append(
|
||||
ACM(pool_scale,
|
||||
self.fusion,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.acm_modules = nn.ModuleList(acm_modules)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(pool_scales) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
acm_outs = [x]
|
||||
for acm_module in self.acm_modules:
|
||||
acm_outs.append(acm_module(x))
|
||||
acm_outs = torch.cat(acm_outs, dim=1)
|
||||
output = self.bottleneck(acm_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
122
finetune/mmseg/models/decode_heads/aspp_head.py
Normal file
122
finetune/mmseg/models/decode_heads/aspp_head.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ASPPModule(nn.ModuleList):
|
||||
"""Atrous Spatial Pyramid Pooling (ASPP) Module.
|
||||
|
||||
Args:
|
||||
dilations (tuple[int]): Dilation rate of each layer.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
|
||||
act_cfg):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
for dilation in dilations:
|
||||
self.append(
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1 if dilation == 1 else 3,
|
||||
dilation=dilation,
|
||||
padding=0 if dilation == 1 else dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
aspp_outs = []
|
||||
for aspp_module in self:
|
||||
aspp_outs.append(aspp_module(x))
|
||||
|
||||
return aspp_outs
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ASPPHead(BaseDecodeHead):
|
||||
"""Rethinking Atrous Convolution for Semantic Image Segmentation.
|
||||
|
||||
This head is the implementation of `DeepLabV3
|
||||
<https://arxiv.org/abs/1706.05587>`_.
|
||||
|
||||
Args:
|
||||
dilations (tuple[int]): Dilation rates for ASPP module.
|
||||
Default: (1, 6, 12, 18).
|
||||
"""
|
||||
|
||||
def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dilations, (list, tuple))
|
||||
self.dilations = dilations
|
||||
self.image_pool = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.aspp_modules = ASPPModule(
|
||||
dilations,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
(len(dilations) + 1) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
aspp_outs = [
|
||||
resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
]
|
||||
aspp_outs.extend(self.aspp_modules(x))
|
||||
aspp_outs = torch.cat(aspp_outs, dim=1)
|
||||
feats = self.bottleneck(aspp_outs)
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
62
finetune/mmseg/models/decode_heads/cascade_decode_head.py
Normal file
62
finetune/mmseg/models/decode_heads/cascade_decode_head.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.utils import ConfigType
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
|
||||
"""Base class for cascade decode head used in
|
||||
:class:`CascadeEncoderDecoder."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, inputs, prev_output):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def loss(self, inputs: List[Tensor], prev_output: Tensor,
|
||||
batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
|
||||
"""Forward function for training.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs, prev_output)
|
||||
losses = self.loss_by_feat(seg_logits, batch_data_samples)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, inputs: List[Tensor], prev_output: Tensor,
|
||||
batch_img_metas: List[dict], tese_cfg: ConfigType):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (List[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_img_metas (dict): List Image info where each dict may also
|
||||
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
seg_logits = self.forward(inputs, prev_output)
|
||||
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
43
finetune/mmseg/models/decode_heads/cc_head.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
try:
|
||||
from mmcv.ops import CrissCrossAttention
|
||||
except ModuleNotFoundError:
|
||||
CrissCrossAttention = None
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CCHead(FCNHead):
|
||||
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `CCNet
|
||||
<https://arxiv.org/abs/1811.11721>`_.
|
||||
|
||||
Args:
|
||||
recurrence (int): Number of recurrence of Criss Cross Attention
|
||||
module. Default: 2.
|
||||
"""
|
||||
|
||||
def __init__(self, recurrence=2, **kwargs):
|
||||
if CrissCrossAttention is None:
|
||||
raise RuntimeError('Please install mmcv-full for '
|
||||
'CrissCrossAttention ops')
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.recurrence = recurrence
|
||||
self.cca = CrissCrossAttention(self.channels)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
for _ in range(self.recurrence):
|
||||
output = self.cca(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
184
finetune/mmseg/models/decode_heads/da_head.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, Scale
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList, add_prefix
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class PAM(_SelfAttentionBlock):
|
||||
"""Position Attention Module (PAM)
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature.
|
||||
channels (int): Output channels of key/query transform.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels):
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=None,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=1,
|
||||
key_query_norm=False,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=False,
|
||||
with_out=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
out = super().forward(x, x)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
class CAM(nn.Module):
|
||||
"""Channel Attention Module (CAM)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gamma = Scale(0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = x.size()
|
||||
proj_query = x.view(batch_size, channels, -1)
|
||||
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
|
||||
energy = torch.bmm(proj_query, proj_key)
|
||||
energy_new = torch.max(
|
||||
energy, -1, keepdim=True)[0].expand_as(energy) - energy
|
||||
attention = F.softmax(energy_new, dim=-1)
|
||||
proj_value = x.view(batch_size, channels, -1)
|
||||
|
||||
out = torch.bmm(attention, proj_value)
|
||||
out = out.view(batch_size, channels, height, width)
|
||||
|
||||
out = self.gamma(out) + x
|
||||
return out
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DAHead(BaseDecodeHead):
|
||||
"""Dual Attention Network for Scene Segmentation.
|
||||
|
||||
This head is the implementation of `DANet
|
||||
<https://arxiv.org/abs/1809.02983>`_.
|
||||
|
||||
Args:
|
||||
pam_channels (int): The channels of Position Attention Module(PAM).
|
||||
"""
|
||||
|
||||
def __init__(self, pam_channels, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pam_channels = pam_channels
|
||||
self.pam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam = PAM(self.channels, pam_channels)
|
||||
self.pam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.pam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
self.cam_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam = CAM()
|
||||
self.cam_out_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.cam_conv_seg = nn.Conv2d(
|
||||
self.channels, self.num_classes, kernel_size=1)
|
||||
|
||||
def pam_cls_seg(self, feat):
|
||||
"""PAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.pam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def cam_cls_seg(self, feat):
|
||||
"""CAM feature classification."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.cam_conv_seg(feat)
|
||||
return output
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
pam_feat = self.pam_in_conv(x)
|
||||
pam_feat = self.pam(pam_feat)
|
||||
pam_feat = self.pam_out_conv(pam_feat)
|
||||
pam_out = self.pam_cls_seg(pam_feat)
|
||||
|
||||
cam_feat = self.cam_in_conv(x)
|
||||
cam_feat = self.cam(cam_feat)
|
||||
cam_feat = self.cam_out_conv(cam_feat)
|
||||
cam_out = self.cam_cls_seg(cam_feat)
|
||||
|
||||
feat_sum = pam_feat + cam_feat
|
||||
pam_cam_out = self.cls_seg(feat_sum)
|
||||
|
||||
return pam_cam_out, pam_out, cam_out
|
||||
|
||||
def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
|
||||
**kwargs) -> List[Tensor]:
|
||||
"""Forward function for testing, only ``pam_cam`` is used."""
|
||||
seg_logits = self.forward(inputs)[0]
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
|
||||
|
||||
def loss_by_feat(self, seg_logit: Tuple[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
"""Compute ``pam_cam``, ``pam``, ``cam`` loss."""
|
||||
pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
|
||||
loss = dict()
|
||||
loss.update(
|
||||
add_prefix(
|
||||
super().loss_by_feat(pam_cam_seg_logit, batch_data_samples),
|
||||
'pam_cam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples),
|
||||
'pam'))
|
||||
loss.update(
|
||||
add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples),
|
||||
'cam'))
|
||||
return loss
|
||||
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
116
finetune/mmseg/models/decode_heads/ddr_head.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DDRHead(BaseDecodeHead):
|
||||
"""Decode head for DDRNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_classes (int): Number of classes.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict, optional): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_classes: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
channels,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**kwargs)
|
||||
|
||||
self.head = self._make_base_head(self.in_channels, self.channels)
|
||||
self.aux_head = self._make_base_head(self.in_channels // 2,
|
||||
self.channels)
|
||||
self.aux_cls_seg = nn.Conv2d(
|
||||
self.channels, self.out_channels, kernel_size=1)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Union[Tensor,
|
||||
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
|
||||
if self.training:
|
||||
c3_feat, c5_feat = inputs
|
||||
x_c = self.head(c5_feat)
|
||||
x_c = self.cls_seg(x_c)
|
||||
x_s = self.aux_head(c3_feat)
|
||||
x_s = self.aux_cls_seg(x_s)
|
||||
|
||||
return x_c, x_s
|
||||
else:
|
||||
x_c = self.head(inputs)
|
||||
x_c = self.cls_seg(x_c)
|
||||
return x_c
|
||||
|
||||
def _make_base_head(self, in_channels: int,
|
||||
channels: int) -> nn.Sequential:
|
||||
layers = [
|
||||
ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
order=('norm', 'act', 'conv')),
|
||||
build_norm_layer(self.norm_cfg, channels)[1],
|
||||
build_activation_layer(self.act_cfg),
|
||||
]
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tuple[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
loss = dict()
|
||||
context_logit, spatial_logit = seg_logits
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
|
||||
context_logit = resize(
|
||||
context_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
spatial_logit = resize(
|
||||
spatial_logit,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
seg_label = seg_label.squeeze(1)
|
||||
|
||||
loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
|
||||
loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
|
||||
loss['acc_seg'] = accuracy(
|
||||
context_logit, seg_label, ignore_index=self.ignore_index)
|
||||
|
||||
return loss
|
||||
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
366
finetune/mmseg/models/decode_heads/decode_head.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import build_pixel_sampler
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
"""Base class for BaseDecodeHead.
|
||||
|
||||
1. The ``init_weights`` method is used to initialize decode_head's
|
||||
model parameters. After segmentor initialization, ``init_weights``
|
||||
is triggered when ``segmentor.init_weights()`` is called externally.
|
||||
|
||||
2. The ``loss`` method is used to calculate the loss of decode_head,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
|
||||
is called based on the feature maps to calculate the loss.
|
||||
|
||||
.. code:: text
|
||||
|
||||
loss(): forward() -> loss_by_feat()
|
||||
|
||||
3. The ``predict`` method is used to predict segmentation results,
|
||||
which includes two steps: (1) the decode_head model performs forward
|
||||
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
|
||||
is called based on the feature maps to predict segmentation results
|
||||
including post-processing.
|
||||
|
||||
.. code:: text
|
||||
|
||||
predict(): forward() -> predict_by_feat()
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
num_classes (int): Number of classes.
|
||||
out_channels (int): Output channels of conv_seg. Default: None.
|
||||
threshold (float): Threshold for binary segmentation in the case of
|
||||
`num_classes==1`. Default: None.
|
||||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
||||
conv_cfg (dict|None): Config of conv layers. Default: None.
|
||||
norm_cfg (dict|None): Config of norm layers. Default: None.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU')
|
||||
in_index (int|Sequence[int]): Input feature index. Default: -1
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
Default: None.
|
||||
loss_decode (dict | Sequence[dict]): Config of decode loss.
|
||||
The `loss_name` is property of corresponding loss function which
|
||||
could be shown in training log. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
e.g. dict(type='CrossEntropyLoss'),
|
||||
[dict(type='CrossEntropyLoss', loss_name='loss_ce'),
|
||||
dict(type='DiceLoss', loss_name='loss_dice')]
|
||||
Default: dict(type='CrossEntropyLoss').
|
||||
ignore_index (int | None): The label index to be ignored. When using
|
||||
masked BCE loss, ignore_index should be set to None. Default: 255.
|
||||
sampler (dict|None): The config of segmentation map sampler.
|
||||
Default: None.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
channels,
|
||||
*,
|
||||
num_classes,
|
||||
out_channels=None,
|
||||
threshold=None,
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
in_index=-1,
|
||||
input_transform=None,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=False,
|
||||
loss_weight=1.0),
|
||||
ignore_index=255,
|
||||
sampler=None,
|
||||
align_corners=False,
|
||||
init_cfg=dict(
|
||||
type='Normal', std=0.01, override=dict(name='conv_seg'))):
|
||||
super().__init__(init_cfg)
|
||||
self._init_inputs(in_channels, in_index, input_transform)
|
||||
self.channels = channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.in_index = in_index
|
||||
|
||||
self.ignore_index = ignore_index
|
||||
self.align_corners = align_corners
|
||||
|
||||
if out_channels is None:
|
||||
if num_classes == 2:
|
||||
warnings.warn('For binary segmentation, we suggest using'
|
||||
'`out_channels = 1` to define the output'
|
||||
'channels of segmentor, and use `threshold`'
|
||||
'to convert `seg_logits` into a prediction'
|
||||
'applying a threshold')
|
||||
out_channels = num_classes
|
||||
|
||||
if out_channels != num_classes and out_channels != 1:
|
||||
raise ValueError(
|
||||
'out_channels should be equal to num_classes,'
|
||||
'except binary segmentation set out_channels == 1 and'
|
||||
f'num_classes == 2, but got out_channels={out_channels}'
|
||||
f'and num_classes={num_classes}')
|
||||
|
||||
if out_channels == 1 and threshold is None:
|
||||
threshold = 0.3
|
||||
warnings.warn('threshold is not defined for binary, and defaults'
|
||||
'to 0.3')
|
||||
self.num_classes = num_classes
|
||||
self.out_channels = out_channels
|
||||
self.threshold = threshold
|
||||
|
||||
if isinstance(loss_decode, dict):
|
||||
self.loss_decode = MODELS.build(loss_decode)
|
||||
elif isinstance(loss_decode, (list, tuple)):
|
||||
self.loss_decode = nn.ModuleList()
|
||||
for loss in loss_decode:
|
||||
self.loss_decode.append(MODELS.build(loss))
|
||||
else:
|
||||
raise TypeError(f'loss_decode must be a dict or sequence of dict,\
|
||||
but got {type(loss_decode)}')
|
||||
|
||||
if sampler is not None:
|
||||
self.sampler = build_pixel_sampler(sampler, context=self)
|
||||
else:
|
||||
self.sampler = None
|
||||
|
||||
self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
if dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout2d(dropout_ratio)
|
||||
else:
|
||||
self.dropout = None
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f'input_transform={self.input_transform}, ' \
|
||||
f'ignore_index={self.ignore_index}, ' \
|
||||
f'align_corners={self.align_corners}'
|
||||
return s
|
||||
|
||||
def _init_inputs(self, in_channels, in_index, input_transform):
|
||||
"""Check and initialize input transforms.
|
||||
|
||||
The in_channels, in_index and input_transform must match.
|
||||
Specifically, when input_transform is None, only single feature map
|
||||
will be selected. So in_channels and in_index must be of type int.
|
||||
When input_transform
|
||||
|
||||
Args:
|
||||
in_channels (int|Sequence[int]): Input channels.
|
||||
in_index (int|Sequence[int]): Input feature index.
|
||||
input_transform (str|None): Transformation type of input features.
|
||||
Options: 'resize_concat', 'multiple_select', None.
|
||||
'resize_concat': Multiple feature maps will be resize to the
|
||||
same size as first one and than concat together.
|
||||
Usually used in FCN head of HRNet.
|
||||
'multiple_select': Multiple feature maps will be bundle into
|
||||
a list and passed into decode head.
|
||||
None: Only one select feature map is allowed.
|
||||
"""
|
||||
|
||||
if input_transform is not None:
|
||||
assert input_transform in ['resize_concat', 'multiple_select']
|
||||
self.input_transform = input_transform
|
||||
self.in_index = in_index
|
||||
if input_transform is not None:
|
||||
assert isinstance(in_channels, (list, tuple))
|
||||
assert isinstance(in_index, (list, tuple))
|
||||
assert len(in_channels) == len(in_index)
|
||||
if input_transform == 'resize_concat':
|
||||
self.in_channels = sum(in_channels)
|
||||
else:
|
||||
self.in_channels = in_channels
|
||||
else:
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(in_index, int)
|
||||
self.in_channels = in_channels
|
||||
|
||||
def _transform_inputs(self, inputs):
|
||||
"""Transform inputs for decoder.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
Tensor: The transformed inputs
|
||||
"""
|
||||
|
||||
if self.input_transform == 'resize_concat':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
upsampled_inputs = [
|
||||
resize(
|
||||
input=x,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for x in inputs
|
||||
]
|
||||
inputs = torch.cat(upsampled_inputs, dim=1)
|
||||
elif self.input_transform == 'multiple_select':
|
||||
inputs = [inputs[i] for i in self.in_index]
|
||||
else:
|
||||
inputs = inputs[self.in_index]
|
||||
|
||||
return inputs
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, inputs):
|
||||
"""Placeholder of forward function."""
|
||||
pass
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.conv_seg(feat)
|
||||
return output
|
||||
|
||||
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Forward function for training.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): List of multi-level img features.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `img_metas` or `gt_semantic_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
losses = self.loss_by_feat(seg_logits, batch_data_samples)
|
||||
return losses
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tensor:
|
||||
"""Forward function for prediction.
|
||||
|
||||
Args:
|
||||
inputs (Tuple[Tensor]): List of multi-level img features.
|
||||
batch_img_metas (dict): List Image info where each dict may also
|
||||
contain: 'img_shape', 'scale_factor', 'flip', 'img_path',
|
||||
'ori_shape', and 'pad_shape'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:PackSegInputs`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
seg_logits = self.forward(inputs)
|
||||
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
|
||||
gt_semantic_segs = [
|
||||
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
|
||||
]
|
||||
return torch.stack(gt_semantic_segs, dim=0)
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tensor,
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
"""Compute segmentation loss.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): The output from decode head forward function.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `metainfo` and `gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
loss = dict()
|
||||
seg_logits = resize(
|
||||
input=seg_logits,
|
||||
size=seg_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
if self.sampler is not None:
|
||||
seg_weight = self.sampler.sample(seg_logits, seg_label)
|
||||
else:
|
||||
seg_weight = None
|
||||
seg_label = seg_label.squeeze(1)
|
||||
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_decode in losses_decode:
|
||||
if loss_decode.loss_name not in loss:
|
||||
loss[loss_decode.loss_name] = loss_decode(
|
||||
seg_logits,
|
||||
seg_label,
|
||||
weight=seg_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
else:
|
||||
loss[loss_decode.loss_name] += loss_decode(
|
||||
seg_logits,
|
||||
seg_label,
|
||||
weight=seg_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
|
||||
loss['acc_seg'] = accuracy(
|
||||
seg_logits, seg_label, ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
def predict_by_feat(self, seg_logits: Tensor,
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Transform a batch of output seg_logits to the input shape.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): The output from decode head forward function.
|
||||
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
|
||||
Returns:
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
|
||||
if isinstance(batch_img_metas[0]['img_shape'], torch.Size):
|
||||
# slide inference
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
elif 'pad_shape' in batch_img_metas[0]:
|
||||
size = batch_img_metas[0]['pad_shape'][:2]
|
||||
else:
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
|
||||
seg_logits = resize(
|
||||
input=seg_logits,
|
||||
size=size,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
return seg_logits
|
||||
141
finetune/mmseg/models/decode_heads/dm_head.py
Normal file
141
finetune/mmseg/models/decode_heads/dm_head.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class DCM(nn.Module):
|
||||
"""Dynamic Convolutional Module used in DMNet.
|
||||
|
||||
Args:
|
||||
filter_size (int): The filter size of generated convolution kernel
|
||||
used in Dynamic Convolutional Module.
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
in_channels (int): Input channels.
|
||||
channels (int): Channels after modules, before conv_seg.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
|
||||
norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.filter_size = filter_size
|
||||
self.fusion = fusion
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
|
||||
0)
|
||||
|
||||
self.input_redu_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
if self.norm_cfg is not None:
|
||||
self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
self.activate = build_activation_layer(self.act_cfg)
|
||||
|
||||
if self.fusion:
|
||||
self.fusion_conv = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
generated_filter = self.filter_gen_conv(
|
||||
F.adaptive_avg_pool2d(x, self.filter_size))
|
||||
x = self.input_redu_conv(x)
|
||||
b, c, h, w = x.shape
|
||||
# [1, b * c, h, w], c = self.channels
|
||||
x = x.view(1, b * c, h, w)
|
||||
# [b * c, 1, filter_size, filter_size]
|
||||
generated_filter = generated_filter.view(b * c, 1, self.filter_size,
|
||||
self.filter_size)
|
||||
pad = (self.filter_size - 1) // 2
|
||||
if (self.filter_size - 1) % 2 == 0:
|
||||
p2d = (pad, pad, pad, pad)
|
||||
else:
|
||||
p2d = (pad + 1, pad, pad + 1, pad)
|
||||
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
|
||||
# [1, b * c, h, w]
|
||||
output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
|
||||
# [b, c, h, w]
|
||||
output = output.view(b, c, h, w)
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
output = self.activate(output)
|
||||
|
||||
if self.fusion:
|
||||
output = self.fusion_conv(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DMHead(BaseDecodeHead):
|
||||
"""Dynamic Multi-scale Filters for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
|
||||
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
|
||||
ICCV_2019_paper.pdf>`_.
|
||||
|
||||
Args:
|
||||
filter_sizes (tuple[int]): The size of generated convolutional filters
|
||||
used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
|
||||
fusion (bool): Add one conv to fuse DCM output feature.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(filter_sizes, (list, tuple))
|
||||
self.filter_sizes = filter_sizes
|
||||
self.fusion = fusion
|
||||
dcm_modules = []
|
||||
for filter_size in self.filter_sizes:
|
||||
dcm_modules.append(
|
||||
DCM(filter_size,
|
||||
self.fusion,
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.dcm_modules = nn.ModuleList(dcm_modules)
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels + len(filter_sizes) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
dcm_outs = [x]
|
||||
for dcm_module in self.dcm_modules:
|
||||
dcm_outs.append(dcm_module(x))
|
||||
dcm_outs = torch.cat(dcm_outs, dim=1)
|
||||
output = self.bottleneck(dcm_outs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
137
finetune/mmseg/models/decode_heads/dnl_head.py
Normal file
137
finetune/mmseg/models/decode_heads/dnl_head.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
from torch import nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
class DisentangledNonLocal2d(NonLocal2d):
|
||||
"""Disentangled Non-Local Blocks.
|
||||
|
||||
Args:
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self, *arg, temperature, **kwargs):
|
||||
super().__init__(*arg, **kwargs)
|
||||
self.temperature = temperature
|
||||
self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
|
||||
|
||||
def embedded_gaussian(self, theta_x, phi_x):
|
||||
"""Embedded gaussian with temperature."""
|
||||
|
||||
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = torch.matmul(theta_x, phi_x)
|
||||
if self.use_scale:
|
||||
# theta_x.shape[-1] is `self.inter_channels`
|
||||
pairwise_weight /= torch.tensor(
|
||||
theta_x.shape[-1],
|
||||
dtype=torch.float,
|
||||
device=pairwise_weight.device)**torch.tensor(
|
||||
0.5, device=pairwise_weight.device)
|
||||
pairwise_weight /= torch.tensor(
|
||||
self.temperature, device=pairwise_weight.device)
|
||||
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
||||
return pairwise_weight
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, C, H, W]
|
||||
n = x.size(0)
|
||||
|
||||
# g_x: [N, HxW, C]
|
||||
g_x = self.g(x).view(n, self.inter_channels, -1)
|
||||
g_x = g_x.permute(0, 2, 1)
|
||||
|
||||
# theta_x: [N, HxW, C], phi_x: [N, C, HxW]
|
||||
if self.mode == 'gaussian':
|
||||
theta_x = x.view(n, self.in_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
if self.sub_sample:
|
||||
phi_x = self.phi(x).view(n, self.in_channels, -1)
|
||||
else:
|
||||
phi_x = x.view(n, self.in_channels, -1)
|
||||
elif self.mode == 'concatenation':
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
|
||||
else:
|
||||
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
||||
theta_x = theta_x.permute(0, 2, 1)
|
||||
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
||||
|
||||
# subtract mean
|
||||
theta_x -= theta_x.mean(dim=-2, keepdim=True)
|
||||
phi_x -= phi_x.mean(dim=-1, keepdim=True)
|
||||
|
||||
pairwise_func = getattr(self, self.mode)
|
||||
# pairwise_weight: [N, HxW, HxW]
|
||||
pairwise_weight = pairwise_func(theta_x, phi_x)
|
||||
|
||||
# y: [N, HxW, C]
|
||||
y = torch.matmul(pairwise_weight, g_x)
|
||||
# y: [N, C, H, W]
|
||||
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
|
||||
*x.size()[2:])
|
||||
|
||||
# unary_mask: [N, 1, HxW]
|
||||
unary_mask = self.conv_mask(x)
|
||||
unary_mask = unary_mask.view(n, 1, -1)
|
||||
unary_mask = unary_mask.softmax(dim=-1)
|
||||
# unary_x: [N, 1, C]
|
||||
unary_x = torch.matmul(unary_mask, g_x)
|
||||
# unary_x: [N, C, 1, 1]
|
||||
unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
|
||||
n, self.inter_channels, 1, 1)
|
||||
|
||||
output = x + self.conv_out(y + unary_x)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DNLHead(FCNHead):
|
||||
"""Disentangled Non-Local Neural Networks.
|
||||
|
||||
This head is the implementation of `DNLNet
|
||||
<https://arxiv.org/abs/2006.06668>`_.
|
||||
|
||||
Args:
|
||||
reduction (int): Reduction factor of projection transform. Default: 2.
|
||||
use_scale (bool): Whether to scale pairwise_weight by
|
||||
sqrt(1/inter_channels). Default: False.
|
||||
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
|
||||
'dot_product'. Default: 'embedded_gaussian.'.
|
||||
temperature (float): Temperature to adjust attention. Default: 0.05
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
temperature=0.05,
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
self.mode = mode
|
||||
self.temperature = temperature
|
||||
self.dnl_block = DisentangledNonLocal2d(
|
||||
in_channels=self.channels,
|
||||
reduction=self.reduction,
|
||||
use_scale=self.use_scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
mode=self.mode,
|
||||
temperature=self.temperature)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.dnl_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
294
finetune/mmseg/models/decode_heads/dpt_head.py
Normal file
294
finetune/mmseg/models/decode_heads/dpt_head.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class ReassembleBlocks(BaseModule):
|
||||
"""ViTPostProcessBlock, process cls_token in ViT backbone output and
|
||||
rearrange the feature vector to feature map.
|
||||
|
||||
Args:
|
||||
in_channels (int): ViT feature channels. Default: 768.
|
||||
out_channels (List): output channels of each stage.
|
||||
Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=768,
|
||||
out_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
assert readout_type in ['ignore', 'add', 'project']
|
||||
self.readout_type = readout_type
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.projects = nn.ModuleList([
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
) for out_channel in out_channels
|
||||
])
|
||||
|
||||
self.resize_layers = nn.ModuleList([
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[0],
|
||||
out_channels=out_channels[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=out_channels[1],
|
||||
out_channels=out_channels[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0),
|
||||
nn.Identity(),
|
||||
nn.Conv2d(
|
||||
in_channels=out_channels[3],
|
||||
out_channels=out_channels[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
])
|
||||
if self.readout_type == 'project':
|
||||
self.readout_projects = nn.ModuleList()
|
||||
for _ in range(len(self.projects)):
|
||||
self.readout_projects.append(
|
||||
nn.Sequential(
|
||||
Linear(2 * in_channels, in_channels),
|
||||
build_activation_layer(dict(type='GELU'))))
|
||||
|
||||
def forward(self, inputs):
|
||||
assert isinstance(inputs, list)
|
||||
out = []
|
||||
for i, x in enumerate(inputs):
|
||||
assert len(x) == 2
|
||||
x, cls_token = x[0], x[1]
|
||||
feature_shape = x.shape
|
||||
if self.readout_type == 'project':
|
||||
x = x.flatten(2).permute((0, 2, 1))
|
||||
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||
x = x.permute(0, 2, 1).reshape(feature_shape)
|
||||
elif self.readout_type == 'add':
|
||||
x = x.flatten(2) + cls_token.unsqueeze(-1)
|
||||
x = x.reshape(feature_shape)
|
||||
else:
|
||||
pass
|
||||
x = self.projects[i](x)
|
||||
x = self.resize_layers[i](x)
|
||||
out.append(x)
|
||||
return out
|
||||
|
||||
|
||||
class PreActResidualConvUnit(BaseModule):
|
||||
"""ResidualConvUnit, pre-activate residual unit.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of channels in the input feature map.
|
||||
act_cfg (dict): dictionary to construct and config activation layer.
|
||||
norm_cfg (dict): dictionary to construct and config norm layer.
|
||||
stride (int): stride of the first block. Default: 1
|
||||
dilation (int): dilation rate for convs layers. Default: 1.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.conv1 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
self.conv2 = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
bias=False,
|
||||
order=('act', 'conv', 'norm'))
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs_ = inputs.clone()
|
||||
x = self.conv1(inputs)
|
||||
x = self.conv2(x)
|
||||
return x + inputs_
|
||||
|
||||
|
||||
class FeatureFusionBlock(BaseModule):
|
||||
"""FeatureFusionBlock, merge feature map from different stages.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
act_cfg (dict): The activation config for ResidualConvUnit.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
expand (bool): Whether expand the channels in post process block.
|
||||
Default: False.
|
||||
align_corners (bool): align_corner setting for bilinear upsample.
|
||||
Default: True.
|
||||
init_cfg (dict, optional): Initialization config dict. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
act_cfg,
|
||||
norm_cfg,
|
||||
expand=False,
|
||||
align_corners=True,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.expand = expand
|
||||
self.align_corners = align_corners
|
||||
|
||||
self.out_channels = in_channels
|
||||
if self.expand:
|
||||
self.out_channels = in_channels // 2
|
||||
|
||||
self.project = ConvModule(
|
||||
self.in_channels,
|
||||
self.out_channels,
|
||||
kernel_size=1,
|
||||
act_cfg=None,
|
||||
bias=True)
|
||||
|
||||
self.res_conv_unit1 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
self.res_conv_unit2 = PreActResidualConvUnit(
|
||||
in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
|
||||
|
||||
def forward(self, *inputs):
|
||||
x = inputs[0]
|
||||
if len(inputs) == 2:
|
||||
if x.shape != inputs[1].shape:
|
||||
res = resize(
|
||||
inputs[1],
|
||||
size=(x.shape[2], x.shape[3]),
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
else:
|
||||
res = inputs[1]
|
||||
x = x + self.res_conv_unit1(res)
|
||||
x = self.res_conv_unit2(x)
|
||||
x = resize(
|
||||
x,
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.project(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DPTHead(BaseDecodeHead):
|
||||
"""Vision Transformers for Dense Prediction.
|
||||
|
||||
This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embed dimension of the ViT backbone.
|
||||
Default: 768.
|
||||
post_process_channels (List): Out channels of post process conv
|
||||
layers. Default: [96, 192, 384, 768].
|
||||
readout_type (str): Type of readout operation. Default: 'ignore'.
|
||||
patch_size (int): The patch size. Default: 16.
|
||||
expand_channels (bool): Whether expand the channels in post process
|
||||
block. Default: False.
|
||||
act_cfg (dict): The activation config for residual conv unit.
|
||||
Default dict(type='ReLU').
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims=768,
|
||||
post_process_channels=[96, 192, 384, 768],
|
||||
readout_type='ignore',
|
||||
patch_size=16,
|
||||
expand_channels=False,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_cfg=dict(type='BN'),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.in_channels = self.in_channels
|
||||
self.expand_channels = expand_channels
|
||||
self.reassemble_blocks = ReassembleBlocks(embed_dims,
|
||||
post_process_channels,
|
||||
readout_type, patch_size)
|
||||
|
||||
self.post_process_channels = [
|
||||
channel * math.pow(2, i) if expand_channels else channel
|
||||
for i, channel in enumerate(post_process_channels)
|
||||
]
|
||||
self.convs = nn.ModuleList()
|
||||
for channel in self.post_process_channels:
|
||||
self.convs.append(
|
||||
ConvModule(
|
||||
channel,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
act_cfg=None,
|
||||
bias=False))
|
||||
self.fusion_blocks = nn.ModuleList()
|
||||
for _ in range(len(self.convs)):
|
||||
self.fusion_blocks.append(
|
||||
FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
|
||||
self.fusion_blocks[0].res_conv_unit1 = None
|
||||
self.project = ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg)
|
||||
self.num_fusion_blocks = len(self.fusion_blocks)
|
||||
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
|
||||
self.num_post_process_channels = len(self.post_process_channels)
|
||||
assert self.num_fusion_blocks == self.num_reassemble_blocks
|
||||
assert self.num_reassemble_blocks == self.num_post_process_channels
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == self.num_reassemble_blocks
|
||||
x = self._transform_inputs(inputs)
|
||||
x = self.reassemble_blocks(x)
|
||||
x = [self.convs[i](feature) for i, feature in enumerate(x)]
|
||||
out = self.fusion_blocks[0](x[-1])
|
||||
for i in range(1, len(self.fusion_blocks)):
|
||||
out = self.fusion_blocks[i](out, x[-(i + 1)])
|
||||
out = self.project(out)
|
||||
out = self.cls_seg(out)
|
||||
return out
|
||||
169
finetune/mmseg/models/decode_heads/ema_head.py
Normal file
169
finetune/mmseg/models/decode_heads/ema_head.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
def reduce_mean(tensor):
|
||||
"""Reduce mean when distributed training."""
|
||||
if not (dist.is_available() and dist.is_initialized()):
|
||||
return tensor
|
||||
tensor = tensor.clone()
|
||||
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
|
||||
return tensor
|
||||
|
||||
|
||||
class EMAModule(nn.Module):
|
||||
"""Expectation Maximization Attention Module used in EMANet.
|
||||
|
||||
Args:
|
||||
channels (int): Channels of the whole module.
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, num_bases, num_stages, momentum):
|
||||
super().__init__()
|
||||
assert num_stages >= 1, 'num_stages must be at least 1!'
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.momentum = momentum
|
||||
|
||||
bases = torch.zeros(1, channels, self.num_bases)
|
||||
bases.normal_(0, math.sqrt(2. / self.num_bases))
|
||||
# [1, channels, num_bases]
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
def forward(self, feats):
|
||||
"""Forward function."""
|
||||
batch_size, channels, height, width = feats.size()
|
||||
# [batch_size, channels, height*width]
|
||||
feats = feats.view(batch_size, channels, height * width)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = self.bases.repeat(batch_size, 1, 1)
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(self.num_stages):
|
||||
# [batch_size, height*width, num_bases]
|
||||
attention = torch.einsum('bcn,bck->bnk', feats, bases)
|
||||
attention = F.softmax(attention, dim=2)
|
||||
# l1 norm
|
||||
attention_normed = F.normalize(attention, dim=1, p=1)
|
||||
# [batch_size, channels, num_bases]
|
||||
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
|
||||
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
|
||||
feats_recon = feats_recon.view(batch_size, channels, height, width)
|
||||
|
||||
if self.training:
|
||||
bases = bases.mean(dim=0, keepdim=True)
|
||||
bases = reduce_mean(bases)
|
||||
# l2 norm
|
||||
bases = F.normalize(bases, dim=1, p=2)
|
||||
self.bases = (1 -
|
||||
self.momentum) * self.bases + self.momentum * bases
|
||||
|
||||
return feats_recon
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class EMAHead(BaseDecodeHead):
|
||||
"""Expectation Maximization Attention Networks for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `EMANet
|
||||
<https://arxiv.org/abs/1907.13426>`_.
|
||||
|
||||
Args:
|
||||
ema_channels (int): EMA module channels
|
||||
num_bases (int): Number of bases.
|
||||
num_stages (int): Number of the EM iterations.
|
||||
concat_input (bool): Whether concat the input and output of convs
|
||||
before classification layer. Default: True
|
||||
momentum (float): Momentum to update the base. Default: 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ema_channels,
|
||||
num_bases,
|
||||
num_stages,
|
||||
concat_input=True,
|
||||
momentum=0.1,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ema_channels = ema_channels
|
||||
self.num_bases = num_bases
|
||||
self.num_stages = num_stages
|
||||
self.concat_input = concat_input
|
||||
self.momentum = momentum
|
||||
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
|
||||
self.num_stages, self.momentum)
|
||||
|
||||
self.ema_in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.ema_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
# project (0, inf) -> (-inf, inf)
|
||||
self.ema_mid_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=None,
|
||||
act_cfg=None)
|
||||
for param in self.ema_mid_conv.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.ema_out_conv = ConvModule(
|
||||
self.ema_channels,
|
||||
self.ema_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=None)
|
||||
self.bottleneck = ConvModule(
|
||||
self.ema_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.ema_in_conv(x)
|
||||
identity = feats
|
||||
feats = self.ema_mid_conv(feats)
|
||||
recon = self.ema_module(feats)
|
||||
recon = F.relu(recon, inplace=True)
|
||||
recon = self.ema_out_conv(recon)
|
||||
output = F.relu(identity + recon, inplace=True)
|
||||
output = self.bottleneck(output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
196
finetune/mmseg/models/decode_heads/enc_head.py
Normal file
196
finetune/mmseg/models/decode_heads/enc_head.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..utils import Encoding, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class EncModule(nn.Module):
|
||||
"""Encoding Module used in EncNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels.
|
||||
num_codes (int): Number of code words.
|
||||
conv_cfg (dict|None): Config of conv layers.
|
||||
norm_cfg (dict|None): Config of norm layers.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__()
|
||||
self.encoding_project = ConvModule(
|
||||
in_channels,
|
||||
in_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# TODO: resolve this hack
|
||||
# change to 1d
|
||||
if norm_cfg is not None:
|
||||
encoding_norm_cfg = norm_cfg.copy()
|
||||
if encoding_norm_cfg['type'] in ['BN', 'IN']:
|
||||
encoding_norm_cfg['type'] += '1d'
|
||||
else:
|
||||
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
|
||||
'2d', '1d')
|
||||
else:
|
||||
# fallback to BN1d
|
||||
encoding_norm_cfg = dict(type='BN1d')
|
||||
self.encoding = nn.Sequential(
|
||||
Encoding(channels=in_channels, num_codes=num_codes),
|
||||
build_norm_layer(encoding_norm_cfg, num_codes)[1],
|
||||
nn.ReLU(inplace=True))
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_channels, in_channels), nn.Sigmoid())
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
encoding_projection = self.encoding_project(x)
|
||||
encoding_feat = self.encoding(encoding_projection).mean(dim=1)
|
||||
batch_size, channels, _, _ = x.size()
|
||||
gamma = self.fc(encoding_feat)
|
||||
y = gamma.view(batch_size, channels, 1, 1)
|
||||
output = F.relu_(x + x * y)
|
||||
return encoding_feat, output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class EncHead(BaseDecodeHead):
|
||||
"""Context Encoding for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `EncNet
|
||||
<https://arxiv.org/abs/1803.08904>`_.
|
||||
|
||||
Args:
|
||||
num_codes (int): Number of code words. Default: 32.
|
||||
use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
|
||||
regularize the training. Default: True.
|
||||
add_lateral (bool): Whether use lateral connection to fuse features.
|
||||
Default: False.
|
||||
loss_se_decode (dict): Config of decode loss.
|
||||
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_codes=32,
|
||||
use_se_loss=True,
|
||||
add_lateral=False,
|
||||
loss_se_decode=dict(
|
||||
type='CrossEntropyLoss',
|
||||
use_sigmoid=True,
|
||||
loss_weight=0.2),
|
||||
**kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.use_se_loss = use_se_loss
|
||||
self.add_lateral = add_lateral
|
||||
self.num_codes = num_codes
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if add_lateral:
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
for in_channels in self.in_channels[:-1]: # skip the last one
|
||||
self.lateral_convs.append(
|
||||
ConvModule(
|
||||
in_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
self.fusion = ConvModule(
|
||||
len(self.in_channels) * self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.enc_module = EncModule(
|
||||
self.channels,
|
||||
num_codes=num_codes,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
if self.use_se_loss:
|
||||
self.loss_se_decode = MODELS.build(loss_se_decode)
|
||||
self.se_layer = nn.Linear(self.channels, self.num_classes)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
feat = self.bottleneck(inputs[-1])
|
||||
if self.add_lateral:
|
||||
laterals = [
|
||||
resize(
|
||||
lateral_conv(inputs[i]),
|
||||
size=feat.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
feat = self.fusion(torch.cat([feat, *laterals], 1))
|
||||
encode_feat, output = self.enc_module(feat)
|
||||
output = self.cls_seg(output)
|
||||
if self.use_se_loss:
|
||||
se_output = self.se_layer(encode_feat)
|
||||
return output, se_output
|
||||
else:
|
||||
return output
|
||||
|
||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType):
|
||||
"""Forward function for testing, ignore se_loss."""
|
||||
if self.use_se_loss:
|
||||
seg_logits = self.forward(inputs)[0]
|
||||
else:
|
||||
seg_logits = self.forward(inputs)
|
||||
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_onehot_labels(seg_label, num_classes):
|
||||
"""Convert segmentation label to onehot.
|
||||
|
||||
Args:
|
||||
seg_label (Tensor): Segmentation label of shape (N, H, W).
|
||||
num_classes (int): Number of classes.
|
||||
|
||||
Returns:
|
||||
Tensor: Onehot labels of shape (N, num_classes).
|
||||
"""
|
||||
|
||||
batch_size = seg_label.size(0)
|
||||
onehot_labels = seg_label.new_zeros((batch_size, num_classes))
|
||||
for i in range(batch_size):
|
||||
hist = seg_label[i].float().histc(
|
||||
bins=num_classes, min=0, max=num_classes - 1)
|
||||
onehot_labels[i] = hist > 0
|
||||
return onehot_labels
|
||||
|
||||
def loss_by_feat(self, seg_logit: Tuple[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
"""Compute segmentation and semantic encoding loss."""
|
||||
seg_logit, se_seg_logit = seg_logit
|
||||
loss = dict()
|
||||
loss.update(super().loss_by_feat(seg_logit, batch_data_samples))
|
||||
|
||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||
se_loss = self.loss_se_decode(
|
||||
se_seg_logit,
|
||||
self._convert_to_onehot_labels(seg_label, self.num_classes))
|
||||
loss['loss_se'] = se_loss
|
||||
return loss
|
||||
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
96
finetune/mmseg/models/decode_heads/fcn_head.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FCNHead(BaseDecodeHead):
|
||||
"""Fully Convolution Networks for Semantic Segmentation.
|
||||
|
||||
This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
|
||||
|
||||
Args:
|
||||
num_convs (int): Number of convs in the head. Default: 2.
|
||||
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
||||
concat_input (bool): Whether concat the input and output of convs
|
||||
before classification layer.
|
||||
dilation (int): The dilation rate for convs in the head. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_convs=2,
|
||||
kernel_size=3,
|
||||
concat_input=True,
|
||||
dilation=1,
|
||||
**kwargs):
|
||||
assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
|
||||
self.num_convs = num_convs
|
||||
self.concat_input = concat_input
|
||||
self.kernel_size = kernel_size
|
||||
super().__init__(**kwargs)
|
||||
if num_convs == 0:
|
||||
assert self.in_channels == self.channels
|
||||
|
||||
conv_padding = (kernel_size // 2) * dilation
|
||||
convs = []
|
||||
convs.append(
|
||||
ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=conv_padding,
|
||||
dilation=dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
for i in range(num_convs - 1):
|
||||
convs.append(
|
||||
ConvModule(
|
||||
self.channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=conv_padding,
|
||||
dilation=dilation,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if num_convs == 0:
|
||||
self.convs = nn.Identity()
|
||||
else:
|
||||
self.convs = nn.Sequential(*convs)
|
||||
if self.concat_input:
|
||||
self.conv_cat = ConvModule(
|
||||
self.in_channels + self.channels,
|
||||
self.channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def _forward_feature(self, inputs):
|
||||
"""Forward function for feature maps before classifying each pixel with
|
||||
``self.cls_seg`` fc.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
|
||||
Returns:
|
||||
feats (Tensor): A tensor of shape (batch_size, self.channels,
|
||||
H, W) which is feature map for last layer of decoder head.
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.convs(x)
|
||||
if self.concat_input:
|
||||
feats = self.conv_cat(torch.cat([x, feats], dim=1))
|
||||
return feats
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
output = self._forward_feature(inputs)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
68
finetune/mmseg/models/decode_heads/fpn_head.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
This head is the implementation of `Semantic FPN
|
||||
<https://arxiv.org/abs/1901.02446>`_.
|
||||
|
||||
Args:
|
||||
feature_strides (tuple[int]): The strides for input feature maps.
|
||||
stack_lateral. All strides suppose to be power of 2. The first
|
||||
one is of largest resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
self.scale_heads = nn.ModuleList()
|
||||
for i in range(len(feature_strides)):
|
||||
head_length = max(
|
||||
1,
|
||||
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
|
||||
scale_head = []
|
||||
for k in range(head_length):
|
||||
scale_head.append(
|
||||
ConvModule(
|
||||
self.in_channels[i] if k == 0 else self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
self.scale_heads.append(nn.Sequential(*scale_head))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
|
||||
output = self.scale_heads[0](x[0])
|
||||
for i in range(1, len(self.feature_strides)):
|
||||
# non inplace
|
||||
output = output + resize(
|
||||
self.scale_heads[i](x[i]),
|
||||
size=output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
48
finetune/mmseg/models/decode_heads/gc_head.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import ContextBlock
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class GCHead(FCNHead):
|
||||
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
|
||||
|
||||
This head is the implementation of `GCNet
|
||||
<https://arxiv.org/abs/1904.11492>`_.
|
||||
|
||||
Args:
|
||||
ratio (float): Multiplier of channels ratio. Default: 1/4.
|
||||
pooling_type (str): The pooling type of context aggregation.
|
||||
Options are 'att', 'avg'. Default: 'avg'.
|
||||
fusion_types (tuple[str]): The fusion type for feature fusion.
|
||||
Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ratio=1 / 4.,
|
||||
pooling_type='att',
|
||||
fusion_types=('channel_add', ),
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.ratio = ratio
|
||||
self.pooling_type = pooling_type
|
||||
self.fusion_types = fusion_types
|
||||
self.gc_block = ContextBlock(
|
||||
in_channels=self.channels,
|
||||
ratio=self.ratio,
|
||||
pooling_type=self.pooling_type,
|
||||
fusion_types=self.fusion_types)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.gc_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
255
finetune/mmseg/models/decode_heads/ham_head.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Originally from https://github.com/visual-attention-network/segnext
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.device import get_device
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class Matrix_Decomposition_2D_Base(nn.Module):
|
||||
"""Base class of 2D Matrix Decomposition.
|
||||
|
||||
Args:
|
||||
MD_S (int): The number of spatial coefficient in
|
||||
Matrix Decomposition, it may be used for calculation
|
||||
of the number of latent dimension D in Matrix
|
||||
Decomposition. Defaults: 1.
|
||||
MD_R (int): The number of latent dimension R in
|
||||
Matrix Decomposition. Defaults: 64.
|
||||
train_steps (int): The number of iteration steps in
|
||||
Multiplicative Update (MU) rule to solve Non-negative
|
||||
Matrix Factorization (NMF) in training. Defaults: 6.
|
||||
eval_steps (int): The number of iteration steps in
|
||||
Multiplicative Update (MU) rule to solve Non-negative
|
||||
Matrix Factorization (NMF) in evaluation. Defaults: 7.
|
||||
inv_t (int): Inverted multiple number to make coefficient
|
||||
smaller in softmax. Defaults: 100.
|
||||
rand_init (bool): Whether to initialize randomly.
|
||||
Defaults: True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
MD_S=1,
|
||||
MD_R=64,
|
||||
train_steps=6,
|
||||
eval_steps=7,
|
||||
inv_t=100,
|
||||
rand_init=True):
|
||||
super().__init__()
|
||||
|
||||
self.S = MD_S
|
||||
self.R = MD_R
|
||||
|
||||
self.train_steps = train_steps
|
||||
self.eval_steps = eval_steps
|
||||
|
||||
self.inv_t = inv_t
|
||||
|
||||
self.rand_init = rand_init
|
||||
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def local_step(self, x, bases, coef):
|
||||
raise NotImplementedError
|
||||
|
||||
def local_inference(self, x, bases):
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
coef = torch.bmm(x.transpose(1, 2), bases)
|
||||
coef = F.softmax(self.inv_t * coef, dim=-1)
|
||||
|
||||
steps = self.train_steps if self.training else self.eval_steps
|
||||
for _ in range(steps):
|
||||
bases, coef = self.local_step(x, bases, coef)
|
||||
|
||||
return bases, coef
|
||||
|
||||
def compute_coef(self, x, bases, coef):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x, return_bases=False):
|
||||
"""Forward Function."""
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# (B, C, H, W) -> (B * S, D, N)
|
||||
D = C // self.S
|
||||
N = H * W
|
||||
x = x.view(B * self.S, D, N)
|
||||
if not self.rand_init and not hasattr(self, 'bases'):
|
||||
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
# (S, D, R) -> (B * S, D, R)
|
||||
if self.rand_init:
|
||||
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
|
||||
else:
|
||||
bases = self.bases.repeat(B, 1, 1)
|
||||
|
||||
bases, coef = self.local_inference(x, bases)
|
||||
|
||||
# (B * S, N, R)
|
||||
coef = self.compute_coef(x, bases, coef)
|
||||
|
||||
# (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
|
||||
x = torch.bmm(bases, coef.transpose(1, 2))
|
||||
|
||||
# (B * S, D, N) -> (B, C, H, W)
|
||||
x = x.view(B, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class NMF2D(Matrix_Decomposition_2D_Base):
|
||||
"""Non-negative Matrix Factorization (NMF) module.
|
||||
|
||||
It is inherited from ``Matrix_Decomposition_2D_Base`` module.
|
||||
"""
|
||||
|
||||
def __init__(self, args=dict()):
|
||||
super().__init__(**args)
|
||||
|
||||
self.inv_t = 1
|
||||
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
"""Build bases in initialization."""
|
||||
if device is None:
|
||||
device = get_device()
|
||||
bases = torch.rand((B * S, D, R)).to(device)
|
||||
bases = F.normalize(bases, dim=1)
|
||||
|
||||
return bases
|
||||
|
||||
def local_step(self, x, bases, coef):
|
||||
"""Local step in iteration to renew bases and coefficient."""
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
numerator = torch.bmm(x.transpose(1, 2), bases)
|
||||
# (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
|
||||
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
||||
# Multiplicative Update
|
||||
coef = coef * numerator / (denominator + 1e-6)
|
||||
|
||||
# (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
|
||||
numerator = torch.bmm(x, coef)
|
||||
# (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
|
||||
denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
|
||||
# Multiplicative Update
|
||||
bases = bases * numerator / (denominator + 1e-6)
|
||||
|
||||
return bases, coef
|
||||
|
||||
def compute_coef(self, x, bases, coef):
|
||||
"""Compute coefficient."""
|
||||
# (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
numerator = torch.bmm(x.transpose(1, 2), bases)
|
||||
# (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
|
||||
denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
|
||||
# multiplication update
|
||||
coef = coef * numerator / (denominator + 1e-6)
|
||||
|
||||
return coef
|
||||
|
||||
|
||||
class Hamburger(nn.Module):
|
||||
"""Hamburger Module. It consists of one slice of "ham" (matrix
|
||||
decomposition) and two slices of "bread" (linear transformation).
|
||||
|
||||
Args:
|
||||
ham_channels (int): Input and output channels of feature.
|
||||
ham_kwargs (dict): Config of matrix decomposition module.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ham_channels=512,
|
||||
ham_kwargs=dict(),
|
||||
norm_cfg=None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.ham_in = ConvModule(
|
||||
ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
|
||||
|
||||
self.ham = NMF2D(ham_kwargs)
|
||||
|
||||
self.ham_out = ConvModule(
|
||||
ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
|
||||
|
||||
def forward(self, x):
|
||||
enjoy = self.ham_in(x)
|
||||
enjoy = F.relu(enjoy, inplace=True)
|
||||
enjoy = self.ham(enjoy)
|
||||
enjoy = self.ham_out(enjoy)
|
||||
ham = F.relu(x + enjoy, inplace=True)
|
||||
|
||||
return ham
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class LightHamHead(BaseDecodeHead):
|
||||
"""SegNeXt decode head.
|
||||
|
||||
This decode head is the implementation of `SegNeXt: Rethinking
|
||||
Convolutional Attention Design for Semantic
|
||||
Segmentation <https://arxiv.org/abs/2209.08575>`_.
|
||||
Inspiration from https://github.com/visual-attention-network/segnext.
|
||||
|
||||
Specifically, LightHamHead is inspired by HamNet from
|
||||
`Is Attention Better Than Matrix Decomposition?
|
||||
<https://arxiv.org/abs/2109.04553>`.
|
||||
|
||||
Args:
|
||||
ham_channels (int): input channels for Hamburger.
|
||||
Defaults: 512.
|
||||
ham_kwargs (int): kwagrs for Ham. Defaults: dict().
|
||||
"""
|
||||
|
||||
def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
|
||||
super().__init__(input_transform='multiple_select', **kwargs)
|
||||
self.ham_channels = ham_channels
|
||||
|
||||
self.squeeze = ConvModule(
|
||||
sum(self.in_channels),
|
||||
self.ham_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
|
||||
|
||||
self.align = ConvModule(
|
||||
self.ham_channels,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
inputs = [
|
||||
resize(
|
||||
level,
|
||||
size=inputs[0].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners) for level in inputs
|
||||
]
|
||||
|
||||
inputs = torch.cat(inputs, dim=1)
|
||||
# apply a conv block to squeeze feature map
|
||||
x = self.squeeze(inputs)
|
||||
# apply hamburger module
|
||||
x = self.hamburger(x)
|
||||
|
||||
# apply a conv block to align feature map
|
||||
output = self.align(x)
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
143
finetune/mmseg/models/decode_heads/isa_head.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
class SelfAttentionBlock(_SelfAttentionBlock):
|
||||
"""Self-Attention Module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Input channels of key/query feature.
|
||||
channels (int): Output channels of key/query transform.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
act_cfg (dict | None): Config of activation layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg):
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=None,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=2,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=False,
|
||||
matmul_norm=True,
|
||||
with_out=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.output_project = self.build_project(
|
||||
in_channels,
|
||||
in_channels,
|
||||
num_convs=1,
|
||||
use_conv_module=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
context = super().forward(x, x)
|
||||
return self.output_project(context)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ISAHead(BaseDecodeHead):
|
||||
"""Interlaced Sparse Self-Attention for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `ISA
|
||||
<https://arxiv.org/abs/1907.12273>`_.
|
||||
|
||||
Args:
|
||||
isa_channels (int): The channels of ISA Module.
|
||||
down_factor (tuple[int]): The local group size of ISA.
|
||||
"""
|
||||
|
||||
def __init__(self, isa_channels, down_factor=(8, 8), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.down_factor = down_factor
|
||||
|
||||
self.in_conv = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.global_relation = SelfAttentionBlock(
|
||||
self.channels,
|
||||
isa_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.local_relation = SelfAttentionBlock(
|
||||
self.channels,
|
||||
isa_channels,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.out_conv = ConvModule(
|
||||
self.channels * 2,
|
||||
self.channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x_ = self._transform_inputs(inputs)
|
||||
x = self.in_conv(x_)
|
||||
residual = x
|
||||
|
||||
n, c, h, w = x.size()
|
||||
loc_h, loc_w = self.down_factor # size of local group in H- and W-axes
|
||||
glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w)
|
||||
pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w
|
||||
if pad_h > 0 or pad_w > 0: # pad if the size is not divisible
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
|
||||
pad_h - pad_h // 2)
|
||||
x = F.pad(x, padding)
|
||||
|
||||
# global relation
|
||||
x = x.view(n, c, glb_h, loc_h, glb_w, loc_w)
|
||||
# do permutation to gather global group
|
||||
x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w)
|
||||
x = x.reshape(-1, c, glb_h, glb_w)
|
||||
# apply attention within each global group
|
||||
x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w)
|
||||
|
||||
# local relation
|
||||
x = x.view(n, loc_h, loc_w, c, glb_h, glb_w)
|
||||
# do permutation to gather local group
|
||||
x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w)
|
||||
x = x.reshape(-1, c, loc_h, loc_w)
|
||||
# apply attention within each local group
|
||||
x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w)
|
||||
|
||||
# permute each pixel back to its original position
|
||||
x = x.view(n, glb_h, glb_w, c, loc_h, loc_w)
|
||||
x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w)
|
||||
x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w)
|
||||
if pad_h > 0 or pad_w > 0: # remove padding
|
||||
x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w]
|
||||
|
||||
x = self.out_conv(torch.cat([x, residual], dim=1))
|
||||
out = self.cls_seg(x)
|
||||
|
||||
return out
|
||||
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
461
finetune/mmseg/models/decode_heads/knet_head.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention,
|
||||
build_transformer_layer)
|
||||
from mmengine.logging import print_log
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class KernelUpdator(nn.Module):
|
||||
"""Dynamic Kernel Updator in Kernel Update Head.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input feature map.
|
||||
Default: 256.
|
||||
feat_channels (int): The number of middle-stage channels in
|
||||
the kernel updator. Default: 64.
|
||||
out_channels (int): The number of output channels.
|
||||
gate_sigmoid (bool): Whether use sigmoid function in gate
|
||||
mechanism. Default: True.
|
||||
gate_norm_act (bool): Whether add normalization and activation
|
||||
layer in gate mechanism. Default: False.
|
||||
activate_out: Whether add activation after gate mechanism.
|
||||
Default: False.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='LN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=None,
|
||||
gate_sigmoid=True,
|
||||
gate_norm_act=False,
|
||||
activate_out=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.feat_channels = feat_channels
|
||||
self.out_channels_raw = out_channels
|
||||
self.gate_sigmoid = gate_sigmoid
|
||||
self.gate_norm_act = gate_norm_act
|
||||
self.activate_out = activate_out
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.out_channels = out_channels if out_channels else in_channels
|
||||
|
||||
self.num_params_in = self.feat_channels
|
||||
self.num_params_out = self.feat_channels
|
||||
self.dynamic_layer = nn.Linear(
|
||||
self.in_channels, self.num_params_in + self.num_params_out)
|
||||
self.input_layer = nn.Linear(self.in_channels,
|
||||
self.num_params_in + self.num_params_out,
|
||||
1)
|
||||
self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
||||
self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
|
||||
if self.gate_norm_act:
|
||||
self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
|
||||
self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
|
||||
|
||||
self.activation = build_activation_layer(act_cfg)
|
||||
|
||||
self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
|
||||
self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
|
||||
|
||||
def forward(self, update_feature, input_feature):
|
||||
"""Forward function of KernelUpdator.
|
||||
|
||||
Args:
|
||||
update_feature (torch.Tensor): Feature map assembled from
|
||||
each group. It would be reshaped with last dimension
|
||||
shape: `self.in_channels`.
|
||||
input_feature (torch.Tensor): Intermediate feature
|
||||
with shape: (N, num_classes, conv_kernel_size**2, channels).
|
||||
Returns:
|
||||
Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is
|
||||
the number of classes, C1 and C2 are the feature map channels of
|
||||
KernelUpdateHead and KernelUpdator, respectively.
|
||||
"""
|
||||
|
||||
update_feature = update_feature.reshape(-1, self.in_channels)
|
||||
num_proposals = update_feature.size(0)
|
||||
# dynamic_layer works for
|
||||
# phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper
|
||||
parameters = self.dynamic_layer(update_feature)
|
||||
param_in = parameters[:, :self.num_params_in].view(
|
||||
-1, self.feat_channels)
|
||||
param_out = parameters[:, -self.num_params_out:].view(
|
||||
-1, self.feat_channels)
|
||||
|
||||
# input_layer works for
|
||||
# phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper
|
||||
input_feats = self.input_layer(
|
||||
input_feature.reshape(num_proposals, -1, self.feat_channels))
|
||||
input_in = input_feats[..., :self.num_params_in]
|
||||
input_out = input_feats[..., -self.num_params_out:]
|
||||
|
||||
# `gate_feats` is F^G in K-Net paper
|
||||
gate_feats = input_in * param_in.unsqueeze(-2)
|
||||
if self.gate_norm_act:
|
||||
gate_feats = self.activation(self.gate_norm(gate_feats))
|
||||
|
||||
input_gate = self.input_norm_in(self.input_gate(gate_feats))
|
||||
update_gate = self.norm_in(self.update_gate(gate_feats))
|
||||
if self.gate_sigmoid:
|
||||
input_gate = input_gate.sigmoid()
|
||||
update_gate = update_gate.sigmoid()
|
||||
param_out = self.norm_out(param_out)
|
||||
input_out = self.input_norm_out(input_out)
|
||||
|
||||
if self.activate_out:
|
||||
param_out = self.activation(param_out)
|
||||
input_out = self.activation(input_out)
|
||||
|
||||
# Gate mechanism. Eq.(5) in original paper.
|
||||
# param_out has shape (batch_size, feat_channels, out_channels)
|
||||
features = update_gate * param_out.unsqueeze(
|
||||
-2) + input_gate * input_out
|
||||
|
||||
features = self.fc_layer(features)
|
||||
features = self.fc_norm(features)
|
||||
features = self.activation(features)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class KernelUpdateHead(nn.Module):
|
||||
"""Kernel Update Head in K-Net.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
num_ffn_fcs (int): The number of fully-connected layers in
|
||||
FFNs. Default: 2.
|
||||
num_heads (int): The number of parallel attention heads.
|
||||
Default: 8.
|
||||
num_mask_fcs (int): The number of fully connected layers for
|
||||
mask prediction. Default: 3.
|
||||
feedforward_channels (int): The hidden dimension of FFNs.
|
||||
Defaults: 2048.
|
||||
in_channels (int): The number of channels of input feature map.
|
||||
Default: 256.
|
||||
out_channels (int): The number of output channels.
|
||||
Default: 256.
|
||||
dropout (float): The Probability of an element to be
|
||||
zeroed in MultiheadAttention and FFN. Default 0.0.
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
ffn_act_cfg (dict): Config of activation layers in FFN.
|
||||
Default: dict(type='ReLU').
|
||||
conv_kernel_size (int): The kernel size of convolution in
|
||||
Kernel Update Head for dynamic kernel updation.
|
||||
Default: 1.
|
||||
feat_transform_cfg (dict | None): Config of feature transform.
|
||||
Default: None.
|
||||
kernel_init (bool): Whether initiate mask kernel in mask head.
|
||||
Default: False.
|
||||
with_ffn (bool): Whether add FFN in kernel update head.
|
||||
Default: True.
|
||||
feat_gather_stride (int): Stride of convolution in feature transform.
|
||||
Default: 1.
|
||||
mask_transform_stride (int): Stride of mask transform.
|
||||
Default: 1.
|
||||
kernel_updator_cfg (dict): Config of kernel updator.
|
||||
Default: dict(
|
||||
type='DynamicConv',
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN')).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes=150,
|
||||
num_ffn_fcs=2,
|
||||
num_heads=8,
|
||||
num_mask_fcs=3,
|
||||
feedforward_channels=2048,
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
dropout=0.0,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
ffn_act_cfg=dict(type='ReLU', inplace=True),
|
||||
conv_kernel_size=1,
|
||||
feat_transform_cfg=None,
|
||||
kernel_init=False,
|
||||
with_ffn=True,
|
||||
feat_gather_stride=1,
|
||||
mask_transform_stride=1,
|
||||
kernel_updator_cfg=dict(
|
||||
type='DynamicConv',
|
||||
in_channels=256,
|
||||
feat_channels=64,
|
||||
out_channels=256,
|
||||
act_cfg=dict(type='ReLU', inplace=True),
|
||||
norm_cfg=dict(type='LN'))):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.fp16_enabled = False
|
||||
self.dropout = dropout
|
||||
self.num_heads = num_heads
|
||||
self.kernel_init = kernel_init
|
||||
self.with_ffn = with_ffn
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.feat_gather_stride = feat_gather_stride
|
||||
self.mask_transform_stride = mask_transform_stride
|
||||
|
||||
self.attention = MultiheadAttention(in_channels * conv_kernel_size**2,
|
||||
num_heads, dropout)
|
||||
self.attention_norm = build_norm_layer(
|
||||
dict(type='LN'), in_channels * conv_kernel_size**2)[1]
|
||||
self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg)
|
||||
|
||||
if feat_transform_cfg is not None:
|
||||
kernel_size = feat_transform_cfg.pop('kernel_size', 1)
|
||||
transform_channels = in_channels
|
||||
self.feat_transform = ConvModule(
|
||||
transform_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=feat_gather_stride,
|
||||
padding=int(feat_gather_stride // 2),
|
||||
**feat_transform_cfg)
|
||||
else:
|
||||
self.feat_transform = None
|
||||
|
||||
if self.with_ffn:
|
||||
self.ffn = FFN(
|
||||
in_channels,
|
||||
feedforward_channels,
|
||||
num_ffn_fcs,
|
||||
act_cfg=ffn_act_cfg,
|
||||
dropout=dropout)
|
||||
self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
|
||||
|
||||
self.mask_fcs = nn.ModuleList()
|
||||
for _ in range(num_mask_fcs):
|
||||
self.mask_fcs.append(
|
||||
nn.Linear(in_channels, in_channels, bias=False))
|
||||
self.mask_fcs.append(
|
||||
build_norm_layer(dict(type='LN'), in_channels)[1])
|
||||
self.mask_fcs.append(build_activation_layer(act_cfg))
|
||||
|
||||
self.fc_mask = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def init_weights(self):
|
||||
"""Use xavier initialization for all weight parameter and set
|
||||
classification head bias as a specific value when use focal loss."""
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
else:
|
||||
# adopt the default initialization for
|
||||
# the weight and bias of the layer norm
|
||||
pass
|
||||
if self.kernel_init:
|
||||
print_log(
|
||||
'mask kernel in mask head is normal initialized by std 0.01')
|
||||
nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01)
|
||||
|
||||
def forward(self, x, proposal_feat, mask_preds, mask_shape=None):
|
||||
"""Forward function of Dynamic Instance Interactive Head.
|
||||
|
||||
Args:
|
||||
x (Tensor): Feature map from FPN with shape
|
||||
(batch_size, feature_dimensions, H , W).
|
||||
proposal_feat (Tensor): Intermediate feature get from
|
||||
diihead in last stage, has shape
|
||||
(batch_size, num_proposals, feature_dimensions)
|
||||
mask_preds (Tensor): mask prediction from the former stage in shape
|
||||
(batch_size, num_proposals, H, W).
|
||||
|
||||
Returns:
|
||||
Tuple: The first tensor is predicted mask with shape
|
||||
(N, num_classes, H, W), the second tensor is dynamic kernel
|
||||
with shape (N, num_classes, channels, K, K).
|
||||
"""
|
||||
N, num_proposals = proposal_feat.shape[:2]
|
||||
if self.feat_transform is not None:
|
||||
x = self.feat_transform(x)
|
||||
|
||||
C, H, W = x.shape[-3:]
|
||||
|
||||
mask_h, mask_w = mask_preds.shape[-2:]
|
||||
if mask_h != H or mask_w != W:
|
||||
gather_mask = F.interpolate(
|
||||
mask_preds, (H, W), align_corners=False, mode='bilinear')
|
||||
else:
|
||||
gather_mask = mask_preds
|
||||
|
||||
sigmoid_masks = gather_mask.softmax(dim=1)
|
||||
|
||||
# Group Feature Assembling. Eq.(3) in original paper.
|
||||
# einsum is faster than bmm by 30%
|
||||
x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x)
|
||||
|
||||
# obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C]
|
||||
proposal_feat = proposal_feat.reshape(N, num_proposals,
|
||||
self.in_channels,
|
||||
-1).permute(0, 1, 3, 2)
|
||||
obj_feat = self.kernel_update_conv(x_feat, proposal_feat)
|
||||
|
||||
# [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C]
|
||||
obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2)
|
||||
obj_feat = self.attention_norm(self.attention(obj_feat))
|
||||
# [N, B, K*K*C] -> [B, N, K*K*C]
|
||||
obj_feat = obj_feat.permute(1, 0, 2)
|
||||
|
||||
# obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C]
|
||||
obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels)
|
||||
|
||||
# FFN
|
||||
if self.with_ffn:
|
||||
obj_feat = self.ffn_norm(self.ffn(obj_feat))
|
||||
|
||||
mask_feat = obj_feat
|
||||
|
||||
for reg_layer in self.mask_fcs:
|
||||
mask_feat = reg_layer(mask_feat)
|
||||
|
||||
# [B, N, K*K, C] -> [B, N, C, K*K]
|
||||
mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2)
|
||||
|
||||
if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1):
|
||||
mask_x = F.interpolate(
|
||||
x, scale_factor=0.5, mode='bilinear', align_corners=False)
|
||||
H, W = mask_x.shape[-2:]
|
||||
else:
|
||||
mask_x = x
|
||||
# group conv is 5x faster than unfold and uses about 1/5 memory
|
||||
# Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms
|
||||
# Group conv vs. unfold vs. concat batch, 278 : 1420 : 369
|
||||
# but in real training group conv is slower than concat batch
|
||||
# so we keep using concat batch.
|
||||
# fold_x = F.unfold(
|
||||
# mask_x,
|
||||
# self.conv_kernel_size,
|
||||
# padding=int(self.conv_kernel_size // 2))
|
||||
# mask_feat = mask_feat.reshape(N, num_proposals, -1)
|
||||
# new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x)
|
||||
# [B, N, C, K*K] -> [B*N, C, K, K]
|
||||
mask_feat = mask_feat.reshape(N, num_proposals, C,
|
||||
self.conv_kernel_size,
|
||||
self.conv_kernel_size)
|
||||
# [B, C, H, W] -> [1, B*C, H, W]
|
||||
new_mask_preds = []
|
||||
for i in range(N):
|
||||
new_mask_preds.append(
|
||||
F.conv2d(
|
||||
mask_x[i:i + 1],
|
||||
mask_feat[i],
|
||||
padding=int(self.conv_kernel_size // 2)))
|
||||
|
||||
new_mask_preds = torch.cat(new_mask_preds, dim=0)
|
||||
new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W)
|
||||
if self.mask_transform_stride == 2:
|
||||
new_mask_preds = F.interpolate(
|
||||
new_mask_preds,
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
if mask_shape is not None and mask_shape[0] != H:
|
||||
new_mask_preds = F.interpolate(
|
||||
new_mask_preds,
|
||||
mask_shape,
|
||||
align_corners=False,
|
||||
mode='bilinear')
|
||||
|
||||
return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape(
|
||||
N, num_proposals, self.in_channels, self.conv_kernel_size,
|
||||
self.conv_kernel_size)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class IterativeDecodeHead(BaseDecodeHead):
|
||||
"""K-Net: Towards Unified Image Segmentation.
|
||||
|
||||
This head is the implementation of
|
||||
`K-Net: <https://arxiv.org/abs/2106.14855>`_.
|
||||
|
||||
Args:
|
||||
num_stages (int): The number of stages (kernel update heads)
|
||||
in IterativeDecodeHead. Default: 3.
|
||||
kernel_generate_head:(dict): Config of kernel generate head which
|
||||
generate mask predictions, dynamic kernels and class predictions
|
||||
for next kernel update heads.
|
||||
kernel_update_head (dict): Config of kernel update head which refine
|
||||
dynamic kernels and class predictions iteratively.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_stages, kernel_generate_head, kernel_update_head,
|
||||
**kwargs):
|
||||
# ``IterativeDecodeHead`` would skip initialization of
|
||||
# ``BaseDecodeHead`` which would be called when building
|
||||
# ``self.kernel_generate_head``.
|
||||
super(BaseDecodeHead, self).__init__(**kwargs)
|
||||
assert num_stages == len(kernel_update_head)
|
||||
self.num_stages = num_stages
|
||||
self.kernel_generate_head = MODELS.build(kernel_generate_head)
|
||||
self.kernel_update_head = nn.ModuleList()
|
||||
self.align_corners = self.kernel_generate_head.align_corners
|
||||
self.num_classes = self.kernel_generate_head.num_classes
|
||||
self.input_transform = self.kernel_generate_head.input_transform
|
||||
self.ignore_index = self.kernel_generate_head.ignore_index
|
||||
self.out_channels = self.num_classes
|
||||
|
||||
for head_cfg in kernel_update_head:
|
||||
self.kernel_update_head.append(MODELS.build(head_cfg))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
feats = self.kernel_generate_head._forward_feature(inputs)
|
||||
sem_seg = self.kernel_generate_head.cls_seg(feats)
|
||||
seg_kernels = self.kernel_generate_head.conv_seg.weight.clone()
|
||||
seg_kernels = seg_kernels[None].expand(
|
||||
feats.size(0), *seg_kernels.size())
|
||||
|
||||
stage_segs = [sem_seg]
|
||||
for i in range(self.num_stages):
|
||||
sem_seg, seg_kernels = self.kernel_update_head[i](feats,
|
||||
seg_kernels,
|
||||
sem_seg)
|
||||
stage_segs.append(sem_seg)
|
||||
if self.training:
|
||||
return stage_segs
|
||||
# only return the prediction of the last stage during testing
|
||||
return stage_segs[-1]
|
||||
|
||||
def loss_by_feat(self, seg_logits: List[Tensor],
|
||||
batch_data_samples: SampleList, **kwargs) -> dict:
|
||||
losses = dict()
|
||||
for i, logit in enumerate(seg_logits):
|
||||
loss = self.kernel_generate_head.loss_by_feat(
|
||||
logit, batch_data_samples)
|
||||
for k, v in loss.items():
|
||||
losses[f'{k}.s{i}'] = v
|
||||
|
||||
return losses
|
||||
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
91
finetune/mmseg/models/decode_heads/lraspp_head.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.utils import is_tuple_of
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class LRASPPHead(BaseDecodeHead):
|
||||
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
|
||||
|
||||
This head is the improved implementation of `Searching for MobileNetV3
|
||||
<https://ieeexplore.ieee.org/document/9008835>`_.
|
||||
|
||||
Args:
|
||||
branch_channels (tuple[int]): The number of output channels in every
|
||||
each branch. Default: (32, 64).
|
||||
"""
|
||||
|
||||
def __init__(self, branch_channels=(32, 64), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if self.input_transform != 'multiple_select':
|
||||
raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
|
||||
f'must be \'multiple_select\'. But received '
|
||||
f'\'{self.input_transform}\'')
|
||||
assert is_tuple_of(branch_channels, int)
|
||||
assert len(branch_channels) == len(self.in_channels) - 1
|
||||
self.branch_channels = branch_channels
|
||||
|
||||
self.convs = nn.Sequential()
|
||||
self.conv_ups = nn.Sequential()
|
||||
for i in range(len(branch_channels)):
|
||||
self.convs.add_module(
|
||||
f'conv{i}',
|
||||
nn.Conv2d(
|
||||
self.in_channels[i], branch_channels[i], 1, bias=False))
|
||||
self.conv_ups.add_module(
|
||||
f'conv_up{i}',
|
||||
ConvModule(
|
||||
self.channels + branch_channels[i],
|
||||
self.channels,
|
||||
1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
bias=False))
|
||||
|
||||
self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
|
||||
|
||||
self.aspp_conv = ConvModule(
|
||||
self.in_channels[-1],
|
||||
self.channels,
|
||||
1,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
bias=False)
|
||||
self.image_pool = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
|
||||
ConvModule(
|
||||
self.in_channels[2],
|
||||
self.channels,
|
||||
1,
|
||||
act_cfg=dict(type='Sigmoid'),
|
||||
bias=False))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
inputs = self._transform_inputs(inputs)
|
||||
|
||||
x = inputs[-1]
|
||||
|
||||
x = self.aspp_conv(x) * resize(
|
||||
self.image_pool(x),
|
||||
size=x.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = self.conv_up_input(x)
|
||||
|
||||
for i in range(len(self.branch_channels) - 1, -1, -1):
|
||||
x = resize(
|
||||
x,
|
||||
size=inputs[i].size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
x = torch.cat([x, self.convs[i](inputs[i])], 1)
|
||||
x = self.conv_ups[i](x)
|
||||
|
||||
return self.cls_seg(x)
|
||||
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
163
finetune/mmseg/models/decode_heads/mask2former_head.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
try:
|
||||
from mmdet.models.dense_heads import \
|
||||
Mask2FormerHead as MMDET_Mask2FormerHead
|
||||
except ModuleNotFoundError:
|
||||
MMDET_Mask2FormerHead = BaseModule
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures.seg_data_sample import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Mask2FormerHead(MMDET_Mask2FormerHead):
|
||||
"""Implements the Mask2Former head.
|
||||
|
||||
See `Mask2Former: Masked-attention Mask Transformer for Universal Image
|
||||
Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
ignore_index (int): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes,
|
||||
align_corners=False,
|
||||
ignore_index=255,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.align_corners = align_corners
|
||||
self.out_channels = num_classes
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
feat_channels = kwargs['feat_channels']
|
||||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
||||
|
||||
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
|
||||
"""Perform forward propagation to convert paradigm from MMSegmentation
|
||||
to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called
|
||||
normally. Specifically, ``batch_gt_instances`` would be added.
|
||||
|
||||
Args:
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: A tuple contains two lists.
|
||||
|
||||
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
||||
gt_instance. It usually includes ``labels``, each is
|
||||
unique ground truth label id of images, with
|
||||
shape (num_gt, ) and ``masks``, each is ground truth
|
||||
masks of each instances of a image, shape (num_gt, h, w).
|
||||
- batch_img_metas (list[dict]): List of image meta information.
|
||||
"""
|
||||
batch_img_metas = []
|
||||
batch_gt_instances = []
|
||||
|
||||
for data_sample in batch_data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
classes = torch.unique(
|
||||
gt_sem_seg,
|
||||
sorted=False,
|
||||
return_inverse=False,
|
||||
return_counts=False)
|
||||
|
||||
# remove ignored region
|
||||
gt_labels = classes[classes != self.ignore_index]
|
||||
|
||||
masks = []
|
||||
for class_id in gt_labels:
|
||||
masks.append(gt_sem_seg == class_id)
|
||||
|
||||
if len(masks) == 0:
|
||||
gt_masks = torch.zeros(
|
||||
(0, gt_sem_seg.shape[-2],
|
||||
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
|
||||
else:
|
||||
gt_masks = torch.stack(masks).squeeze(1).long()
|
||||
|
||||
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
|
||||
batch_gt_instances.append(instance_data)
|
||||
return batch_gt_instances, batch_img_metas
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
|
||||
batch_gt_instances, batch_img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tuple[Tensor]:
|
||||
"""Test without augmentaton.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_img_metas (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
test_cfg (ConfigType): Test config.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor of segmentation mask.
|
||||
"""
|
||||
batch_data_samples = [
|
||||
SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas
|
||||
]
|
||||
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
mask_cls_results = all_cls_scores[-1]
|
||||
mask_pred_results = all_mask_preds[-1]
|
||||
if 'pad_shape' in batch_img_metas[0]:
|
||||
size = batch_img_metas[0]['pad_shape']
|
||||
else:
|
||||
size = batch_img_metas[0]['img_shape']
|
||||
# upsample mask
|
||||
mask_pred_results = F.interpolate(
|
||||
mask_pred_results, size=size, mode='bilinear', align_corners=False)
|
||||
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred_results.sigmoid()
|
||||
seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred)
|
||||
return seg_logits
|
||||
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
174
finetune/mmseg/models/decode_heads/maskformer_head.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
try:
|
||||
from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead
|
||||
except ModuleNotFoundError:
|
||||
MMDET_MaskFormerHead = BaseModule
|
||||
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures.seg_data_sample import SegDataSample
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MaskFormerHead(MMDET_MaskFormerHead):
|
||||
"""Implements the MaskFormer head.
|
||||
|
||||
See `Per-Pixel Classification is Not All You Need for Semantic Segmentation
|
||||
<https://arxiv.org/pdf/2107.06278>`_ for details.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of classes. Default: 150.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
ignore_index (int): The label index to be ignored. Default: 255.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes: int = 150,
|
||||
align_corners: bool = False,
|
||||
ignore_index: int = 255,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.out_channels = kwargs['out_channels']
|
||||
self.align_corners = True
|
||||
self.num_classes = num_classes
|
||||
self.align_corners = align_corners
|
||||
self.out_channels = num_classes
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
feat_channels = kwargs['feat_channels']
|
||||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
||||
|
||||
def _seg_data_to_instance_data(self, batch_data_samples: SampleList):
|
||||
"""Perform forward propagation to convert paradigm from MMSegmentation
|
||||
to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called
|
||||
normally. Specifically, ``batch_gt_instances`` would be added.
|
||||
|
||||
Args:
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: A tuple contains two lists.
|
||||
|
||||
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
||||
gt_instance. It usually includes ``labels``, each is
|
||||
unique ground truth label id of images, with
|
||||
shape (num_gt, ) and ``masks``, each is ground truth
|
||||
masks of each instances of a image, shape (num_gt, h, w).
|
||||
- batch_img_metas (list[dict]): List of image meta information.
|
||||
"""
|
||||
batch_img_metas = []
|
||||
batch_gt_instances = []
|
||||
for data_sample in batch_data_samples:
|
||||
# Add `batch_input_shape` in metainfo of data_sample, which would
|
||||
# be used in MaskFormerHead of MMDetection.
|
||||
metainfo = data_sample.metainfo
|
||||
metainfo['batch_input_shape'] = metainfo['img_shape']
|
||||
data_sample.set_metainfo(metainfo)
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
gt_sem_seg = data_sample.gt_sem_seg.data
|
||||
classes = torch.unique(
|
||||
gt_sem_seg,
|
||||
sorted=False,
|
||||
return_inverse=False,
|
||||
return_counts=False)
|
||||
|
||||
# remove ignored region
|
||||
gt_labels = classes[classes != self.ignore_index]
|
||||
|
||||
masks = []
|
||||
for class_id in gt_labels:
|
||||
masks.append(gt_sem_seg == class_id)
|
||||
|
||||
if len(masks) == 0:
|
||||
gt_masks = torch.zeros((0, gt_sem_seg.shape[-2],
|
||||
gt_sem_seg.shape[-1])).to(gt_sem_seg)
|
||||
else:
|
||||
gt_masks = torch.stack(masks).squeeze(1)
|
||||
|
||||
instance_data = InstanceData(
|
||||
labels=gt_labels, masks=gt_masks.long())
|
||||
batch_gt_instances.append(instance_data)
|
||||
return batch_gt_instances, batch_img_metas
|
||||
|
||||
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList,
|
||||
train_cfg: ConfigType) -> dict:
|
||||
"""Perform forward propagation and loss calculation of the decoder head
|
||||
on the features of the upstream network.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the upstream
|
||||
network, each is a 4D-tensor.
|
||||
batch_data_samples (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
train_cfg (ConfigType): Training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components.
|
||||
"""
|
||||
# batch SegDataSample to InstanceDataSample
|
||||
batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data(
|
||||
batch_data_samples)
|
||||
|
||||
# forward
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
|
||||
# loss
|
||||
losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
|
||||
batch_gt_instances, batch_img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict],
|
||||
test_cfg: ConfigType) -> Tuple[Tensor]:
|
||||
"""Test without augmentaton.
|
||||
|
||||
Args:
|
||||
x (tuple[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
batch_img_metas (List[:obj:`SegDataSample`]): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_sem_seg`.
|
||||
test_cfg (ConfigType): Test config.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor of segmentation mask.
|
||||
"""
|
||||
|
||||
batch_data_samples = []
|
||||
for metainfo in batch_img_metas:
|
||||
metainfo['batch_input_shape'] = metainfo['img_shape']
|
||||
batch_data_samples.append(SegDataSample(metainfo=metainfo))
|
||||
# Forward function of MaskFormerHead from MMDetection needs
|
||||
# 'batch_data_samples' as inputs, which is image shape actually.
|
||||
all_cls_scores, all_mask_preds = self(x, batch_data_samples)
|
||||
mask_cls_results = all_cls_scores[-1]
|
||||
mask_pred_results = all_mask_preds[-1]
|
||||
|
||||
# upsample masks
|
||||
img_shape = batch_img_metas[0]['batch_input_shape']
|
||||
mask_pred_results = F.interpolate(
|
||||
mask_pred_results,
|
||||
size=img_shape,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
# semantic inference
|
||||
cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred_results.sigmoid()
|
||||
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
|
||||
return seg_logits
|
||||
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
50
finetune/mmseg/models/decode_heads/nl_head.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class NLHead(FCNHead):
|
||||
"""Non-local Neural Networks.
|
||||
|
||||
This head is the implementation of `NLNet
|
||||
<https://arxiv.org/abs/1711.07971>`_.
|
||||
|
||||
Args:
|
||||
reduction (int): Reduction factor of projection transform. Default: 2.
|
||||
use_scale (bool): Whether to scale pairwise_weight by
|
||||
sqrt(1/inter_channels). Default: True.
|
||||
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
|
||||
'dot_product'. Default: 'embedded_gaussian.'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction=2,
|
||||
use_scale=True,
|
||||
mode='embedded_gaussian',
|
||||
**kwargs):
|
||||
super().__init__(num_convs=2, **kwargs)
|
||||
self.reduction = reduction
|
||||
self.use_scale = use_scale
|
||||
self.mode = mode
|
||||
self.nl_block = NonLocal2d(
|
||||
in_channels=self.channels,
|
||||
reduction=self.reduction,
|
||||
use_scale=self.use_scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
mode=self.mode)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
output = self.convs[0](x)
|
||||
output = self.nl_block(output)
|
||||
output = self.convs[1](output)
|
||||
if self.concat_input:
|
||||
output = self.conv_cat(torch.cat([x, output], dim=1))
|
||||
output = self.cls_seg(output)
|
||||
return output
|
||||
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
127
finetune/mmseg/models/decode_heads/ocr_head.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
class SpatialGatherModule(nn.Module):
|
||||
"""Aggregate the context features according to the initial predicted
|
||||
probability distribution.
|
||||
|
||||
Employ the soft-weighted method to aggregate the context.
|
||||
"""
|
||||
|
||||
def __init__(self, scale):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, feats, probs):
|
||||
"""Forward function."""
|
||||
batch_size, num_classes, height, width = probs.size()
|
||||
channels = feats.size(1)
|
||||
probs = probs.view(batch_size, num_classes, -1)
|
||||
feats = feats.view(batch_size, channels, -1)
|
||||
# [batch_size, height*width, num_classes]
|
||||
feats = feats.permute(0, 2, 1)
|
||||
# [batch_size, channels, height*width]
|
||||
probs = F.softmax(self.scale * probs, dim=2)
|
||||
# [batch_size, channels, num_classes]
|
||||
ocr_context = torch.matmul(probs, feats)
|
||||
ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
|
||||
return ocr_context
|
||||
|
||||
|
||||
class ObjectAttentionBlock(_SelfAttentionBlock):
|
||||
"""Make a OCR used SelfAttentionBlock."""
|
||||
|
||||
def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
|
||||
act_cfg):
|
||||
if scale > 1:
|
||||
query_downsample = nn.MaxPool2d(kernel_size=scale)
|
||||
else:
|
||||
query_downsample = None
|
||||
super().__init__(
|
||||
key_in_channels=in_channels,
|
||||
query_in_channels=in_channels,
|
||||
channels=channels,
|
||||
out_channels=in_channels,
|
||||
share_key_query=False,
|
||||
query_downsample=query_downsample,
|
||||
key_downsample=None,
|
||||
key_query_num_convs=2,
|
||||
key_query_norm=True,
|
||||
value_out_num_convs=1,
|
||||
value_out_norm=True,
|
||||
matmul_norm=True,
|
||||
with_out=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.bottleneck = ConvModule(
|
||||
in_channels * 2,
|
||||
in_channels,
|
||||
1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, query_feats, key_feats):
|
||||
"""Forward function."""
|
||||
context = super().forward(query_feats, key_feats)
|
||||
output = self.bottleneck(torch.cat([context, query_feats], dim=1))
|
||||
if self.query_downsample is not None:
|
||||
output = resize(query_feats)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class OCRHead(BaseCascadeDecodeHead):
|
||||
"""Object-Contextual Representations for Semantic Segmentation.
|
||||
|
||||
This head is the implementation of `OCRNet
|
||||
<https://arxiv.org/abs/1909.11065>`_.
|
||||
|
||||
Args:
|
||||
ocr_channels (int): The intermediate channels of OCR block.
|
||||
scale (int): The scale of probability map in SpatialGatherModule in
|
||||
Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, ocr_channels, scale=1, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.ocr_channels = ocr_channels
|
||||
self.scale = scale
|
||||
self.object_context_block = ObjectAttentionBlock(
|
||||
self.channels,
|
||||
self.ocr_channels,
|
||||
self.scale,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
self.spatial_gather_module = SpatialGatherModule(self.scale)
|
||||
|
||||
self.bottleneck = ConvModule(
|
||||
self.in_channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg)
|
||||
|
||||
def forward(self, inputs, prev_output):
|
||||
"""Forward function."""
|
||||
x = self._transform_inputs(inputs)
|
||||
feats = self.bottleneck(x)
|
||||
context = self.spatial_gather_module(feats, prev_output)
|
||||
object_context = self.object_context_block(feats, context)
|
||||
output = self.cls_seg(object_context)
|
||||
|
||||
return output
|
||||
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
183
finetune/mmseg/models/decode_heads/pid_head.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.models.losses import accuracy
|
||||
from mmseg.models.utils import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import OptConfigType, SampleList
|
||||
|
||||
|
||||
class BasePIDHead(BaseModule):
|
||||
"""Base class for PID head.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
init_cfg (dict or list[dict], optional): Init config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
init_cfg: OptConfigType = None):
|
||||
super().__init__(init_cfg)
|
||||
self.conv = ConvModule(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
order=('norm', 'act', 'conv'))
|
||||
_, self.norm = build_norm_layer(norm_cfg, num_features=channels)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
|
||||
"""Forward function.
|
||||
Args:
|
||||
x (Tensor): Input tensor.
|
||||
cls_seg (nn.Module, optional): The classification head.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor.
|
||||
"""
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
x = self.act(x)
|
||||
if cls_seg is not None:
|
||||
x = cls_seg(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PIDHead(BaseDecodeHead):
|
||||
"""Decode head for PIDNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
channels (int): Number of output channels.
|
||||
num_classes (int): Number of classes.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU', inplace=True).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
channels: int,
|
||||
num_classes: int,
|
||||
norm_cfg: OptConfigType = dict(type='BN'),
|
||||
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
channels,
|
||||
num_classes=num_classes,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
**kwargs)
|
||||
self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
|
||||
self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
|
||||
act_cfg)
|
||||
self.d_head = BasePIDHead(
|
||||
in_channels // 2,
|
||||
in_channels // 4,
|
||||
norm_cfg,
|
||||
)
|
||||
self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
|
||||
self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Union[Tensor,
|
||||
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
|
||||
"""Forward function.
|
||||
Args:
|
||||
inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
|
||||
Tensor. When training, the input is a tuple of three tensors,
|
||||
(p_feat, i_feat, d_feat), and the output is a tuple of three
|
||||
tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
|
||||
When inference, only the head of integral branch is used, and
|
||||
input is a tensor of integral feature map, and the output is
|
||||
the segmentation logit.
|
||||
|
||||
Returns:
|
||||
Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
|
||||
"""
|
||||
if self.training:
|
||||
x_p, x_i, x_d = inputs
|
||||
x_p = self.p_head(x_p, self.p_cls_seg)
|
||||
x_i = self.i_head(x_i, self.cls_seg)
|
||||
x_d = self.d_head(x_d, self.d_cls_seg)
|
||||
return x_p, x_i, x_d
|
||||
else:
|
||||
return self.i_head(inputs, self.cls_seg)
|
||||
|
||||
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
|
||||
gt_semantic_segs = [
|
||||
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
|
||||
]
|
||||
gt_edge_segs = [
|
||||
data_sample.gt_edge_map.data for data_sample in batch_data_samples
|
||||
]
|
||||
gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
|
||||
gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
|
||||
return gt_sem_segs, gt_edge_segs
|
||||
|
||||
def loss_by_feat(self, seg_logits: Tuple[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
loss = dict()
|
||||
p_logit, i_logit, d_logit = seg_logits
|
||||
sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
|
||||
p_logit = resize(
|
||||
input=p_logit,
|
||||
size=sem_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
i_logit = resize(
|
||||
input=i_logit,
|
||||
size=sem_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
d_logit = resize(
|
||||
input=d_logit,
|
||||
size=bd_label.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
sem_label = sem_label.squeeze(1)
|
||||
bd_label = bd_label.squeeze(1)
|
||||
loss['loss_sem_p'] = self.loss_decode[0](
|
||||
p_logit, sem_label, ignore_index=self.ignore_index)
|
||||
loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
|
||||
loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
|
||||
filler = torch.ones_like(sem_label) * self.ignore_index
|
||||
sem_bd_label = torch.where(
|
||||
torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
|
||||
loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
|
||||
loss['acc_seg'] = accuracy(
|
||||
i_logit, sem_label, ignore_index=self.ignore_index)
|
||||
return loss
|
||||
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
367
finetune/mmseg/models/decode_heads/point_head.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
try:
|
||||
from mmcv.ops import point_sample
|
||||
except ModuleNotFoundError:
|
||||
point_sample = None
|
||||
|
||||
from typing import List
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
def calculate_uncertainty(seg_logits):
|
||||
"""Estimate uncertainty based on seg logits.
|
||||
|
||||
For each location of the prediction ``seg_logits`` we estimate
|
||||
uncertainty as the difference between top first and top second
|
||||
predicted logits.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): Semantic segmentation logits,
|
||||
shape (batch_size, num_classes, height, width).
|
||||
|
||||
Returns:
|
||||
scores (Tensor): T uncertainty scores with the most uncertain
|
||||
locations having the highest uncertainty score, shape (
|
||||
batch_size, 1, height, width)
|
||||
"""
|
||||
top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
|
||||
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PointHead(BaseCascadeDecodeHead):
|
||||
"""A mask point head use in PointRend.
|
||||
|
||||
This head is implemented of `PointRend: Image Segmentation as
|
||||
Rendering <https://arxiv.org/abs/1912.08193>`_.
|
||||
``PointHead`` use shared multi-layer perceptron (equivalent to
|
||||
nn.Conv1d) to predict the logit of input points. The fine-grained feature
|
||||
and coarse feature will be concatenate together for predication.
|
||||
|
||||
Args:
|
||||
num_fcs (int): Number of fc layers in the head. Default: 3.
|
||||
in_channels (int): Number of input channels. Default: 256.
|
||||
fc_channels (int): Number of fc channels. Default: 256.
|
||||
num_classes (int): Number of classes for logits. Default: 80.
|
||||
class_agnostic (bool): Whether use class agnostic classification.
|
||||
If so, the output channels of logits will be 1. Default: False.
|
||||
coarse_pred_each_layer (bool): Whether concatenate coarse feature with
|
||||
the output of each fc layer. Default: True.
|
||||
conv_cfg (dict|None): Dictionary to construct and config conv layer.
|
||||
Default: dict(type='Conv1d'))
|
||||
norm_cfg (dict|None): Dictionary to construct and config norm layer.
|
||||
Default: None.
|
||||
loss_point (dict): Dictionary to construct and config loss layer of
|
||||
point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
|
||||
loss_weight=1.0).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_fcs=3,
|
||||
coarse_pred_each_layer=True,
|
||||
conv_cfg=dict(type='Conv1d'),
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ReLU', inplace=False),
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
input_transform='multiple_select',
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
init_cfg=dict(
|
||||
type='Normal', std=0.01, override=dict(name='fc_seg')),
|
||||
**kwargs)
|
||||
if point_sample is None:
|
||||
raise RuntimeError('Please install mmcv-full for '
|
||||
'point_sample ops')
|
||||
|
||||
self.num_fcs = num_fcs
|
||||
self.coarse_pred_each_layer = coarse_pred_each_layer
|
||||
|
||||
fc_in_channels = sum(self.in_channels) + self.num_classes
|
||||
fc_channels = self.channels
|
||||
self.fcs = nn.ModuleList()
|
||||
for k in range(num_fcs):
|
||||
fc = ConvModule(
|
||||
fc_in_channels,
|
||||
fc_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.fcs.append(fc)
|
||||
fc_in_channels = fc_channels
|
||||
fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
|
||||
else 0
|
||||
self.fc_seg = nn.Conv1d(
|
||||
fc_in_channels,
|
||||
self.num_classes,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
if self.dropout_ratio > 0:
|
||||
self.dropout = nn.Dropout(self.dropout_ratio)
|
||||
delattr(self, 'conv_seg')
|
||||
|
||||
def cls_seg(self, feat):
|
||||
"""Classify each pixel with fc."""
|
||||
if self.dropout is not None:
|
||||
feat = self.dropout(feat)
|
||||
output = self.fc_seg(feat)
|
||||
return output
|
||||
|
||||
def forward(self, fine_grained_point_feats, coarse_point_feats):
|
||||
x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
|
||||
for fc in self.fcs:
|
||||
x = fc(x)
|
||||
if self.coarse_pred_each_layer:
|
||||
x = torch.cat((x, coarse_point_feats), dim=1)
|
||||
return self.cls_seg(x)
|
||||
|
||||
def _get_fine_grained_point_feats(self, x, points):
|
||||
"""Sample from fine grained features.
|
||||
|
||||
Args:
|
||||
x (list[Tensor]): Feature pyramid from by neck or backbone.
|
||||
points (Tensor): Point coordinates, shape (batch_size,
|
||||
num_points, 2).
|
||||
|
||||
Returns:
|
||||
fine_grained_feats (Tensor): Sampled fine grained feature,
|
||||
shape (batch_size, sum(channels of x), num_points).
|
||||
"""
|
||||
|
||||
fine_grained_feats_list = [
|
||||
point_sample(_, points, align_corners=self.align_corners)
|
||||
for _ in x
|
||||
]
|
||||
if len(fine_grained_feats_list) > 1:
|
||||
fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
|
||||
else:
|
||||
fine_grained_feats = fine_grained_feats_list[0]
|
||||
|
||||
return fine_grained_feats
|
||||
|
||||
def _get_coarse_point_feats(self, prev_output, points):
|
||||
"""Sample from fine grained features.
|
||||
|
||||
Args:
|
||||
prev_output (list[Tensor]): Prediction of previous decode head.
|
||||
points (Tensor): Point coordinates, shape (batch_size,
|
||||
num_points, 2).
|
||||
|
||||
Returns:
|
||||
coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
|
||||
num_classes, num_points).
|
||||
"""
|
||||
|
||||
coarse_feats = point_sample(
|
||||
prev_output, points, align_corners=self.align_corners)
|
||||
|
||||
return coarse_feats
|
||||
|
||||
def loss(self, inputs, prev_output, batch_data_samples: SampleList,
|
||||
train_cfg, **kwargs):
|
||||
"""Forward function for training.
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
batch_data_samples (list[:obj:`SegDataSample`]): The seg
|
||||
data samples. It usually includes information such
|
||||
as `img_metas` or `gt_semantic_seg`.
|
||||
train_cfg (dict): The training config.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
x = self._transform_inputs(inputs)
|
||||
with torch.no_grad():
|
||||
points = self.get_points_train(
|
||||
prev_output, calculate_uncertainty, cfg=train_cfg)
|
||||
fine_grained_point_feats = self._get_fine_grained_point_feats(
|
||||
x, points)
|
||||
coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
|
||||
point_logits = self.forward(fine_grained_point_feats,
|
||||
coarse_point_feats)
|
||||
|
||||
losses = self.loss_by_feat(point_logits, points, batch_data_samples)
|
||||
|
||||
return losses
|
||||
|
||||
def predict(self, inputs, prev_output, batch_img_metas: List[dict],
|
||||
test_cfg, **kwargs):
|
||||
"""Forward function for testing.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): List of multi-level img features.
|
||||
prev_output (Tensor): The output of previous decode head.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
test_cfg (dict): The testing config.
|
||||
|
||||
Returns:
|
||||
Tensor: Output segmentation map.
|
||||
"""
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
refined_seg_logits = prev_output.clone()
|
||||
for _ in range(test_cfg.subdivision_steps):
|
||||
refined_seg_logits = resize(
|
||||
refined_seg_logits,
|
||||
scale_factor=test_cfg.scale_factor,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
batch_size, channels, height, width = refined_seg_logits.shape
|
||||
point_indices, points = self.get_points_test(
|
||||
refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
|
||||
fine_grained_point_feats = self._get_fine_grained_point_feats(
|
||||
x, points)
|
||||
coarse_point_feats = self._get_coarse_point_feats(
|
||||
prev_output, points)
|
||||
point_logits = self.forward(fine_grained_point_feats,
|
||||
coarse_point_feats)
|
||||
|
||||
point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
|
||||
refined_seg_logits = refined_seg_logits.reshape(
|
||||
batch_size, channels, height * width)
|
||||
refined_seg_logits = refined_seg_logits.scatter_(
|
||||
2, point_indices, point_logits)
|
||||
refined_seg_logits = refined_seg_logits.view(
|
||||
batch_size, channels, height, width)
|
||||
|
||||
return self.predict_by_feat(refined_seg_logits, batch_img_metas,
|
||||
**kwargs)
|
||||
|
||||
def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
|
||||
"""Compute segmentation loss."""
|
||||
gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
|
||||
point_label = point_sample(
|
||||
gt_semantic_seg.float(),
|
||||
points,
|
||||
mode='nearest',
|
||||
align_corners=self.align_corners)
|
||||
point_label = point_label.squeeze(1).long()
|
||||
|
||||
loss = dict()
|
||||
if not isinstance(self.loss_decode, nn.ModuleList):
|
||||
losses_decode = [self.loss_decode]
|
||||
else:
|
||||
losses_decode = self.loss_decode
|
||||
for loss_module in losses_decode:
|
||||
loss['point' + loss_module.loss_name] = loss_module(
|
||||
point_logits, point_label, ignore_index=self.ignore_index)
|
||||
|
||||
loss['acc_point'] = accuracy(
|
||||
point_logits, point_label, ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
def get_points_train(self, seg_logits, uncertainty_func, cfg):
|
||||
"""Sample points for training.
|
||||
|
||||
Sample points in [0, 1] x [0, 1] coordinate space based on their
|
||||
uncertainty. The uncertainties are calculated for each point using
|
||||
'uncertainty_func' function that takes point's logit prediction as
|
||||
input.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): Semantic segmentation logits, shape (
|
||||
batch_size, num_classes, height, width).
|
||||
uncertainty_func (func): uncertainty calculation function.
|
||||
cfg (dict): Training config of point head.
|
||||
|
||||
Returns:
|
||||
point_coords (Tensor): A tensor of shape (batch_size, num_points,
|
||||
2) that contains the coordinates of ``num_points`` sampled
|
||||
points.
|
||||
"""
|
||||
num_points = cfg.num_points
|
||||
oversample_ratio = cfg.oversample_ratio
|
||||
importance_sample_ratio = cfg.importance_sample_ratio
|
||||
assert oversample_ratio >= 1
|
||||
assert 0 <= importance_sample_ratio <= 1
|
||||
batch_size = seg_logits.shape[0]
|
||||
num_sampled = int(num_points * oversample_ratio)
|
||||
point_coords = torch.rand(
|
||||
batch_size, num_sampled, 2, device=seg_logits.device)
|
||||
point_logits = point_sample(seg_logits, point_coords)
|
||||
# It is crucial to calculate uncertainty based on the sampled
|
||||
# prediction value for the points. Calculating uncertainties of the
|
||||
# coarse predictions first and sampling them for points leads to
|
||||
# incorrect results. To illustrate this: assume uncertainty func(
|
||||
# logits)=-abs(logits), a sampled point between two coarse
|
||||
# predictions with -1 and 1 logits has 0 logits, and therefore 0
|
||||
# uncertainty value. However, if we calculate uncertainties for the
|
||||
# coarse predictions first, both will have -1 uncertainty,
|
||||
# and sampled point will get -1 uncertainty.
|
||||
point_uncertainties = uncertainty_func(point_logits)
|
||||
num_uncertain_points = int(importance_sample_ratio * num_points)
|
||||
num_random_points = num_points - num_uncertain_points
|
||||
idx = torch.topk(
|
||||
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
|
||||
shift = num_sampled * torch.arange(
|
||||
batch_size, dtype=torch.long, device=seg_logits.device)
|
||||
idx += shift[:, None]
|
||||
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
|
||||
batch_size, num_uncertain_points, 2)
|
||||
if num_random_points > 0:
|
||||
rand_point_coords = torch.rand(
|
||||
batch_size, num_random_points, 2, device=seg_logits.device)
|
||||
point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
|
||||
return point_coords
|
||||
|
||||
def get_points_test(self, seg_logits, uncertainty_func, cfg):
|
||||
"""Sample points for testing.
|
||||
|
||||
Find ``num_points`` most uncertain points from ``uncertainty_map``.
|
||||
|
||||
Args:
|
||||
seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
|
||||
height, width) for class-specific or class-agnostic prediction.
|
||||
uncertainty_func (func): uncertainty calculation function.
|
||||
cfg (dict): Testing config of point head.
|
||||
|
||||
Returns:
|
||||
point_indices (Tensor): A tensor of shape (batch_size, num_points)
|
||||
that contains indices from [0, height x width) of the most
|
||||
uncertain points.
|
||||
point_coords (Tensor): A tensor of shape (batch_size, num_points,
|
||||
2) that contains [0, 1] x [0, 1] normalized coordinates of the
|
||||
most uncertain points from the ``height x width`` grid .
|
||||
"""
|
||||
|
||||
num_points = cfg.subdivision_num_points
|
||||
uncertainty_map = uncertainty_func(seg_logits)
|
||||
batch_size, _, height, width = uncertainty_map.shape
|
||||
h_step = 1.0 / height
|
||||
w_step = 1.0 / width
|
||||
|
||||
uncertainty_map = uncertainty_map.view(batch_size, height * width)
|
||||
num_points = min(height * width, num_points)
|
||||
point_indices = uncertainty_map.topk(num_points, dim=1)[1]
|
||||
point_coords = torch.zeros(
|
||||
batch_size,
|
||||
num_points,
|
||||
2,
|
||||
dtype=torch.float,
|
||||
device=seg_logits.device)
|
||||
point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
|
||||
width).float() * w_step
|
||||
point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
|
||||
width).float() * h_step
|
||||
return point_indices, point_coords
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user