commit 01adcfdf60c3501693e6e6faec4a6a9a3cc1c895 Author: esenke Date: Mon Dec 8 22:16:31 2025 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2b12c4e --- /dev/null +++ b/.gitignore @@ -0,0 +1,130 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +.envrc + +# Jupyter Notebook +.ipynb_checkpoints/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre +.pyre/ + +# ruff +.ruff_cache/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ +.project +.pydevproject +.settings/ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Deep Learning +*.pth +*.pt +*.onnx +*.engine +*.trt +checkpoints/ +work_dirs/ +logs/ +runs/ +outputs/ +wandb/ +mlruns/ +tensorboard/ + +# Data (uncomment if needed) +# data/ +# datasets/ +# *.tif +# *.tiff +# *.zip +# *.tar +# *.tar.gz + +# Misc +*.log +*.tmp +*.temp +.cache/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..75f00b0 --- /dev/null +++ b/README.md @@ -0,0 +1,353 @@ +# SkySense++ + +This repository is the official implementation of the paper "SkySense++: A Semantic-Enhanced Multi-Modal Remote Sensing Foundation Model Beyond SkySense for Earth Observation". + +## 📢 Latest Updates + +🔥🔥🔥 Last Updated on 2025.09.15 🔥🔥🔥 +- [2025.09.15] Add a [🌍 project page](https://zqcrafts.github.io/SkySense-O/project.html). +- [2025.08.04] Our work has been published in [*Nature Machine Intelligence*](https://www.nature.com/articles/s42256-025-01078-8). +- [2025.03.23] Code for preprocessing/pretraining/application and [model weights](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509?source=copy_link) for models have been uploaded. +- [2025.03.14] updated optical images of JL-16 dataset in [Huggingface](https://huggingface.co/datasets/KKKKKKang/JL-16). +- [2025.03.12] updated sentinel-1 images and labels of JL-16 dataset in [Zenodo](https://zenodo.org/records/15010418). +- [2025.03.09] created repo in [Zenodo](https://zenodo.org/records/15010418), datasets are uploading. +- [2024.11.13] updated details of pretrain and evaluation data. + +## Pretrain Data + +### RS-Semantic Dataset + +We conduct semantic-enhanced pretraining on the RS-Semantic dataset, which consists of 13 datasets with pixel-level annotations. Below are the specifics of these datasets. (also see in [Zenodo](https://zenodo.org/records/15010418)). + +| Dataset | Modalities | GSD(m) | Size | Categories | Download Link | +| ------------------- | ---------------- | ------ | --------------------- | ---------- | --------------------------------------------------------------------------------------------- | +| Five Billion Pixels | Gaofen-2 | 4 | 6800x7200 | 24 | [Download](https://x-ytong.github.io/project/Five-Billion-Pixels.html) | +| Potsdam | Airborne | 0.05 | 6000x6000 | 5 | [Download](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-potsdam.aspx) | +| Vaihingen | Airborne | 0.05 | 2494x2064 | 5 | [Download](https://www.isprs.org/education/benchmarks/UrbanSemLab/2d-sem-label-vaihingen.aspx) | +| Deepglobe | WorldView | 0.5 | 2448x2448 | 6 | [Download](https://www.kaggle.com/datasets/balraj98/deepglobe-land-cover-classification-dataset) | +| iSAID | Multiple Sensors | - | 800x800 to 4000x13000 | 15 | [Download](https://captain-whu.github.io/iSAID/index.html) | +| LoveDA | Spaceborne | 0.3 | 1024x1024 | 7 | [Download](https://github.com/Junjue-Wang/LoveDA) | +| DynamicEarthNet | WorldView | 0.3 | 1024x1024 | 7 | [Download](https://github.com/aysim/dynnet) | +| | Sentinel-2* | 10 | 32x32 | | | +| | Sentinel-1* | 10 | 32x33 | | | +| Pastis-MM | WorldView | 0.3 | 1024x1024 | 18 | [Download](https://github.com/VSainteuf/pastis-benchmark) | +| | Sentinel-2* | 10 | 32x32 | | | +| | Sentinel-1* | 10 | 32x33 | | | +| C2Seg-AB | Sentinel-2* | 10 | 128x128 | 13 | [Download](https://github.com/danfenghong/RSE_Cross-city) | +| | Sentinel-1* | 10 | 128x128 | | | +| FLAIR | Spot-5 | 0.2 | 512x512 | 12 | [Download](https://github.com/IGNF/FLAIR-2) | +| | Sentinel-2* | 10 | 40x40 | | | +| DFC20 | Sentinel-2 | 10 | 256x256 | 9 | [Download](https://ieee-dataport.org/competitions/2020-ieee-grss-data-fusion-contest) | +| | Sentinel-1 | 10 | 256x256 | | | +| S2-naip | NAIP | 1 | 512x512 | 32 | [Download](https://huggingface.co/datasets/allenai/s2-naip) | +| | Sentinel-2* | 10 | 64x64 | | | +| | Sentinel-1* | 10 | 64x64 | | | +| JL-16 | Jilin-1 | 0.72 | 512x512 | 16 | [Download](https://zenodo.org/records/15010418) | +| | Sentinel-1* | 10 | 40x40 | | | + +*\* for time-series data.* + +### RS-Representation Dataset + +The pretraining list is in the [Zenodo](https://zenodo.org/records/15068572)- `rep_data_list.tar`. The download and process scripts are in [tools/pretraining_data_builder](tools/pretraining_data_builder). + +## EO Benchmark + +We evaluate our SkySense++ on 12 typical Earth Observation (EO) tasks across 7 domains: *agriculture*, *forestry*, *oceanography*, *atmosphere*, *biology*, *land surveying*, and *disaster management*. The detailed information about the datasets used for evaluation is as follows. + +| Domain | Task type | Dataset | Modalities | GSD | Image size | Download Link | Notes | +| ------------------- | --------------------------- | --------------------- | ---------------------- | ---- | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | ----- | +| Agriculture | Crop classification | Germany | Sentinel-2* | 10 | 24x24 | [Download](https://github.com/michaeltrs/DeepSatModels/tree/main/data) | | +| Foresetry | Tree species classification | TreeSatAI-Time-Series | Airborne, | 0.2 | 304x304 | [Download](http://example.com/download/treesatai-time-series) | | +| | | | Sentinel-2* | 10 | 6x6 | | | +| | | | Sentinel-1* | 10 | 6x6 | | | +| | Deforestation segmentation | Atlantic | Sentinel-2 | 10 | 512x512 | [Download](https://github.com/davej23/attention-mechanism-unet) | | +| Oceanography | Oil spill segmentation | SOS | Sentinel-1 | 10 | 256x256 | [Download](https://grzy.cug.edu.cn/zhuqiqi/en/yjgk/32384/list/index.htm) | | +| Atmosphere | Air pollution regression | 3pollution | Sentinel-2 | 10 | 200x200 | [Download](https://github.com/CoDIS-Lab/AQNet) | | +| | | | Sentinel-5P | 2600 | 120x120 | | | +| Biology | Wildlife detection | Kenya | Airborne | - | 3068x4603 | [Download](https://data.4tu.nl/articles/_/12713903/1) | | +| Land surveying | LULC mapping | C2Seg-BW | Gaofen-6 | 10 | 256x256 | [Download](https://github.com/danfenghong/RSE_Cross-city) | | +| | | | Gaofen-3 | 10 | 256x256 | | | +| | Change detection | dsifn-cd | GoogleEarth | 0.3 | 512x512 | [Download](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images/tree/master/dataset) | | +| Disaster management | Flood monitoring | Flood-3i | Airborne | 0.05 | 256 × 256 | [Download](https://drive.google.com/drive/folders/1FMAKf2sszoFKjq0UrUmSLnJDbwQSpfxR) | | +| | | C2SMSFloods | Sentinel-2, Sentinel-1 | 10 | 512x512 | [Download](https://beta.source.coop/c2sms/) | | +| | Wildfire monitoring | CABUAR | Sentinel-2 | 10 | 5490 × 5490 | [Download](https://github.com/DarthReca/CaBuAr) | | +| | Landslide mapping | GVLM | GoogleEarth | 0.3 | 1748x1748 ~ 10808x7424 | [Download](https://github.com/zxk688/GVLM) | | +| | Building damage assessment | xBD | WorldView | 0.3 | 1024x1024 | [Download](https://xview2.org/) | | + +*\* for time-series data.* + +## Implementation Code + +### Structure + +This project mainly contains the following parts. + +```plain +./ +├── antmmf/ # antmmf framework code +├── configs/ +│ ├── eval_skysense_pp_flood3i.yml # eval cfg on flood3i +│ └── pretrain_skysensepp.yml # pretrain cfg +├── finetune/ # finetuning code +│ ├── configs/ # finetuning configs +│ ├── mmseg/ # mmseg library +│ ├── requirements/ # mmseg install requirements folder +│ ├── requirements.txt # mmseg install requirements +│ ├── setup.py # mmseg setup file +│ └── tools/ # mmseg utils +├── lib/ # model implementation +│ ├── datasets/ # datasets for evaluation +│ ├── evaluation/ # evaluation code +│ ├── models/ # model architecture +│ ├── predictors/ # inference code +│ ├── task/ # task code +│ ├── trainer/ # trainer code +│ ├── utils/ # library code +│ └── __init__.py # packages init file +├── pretrain/ # pretrain ckpts +├── tools/ # tools ckpts +│ ├── pretraining_data_builder # pretraining dataset builder +│ ├── run_1shot_flood3i.sh # datasets for evaluation +│ ├── run_ft_atlantic.sh # run ft script +│ ├── run_pretrain.sh # run pretrain script +│ └── run.py # Program entry point +└── readme.md # project readme +``` + +### Environment + +Each machine for implementating the pretraining or fintuning are with *Alibaba Group Enterprise Linux(7.2)* and *Python 3.8.10*. The pretraining and finetuning code are implemented on severs with *Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz* and *Nvidia A100 GPUS*. + +### Pretraining + +To run our pretraining code, please install dependency packages. (Instalazation takes about 14 minutes on a node with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 8 A100 GPUs.) + +```plain +torch==1.13.1 +atorch==0.1.3 +torchvision==0.14.1 +mmcv-full==1.7.1 +mmsegmentation==0.30.0 +mmcls==0.25.0 +timm==0.6.13 +gdal==3.4.0 +scikit-image==0.19.3 +``` + +Step1. Install the above packages and clone [antmmf framework](https://github.com/alipay/Ant-Multi-Modal-Framework): + +```bash +git clone https://github.com/alipay/Ant-Multi-Modal-Framework.git antmmf/ +``` + +Step2. Download the pretraining datasets in [Zenodo](https://zenodo.org/records/15010418) and orgnize them as follows: + +``` +pretrain_datasets +├── dynamic-mm # multi-modal dynamic-mm datasets +│ ├── images_hr # hr images +│ ├── images_s2 # sentinel-2 images +│ ├── images_s1 # sentinel-1 images +│ ├── labels # segmentation annotations +│ ├── dynamic-mm_train.json # train list file +│ └── dynamic-mm_val.json # val list file +├── fbp # single-modal fbp datasets +│ ├── images # input gaofen-2 images +│ ├── labels # segmentation annotations +│ ├── fbp_train.json # train list file +│ └── fbp_val.json # val list file +└── ...... +``` + +The `_.json` is used to read information for training and validation, with a unified organizational format: + +```json +[ + { + "hr_path": "dataset_name/images_hr/.png", // hr info c,h,w + "s2_path": ["dataset_name/images_s2/_20240101.npz", "dataset_name/images_s2/_20240103.npz"], // s2 c,h,w + "s1_path": ["dataset_name/images_s1/_20240104.npz", "dataset_name/images_s1/_20240108.npz"], // s1 c,h,w + "target_path": "dataset_name/labels/.png", // annotation info + "type": "dataset_name", // dataset_name + "classes": [ // Included categories + 0, + 2, + 4, + 5 + ] + }, + { + ... + } +] +``` + +Step3. Download the pretraining weights of SkySense [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509) and move it to `pretrain/` + +Step4. Run the pretrain code on 4 nodes (each node with 8 A100 gpus): + +``` +bash tools/run_pretrain.sh +``` + +For example, if the ip adress of master node is *192.168.112.10*, the command for node *1* is: + +``` +bash tools/run_pretrain.sh 1 192.168.112.10 +``` + +### Downstream 1-shot application + +#### Requirments + +To run our code, please install dependency packages. ( Instalazation takes about 10 minutes on a sever with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 2 A100 GPUs.) + +```plain +torch==1.13.1 +atorch==0.1.3 +torchvision==0.14.1 +mmcv-full==1.7.1 +mmcls==0.25.0 +mmsegmentation==0.30.0 +timm==0.6.13 +gdal==3.4.0 +scikit-image==0.19.3 +``` + +#### Run steps + +step1. Clone [antmmf framework](https://github.com/alipay/Ant-Multi-Modal-Framework). and install the above packages: + +```plain +git clone https://github.com/alipay/Ant-Multi-Modal-Framework.git antmmf/ +``` + +step1. Download the flood-3i dataset (`Images.zip`/`Semantic_mask.zip` at [here](https://drive.google.com/drive/folders/1FMAKf2sszoFKjq0UrUmSLnJDbwQSpfxR), `val.txt` at [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509). Testing dataset should be organized as follows: +```plain +eval_datasets/ +└── flood3i/ + ├── Images/ + │ ├── 10165_0_2.jpg + │ ├── 10165_1_0.jpg + │ └── ... + ├── Semantic_mask/ + │ ├── 10165_lab_0_2.png + │ ├── 10165_lab_1_0.png + │ └── ... + └── val.txt +``` + +step2. Using the above pretraining wieights or download the pretrained model weights [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509). + +step3. Run the script for evaluating 1-shot performance on flood-3i: + +```plain +bash tools/run_1shot.sh flood-3i(dataset_name) +``` + +### Downstream finetuning application + +#### Requirments + +We build our fine-tuning application code on the [openmmlab framework](https://github.com/open-mmlab/mmcv). + +To run our code, please install dependency packages. ( Instalazation takes about 10 minutes on a sever with Intel(R) Xeon(R) Platinum 8369B CPU @ 2.90GHz and 2 A100 GPUs.) + +```plain +torch==1.13.1 +torchvision==0.14.1 +mmcv-full==2.1.0 +mmpretrain==1.2.0 +mmsegmentation==1.2.2 +mmdetection==3.3.0 +timm==0.6.13 +gdal==3.4.0 +scikit-image==0.19.3 +``` + +#### Run steps + +Step1. Install the mmsegmentation framework under the instrction in [here](https://mmsegmentation.readthedocs.io/en/latest/index.html) + +Step2. Download the evaluation datsets. We take Atlantic dataset for deforestation segmentation as an example. Download the Atlantic dataset at [here](https://github.com/davej23/attention-mechanism-unet). Spliting json files of evaluation framwork [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509). +``` +../rs_datasets/deforestation_atlantic/ +-- +├── deforestation_atlantic_test.json +├── deforestation_atlantic_train.json +├── deforestation_atlantic_val.json +├── Test/ +│ ├── image/ +│ └── label/ +├── Training/ +│ ├── image/ +│ └── label/ +└── Validation/ + ├── image/ + └── label/ +``` + +Step3. Use your pretrained model weights or download the model weights: [here](https://www.notion.so/SkySense-Checkpoints-a7fcff6ce29a4647a08c7fe416910509) + +Step4. Run the finetuning script. We take the Atlantic dataset as an example: + +```bash +bash tools/run_finetune.sh configs/atlantic.py +``` + +## Acknowledgments + +This projects are mainly built on the following projects: + ++ [antmmf](https://github.com/alipay/Ant-Multi-Modal-Framework) ++ [mmcv](https://github.com/open-mmlab/mmcv) ++ [mmpretrain](https://github.com/open-mmlab/mmpretrain) ++ [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) ++ [mmdetection](https://github.com/open-mmlab/mmdetection) ++ [Painter](https://github.com/baaivision/Painter) + +## License + +The pre-trained model weight and pre-training code are only available for the non-commercial research. For any commercial use or cooperation, please contact Yansheng Li at Wuhan University (e-mail: yansheng.li@whu.edu.cn). + +## Citation +If you find our repo useful, please consider giving a star and citation: + +``` +@article{wu2025semantic, + author = {Wu, Kang and Zhang, Yingying and Ru, Lixiang and Dang, Bo and Lao, Jiangwei and Yu, Lei and Luo, Junwei and Zhu, Zifan and Sun, Yue and Zhang, Jiahao and Zhu, Qi and Wang, Jian and Yang, Ming and Chen, Jingdong and Zhang, Yongjun and Li, Yansheng}, + title = {A semantic‑enhanced multi‑modal remote sensing foundation model for Earth observation}, + journal = {Nature Machine Intelligence}, + year = {2025}, + doi = {10.1038/s42256-025-01078-8}, + url = {https://doi.org/10.1038/s42256-025-01078-8} +} + +@inproceedings{guo2024skysense, + author = {Guo, Xin and Lao, Jiangwei and Dang, Bo and Zhang, Yingying and Yu, Lei and Ru, Lixiang and Zhong, Liheng and Huang, Ziyuan and Wu, Kang and Hu, Dingxiang and He, Huimei and Wang, Jian and Chen, Jingdong and Yang, Ming and Zhang, Yongjun and Li, Yansheng}, + title = {SkySense: A Multi-Modal Remote Sensing Foundation Model Towards Universal Interpretation for Earth Observation Imagery}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2024}, + pages = {27672-27683} +} + +@inproceedings{zhu2025skysenseo, + title={Skysense-o: Towards open-world remote sensing interpretation with vision-centric visual-language modeling}, + author={Zhu, Qi and Lao, Jiangwei and Ji, Deyi and Luo, Junwei and Wu, Kang and Zhang, Yingying and Ru, Lixiang and Wang, Jian and Chen, Jingdong and Yang, Ming and others}, + booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, + pages={14733--14744}, + year={2025} +} + +@article{luo2024skysensegpt, + title={Skysensegpt: A fine-grained instruction tuning dataset and model for remote sensing vision-language understanding}, + author={Luo, Junwei and Pang, Zhen and Zhang, Yongjun and Wang, Tingzhu and Wang, Linlin and Dang, Bo and Lao, Jiangwei and Wang, Jian and Chen, Jingdong and Tan, Yihua and others}, + journal={arXiv preprint arXiv:2406.10100}, + year={2024} +} +``` +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=kang-wu/SkySensePlusPlus&type=Date)](https://www.star-history.com/#kang-wu/SkySensePlusPlus&Date) diff --git a/configs/eval_skysense_pp_flood3i.yml b/configs/eval_skysense_pp_flood3i.yml new file mode 100644 index 0000000..eaeb1bc --- /dev/null +++ b/configs/eval_skysense_pp_flood3i.yml @@ -0,0 +1,171 @@ +task_attributes: + segmentation: + dataset_attributes: + few_shot_flood_segmentation: + data_root_dir: 'eval_datasets/flood3i' + data_txt: 'eval_datasets/flood3i/val.txt' + img_dir: 'eval_datasets/flood3i/Images' + tgt_dir: 'eval_datasets/flood3i/Semantic_mask' + num_shot: 1 + seq_len: 1 + npz_key: 'arr_0' + image_size: + hr: (512, 512) + s2: (16, 16) + s1: (16, 16) + anno: (512, 512) + mim: + input_size: (1024, 512) + patch_size: 128 + mask_ratio: 0.5 + +model_attributes: + SkySensePP: + sources: ['hr', 's2', 's1'] + use_glbank: False + use_modal_vae: True + use_ctpe: False + use_cls_token_uper_head: False + upsacle_results: True + calendar_time: 365 + vocabulary_size: 64 + backbone_hr: + type: 'SwinTransformerV2MSL' + arch: 'huge' + use_attn: True + merge_stage: 2 + vocabulary_size: 64 + img_size: 224 + patch_size: 4 + in_channels: 3 + window_size: 8 + drop_rate: 0. + drop_path_rate: 0.2 + out_indices: (0,1,2,3) + use_abs_pos_embed: False + interpolate_mode: 'bicubic' + with_cp: True + frozen_stages: -1 + norm_eval: False + pad_small_map: False + pretrained_window_sizes: [0, 0, 0, 0] + + backbone_s2: + type: 'VisionTransformerMSL' + img_size: (16, 16) + use_attn: False + merge_stage: 4 + vocabulary_size: 64 + patch_size: 4 + in_channels: 10 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + out_indices: (5,11,17,23) + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: False + output_cls_token: False + act_cfg: + type: 'GELU' + norm_cfg: + type: 'LN' + eps: 1e-6 + with_cp: True + interpolate_mode: 'bicubic' + + head_s2: + type: 'UPHead' + in_dim: 1024 + out_dim: 2816 + up_scale: 4 + + backbone_s1: + type: 'VisionTransformerMSL' + img_size: (16, 16) + use_attn: False + merge_stage: 4 + vocabulary_size: 64 + patch_size: 4 + in_channels: 2 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + out_indices: (5,11,17,23) + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: False + output_cls_token: False + act_cfg: + type: 'GELU' + norm_cfg: + type: 'LN' + eps: 1e-6 + with_cp: True + interpolate_mode: 'bicubic' + + head_s1: + type: 'UPHead' + in_dim: 1024 + out_dim: 2816 + up_scale: 4 + + rec_head_hr: + type: 'UPerHead' + in_channels: [704, 704, 1408, 2816, 1024] + in_index: [0, 1, 2, 3, 4] + pool_scales: (1, 2, 3, 6) + channels: 512 + dropout_ratio: 0.1 + num_classes: 65 + norm_cfg: + type: 'SyncBN' + requires_grad: true + align_corners: false + + necks: + type: 'TransformerEncoder' + input_dims: 2816 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: True + output_cls_token: True + norm_cfg: + type: 'LN' + act_cfg: + type: 'GELU' + num_fcs: 2 + norm_eval: False + with_cp: True + + modality_vae: + type: 'ModalityCompletion' + input_shape_hr: [2816, 32, 16] + input_shape_s2: [2816, 32, 16] + input_shape_s1: [2816, 32, 16] + conv_dim: 256 + z_dim: 256 + n_codebook: 8192 + +amp_attributes: + amp_escapes: Conv2d + opt_level: O1 + init_scale: 1 + +predictor_parameters: + predictor: 'OneshotPredictor' + replace_speedup_op: True + device: cuda + local_rank: 0 diff --git a/configs/pretrain_skysensepp.yml b/configs/pretrain_skysensepp.yml new file mode 100644 index 0000000..eb1af9d --- /dev/null +++ b/configs/pretrain_skysensepp.yml @@ -0,0 +1,248 @@ +task_attributes: + segmentation: + dataset_attributes: + pretraining_loader: + data_root_dir: 'pretrain_datasets/' + train_json_path_list: ['dynamic-mm/dynamic-mm_train.json', 'pastis-mm/pastis-mm_train.json', 'vaihingen/vaihingen_train.json', 'deepglobe/deepglobe_train.json', 'potsdam/potsdam_train.json', 'fbp/fbp_train.json', 'loveda/loveda_train.json', 'isaid/isaid_train.json', 'jl16-mm/jl16-mm_train.json', 'flair-mm/flair-mm_train.json', 's2naip-mm/s2naip-mm_train.json', 'dfc20-mm/dfc20-mm_train.json', 'c2segab-mm/c2segab-mm_train.json'] + val_json_path_list: ['dynamic-mm/dynamic-mm_val.json', 'pastis-mm/pastis-mm_val.json', 'vaihingen/vaihingen_val.json', 'deepglobe/deepglobe_val.json', 'potsdam/potsdam_val.json', 'fbp/fbp_val.json', 'loveda/loveda_val.json', 'isaid/isaid_val.json', 'jl16-mm/jl16-mm_val.json', 'flair-mm/flair-mm_val.json', 's2naip-mm/s2naip-mm_val.json', 'dfc20-mm/dfc20-mm_val.json', 'c2segab-mm/c2segab-mm_val.json'] + use_multi_pairs: True + seq_len: 1 + half_mask_ratio: 0.3 + min_random_scale: 0.3 + cls_repeat_cnt: 2000 + image_size: + hr: (512, 512) + s2: (16, 16) + s1: (16, 16) + anno: (512, 512) + mim: + input_size: (1024, 512) + patch_size: 128 + mask_ratio: 0.5 + +model_attributes: + SkySensePP: + sources: ['hr', 's2', 's1'] + use_modal_vae: True + use_ctpe: False + use_cls_token_uper_head: False + upsacle_results: True + calendar_time: 365 + vocabulary_size: 64 + backbone_hr: + type: 'SwinTransformerV2MSL' + arch: 'huge' + use_attn: True + merge_stage: 2 + vocabulary_size: 64 + img_size: 224 + patch_size: 4 + in_channels: 3 + window_size: 8 + drop_rate: 0. + drop_path_rate: 0.2 + out_indices: (0,1,2,3) + use_abs_pos_embed: False + interpolate_mode: 'bicubic' + with_cp: True + frozen_stages: -1 + norm_eval: False + pad_small_map: False + pretrained_window_sizes: [0, 0, 0, 0] + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_backbone_hr.pth' + + backbone_s2: + type: 'VisionTransformerMSL' + img_size: (16, 16) + use_attn: False + merge_stage: 4 + vocabulary_size: 64 + patch_size: 4 + in_channels: 10 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + out_indices: (5,11,17,23) + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: False + output_cls_token: False + act_cfg: + type: 'GELU' + norm_cfg: + type: 'LN' + eps: 1e-6 + with_cp: True + interpolate_mode: 'bicubic' + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_backbone_s2.pth' + + head_s2: + type: 'UPHead' + in_dim: 1024 + out_dim: 2816 #2816 + up_scale: 4 + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_head_s2.pth' + + backbone_s1: + type: 'VisionTransformerMSL' + img_size: (16, 16) + use_attn: False + merge_stage: 4 + vocabulary_size: 64 + patch_size: 4 + in_channels: 2 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + out_indices: (5,11,17,23) + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: False + output_cls_token: False + act_cfg: + type: 'GELU' + norm_cfg: + type: 'LN' + eps: 1e-6 + with_cp: True + interpolate_mode: 'bicubic' + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_backbone_s1.pth' + + head_s1: + type: 'UPHead' + in_dim: 1024 + out_dim: 2816 #2816 + up_scale: 4 + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_head_s1.pth' + + rec_head_hr: + type: 'UPerHead' + in_channels: [704, 704, 1408, 2816, 1024] + in_index: [0, 1, 2, 3, 4] + pool_scales: (1, 2, 3, 6) + channels: 512 + dropout_ratio: 0.1 + num_classes: 65 + norm_cfg: + type: 'SyncBN' + requires_grad: true + align_corners: false + + necks: + type: 'TransformerEncoder' + input_dims: 2816 + embed_dims: 1024 + num_layers: 24 + num_heads: 16 + mlp_ratio: 4 + qkv_bias: True + drop_rate: 0. + attn_drop_rate: 0. + drop_path_rate: 0.3 + with_cls_token: True + output_cls_token: True + norm_cfg: + type: 'LN' + act_cfg: + type: 'GELU' + num_fcs: 2 + norm_eval: False + with_cp: True + init_cfg: + type: Pretrained + checkpoint: 'pretrain/skysense_model_fusion.pth' + + modality_vae: + type: 'ModalityCompletion' + input_shape_hr: [2816, 32, 16] + input_shape_s2: [2816, 32, 16] + input_shape_s1: [2816, 32, 16] + conv_dim: 256 + z_dim: 256 + n_codebook: 8192 + + metrics: + - type: 'sem_metric' + + losses: + - type: 'RecLoss' + params: + weight: 1.0 + patch_size: 4 + balance: True + use_all_patch: True + vocabulary_size: 64 + feature_merged: True + pred_key: 'logits_hr' + mask_key: 'mask_hr' + target_key: 'mapped_targets' + use_bg: True + + - type: 'ModalityVAELoss' + params: + weight: 1.0 + + +optimizer_attributes: + type: AdamW + params: + lr: 2e-04 + betas: (0.9, 0.999) + weight_decay: 0.04 + +lr_parameters: + layer_decay: 0.7 + frozen_blocks: 12 + frozen_fusion_blocks_start: 3 + +training_parameters: + trainer: 'seg_trainer' + run_type: train + seed: 24042301 + pin_memory: True + batch_size: 256 + test_batch_size: 128 + num_workers: 16 + max_iterations: 30000 + num_warmup_steps: 1000 + log_interval: 50 + snapshot_interval: 3000 + cos_lr: False + + clip_norm_mode: all + clip_gradients: true + max_grad_l2_norm: 5 + + enable_tf32: False + enable_amp: True + find_unused_parameters: True + synchronized_loss: True + + static_graph: True + replace_speedup_op: True + + ema: False + + distributed_batch_sampler: + batch_size: 8 + +amp_attributes: + amp_escapes: Conv2d + opt_level: O1 + init_scale: 1 \ No newline at end of file diff --git a/finetune/configs/atlantic.py b/finetune/configs/atlantic.py new file mode 100644 index 0000000..95ec894 --- /dev/null +++ b/finetune/configs/atlantic.py @@ -0,0 +1,303 @@ +crop_size = ( + 256, + 256, +) +data_preprocessor = dict( + bgr_to_rgb=True, + mean=[ + 537.9411629981602, + 615.7886221108977, + 343.4481583821405, + 3010.641650390625, + ], + pad_val=0, + seg_pad_val=255, + size=crop_size, + std=[ + 367.4598430230881, + 254.2473100510193, + 187.5437562223154, + 921.0792775874182, + ], + type='SegDataPreProcessor') +data_root = 'rs_datasets/' +dataset_type = 'AtlanticDataset' +default_hooks = dict( + checkpoint=dict( + by_epoch=False, interval=1000, save_best='mIoU',max_keep_ckpts=1, + type='CheckpointHook'), + logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='SegVisualizationHook')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +img_ratios = [ + 0.5, + 0.75, + 1.0, + 1.25, + 1.5, + 1.75, +] +launcher = 'pytorch' +load_from = None +log_level = 'INFO' +log_processor = dict(by_epoch=False) +mean = [ + 537.9411629981602, + 615.7886221108977, + 343.4481583821405, + 3010.641650390625, +] +model = dict( + auxiliary_head=dict( + align_corners=False, + channels=256, + concat_input=False, + dropout_ratio=0.1, + in_channels=512, + in_index=3, + loss_decode=dict( + loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=2, + num_convs=1, + type='FCNHead'), + backbone=dict( + act_cfg=dict(type='GELU'), + attn_drop_rate=0.0, + drop_path_rate=0.0, + drop_rate=0.1, + embed_dims=1024, + img_size=crop_size, + in_channels=4, + interpolate_mode='bilinear', + mlp_ratio=4, + norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'), + norm_eval=False, + num_heads=16, + num_layers=24, + out_indices=( + 5, + 11, + 17, + 23, + ), + patch_size=4, + qkv_bias=True, + type='VisionTransformer', + with_cls_token=False), + data_preprocessor=dict( + bgr_to_rgb=True, + mean=[ + 537.9411629981602, + 615.7886221108977, + 343.4481583821405, + 3010.641650390625, + ], + pad_val=0, + seg_pad_val=255, + size=crop_size, + std=[ + 367.4598430230881, + 254.2473100510193, + 187.5437562223154, + 921.0792775874182, + ], + type='SegDataPreProcessor'), + decode_head=dict( + align_corners=False, + channels=512, + dropout_ratio=0.1, + in_channels=[ + 512, + 512, + 512, + 512, + ], + in_index=[ + 0, + 1, + 2, + 3, + ], + loss_decode=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=2, + pool_scales=( + 1, + 2, + 3, + 6, + ), + type='UPerHead'), + neck=dict( + in_channels=[ + 1024, + 1024, + 1024, + 1024, + ], + out_channels=512, + scales=[ + 1, + 1, + 1, + 1, + ], + type='MultiLevelNeck'), + pretrained="pretrain/skysensepp_release_s2.pth", + test_cfg=dict(mode='whole'), + train_cfg=dict(), + type='EncoderDecoder') +norm_cfg = dict(requires_grad=True, type='SyncBN') +optim_wrapper = dict( + optimizer=dict( + betas=( + 0.9, + 0.999, + ), lr=6e-05, type='AdamW', weight_decay=0.01), + paramwise_cfg=dict( + custom_keys=dict( + cls_token=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0))), + type='OptimWrapper') +optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005) +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=1000, start_factor=1e-06, + type='LinearLR'), + dict( + begin=1000, + by_epoch=False, + end=10000, + eta_min=0.0, + power=1.0, + type='PolyLR'), +] +randomness = dict(seed=20240315) +resume = False +std = [ + 367.4598430230881, + 254.2473100510193, + 187.5437562223154, + 921.0792775874182, +] +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'deforestation_atlantic/deforestation_atlantic_test.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadSingleRSImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), + ], + type='AtlanticDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +test_pipeline = [ + dict(type='LoadSingleRSImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), +] +train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500) +train_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'deforestation_atlantic/deforestation_atlantic_train.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadSingleRSImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), + ], + type='AtlanticDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=True, type='InfiniteSampler')) +train_pipeline = [ + dict(type='LoadSingleRSImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(cat_max_ratio=0.75, crop_size=crop_size, type='RandomCrop'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), +] +tta_model = dict(type='SegTTAModel') +tta_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict( + transforms=[ + [ + dict(keep_ratio=True, scale_factor=0.5, type='Resize'), + dict(keep_ratio=True, scale_factor=0.75, type='Resize'), + dict(keep_ratio=True, scale_factor=1.0, type='Resize'), + dict(keep_ratio=True, scale_factor=1.25, type='Resize'), + dict(keep_ratio=True, scale_factor=1.5, type='Resize'), + dict(keep_ratio=True, scale_factor=1.75, type='Resize'), + ], + [ + dict(direction='horizontal', prob=0.0, type='RandomFlip'), + dict(direction='horizontal', prob=1.0, type='RandomFlip'), + ], + [ + dict(type='LoadAnnotations'), + ], + [ + dict(type='PackSegInputs'), + ], + ], + type='TestTimeAug'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'deforestation_atlantic/deforestation_atlantic_val.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadSingleRSImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), + ], + type='AtlanticDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='SegLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) +work_dir = 'save/atlantic_skysensepp' diff --git a/finetune/configs/c2smsflood.py b/finetune/configs/c2smsflood.py new file mode 100644 index 0000000..01dc0ac --- /dev/null +++ b/finetune/configs/c2smsflood.py @@ -0,0 +1,284 @@ +dataset_type = 'C2SFloodDataset' +default_hooks = dict( + checkpoint=dict( + by_epoch=False, + interval=2000, + max_keep_ckpts=2, + save_best='mIoU', + type='CheckpointHook'), + logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='SegVisualizationHook')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +launcher = 'pytorch' +load_from = None +log_level = 'INFO' +log_processor = dict(by_epoch=False) +model = dict( + auxiliary_head=dict( + align_corners=False, + channels=256, + concat_input=False, + dropout_ratio=0.1, + in_channels=512, + in_index=3, + loss_decode=dict( + loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=4, + num_convs=1, + type='FCNHead'), + backbone=dict( + act_cfg=dict(type='GELU'), + attn_drop_rate=0.0, + downscale_indices=( + 5, + 11, + ), + drop_path_rate=0.0, + drop_rate=0.1, + embed_dims=1024, + img_size=( + 256, + 256, + ), + in_channels=10, + interpolate_mode='bilinear', + mlp_ratio=4, + norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'), + norm_eval=False, + num_heads=16, + num_layers=24, + out_indices=( + 5, + 11, + 17, + 23, + ), + patch_size=4, + qkv_bias=True, + type='VisionTransformer', + with_cls_token=False), + data_preprocessor=dict( + bgr_to_rgb=True, + mean=[ + 1381.26, + 1302.05, + 1179.27, + 1393.56, + 2164.76, + 2561.75, + 2377.94, + 2791.13, + 1998.09, + 1196.08, + ], + pad_val=0, + seg_pad_val=255, + size=( + 256, + 256, + ), + std=[ + 653.25, + 659.61, + 779.58, + 720.45, + 871.09, + 1035.57, + 965.36, + 1141.71, + 1019.73, + 825.01, + ], + type='SegDataPreProcessor'), + decode_head=dict( + align_corners=False, + channels=512, + dropout_ratio=0.1, + in_channels=[ + 512, + 512, + 512, + 512, + ], + in_index=[ + 0, + 1, + 2, + 3, + ], + loss_decode=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=4, + pool_scales=( + 1, + 2, + 3, + 6, + ), + type='UPerHead'), + neck=dict( + in_channels=[ + 1024, + 1024, + 1024, + 1024, + ], + out_channels=512, + scales=[ + 1, + 1, + 1, + 1, + ], + type='MultiLevelNeck'), + pretrained= + 'pretrain/skysensepp_mmcvt_s2.pth', + test_cfg=dict(mode='whole'), + train_cfg=dict(), + type='EncoderDecoder') +norm_cfg = None +optim_wrapper = dict( + constructor='LearningRateDecayOptimizerConstructor', + optimizer=dict( + betas=( + 0.9, + 0.999, + ), lr=0.0005, type='AdamW', weight_decay=0.015), + paramwise_cfg=dict( + custom_keys=dict( + cls_token=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0)), + decay_rate=0.84, + decay_type='layer_wise', + num_layers=24), + type='OptimWrapper') +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=1000, start_factor=1e-06, + type='LinearLR'), + dict( + begin=1000, + by_epoch=False, + end=10000, + eta_min=0.0, + power=1.0, + type='PolyLR'), +] +randomness = dict(seed=1) +resume = False +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/c2smsfloods/c2s_s2_val_mm.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromNpz'), + dict(reduce_zero_label=False, type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='C2SFloodDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore' + ], type='IoUMetric') +test_pipeline = [ + dict(type='LoadImageFromNpz'), + dict(reduce_zero_label=False, type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), +] +train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500) +train_dataloader = dict( + batch_size=2, + dataset=dict( + ann_file= + 'rs_datasets/c2smsfloods/c2s_s2_train_mm.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromNpz'), + dict(reduce_zero_label=False, type='LoadAnnotationsNpz'), + dict( + cat_max_ratio=0.75, crop_size=( + 256, + 256, + ), type='RandomCrop'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), + ], + type='C2SFloodDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=True, type='InfiniteSampler')) +tta_model = dict(type='SegTTAModel') +tta_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict( + transforms=[ + [ + dict(keep_ratio=True, scale_factor=1.0, type='Resize'), + dict(keep_ratio=True, scale_factor=1.2, type='Resize'), + dict(keep_ratio=True, scale_factor=1.4, type='Resize'), + dict(keep_ratio=True, scale_factor=1.6, type='Resize'), + dict(keep_ratio=True, scale_factor=1.8, type='Resize'), + dict(keep_ratio=True, scale_factor=2.0, type='Resize'), + ], + [ + dict(direction='horizontal', prob=0.0, type='RandomFlip'), + dict(direction='horizontal', prob=1.0, type='RandomFlip'), + ], + [ + dict(type='LoadAnnotations'), + ], + [ + dict(type='PackSegInputs'), + ], + ], + type='TestTimeAug'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/c2smsfloods/c2s_s2_val_mm.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromNpz'), + dict(reduce_zero_label=False, type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='C2SFloodDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + iou_metrics=[ + 'mIoU', + ], type='IoUMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='SegLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) diff --git a/finetune/configs/cabuar.py b/finetune/configs/cabuar.py new file mode 100644 index 0000000..e2cefc3 --- /dev/null +++ b/finetune/configs/cabuar.py @@ -0,0 +1,227 @@ +default_hooks = dict( + checkpoint=dict( + by_epoch=False, + interval=1000, + max_keep_ckpts=1, + save_best='mIoU', + type='CheckpointHook'), + logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='SegVisualizationHook')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +launcher = 'pytorch' +load_from = None +log_level = 'INFO' +log_processor = dict(by_epoch=False) +model = dict( + auxiliary_head=dict( + align_corners=False, + channels=256, + concat_input=False, + dropout_ratio=0.1, + in_channels=512, + in_index=3, + loss_decode=dict( + loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=2, + num_convs=1, + type='FCNHead'), + backbone=dict( + act_cfg=dict(type='GELU'), + attn_drop_rate=0.0, + downscale_indices=(5, 11,), + drop_path_rate=0.0, + drop_rate=0.1, + embed_dims=1024, + img_size=(256, 256,), + in_channels=10, + interpolate_mode='bilinear', + mlp_ratio=4, + norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'), + norm_eval=False, + num_heads=16, + num_layers=24, + out_indices=(5, 11, 17, 23,), + patch_size=4, + qkv_bias=True, + type='VisionTransformer', + with_cls_token=False), + data_preprocessor=dict( + bgr_to_rgb=True, + mean=[4.50021132, 6.09891466, 7.50766315, 9.54643074, 12.82568112, 14.29062133, 15.24644993, 15.73945708, 16.60374872, 12.31011599,], + pad_val=0, + seg_pad_val=255, + size=(256, 256,), + std=[2.6094148, 2.49566825, 1.37103968, 2.6094148, 2.49566825, 1.37103968, 2.6094148, 2.49566825, 1.37103968, 1.37103968, ], + type='SegDataPreProcessor'), + decode_head=dict( + align_corners=False, + channels=512, + dropout_ratio=0.1, + in_channels=[512, 512, 512, 512,], + in_index=[0, 1, 2, 3,], + loss_decode=dict( + loss_weight=1.0, type='CrossEntropyLoss', balance=True, max_scale=4.0, use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=2, + pool_scales=(1, 2, 3, 6,), + type='UPerHead'), + neck=dict( + in_channels=[1024, 1024, 1024, 1024,], + out_channels=512, + scales=[1, 1, 1, 1,], + type='MultiLevelNeck'), + pretrained='pretrain/skysensepp_mmcvt_s2.pth', + test_cfg=dict(mode='whole'), + train_cfg=dict(), + type='EncoderDecoder') +optim_wrapper = dict( + optimizer=dict(betas=(0.9, 0.999,), lr=None, type='AdamW', weight_decay=None,), + constructor='LearningRateDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=24, + decay_rate=None, + decay_type='layer_wise', + custom_keys=dict( + cls_token=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0))), + type='OptimWrapper') +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=1000, start_factor=1e-06, + type='LinearLR'), + dict( + begin=1000, + by_epoch=False, + end=10000, + eta_min=0.0, + power=1.0, + type='PolyLR'), +] +randomness = dict(seed=20240315) +resume = False +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/cabuar/cabura_val_fold_4.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', type='LoadImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=False, + type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='CABURADataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +test_pipeline = [ + dict(data_key='image', type='LoadImageFromNpz'), + dict(data_key='image', reduce_zero_label=False, type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), +] +train_cfg = dict(max_iters=10000, type='IterBasedTrainLoop', val_interval=500) +train_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/cabuar/cabura_train_fold0_3.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', type='LoadImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=False, + type='LoadAnnotationsNpz'), + dict( + cat_max_ratio=0.75, crop_size=( + 256, + 256, + ), type='RandomCrop'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), + ], + type='CABURADataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=True, type='InfiniteSampler')) +tta_model = dict(type='SegTTAModel') +tta_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict( + transforms=[ + [ + dict(keep_ratio=True, scale_factor=1.0, type='Resize'), + dict(keep_ratio=True, scale_factor=1.2, type='Resize'), + dict(keep_ratio=True, scale_factor=1.4, type='Resize'), + dict(keep_ratio=True, scale_factor=1.6, type='Resize'), + dict(keep_ratio=True, scale_factor=1.8, type='Resize'), + dict(keep_ratio=True, scale_factor=2.0, type='Resize'), + ], + [ + dict(direction='horizontal', prob=0.0, type='RandomFlip'), + dict(direction='horizontal', prob=1.0, type='RandomFlip'), + ], + [ + dict(type='LoadAnnotations'), + ], + [ + dict(type='PackSegInputs'), + ], + ], + type='TestTimeAug'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/cabuar/cabura_val_fold_4.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', type='LoadImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=False, + type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='CABURADataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='SegLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) +work_dir = 'work_dirs/ft_cabura_test' \ No newline at end of file diff --git a/finetune/configs/germany.py b/finetune/configs/germany.py new file mode 100644 index 0000000..b0610db --- /dev/null +++ b/finetune/configs/germany.py @@ -0,0 +1,414 @@ +crop_size = ( + 24, + 24, +) +data_preprocessor = dict( + bgr_to_rgb=False, + mean=[ + 2482.0061841829206, + 2456.642580060208, + 2667.8229979675334, + 2744.9377076257624, + 3620.1499158373827, + 4063.9126981046647, + 3922.2406108776354, + 4264.908986788407, + 2453.0070206816135, + 1774.0019119673998, + ], + pad_val=0, + seg_pad_val=255, + size=( + 24, + 24, + ), + std=[ + 2392.1256366526068, + 2100.1364646122875, + 2262.6154840764625, + 2353.899770400333, + 2089.598452203458, + 2057.1247114077073, + 2013.2108514271458, + 2041.0248949410561, + 1380.4643757742374, + 1243.547946113518, + ], + ts_size=30, + type='RSTsSegDataPreProcessor') +data_root = 'rs_datasets/' +dataset_type = 'GermanyCropDataset' +default_hooks = dict( + checkpoint=dict( + by_epoch=False, + interval=2000, + max_keep_ckpts=1, + save_best='mIoU', + type='CheckpointHook'), + logger=dict(interval=20, log_metric_by_epoch=False, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='SegVisualizationHook')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +find_unused_parameters = True +img_ratios = [ + 0.5, + 0.75, + 1.0, + 1.25, + 1.5, + 1.75, +] +launcher = 'pytorch' +load_from = None +log_level = 'INFO' +log_processor = dict(by_epoch=False) +mean = [ + 2482.0061841829206, + 2456.642580060208, + 2667.8229979675334, + 2744.9377076257624, + 3620.1499158373827, + 4063.9126981046647, + 3922.2406108776354, + 4264.908986788407, + 2453.0070206816135, + 1774.0019119673998, +] +model = dict( + auxiliary_head=dict( + align_corners=False, + channels=256, + concat_input=False, + dropout_ratio=0.1, + in_channels=1024, + in_index=3, + loss_decode=dict( + loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=18, + num_convs=1, + type='FCNHead'), + backbone=dict( + act_cfg=dict(type='GELU'), + attn_drop_rate=0.0, + downscale_indices=[ + -1, + ], + drop_path_rate=0.0, + drop_rate=0.1, + embed_dims=1024, + img_size=( + 24, + 24, + ), + in_channels=10, + init_cfg=dict( + checkpoint= + 'pretrain/skysensepp_mmcvt_s2.pth', + type='Pretrained'), + interpolate_mode='bilinear', + mlp_ratio=4, + norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'), + norm_eval=False, + num_heads=16, + num_layers=24, + out_indices=( + 5, + 11, + 17, + 23, + ), + patch_size=4, + qkv_bias=True, + type='VisionTransformer', + with_cls_token=False), + data_preprocessor=dict( + bgr_to_rgb=False, + mean=[ + 2482.0061841829206, + 2456.642580060208, + 2667.8229979675334, + 2744.9377076257624, + 3620.1499158373827, + 4063.9126981046647, + 3922.2406108776354, + 4264.908986788407, + 2453.0070206816135, + 1774.0019119673998, + ], + pad_val=0, + seg_pad_val=255, + size=( + 24, + 24, + ), + std=[ + 2392.1256366526068, + 2100.1364646122875, + 2262.6154840764625, + 2353.899770400333, + 2089.598452203458, + 2057.1247114077073, + 2013.2108514271458, + 2041.0248949410561, + 1380.4643757742374, + 1243.547946113518, + ], + ts_size=30, + type='RSTsSegDataPreProcessor'), + decode_head=dict( + align_corners=False, + channels=512, + dropout_ratio=0.1, + in_channels=[ + 1024, + 1024, + 1024, + 1024, + ], + in_index=[ + 0, + 1, + 2, + 3, + ], + loss_decode=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=False), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=18, + pool_scales=( + 1, + 2, + 3, + 6, + ), + type='UPerHead'), + neck=dict( + attn_drop_rate=0.0, + drop_path_rate=0.3, + drop_rate=0.0, + embed_dims=1024, + in_channels=[ + 768, + 768, + 768, + 768, + ], + in_channels_ml=[ + 1024, + 1024, + 1024, + 1024, + ], + init_cfg=dict( + checkpoint= + 'pretrain/skysensepp_mmcvt_fusion.pth', + type='Pretrained'), + input_dims=1024, + mlp_ratio=4, + num_heads=16, + num_layers=24, + out_channels=768, + out_channels_ml=1024, + output_cls_token=True, + qkv_bias=True, + scales=[ + 4, + 2, + 1, + 0.5, + ], + scales_ml=[ + 1, + 1, + 1, + 1, + ], + ts_size=30, + type='FusionMultiLevelNeck', + with_cls_token=True), + test_cfg=dict(mode='whole'), + train_cfg=dict(), + type='EncoderDecoder') +norm_cfg = dict(requires_grad=True, type='SyncBN') +optim_wrapper = dict( + optimizer=dict( + betas=( + 0.9, + 0.999, + ), lr=0.0001, type='AdamW', weight_decay=0.01), + paramwise_cfg=dict( + custom_keys=dict( + cls_token=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0))), + type='OptimWrapper') +optimizer = dict(lr=0.01, momentum=0.9, type='SGD', weight_decay=0.0005) +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=2000, start_factor=1e-06, + type='LinearLR'), + dict( + begin=2000, + by_epoch=False, + end=20000, + eta_min=0.0, + power=1.0, + type='PolyLR'), +] +randomness = dict(seed=20240311) +resume = False +static_graph = True +std = [ + 2392.1256366526068, + 2100.1364646122875, + 2262.6154840764625, + 2353.899770400333, + 2089.598452203458, + 2057.1247114077073, + 2013.2108514271458, + 2041.0248949410561, + 1380.4643757742374, + 1243.547946113518, +] +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/germany_crop/germany_crop_val.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=True, + type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='GermanyCropDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +test_pipeline = [ + dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'), + dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), +] +train_cfg = dict( + dynamic_intervals=[ + ( + 0, + 1000, + ), + ( + 4000, + 2000, + ), + ( + 8000, + 4000, + ), + ], + max_iters=20000, + type='IterBasedTrainLoop', + val_interval=2000) +train_dataloader = dict( + batch_size=2, + dataset=dict( + ann_file= + 'rs_datasets/germany_crop/germany_crop_train.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=True, + type='LoadAnnotationsNpz'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), + ], + type='GermanyCropDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=True, type='InfiniteSampler')) +train_pipeline = [ + dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'), + dict(data_key='image', reduce_zero_label=True, type='LoadAnnotationsNpz'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), +] +tta_model = dict(type='SegTTAModel') +tta_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict( + transforms=[ + [ + dict(keep_ratio=True, scale_factor=0.5, type='Resize'), + dict(keep_ratio=True, scale_factor=0.75, type='Resize'), + dict(keep_ratio=True, scale_factor=1.0, type='Resize'), + dict(keep_ratio=True, scale_factor=1.25, type='Resize'), + dict(keep_ratio=True, scale_factor=1.5, type='Resize'), + dict(keep_ratio=True, scale_factor=1.75, type='Resize'), + ], + [ + dict(direction='horizontal', prob=0.0, type='RandomFlip'), + dict(direction='horizontal', prob=1.0, type='RandomFlip'), + ], + [ + dict(type='LoadAnnotations'), + ], + [ + dict(type='PackSegInputs'), + ], + ], + type='TestTimeAug'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/germany_crop/germany_crop_val.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(data_key='image', ts_size=30, type='LoadTsImageFromNpz'), + dict( + data_key='image', + reduce_zero_label=True, + type='LoadAnnotationsNpz'), + dict(type='PackSegInputs'), + ], + type='GermanyCropDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='SegLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) +work_dir = 'save_germany' \ No newline at end of file diff --git a/finetune/configs/sos.py b/finetune/configs/sos.py new file mode 100644 index 0000000..a760b70 --- /dev/null +++ b/finetune/configs/sos.py @@ -0,0 +1,297 @@ +crop_size = ( + 256, + 256, +) +data_preprocessor = dict( + bgr_to_rgb=True, + mean=[ + 83.37, + 83.37, + 83.37, + ], + pad_val=0, + seg_pad_val=255, + size=( + 256, + 256, + ), + std=[ + 40.45, + 40.45, + 40.45, + ], + type='SegDataPreProcessor') +data_root = 'rs_datasets/' +dataset_type = 'SOSDataset' +default_hooks = dict( + checkpoint=dict( + by_epoch=False, + interval=1000, + max_keep_ckpts=2, + save_best='mIoU', + type='CheckpointHook'), + logger=dict(interval=50, log_metric_by_epoch=False, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='SegVisualizationHook')) +default_scope = 'mmseg' +env_cfg = dict( + cudnn_benchmark=True, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +img_ratios = [ + 0.5, + 0.75, + 1.0, + 1.25, + 1.5, + 1.75, +] +launcher = 'pytorch' +load_from = None +log_level = 'INFO' +log_processor = dict(by_epoch=False) +mean = [ + 83.37, + 83.37, + 83.37, +] +model = dict( + backbone=dict( + act_cfg=dict(type='GELU'), + attn_drop_rate=0.0, + downscale_indices=[ + -1, + ], + drop_path_rate=0.0, + drop_rate=0.1, + embed_dims=1024, + img_size=( + 256, + 256, + ), + in_channels=2, + interpolate_mode='bilinear', + mlp_ratio=4, + norm_cfg=dict(eps=1e-06, requires_grad=True, type='LN'), + norm_eval=False, + num_heads=16, + num_layers=24, + out_indices=( + 5, + 11, + 17, + 23, + ), + patch_size=4, + qkv_bias=True, + type='VisionTransformer', + with_cls_token=False), + data_preprocessor=dict( + bgr_to_rgb=True, + mean=[ + 83.37, + 83.37, + 83.37, + ], + pad_val=0, + seg_pad_val=255, + size=( + 256, + 256, + ), + std=[ + 40.45, + 40.45, + 40.45, + ], + type='SegDataPreProcessor'), + decode_head=dict( + align_corners=False, + channels=512, + dropout_ratio=0.1, + in_channels=[ + 1024, + 1024, + 1024, + 1024, + ], + in_index=[ + 0, + 1, + 2, + 3, + ], + loss_decode=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True), + norm_cfg=dict(requires_grad=True, type='SyncBN'), + num_classes=2, + pool_scales=( + 1, + 2, + 3, + 6, + ), + type='UPerHead'), + neck=dict( + in_channels=[ + 1024, + 1024, + 1024, + 1024, + ], + out_channels=1024, + scales=[ + 4, + 2, + 1, + 0.5, + ], + type='MultiLevelNeck'), + pretrained= + 'pretrain/skysensepp_mmcvt_s1.pth', + test_cfg=dict(mode='whole'), + train_cfg=dict(), + type='EncoderDecoder') +norm_cfg = dict(requires_grad=True, type='SyncBN') +optim_wrapper = dict( + optimizer=dict( + betas=( + 0.9, + 0.999, + ), lr=6e-05, type='AdamW', weight_decay=0.01), + paramwise_cfg=dict( + custom_keys=dict( + cls_token=dict(decay_mult=0.0), + norm=dict(decay_mult=0.0), + pos_embed=dict(decay_mult=0.0))), + type='OptimWrapper') +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=1000, start_factor=1e-06, + type='LinearLR'), + dict( + begin=1000, + by_epoch=False, + end=10000, + eta_min=0.0, + power=1.0, + type='PolyLR'), +] +randomness = dict(seed=0) +resume = False +std = [ + 40.45, + 40.45, + 40.45, +] +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/sos/sos_test_sentinel.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), + ], + type='SOSDataset'), + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), +] +train_cfg = dict(max_iters=0, type='IterBasedTrainLoop', val_interval=500) +train_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/sos/sos_train_sentinel.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), + ], + type='SOSDataset'), + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=True, type='InfiniteSampler')) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackSegInputs'), +] +tta_model = dict(type='SegTTAModel') +tta_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict( + transforms=[ + [ + dict(keep_ratio=True, scale_factor=0.5, type='Resize'), + dict(keep_ratio=True, scale_factor=0.75, type='Resize'), + dict(keep_ratio=True, scale_factor=1.0, type='Resize'), + dict(keep_ratio=True, scale_factor=1.25, type='Resize'), + dict(keep_ratio=True, scale_factor=1.5, type='Resize'), + dict(keep_ratio=True, scale_factor=1.75, type='Resize'), + ], + [ + dict(direction='horizontal', prob=0.0, type='RandomFlip'), + dict(direction='horizontal', prob=1.0, type='RandomFlip'), + ], + [ + dict(type='LoadAnnotations'), + ], + [ + dict(type='PackSegInputs'), + ], + ], + type='TestTimeAug'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file= + 'rs_datasets/sos/sos_test_sentinel.json', + data_prefix=dict(img_path='images', seg_map_path='idx_labels'), + data_root='rs_datasets/', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(reduce_zero_label=False, type='LoadAnnotations'), + dict(type='PackSegInputs'), + ], + type='SOSDataset'), + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + iou_metrics=[ + 'mIoU', + 'mFscore', + ], type='IoUMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='SegLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) +work_dir = 'save_sos/' \ No newline at end of file diff --git a/finetune/mmseg/__init__.py b/finetune/mmseg/__init__.py new file mode 100644 index 0000000..5fcb84e --- /dev/null +++ b/finetune/mmseg/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import mmcv +import mmengine +from packaging.version import parse + +from .version import __version__, version_info + +MMCV_MIN = '2.0.0rc4' +MMCV_MAX = '2.2.0' +MMENGINE_MIN = '0.5.0' +MMENGINE_MAX = '1.0.0' + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +mmcv_min_version = digit_version(MMCV_MIN) +mmcv_max_version = digit_version(MMCV_MAX) +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>=2.0.0rc4.' + +mmengine_min_version = digit_version(MMENGINE_MIN) +mmengine_max_version = digit_version(MMENGINE_MAX) +mmengine_version = digit_version(mmengine.__version__) + +assert (mmengine_min_version <= mmengine_version < mmengine_max_version), \ + f'MMEngine=={mmengine.__version__} is used but incompatible. ' \ + f'Please install mmengine>={mmengine_min_version}, '\ + f'<{mmengine_max_version}.' + +__all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/finetune/mmseg/apis/__init__.py b/finetune/mmseg/apis/__init__.py new file mode 100644 index 0000000..b50a266 --- /dev/null +++ b/finetune/mmseg/apis/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .inference import inference_model, init_model, show_result_pyplot +from .mmseg_inferencer import MMSegInferencer +from .remote_sense_inferencer import RSImage, RSInferencer + +__all__ = [ + 'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer', + 'RSInferencer', 'RSImage' +] diff --git a/finetune/mmseg/apis/inference.py b/finetune/mmseg/apis/inference.py new file mode 100644 index 0000000..aab11d1 --- /dev/null +++ b/finetune/mmseg/apis/inference.py @@ -0,0 +1,189 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from pathlib import Path +from typing import Optional, Union + +import mmcv +import numpy as np +import torch +from mmengine import Config +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint +from mmengine.utils import mkdir_or_exist + +from mmseg.models import BaseSegmentor +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette +from mmseg.visualization import SegLocalVisualizer +from .utils import ImageType, _preprare_data + + +def init_model(config: Union[str, Path, Config], + checkpoint: Optional[str] = None, + device: str = 'cuda:0', + cfg_options: Optional[dict] = None): + """Initialize a segmentor from config file. + + Args: + config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, + :obj:`Path`, or the config object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + cfg_options (dict, optional): Options to override some settings in + the used config. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, (str, Path)): + config = Config.fromfile(config) + elif not isinstance(config, Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + if cfg_options is not None: + config.merge_from_dict(cfg_options) + if config.model.type == 'EncoderDecoder': + if 'init_cfg' in config.model.backbone: + config.model.backbone.init_cfg = None + elif config.model.type == 'MultimodalEncoderDecoder': + for k, v in config.model.items(): + if isinstance(v, dict) and 'init_cfg' in v: + config.model[k].init_cfg = None + config.model.pretrained = None + config.model.train_cfg = None + init_default_scope(config.get('default_scope', 'mmseg')) + + model = MODELS.build(config.model) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + dataset_meta = checkpoint['meta'].get('dataset_meta', None) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint.get('meta', {}): + # mmseg 1.x + model.dataset_meta = dataset_meta + elif 'CLASSES' in checkpoint.get('meta', {}): + # < mmseg 1.x + classes = checkpoint['meta']['CLASSES'] + palette = checkpoint['meta']['PALETTE'] + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.simplefilter('once') + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, classes and palette will be' + 'set according to num_classes ') + num_classes = model.decode_head.num_classes + dataset_name = None + for name in dataset_aliases.keys(): + if len(get_classes(name)) == num_classes: + dataset_name = name + break + if dataset_name is None: + warnings.warn( + 'No suitable dataset found, use Cityscapes by default') + dataset_name = 'cityscapes' + model.dataset_meta = { + 'classes': get_classes(dataset_name), + 'palette': get_palette(dataset_name) + } + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def inference_model(model: BaseSegmentor, + img: ImageType) -> Union[SegDataSample, SampleList]: + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + :obj:`SegDataSample` or list[:obj:`SegDataSample`]: + If imgs is a list or tuple, the same length list type results + will be returned, otherwise return the segmentation results directly. + """ + # prepare data + data, is_batch = _preprare_data(img, model) + + # forward the model + with torch.no_grad(): + results = model.test_step(data) + + return results if is_batch else results[0] + + +def show_result_pyplot(model: BaseSegmentor, + img: Union[str, np.ndarray], + result: SegDataSample, + opacity: float = 0.5, + title: str = '', + draw_gt: bool = True, + draw_pred: bool = True, + wait_time: float = 0, + show: bool = True, + with_labels: Optional[bool] = True, + save_dir=None, + out_file=None): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (SegDataSample): The prediction SegDataSample result. + opacity(float): Opacity of painted segmentation map. + Default 0.5. Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + wait_time (float): The interval of show (s). 0 is the special value + that means "forever". Defaults to 0. + show (bool): Whether to display the drawn image. + Default to True. + with_labels(bool, optional): Add semantic labels in visualization + result, Default to True. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + out_file (str, optional): Path to output file. Default to None. + + + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + if hasattr(model, 'module'): + model = model.module + if isinstance(img, str): + image = mmcv.imread(img, channel_order='rgb') + else: + image = img + if save_dir is not None: + mkdir_or_exist(save_dir) + # init visualizer + visualizer = SegLocalVisualizer( + vis_backends=[dict(type='LocalVisBackend')], + save_dir=save_dir, + alpha=opacity) + visualizer.dataset_meta = dict( + classes=model.dataset_meta['classes'], + palette=model.dataset_meta['palette']) + visualizer.add_datasample( + name=title, + image=image, + data_sample=result, + draw_gt=draw_gt, + draw_pred=draw_pred, + wait_time=wait_time, + out_file=out_file, + show=show, + with_labels=with_labels) + vis_img = visualizer.get_image() + + return vis_img diff --git a/finetune/mmseg/apis/mmseg_inferencer.py b/finetune/mmseg/apis/mmseg_inferencer.py new file mode 100644 index 0000000..02a198b --- /dev/null +++ b/finetune/mmseg/apis/mmseg_inferencer.py @@ -0,0 +1,382 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import List, Optional, Sequence, Union + +import mmcv +import mmengine +import numpy as np +import torch +import torch.nn as nn +from mmcv.transforms import Compose +from mmengine.infer.infer import BaseInferencer, ModelType +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner.checkpoint import _load_checkpoint_to_model +from PIL import Image + +from mmseg.structures import SegDataSample +from mmseg.utils import ConfigType, SampleList, get_classes, get_palette +from mmseg.visualization import SegLocalVisualizer + +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[SegDataSample, SampleList] + + +class MMSegInferencer(BaseInferencer): + """Semantic segmentation inferencer, provides inference and visualization + interfaces. Note: MMEngine >= 0.5.0 is required. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. Take the `mmseg metafile `_ + as an example the `model` could be + "fcn_r50-d8_4xb2-40k_cityscapes-512x1024", and the weights of model + will be download automatically. If use config file, like + "configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py", the + `weights` should be defined. + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. If palette is + not defined, visualizer will take `cityscapes` palette by default. + Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + scope (str, optional): The scope of the model. Defaults to 'mmseg'. + """ # noqa + + preprocess_kwargs: set = set() + forward_kwargs: set = {'mode', 'out_dir'} + visualize_kwargs: set = { + 'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis', + 'with_labels' + } + postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'} + + def __init__(self, + model: Union[ModelType, str], + weights: Optional[str] = None, + classes: Optional[Union[str, List]] = None, + palette: Optional[Union[str, List]] = None, + dataset_name: Optional[str] = None, + device: Optional[str] = None, + scope: Optional[str] = 'mmseg') -> None: + # A global counter tracking the number of images processes, for + # naming of the output images + self.num_visualized_imgs = 0 + self.num_pred_imgs = 0 + init_default_scope(scope if scope else 'mmseg') + super().__init__( + model=model, weights=weights, device=device, scope=scope) + + if device == 'cpu' or not torch.cuda.is_available(): + self.model = revert_sync_batchnorm(self.model) + + assert isinstance(self.visualizer, SegLocalVisualizer) + self.visualizer.set_dataset_meta(classes, palette, dataset_name) + + def _load_weights_to_model(self, model: nn.Module, + checkpoint: Optional[dict], + cfg: Optional[ConfigType]) -> None: + """Loading model weights and meta information from cfg and checkpoint. + + Subclasses could override this method to load extra meta information + from ``checkpoint`` and ``cfg`` to model. + + Args: + model (nn.Module): Model to load weights and meta information. + checkpoint (dict, optional): The loaded checkpoint. + cfg (Config or ConfigDict, optional): The loaded config. + """ + + if checkpoint is not None: + _load_checkpoint_to_model(model, checkpoint) + checkpoint_meta = checkpoint.get('meta', {}) + # save the dataset_meta in the model for convenience + if 'dataset_meta' in checkpoint_meta: + # mmsegmentation 1.x + model.dataset_meta = { + 'classes': checkpoint_meta['dataset_meta'].get('classes'), + 'palette': checkpoint_meta['dataset_meta'].get('palette') + } + elif 'CLASSES' in checkpoint_meta: + # mmsegmentation 0.x + classes = checkpoint_meta['CLASSES'] + palette = checkpoint_meta.get('PALETTE', None) + model.dataset_meta = {'classes': classes, 'palette': palette} + else: + warnings.warn( + 'dataset_meta or class names are not saved in the ' + 'checkpoint\'s meta data, use classes of Cityscapes by ' + 'default.') + model.dataset_meta = { + 'classes': get_classes('cityscapes'), + 'palette': get_palette('cityscapes') + } + else: + warnings.warn('Checkpoint is not loaded, and the inference ' + 'result is calculated by the randomly initialized ' + 'model!') + warnings.warn( + 'weights is None, use cityscapes classes by default.') + model.dataset_meta = { + 'classes': get_classes('cityscapes'), + 'palette': get_palette('cityscapes') + } + + def __call__(self, + inputs: InputsType, + return_datasamples: bool = False, + batch_size: int = 1, + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + out_dir: str = '', + img_out_dir: str = 'vis', + pred_out_dir: str = 'pred', + **kwargs) -> dict: + """Call the inferencer. + + Args: + inputs (Union[list, str, np.ndarray]): Inputs for the inferencer. + return_datasamples (bool): Whether to return results as + :obj:`SegDataSample`. Defaults to False. + batch_size (int): Batch size. Defaults to 1. + show (bool): Whether to display the rendering color segmentation + mask in a popup window. Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_dir (str): Output directory of inference results. Defaults + to ''. + img_out_dir (str): Subdirectory of `out_dir`, used to save + rendering color segmentation mask, so `out_dir` must be defined + if you would like to save predicted mask. Defaults to 'vis'. + pred_out_dir (str): Subdirectory of `out_dir`, used to save + predicted mask file, so `out_dir` must be defined if you would + like to save predicted mask. Defaults to 'pred'. + + **kwargs: Other keyword arguments passed to :meth:`preprocess`, + :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. + Each key in kwargs should be in the corresponding set of + ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` + and ``postprocess_kwargs``. + + + Returns: + dict: Inference and visualization results. + """ + + if out_dir != '': + pred_out_dir = osp.join(out_dir, pred_out_dir) + img_out_dir = osp.join(out_dir, img_out_dir) + else: + pred_out_dir = '' + img_out_dir = '' + + return super().__call__( + inputs=inputs, + return_datasamples=return_datasamples, + batch_size=batch_size, + show=show, + wait_time=wait_time, + img_out_dir=img_out_dir, + pred_out_dir=pred_out_dir, + return_vis=return_vis, + **kwargs) + + def visualize(self, + inputs: list, + preds: List[dict], + return_vis: bool = False, + show: bool = False, + wait_time: int = 0, + img_out_dir: str = '', + opacity: float = 0.8, + with_labels: Optional[bool] = True) -> List[np.ndarray]: + """Visualize predictions. + + Args: + inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. + preds (Any): Predictions of the model. + show (bool): Whether to display the image in a popup window. + Defaults to False. + wait_time (float): The interval of show (s). Defaults to 0. + img_out_dir (str): Output directory of rendering prediction i.e. + color segmentation mask. Defaults: '' + opacity (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Returns: + List[np.ndarray]: Visualization results. + """ + if not show and img_out_dir == '' and not return_vis: + return None + if self.visualizer is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + self.visualizer.set_dataset_meta(**self.model.dataset_meta) + self.visualizer.alpha = opacity + + results = [] + + for single_input, pred in zip(inputs, preds): + if isinstance(single_input, str): + img_bytes = mmengine.fileio.get(single_input) + img = mmcv.imfrombytes(img_bytes) + img = img[:, :, ::-1] + img_name = osp.basename(single_input) + elif isinstance(single_input, np.ndarray): + img = single_input.copy() + img_num = str(self.num_visualized_imgs).zfill(8) + '_vis' + img_name = f'{img_num}.jpg' + else: + raise ValueError('Unsupported input type:' + f'{type(single_input)}') + + out_file = osp.join(img_out_dir, img_name) if img_out_dir != ''\ + else None + + self.visualizer.add_datasample( + img_name, + img, + pred, + show=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=True, + out_file=out_file, + with_labels=with_labels) + if return_vis: + results.append(self.visualizer.get_image()) + self.num_visualized_imgs += 1 + + return results if return_vis else None + + def postprocess(self, + preds: PredType, + visualization: List[np.ndarray], + return_datasample: bool = False, + pred_out_dir: str = '') -> dict: + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Pack the predictions and visualization results and return them. + 2. Save the predictions, if it needed. + + Args: + preds (List[Dict]): Predictions of the model. + visualization (List[np.ndarray]): The list of rendering color + segmentation mask. + return_datasample (bool): Whether to return results as datasamples. + Defaults to False. + pred_out_dir: File to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. + + Returns: + dict: Inference and visualization results with key ``predictions`` + and ``visualization`` + + - ``visualization (Any)``: Returned by :meth:`visualize` + - ``predictions`` (List[np.ndarray], np.ndarray): Returned by + :meth:`forward` and processed in :meth:`postprocess`. + If ``return_datasample=False``, it will be the segmentation mask + with label indice. + """ + if return_datasample: + if len(preds) == 1: + return preds[0] + else: + return preds + + results_dict = {} + + results_dict['predictions'] = [] + results_dict['visualization'] = [] + + for i, pred in enumerate(preds): + pred_data = dict() + if 'pred_sem_seg' in pred.keys(): + pred_data['sem_seg'] = pred.pred_sem_seg.numpy().data[0] + elif 'pred_depth_map' in pred.keys(): + pred_data['depth_map'] = pred.pred_depth_map.numpy().data[0] + + if visualization is not None: + vis = visualization[i] + results_dict['visualization'].append(vis) + if pred_out_dir != '': + mmengine.mkdir_or_exist(pred_out_dir) + for key, data in pred_data.items(): + post_fix = '_pred.png' if key == 'sem_seg' else '_pred.npy' + img_name = str(self.num_pred_imgs).zfill(8) + post_fix + img_path = osp.join(pred_out_dir, img_name) + if key == 'sem_seg': + output = Image.fromarray(data.astype(np.uint8)) + output.save(img_path) + else: + np.save(img_path, data) + pred_data = next(iter(pred_data.values())) + results_dict['predictions'].append(pred_data) + self.num_pred_imgs += 1 + + if len(results_dict['predictions']) == 1: + results_dict['predictions'] = results_dict['predictions'][0] + if visualization is not None: + results_dict['visualization'] = \ + results_dict['visualization'][0] + return results_dict + + def _init_pipeline(self, cfg: ConfigType) -> Compose: + """Initialize the test pipeline. + + Return a pipeline to handle various input data, such as ``str``, + ``np.ndarray``. It is an abstract method in BaseInferencer, and should + be implemented in subclasses. + + The returned pipeline will be used to process a single data. + It will be used in :meth:`preprocess` like this: + + .. code-block:: python + def preprocess(self, inputs, batch_size, **kwargs): + ... + dataset = map(self.pipeline, dataset) + ... + """ + pipeline_cfg = cfg.test_dataloader.dataset.pipeline + # Loading annotations is also not applicable + for transform in ('LoadAnnotations', 'LoadDepthAnnotation'): + idx = self._get_transform_idx(pipeline_cfg, transform) + if idx != -1: + del pipeline_cfg[idx] + + load_img_idx = self._get_transform_idx(pipeline_cfg, + 'LoadImageFromFile') + if load_img_idx == -1: + raise ValueError( + 'LoadImageFromFile is not found in the test pipeline') + pipeline_cfg[load_img_idx]['type'] = 'InferencerLoader' + return Compose(pipeline_cfg) + + def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int: + """Returns the index of the transform in a pipeline. + + If the transform is not found, returns -1. + """ + for i, transform in enumerate(pipeline_cfg): + if transform['type'] == name: + return i + return -1 diff --git a/finetune/mmseg/apis/remote_sense_inferencer.py b/finetune/mmseg/apis/remote_sense_inferencer.py new file mode 100644 index 0000000..6726c6a --- /dev/null +++ b/finetune/mmseg/apis/remote_sense_inferencer.py @@ -0,0 +1,279 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import threading +from queue import Queue +from typing import List, Optional, Tuple + +import numpy as np +import torch +from mmengine import Config +from mmengine.model import BaseModel +from mmengine.registry import init_default_scope +from mmengine.runner import load_checkpoint + +try: + from osgeo import gdal +except ImportError: + gdal = None + +from mmseg.registry import MODELS +from .utils import _preprare_data + + +class RSImage: + """Remote sensing image class. + + Args: + img (str or gdal.Dataset): Image file path or gdal.Dataset. + """ + + def __init__(self, image): + self.dataset = gdal.Open(image, gdal.GA_ReadOnly) if isinstance( + image, str) else image + assert isinstance(self.dataset, gdal.Dataset), \ + f'{image} is not a image' + self.width = self.dataset.RasterXSize + self.height = self.dataset.RasterYSize + self.channel = self.dataset.RasterCount + self.trans = self.dataset.GetGeoTransform() + self.proj = self.dataset.GetProjection() + self.band_list = [] + self.band_list.extend( + self.dataset.GetRasterBand(c + 1) for c in range(self.channel)) + self.grids = [] + + def read(self, grid: Optional[List] = None) -> np.ndarray: + """Read image data. If grid is None, read the whole image. + + Args: + grid (Optional[List], optional): Grid to read. Defaults to None. + Returns: + np.ndarray: Image data. + """ + if grid is None: + return np.einsum('ijk->jki', self.dataset.ReadAsArray()) + assert len( + grid) >= 4, 'grid must be a list containing at least 4 elements' + data = self.dataset.ReadAsArray(*grid[:4]) + if data.ndim == 2: + data = data[np.newaxis, ...] + return np.einsum('ijk->jki', data) + + def write(self, data: Optional[np.ndarray], grid: Optional[List] = None): + """Write image data. + + Args: + grid (Optional[List], optional): Grid to write. Defaults to None. + data (Optional[np.ndarray], optional): Data to write. + Defaults to None. + + Raises: + ValueError: Either grid or data must be provided. + """ + if grid is not None: + assert len(grid) == 8, 'grid must be a list of 8 elements' + for band in self.band_list: + band.WriteArray( + data[grid[5]:grid[5] + grid[7], grid[4]:grid[4] + grid[6]], + grid[0] + grid[4], grid[1] + grid[5]) + elif data is not None: + for i in range(self.channel): + self.band_list[i].WriteArray(data[..., i]) + else: + raise ValueError('Either grid or data must be provided.') + + def create_seg_map(self, output_path: Optional[str] = None): + if output_path is None: + output_path = 'output_label.tif' + driver = gdal.GetDriverByName('GTiff') + seg_map = driver.Create(output_path, self.width, self.height, 1, + gdal.GDT_Byte) + seg_map.SetGeoTransform(self.trans) + seg_map.SetProjection(self.proj) + seg_map_img = RSImage(seg_map) + seg_map_img.path = output_path + return seg_map_img + + def create_grids(self, + window_size: Tuple[int, int], + stride: Tuple[int, int] = (0, 0)): + """Create grids for image inference. + + Args: + window_size (Tuple[int, int]): the size of the sliding window. + stride (Tuple[int, int], optional): the stride of the sliding + window. Defaults to (0, 0). + + Raises: + AssertionError: window_size must be a tuple of 2 elements. + AssertionError: stride must be a tuple of 2 elements. + """ + assert len( + window_size) == 2, 'window_size must be a tuple of 2 elements' + assert len(stride) == 2, 'stride must be a tuple of 2 elements' + win_w, win_h = window_size + stride_x, stride_y = stride + + stride_x = win_w if stride_x == 0 else stride_x + stride_y = win_h if stride_y == 0 else stride_y + + x_half_overlap = (win_w - stride_x + 1) // 2 + y_half_overlap = (win_h - stride_y + 1) // 2 + + for y in range(0, self.height, stride_y): + y_end = y + win_h >= self.height + y_offset = self.height - win_h if y_end else y + y_size = win_h + y_crop_off = 0 if y_offset == 0 else y_half_overlap + y_crop_size = y_size if y_end else win_h - y_crop_off + + for x in range(0, self.width, stride_x): + x_end = x + win_w >= self.width + x_offset = self.width - win_w if x_end else x + x_size = win_w + x_crop_off = 0 if x_offset == 0 else x_half_overlap + x_crop_size = x_size if x_end else win_w - x_crop_off + + self.grids.append([ + x_offset, y_offset, x_size, y_size, x_crop_off, y_crop_off, + x_crop_size, y_crop_size + ]) + + +class RSInferencer: + """Remote sensing inference class. + + Args: + model (BaseModel): The loaded model. + batch_size (int, optional): Batch size. Defaults to 1. + thread (int, optional): Number of threads. Defaults to 1. + """ + + def __init__(self, model: BaseModel, batch_size: int = 1, thread: int = 1): + self.model = model + self.batch_size = batch_size + self.END_FLAG = object() + self.read_buffer = Queue(self.batch_size) + self.write_buffer = Queue(self.batch_size) + self.thread = thread + + @classmethod + def from_config_path(cls, + config_path: str, + checkpoint_path: str, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from config file. + + Args: + config_path (str): Config file path. + checkpoint_path (str): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + init_default_scope('mmseg') + cfg = Config.fromfile(config_path) + model = MODELS.build(cfg.model) + model.cfg = cfg + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + model.eval() + return cls(model, batch_size, thread) + + @classmethod + def from_model(cls, + model: BaseModel, + checkpoint_path: Optional[str] = None, + batch_size: int = 1, + thread: int = 1, + device: Optional[str] = 'cpu'): + """Initialize a segmentor from model. + + Args: + model (BaseModel): The loaded model. + checkpoint_path (Optional[str]): Checkpoint path. + batch_size (int, optional): Batch size. Defaults to 1. + """ + if checkpoint_path is not None: + load_checkpoint(model, checkpoint_path, map_location='cpu') + model.to(device) + return cls(model, batch_size, thread) + + def read(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0)): + """Load image data to read buffer. + + Args: + image (RSImage): The image to read. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). + """ + image.create_grids(window_size, strides) + for grid in image.grids: + self.read_buffer.put([grid, image.read(grid=grid)]) + self.read_buffer.put(self.END_FLAG) + + def inference(self): + """Inference image data from read buffer and put the result to write + buffer.""" + while True: + item = self.read_buffer.get() + if item == self.END_FLAG: + self.read_buffer.put(self.END_FLAG) + self.write_buffer.put(item) + break + data, _ = _preprare_data(item[1], self.model) + with torch.no_grad(): + result = self.model.test_step(data) + item[1] = result[0].pred_sem_seg.cpu().data.numpy()[0] + self.write_buffer.put(item) + self.read_buffer.task_done() + + def write(self, image: RSImage, output_path: Optional[str] = None): + """Write image data from write buffer. + + Args: + image (RSImage): The image to write. + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + seg_map = image.create_seg_map(output_path) + while True: + item = self.write_buffer.get() + if item == self.END_FLAG: + break + seg_map.write(data=item[1], grid=item[0]) + self.write_buffer.task_done() + + def run(self, + image: RSImage, + window_size: Tuple[int, int], + strides: Tuple[int, int] = (0, 0), + output_path: Optional[str] = None): + """Run inference with multi-threading. + + Args: + image (RSImage): The image to inference. + window_size (Tuple[int, int]): The size of the sliding window. + strides (Tuple[int, int], optional): The stride of the sliding + window. Defaults to (0, 0). + output_path (Optional[str], optional): The path to save the + segmentation map. Defaults to None. + """ + read_thread = threading.Thread( + target=self.read, args=(image, window_size, strides)) + read_thread.start() + inference_threads = [] + for _ in range(self.thread): + inference_thread = threading.Thread(target=self.inference) + inference_thread.start() + inference_threads.append(inference_thread) + write_thread = threading.Thread( + target=self.write, args=(image, output_path)) + write_thread.start() + read_thread.join() + for inference_thread in inference_threads: + inference_thread.join() + write_thread.join() diff --git a/finetune/mmseg/apis/utils.py b/finetune/mmseg/apis/utils.py new file mode 100644 index 0000000..4cf8775 --- /dev/null +++ b/finetune/mmseg/apis/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict +from typing import Sequence, Union + +import numpy as np +from mmengine.dataset import Compose +from mmengine.model import BaseModel + +ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]] + + +def _preprare_data(imgs: ImageType, model: BaseModel): + + cfg = model.cfg + for t in cfg.test_pipeline: + if t.get('type') == 'LoadAnnotations': + cfg.test_pipeline.remove(t) + + is_batch = True + if not isinstance(imgs, (list, tuple)): + imgs = [imgs] + is_batch = False + + if isinstance(imgs[0], np.ndarray): + cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray' + + # TODO: Consider using the singleton pattern to avoid building + # a pipeline for each inference + pipeline = Compose(cfg.test_pipeline) + + data = defaultdict(list) + for img in imgs: + if isinstance(img, np.ndarray): + data_ = dict(img=img) + else: + data_ = dict(img_path=img) + data_ = pipeline(data_) + data['inputs'].append(data_['inputs']) + data['data_samples'].append(data_['data_samples']) + + return data, is_batch diff --git a/finetune/mmseg/datasets/__init__.py b/finetune/mmseg/datasets/__init__.py new file mode 100644 index 0000000..dc903ec --- /dev/null +++ b/finetune/mmseg/datasets/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# yapf: disable +from .atlantic import AtlanticDataset +from .c2sfloods import C2SFloodDataset +from .cabuar import CABURADataset +from .germany import GermanyCropDataset +from .sos import SOSDataset +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, + LoadAnnotations, LoadBiomedicalAnnotation, + LoadBiomedicalData, LoadBiomedicalImageFromFile, + LoadImageFromNDArray, LoadMultipleRSImageFromFile, + LoadSingleRSImageFromFile, PackSegInputs, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomMosaic, RandomRotate, RandomRotFlip, Rerange, + ResizeShortestEdge, ResizeToMultiple, RGB2Gray, + SegRescale) + +# yapf: enable +__all__ = [ + 'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip', + 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', + 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', + 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', + 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', + 'Albu', 'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile', + 'ConcatCDInput', 'AtlanticDataset', 'C2SFloodDataset', + 'CABURADataset', 'GermanyCropDataset', 'SOSDataset' +] diff --git a/finetune/mmseg/datasets/atlantic.py b/finetune/mmseg/datasets/atlantic.py new file mode 100644 index 0000000..eb3025b --- /dev/null +++ b/finetune/mmseg/datasets/atlantic.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import List + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + +@DATASETS.register_module() +class AtlanticDataset(BaseSegDataset): + METAINFO = dict( + classes=('Background', 'Deforestation area'), + palette=[[0,0,0], [255,255,255]] + ) + def __init__(self, + img_suffix='.tif', + seg_map_suffix='.tif', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = json.load(open(self.ann_file)) + for line in lines: + # img_name = line.strip() + data_info = dict( + img_path=osp.join(self.data_root, line['s2_path'])) + if 'target_path' in line.keys(): + data_info['seg_map_path'] = osp.join(self.data_root, line['target_path']) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/finetune/mmseg/datasets/basesegdataset.py b/finetune/mmseg/datasets/basesegdataset.py new file mode 100644 index 0000000..9c4668c --- /dev/null +++ b/finetune/mmseg/datasets/basesegdataset.py @@ -0,0 +1,552 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +import mmengine +import mmengine.fileio as fileio +import numpy as np +from mmengine.dataset import BaseDataset, Compose + +from mmseg.registry import DATASETS + + +@DATASETS.register_module() +class BaseSegDataset(BaseDataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of BaseSegDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=True``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + seg_map_suffix='.png', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img_path='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + ignore_index: int = 255, + reduce_zero_label: bool = False, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.seg_map_suffix = seg_map_suffix + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map(new_classes) + self._metainfo.update( + dict( + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + self.pipeline = Compose(pipeline) + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. + + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != 255: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix)) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + _suffix_len = len(self.img_suffix) + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img[:-_suffix_len] + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list + + +@DATASETS.register_module() +class BaseCDDataset(BaseDataset): + """Custom dataset for change detection. An example of file structure is as + followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── img_dir2 + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The image names in img_dir and img_dir2 should be consistent. + The img/gt_semantic_seg pair of BaseSegDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, img_path2=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + img_suffix2 (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=True``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + img_suffix2='.jpg', + seg_map_suffix='.png', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict( + img_path='', img_path2='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + ignore_index: int = 255, + reduce_zero_label: bool = False, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.img_suffix2 = img_suffix2 + self.seg_map_suffix = seg_map_suffix + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map(new_classes) + self._metainfo.update( + dict( + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + self.pipeline = Compose(pipeline) + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. + + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + label_map[i] = 255 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != 255: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + img_dir2 = self.data_prefix.get('img_path2', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if osp.isfile(self.ann_file): + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + if '.' in osp.basename(img_name): + img_name, img_ext = osp.splitext(img_name) + self.img_suffix = img_ext + self.img_suffix2 = img_ext + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix), + img_path2=osp.join(img_dir2, img_name + self.img_suffix2)) + + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + else: + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + if '.' in osp.basename(img): + img, img_ext = osp.splitext(img) + self.img_suffix = img_ext + self.img_suffix2 = img_ext + data_info = dict( + img_path=osp.join(img_dir, img + self.img_suffix), + img_path2=osp.join(img_dir2, img + self.img_suffix2)) + if ann_dir is not None: + seg_map = img + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/finetune/mmseg/datasets/c2sfloods.py b/finetune/mmseg/datasets/c2sfloods.py new file mode 100644 index 0000000..3e4a9b8 --- /dev/null +++ b/finetune/mmseg/datasets/c2sfloods.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + +# LEGEND = [ +# 255 255 255; % Background +# 0 0 0; % Roads +# 100 100 100; % Buildings +# 0 125 0; % Trees +# 0 255 0; % Grass +# 150 80 0; % Bare Soil +# 0 0 150; % Water +# 255 255 0; % Railways +# 150 150 255]; % Swimming Pools + +@DATASETS.register_module() +class C2SFloodDataset(BaseSegDataset): + """Zurich dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('Background', 'Water', 'Cloud', 'Cloud shadow'), + palette=[[0,0,0], [255,255,255], [255,0,0], [0,255,0]] + ) + def __init__(self, + img_suffix='.npz', + seg_map_suffix='.npz', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = json.load(open(self.ann_file)) + for line in lines: + # img_name = line.strip() + data_info = dict( + img_path=osp.join(self.data_root, line['s2_path'])) + if 'target_path' in line.keys(): + data_info['seg_map_path'] = osp.join(self.data_root, line['target_path']) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/finetune/mmseg/datasets/cabuar.py b/finetune/mmseg/datasets/cabuar.py new file mode 100644 index 0000000..d662e5c --- /dev/null +++ b/finetune/mmseg/datasets/cabuar.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + +@DATASETS.register_module() +class CABURADataset(BaseSegDataset): + """Zurich dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('Background', 'Burned area'), + palette=[[0,0,0], [255,255,255]] + ) + def __init__(self, + img_suffix='.npz', + seg_map_suffix='.npz', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = json.load(open(self.ann_file)) + for line in lines: + # img_name = line.strip() + data_info = dict( + img_path=osp.join(self.data_root, line['post_fire'])) + if 'mask' in line.keys(): + data_info['seg_map_path'] = osp.join(self.data_root, line['mask']) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/finetune/mmseg/datasets/germany.py b/finetune/mmseg/datasets/germany.py new file mode 100644 index 0000000..21d5218 --- /dev/null +++ b/finetune/mmseg/datasets/germany.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset +from mmengine.logging import print_log +import pandas as pd +# LEGEND = [ +# 255 255 255; % Background +# 0 0 0; % Roads +# 100 100 100; % Buildings +# 0 125 0; % Trees +# 0 255 0; % Grass +# 150 80 0; % Bare Soil +# 0 0 150; % Water +# 255 255 0; % Railways +# 150 150 255]; % Swimming Pools + +@DATASETS.register_module() +class GermanyCropDataset(BaseSegDataset): + """Zurich dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + # {0: "unknown", 1: "sugar_beet", 2: "summer_oat", 3: "meadow", 5: "rape", 8: "hop", + # 9: "winter_spelt", 12: "winter_triticale", 13: "beans", 15: "peas", 16: "potatoes", + # 17: "soybeans", 19: "asparagus", 22: "winter_wheat", 23: "winter_barley", 24: "winter_rye", + # 25: "summer_barley", 26: "maize"} + METAINFO = dict( + classes=('sugar_beet', 'summer_oat', 'meadow', 'rape', 'hop', 'winter_spelt', 'winter_triticale', 'beans', 'peas',\ + 'potatoes', 'soybeans', 'asparagus', 'winter_wheat', 'winter_barley', 'winter_rye', 'summer_barley', 'maize'), + palette=[(255, 255, 255), (255, 255, 170), (255, 255, 85), (255, 170, 255), (255, 170, 170), (255, 170, 85), \ + (255, 85, 255), (255, 85, 170), (255, 85, 85), (170, 255, 255), (170, 255, 170), (170, 255, 85), (170, 170, 255), \ + (170, 170, 170), (170, 170, 85), (170, 85, 255), (170, 85, 170)]) + def __init__(self, + img_suffix='.pickle', + seg_map_suffix='.pickle', + reduce_zero_label=True, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = json.load(open(self.ann_file)) + print_log(f'dataset count: {len(lines)}') + for line in lines: + data_info = dict( + img_path=osp.join(self.data_root, line['s2_path'])) + if 'target_path' in line.keys(): + data_info['seg_map_path'] = osp.join(self.data_root, line['target_path']) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/finetune/mmseg/datasets/sos.py b/finetune/mmseg/datasets/sos.py new file mode 100644 index 0000000..f386c9d --- /dev/null +++ b/finetune/mmseg/datasets/sos.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +from mmseg.registry import DATASETS +from .basesegdataset import BaseSegDataset + +# LEGEND = [ +# 255 255 255; % Background +# 0 0 0; % Roads +# 100 100 100; % Buildings +# 0 125 0; % Trees +# 0 255 0; % Grass +# 150 80 0; % Bare Soil +# 0 0 150; % Water +# 255 255 0; % Railways +# 150 150 255]; % Swimming Pools + +@DATASETS.register_module() +class SOSDataset(BaseSegDataset): + """Zurich dataset. + + In segmentation map annotation for LoveDA, 0 is the ignore index. + ``reduce_zero_label`` should be set to True. The ``img_suffix`` and + ``seg_map_suffix`` are both fixed to '.png'. + """ + METAINFO = dict( + classes=('Background', 'Oil Spill Area'), + palette=[[0,0,0], [255,255,255]] + ) + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) -> None: + super().__init__( + img_suffix=img_suffix, + seg_map_suffix=seg_map_suffix, + reduce_zero_label=reduce_zero_label, + **kwargs) + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = json.load(open(self.ann_file)) + for line in lines: + # img_name = line.strip() + data_info = dict( + img_path=osp.join(self.data_root, line['s1_path'])) + if 'target_path' in line.keys(): + data_info['seg_map_path'] = osp.join(self.data_root, line['target_path']) + data_info['label_map'] = self.label_map + data_info['reduce_zero_label'] = self.reduce_zero_label + data_info['seg_fields'] = [] + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + # print(data_list) + return data_list diff --git a/finetune/mmseg/datasets/transforms/__init__.py b/finetune/mmseg/datasets/transforms/__init__.py new file mode 100644 index 0000000..08789fd --- /dev/null +++ b/finetune/mmseg/datasets/transforms/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .formatting import PackSegInputs +from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, + LoadBiomedicalData, LoadBiomedicalImageFromFile, + LoadDepthAnnotation, LoadImageFromNDArray, + LoadMultipleRSImageFromFile, LoadSingleRSImageFromFile) +# yapf: disable +from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, + BioMedical3DRandomCrop, BioMedical3DRandomFlip, + BioMedicalGaussianBlur, BioMedicalGaussianNoise, + BioMedicalRandomGamma, ConcatCDInput, GenerateEdge, + PhotoMetricDistortion, RandomCrop, RandomCutOut, + RandomDepthMix, RandomFlip, RandomMosaic, + RandomRotate, RandomRotFlip, Rerange, Resize, + ResizeShortestEdge, ResizeToMultiple, RGB2Gray, + SegRescale) +from .loading_npz import (LoadAnnotationsNpz, LoadImageFromNpz, LoadTsImageFromNpz, LoadAnnotationsOil, LoadImageOil, LoadImageSingleChannel) + +# yapf: enable +__all__ = [ + 'LoadAnnotations', 'RandomCrop', 'BioMedical3DRandomCrop', 'SegRescale', + 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', + 'RGB2Gray', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', + 'ResizeToMultiple', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', + 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', + 'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', + 'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad', + 'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput', + 'LoadMultipleRSImageFromFile', 'LoadDepthAnnotation', 'RandomDepthMix', + 'RandomFlip', 'Resize', 'LoadAnnotationsNpz', 'LoadImageFromNpz', 'LoadTsImageFromNpz', + 'LoadAnnotationsOil', 'LoadImageOil', 'LoadImageSingleChannel' +] diff --git a/finetune/mmseg/datasets/transforms/formatting.py b/finetune/mmseg/datasets/transforms/formatting.py new file mode 100644 index 0000000..bd25055 --- /dev/null +++ b/finetune/mmseg/datasets/transforms/formatting.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmengine.structures import PixelData + +from mmseg.registry import TRANSFORMS +from mmseg.structures import SegDataSample + + +@TRANSFORMS.register_module() +class PackSegInputs(BaseTransform): + """Pack the inputs data for the semantic segmentation. + + The ``img_meta`` item is always populated. The contents of the + ``img_meta`` dictionary depends on ``meta_keys``. By default this includes: + + - ``img_path``: filename of the image + + - ``ori_shape``: original shape of the image as a tuple (h, w, c) + + - ``img_shape``: shape of the image input to the network as a tuple \ + (h, w, c). Note that images may be zero padded on the \ + bottom/right if the batch tensor is larger than this shape. + + - ``pad_shape``: shape of padded images + + - ``scale_factor``: a float indicating the preprocessing scale + + - ``flip``: a boolean indicating if image flip transform was used + + - ``flip_direction``: the flipping direction + + Args: + meta_keys (Sequence[str], optional): Meta keys to be packed from + ``SegDataSample`` and collected in ``data[img_metas]``. + Default: ``('img_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction')`` + """ + + def __init__(self, + meta_keys=('img_path', 'seg_map_path', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'reduce_zero_label')): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + + Args: + results (dict): Result dict from the data pipeline. + + Returns: + dict: + + - 'inputs' (obj:`torch.Tensor`): The forward data of models. + - 'data_sample' (obj:`SegDataSample`): The annotation info of the + sample. + """ + packed_results = dict() + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + if not img.flags.c_contiguous: + img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) + else: + img = img.transpose(2, 0, 1) + img = to_tensor(img).contiguous() + packed_results['inputs'] = img + + data_sample = SegDataSample() + if 'gt_seg_map' in results: + if len(results['gt_seg_map'].shape) == 2: + data = to_tensor(results['gt_seg_map'][None, + ...].astype(np.int64)) + else: + warnings.warn('Please pay attention your ground truth ' + 'segmentation map, usually the segmentation ' + 'map is 2D, but got ' + f'{results["gt_seg_map"].shape}') + data = to_tensor(results['gt_seg_map'].astype(np.int64)) + gt_sem_seg_data = dict(data=data) + data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data) + + if 'gt_edge_map' in results: + gt_edge_data = dict( + data=to_tensor(results['gt_edge_map'][None, + ...].astype(np.int64))) + data_sample.set_data(dict(gt_edge_map=PixelData(**gt_edge_data))) + + if 'gt_depth_map' in results: + gt_depth_data = dict( + data=to_tensor(results['gt_depth_map'][None, ...])) + data_sample.set_data(dict(gt_depth_map=PixelData(**gt_depth_data))) + + img_meta = {} + for key in self.meta_keys: + if key in results: + img_meta[key] = results[key] + data_sample.set_metainfo(img_meta) + packed_results['data_samples'] = data_sample + + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(meta_keys={self.meta_keys})' + return repr_str diff --git a/finetune/mmseg/datasets/transforms/loading.py b/finetune/mmseg/datasets/transforms/loading.py new file mode 100644 index 0000000..c28937e --- /dev/null +++ b/finetune/mmseg/datasets/transforms/loading.py @@ -0,0 +1,771 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from pathlib import Path +from typing import Dict, Optional, Union + +import mmcv +import mmengine.fileio as fileio +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile + +from mmseg.registry import TRANSFORMS +from mmseg.utils import datafrombytes + +try: + from osgeo import gdal +except ImportError: + gdal = None + + +@TRANSFORMS.register_module() +class LoadAnnotations(MMCV_LoadAnnotations): + """Load annotations for semantic segmentation provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of semantic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of semantic segmentation ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_zero_label (bool, optional): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__( + self, + reduce_zero_label=None, + backend_args=None, + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + img_bytes = fileio.get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNDArray(LoadImageFromFile): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results['img'] + if self.to_float32: + img = img.astype(np.float32) + + results['img_path'] = None + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + +@TRANSFORMS.register_module() +class LoadBiomedicalImageFromFile(BaseTransform): + """Load an biomedical mage from file. + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities, and data type is float32 + if set to_float32 = True, or float64 if decode_backend is 'nifti' and + to_float32 is False. + - img_shape + - ori_shape + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an float64 array. + Defaults to True. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + decode_backend: str = 'nifti', + to_xyz: bool = False, + to_float32: bool = True, + backend_args: Optional[dict] = None) -> None: + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.to_float32 = to_float32 + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + + data_bytes = fileio.get(filename, self.backend_args) + img = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + img = img.astype(np.float32) + + if len(img.shape) == 3: + img = img[None, ...] + + if self.decode_backend == 'nifti': + img = img.transpose(0, 3, 2, 1) + + if self.to_xyz: + img = img.transpose(0, 3, 2, 1) + + results['img'] = img + results['img_shape'] = img.shape[1:] + results['ori_shape'] = img.shape[1:] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadBiomedicalAnnotation(BaseTransform): + """Load ``seg_map`` annotation provided by biomedical dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + 'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X) + } + + Required Keys: + + - seg_map_path + + Added Keys: + + - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by + default, and data type is float32 if set to_float32 = True, or + float64 if decode_backend is 'nifti' and to_float32 is False. + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + to_float32 (bool): Whether to convert the loaded seg map to a float32 + numpy array. If set to False, the loaded image is an float64 array. + Defaults to True. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See :class:`mmengine.fileio` for details. + Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + decode_backend: str = 'nifti', + to_xyz: bool = False, + to_float32: bool = True, + backend_args: Optional[dict] = None) -> None: + super().__init__() + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.to_float32 = to_float32 + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + data_bytes = fileio.get(results['seg_map_path'], self.backend_args) + gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + gt_seg_map = gt_seg_map.astype(np.float32) + + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadBiomedicalData(BaseTransform): + """Load an biomedical image and annotation from file. + + The loading data format is as the following: + + .. code-block:: python + + { + 'img': np.ndarray data[:-1, X, Y, Z] + 'seg_map': np.ndarray data[-1, X, Y, Z] + } + + + Required Keys: + + - img_path + + Added Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + - img_shape + - ori_shape + + Args: + with_seg (bool): Whether to parse and load the semantic segmentation + annotation. Defaults to False. + decode_backend (str): The data decoding backend type. Options are + 'numpy'and 'nifti', and there is a convention that when backend is + 'nifti' the axis of data loaded is XYZ, and when backend is + 'numpy', the the axis is ZYX. The data will be transposed if the + backend is 'nifti'. Defaults to 'nifti'. + to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z. + Defaults to False. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + with_seg=False, + decode_backend: str = 'numpy', + to_xyz: bool = False, + backend_args: Optional[dict] = None) -> None: # noqa + self.with_seg = with_seg + self.decode_backend = decode_backend + self.to_xyz = to_xyz + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + data_bytes = fileio.get(results['img_path'], self.backend_args) + data = datafrombytes(data_bytes, backend=self.decode_backend) + # img is 4D data (N, X, Y, Z), N is the number of protocol + img = data[:-1, :] + + if self.decode_backend == 'nifti': + img = img.transpose(0, 3, 2, 1) + + if self.to_xyz: + img = img.transpose(0, 3, 2, 1) + + results['img'] = img + results['img_shape'] = img.shape[1:] + results['ori_shape'] = img.shape[1:] + + if self.with_seg: + gt_seg_map = data[-1, :] + if self.decode_backend == 'nifti': + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + + if self.to_xyz: + gt_seg_map = gt_seg_map.transpose(2, 1, 0) + results['gt_seg_map'] = gt_seg_map + return results + + def __repr__(self) -> str: + repr_str = (f'{self.__class__.__name__}(' + f'with_seg={self.with_seg}, ' + f"decode_backend='{self.decode_backend}', " + f'to_xyz={self.to_xyz}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class InferencerLoader(BaseTransform): + """Load an image from ``results['img']``. + + Similar with :obj:`LoadImageFromFile`, but the image has been loaded as + :obj:`np.ndarray` in ``results['img']``. Can be used when loading image + from webcam. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_path + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__(self, **kwargs) -> None: + super().__init__() + self.from_file = TRANSFORMS.build( + dict(type='LoadImageFromFile', **kwargs)) + self.from_ndarray = TRANSFORMS.build( + dict(type='LoadImageFromNDArray', **kwargs)) + + def transform(self, single_input: Union[str, np.ndarray, dict]) -> dict: + """Transform function to add image meta information. + + Args: + results (dict): Result dict with Webcam read image in + ``results['img']``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + if isinstance(single_input, str): + inputs = dict(img_path=single_input) + elif isinstance(single_input, np.ndarray): + inputs = dict(img=single_input) + elif isinstance(single_input, dict): + inputs = single_input + else: + raise NotImplementedError + + if 'img' in inputs: + return self.from_ndarray(inputs) + return self.from_file(inputs) + + +@TRANSFORMS.register_module() +class LoadSingleRSImageFromFile(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + self.to_float32 = to_float32 + + if gdal is None: + raise RuntimeError('gdal is not installed') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + ds = gdal.Open(filename) + if ds is None: + raise Exception(f'Unable to open file: {filename}') + img = np.einsum('ijk->jki', ds.ReadAsArray()) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadMultipleRSImageFromFile(BaseTransform): + """Load two Remote Sensing mage from file. + + Required Keys: + + - img_path + - img_path2 + + Modified Keys: + + - img + - img2 + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + if gdal is None: + raise RuntimeError('gdal is not installed') + self.to_float32 = to_float32 + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + filename2 = results['img_path2'] + + ds = gdal.Open(filename) + ds2 = gdal.Open(filename2) + + if ds is None: + raise Exception(f'Unable to open file: {filename}') + if ds2 is None: + raise Exception(f'Unable to open file: {filename2}') + + img = np.einsum('ijk->jki', ds.ReadAsArray()) + img2 = np.einsum('ijk->jki', ds2.ReadAsArray()) + + if self.to_float32: + img = img.astype(np.float32) + img2 = img2.astype(np.float32) + + if img.shape != img2.shape: + raise Exception(f'Image shapes do not match:' + f' {img.shape} vs {img2.shape}') + + results['img'] = img + results['img2'] = img2 + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadDepthAnnotation(BaseTransform): + """Load ``depth_map`` annotation provided by depth estimation dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + 'gt_depth_map': np.ndarray [Y, X] + } + + Required Keys: + + - seg_depth_path + + Added Keys: + + - gt_depth_map (np.ndarray): Depth map with shape (Y, X) by + default, and data type is float32 if set to_float32 = True. + - depth_rescale_factor (float): The rescale factor of depth map, which + can be used to recover the original value of depth map. + + Args: + decode_backend (str): The data decoding backend type. Options are + 'numpy', 'nifti', and 'cv2'. Defaults to 'cv2'. + to_float32 (bool): Whether to convert the loaded depth map to a float32 + numpy array. If set to False, the loaded image is an uint16 array. + Defaults to True. + depth_rescale_factor (float): Factor to rescale the depth value to + limit the range. Defaults to 1.0. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See :class:`mmengine.fileio` for details. + Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + decode_backend: str = 'cv2', + to_float32: bool = True, + depth_rescale_factor: float = 1.0, + backend_args: Optional[dict] = None) -> None: + super().__init__() + self.decode_backend = decode_backend + self.to_float32 = to_float32 + self.depth_rescale_factor = depth_rescale_factor + self.backend_args = backend_args.copy() if backend_args else None + + def transform(self, results: Dict) -> Dict: + """Functions to load depth map. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded depth map. + """ + data_bytes = fileio.get(results['depth_map_path'], self.backend_args) + gt_depth_map = datafrombytes(data_bytes, backend=self.decode_backend) + + if self.to_float32: + gt_depth_map = gt_depth_map.astype(np.float32) + + gt_depth_map *= self.depth_rescale_factor + results['gt_depth_map'] = gt_depth_map + results['seg_fields'].append('gt_depth_map') + results['depth_rescale_factor'] = self.depth_rescale_factor + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f"decode_backend='{self.decode_backend}', " + f'to_float32={self.to_float32}, ' + f'backend_args={self.backend_args})') + return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNpyFile(LoadImageFromFile): + """Load an image from ``results['img_path']``. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def transform(self, results: dict) -> Optional[dict]: + """Functions to load image. + + Args: + results (dict): Result dict from + :class:`mmengine.dataset.BaseDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + + try: + if Path(filename).suffix in ['.npy', '.npz']: + img = np.load(filename) + else: + if self.file_client_args is not None: + file_client = fileio.FileClient.infer_client( + self.file_client_args, filename) + img_bytes = file_client.get(filename) + else: + img_bytes = fileio.get( + filename, backend_args=self.backend_args) + img = mmcv.imfrombytes( + img_bytes, + flag=self.color_type, + backend=self.imdecode_backend) + except Exception as e: + if self.ignore_empty: + return None + else: + raise e + + # in some cases, images are not read successfully, the img would be + # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427 + assert img is not None, f'failed to load image: {filename}' + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results diff --git a/finetune/mmseg/datasets/transforms/loading_npz.py b/finetune/mmseg/datasets/transforms/loading_npz.py new file mode 100644 index 0000000..7418b5c --- /dev/null +++ b/finetune/mmseg/datasets/transforms/loading_npz.py @@ -0,0 +1,493 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Dict, Optional, Union +import io + +import mmcv +import mmengine.fileio as fileio +import numpy as np +from mmcv.transforms import BaseTransform +from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile +import imageio +from mmseg.registry import TRANSFORMS +from mmseg.utils import datafrombytes + +try: + from osgeo import gdal +except ImportError: + gdal = None + + +@TRANSFORMS.register_module() +class LoadAnnotationsNpz(MMCV_LoadAnnotations): + """Load annotations for semantic segmentation provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of semantic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of semantic segmentation ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_zero_label (bool, optional): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__( + self, + reduce_zero_label=None, + backend_args=None, + data_key='arr_0', + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + self.data_key = data_key + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + # img_bytes = fileio.get( + # results['seg_map_path'], backend_args=self.backend_args) + # gt_semantic_seg = mmcv.imfrombytes( + # img_bytes, flag='unchanged', + # backend=self.imdecode_backend).squeeze().astype(np.uint8) + gt_semantic_seg = np.load(results['seg_map_path'])[self.data_key].squeeze().astype(np.uint8) + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + +@TRANSFORMS.register_module() +class LoadImageSingleChannel(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, to_float32: bool = True): + self.to_float32 = to_float32 + # self.data_key = data_key + if gdal is None: + raise RuntimeError('gdal is not installed') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + img = imageio.imread(filename) # h, w, c + img = img[:, :, 0] + # if ds is None: + # raise Exception(f'Unable to open file: {filename}') + # print(img)s + # img = np.einsum('ijk->jki', img) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + +@TRANSFORMS.register_module() +class LoadAnnotationsOil(MMCV_LoadAnnotations): + """Load annotations for semantic segmentation provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of semantic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # in str + 'seg_fields': List + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of semantic segmentation ground truth file. + + Added Keys: + + - seg_fields (List) + - gt_seg_map (np.uint8) + + Args: + reduce_zero_label (bool, optional): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to None. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :fun:``mmcv.imfrombytes`` for details. + Defaults to 'pillow'. + backend_args (dict): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__( + self, + reduce_zero_label=None, + backend_args=None, + data_key='arr_0', + imdecode_backend='pillow', + ) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + imdecode_backend=imdecode_backend, + backend_args=backend_args) + self.reduce_zero_label = reduce_zero_label + if self.reduce_zero_label is not None: + warnings.warn('`reduce_zero_label` will be deprecated, ' + 'if you would like to ignore the zero label, please ' + 'set `reduce_zero_label=True` when dataset ' + 'initialized') + self.imdecode_backend = imdecode_backend + self.data_key = data_key + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + # img_bytes = fileio.get( + # results['seg_map_path'], backend_args=self.backend_args) + # gt_semantic_seg = mmcv.imfrombytes( + # img_bytes, flag='unchanged', + # backend=self.imdecode_backend).squeeze().astype(np.uint8) + seg_map = gdal.Open(results['seg_map_path']).ReadAsArray() + gt_semantic_seg = np.zeros_like(seg_map).astype(np.uint8) + gt_semantic_seg[seg_map==3.] = 1 + # reduce zero_label + if self.reduce_zero_label is None: + self.reduce_zero_label = results['reduce_zero_label'] + assert self.reduce_zero_label == results['reduce_zero_label'], \ + 'Initialize dataset with `reduce_zero_label` as ' \ + f'{results["reduce_zero_label"]} but when load annotation ' \ + f'the `reduce_zero_label` is {self.reduce_zero_label}' + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + results['seg_fields'].append('gt_seg_map') + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(reduce_zero_label={self.reduce_zero_label}, ' + repr_str += f"imdecode_backend='{self.imdecode_backend}', " + repr_str += f'backend_args={self.backend_args})' + return repr_str + + +@TRANSFORMS.register_module() +class LoadImageFromNpz(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, data_key='arr_0', to_float32: bool = True): + self.to_float32 = to_float32 + self.data_key = data_key + if gdal is None: + raise RuntimeError('gdal is not installed') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + img = np.load(filename)[self.data_key] + # if ds is None: + # raise Exception(f'Unable to open file: {filename}') + # print(img)s + img = np.einsum('ijk->jki', img) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + +@TRANSFORMS.register_module() +class LoadImageOil(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, data_key='arr_0', to_float32: bool = True): + self.to_float32 = to_float32 + self.data_key = data_key + if gdal is None: + raise RuntimeError('gdal is not installed') + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + img = gdal.Open(filename).ReadAsArray() + img = img[:,:,None] + # if ds is None: + # raise Exception(f'Unable to open file: {filename}') + # print(img)s + # img = np.einsum('ijk->jki', img) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + +@TRANSFORMS.register_module() +class LoadTsImageFromNpz(BaseTransform): + """Load a Remote Sensing mage from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is a float64 array. + Defaults to True. + """ + + def __init__(self, data_key='arr_0', to_float32: bool = True, ts_size: int=10): + self.to_float32 = to_float32 + self.data_key = data_key + if gdal is None: + raise RuntimeError('gdal is not installed') + self.ts_size = ts_size + + def transform(self, results: Dict) -> Dict: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + filename = results['img_path'] + img = np.load(filename)[self.data_key] + ts, c, h, w = img.shape + if ts >= self.ts_size: + selected_indices = np.random.choice(ts, size=self.ts_size, replace=False) + else: + selected_indices = np.random.choice(ts, size=self.ts_size, replace=True) + selected_indices.sort() + img = img[selected_indices, :, :, :] + # print(f'after input shape: {img.shape}') + img = img.transpose(2, 3, 0, 1).reshape(h, w, self.ts_size*c) # h, w, ts, c -> h, w, ts*c + # if ds is None: + # raise Exception(f'Unable to open file: {filename}') + # print(img)s + # img = np.einsum('ijk->jki', img) + + if self.to_float32: + img = img.astype(np.float32) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'to_float32={self.to_float32})') + return repr_str + + diff --git a/finetune/mmseg/datasets/transforms/transforms.py b/finetune/mmseg/datasets/transforms/transforms.py new file mode 100644 index 0000000..64e2323 --- /dev/null +++ b/finetune/mmseg/datasets/transforms/transforms.py @@ -0,0 +1,2537 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import inspect +import warnings +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import mmcv +import mmengine +import numpy as np +from mmcv.transforms import RandomFlip as MMCV_RandomFlip +from mmcv.transforms import Resize as MMCV_Resize +from mmcv.transforms.base import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmengine.utils import is_tuple_of +from numpy import random +from scipy.ndimage import gaussian_filter + +from mmseg.datasets.dataset_wrappers import MultiImageMixDataset +from mmseg.registry import TRANSFORMS + +try: + import albumentations + from albumentations import Compose + ALBU_INSTALLED = True +except ImportError: + albumentations = None + Compose = None + ALBU_INSTALLED = False + + +@TRANSFORMS.register_module() +class ResizeToMultiple(BaseTransform): + """Resize images & seg to multiple of divisor. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - pad_shape + + Args: + size_divisor (int): images and gt seg maps need to resize to multiple + of size_divisor. Default: 32. + interpolation (str, optional): The interpolation mode of image resize. + Default: None + """ + + def __init__(self, size_divisor=32, interpolation=None): + self.size_divisor = size_divisor + self.interpolation = interpolation + + def transform(self, results: dict) -> dict: + """Call function to resize images, semantic segmentation map to + multiple of size divisor. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape' keys are updated. + """ + # Align image to multiple of size divisor. + img = results['img'] + img = mmcv.imresize_to_multiple( + img, + self.size_divisor, + scale_factor=1, + interpolation=self.interpolation + if self.interpolation else 'bilinear') + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['pad_shape'] = img.shape[:2] + + # Align segmentation map to multiple of size divisor. + for key in results.get('seg_fields', []): + gt_seg = results[key] + gt_seg = mmcv.imresize_to_multiple( + gt_seg, + self.size_divisor, + scale_factor=1, + interpolation='nearest') + results[key] = gt_seg + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size_divisor={self.size_divisor}, ' + f'interpolation={self.interpolation})') + return repr_str + + +@TRANSFORMS.register_module() +class Rerange(BaseTransform): + """Rerange the image pixel value. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def transform(self, results: dict) -> dict: + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. + """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@TRANSFORMS.register_module() +class CLAHE(BaseTransform): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def transform(self, results: dict) -> dict: + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, ' \ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomCrop(BaseTransform): + """Random crop the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - img_shape + - gt_seg_map + + + Args: + crop_size (Union[int, Tuple[int, int]]): Expected size after cropping + with the format of (h, w). If set to an integer, then cropping + width and height are equal to this integer. + cat_max_ratio (float): The maximum ratio that single category could + occupy. + ignore_index (int): The label index to be ignored. Default: 255 + """ + + def __init__(self, + crop_size: Union[int, Tuple[int, int]], + cat_max_ratio: float = 1., + ignore_index: int = 255): + super().__init__() + assert isinstance(crop_size, int) or ( + isinstance(crop_size, tuple) and len(crop_size) == 2 + ), 'The expected crop_size is an integer, or a tuple containing two ' + 'intergers' + + if isinstance(crop_size, int): + crop_size = (crop_size, crop_size) + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + @cache_randomness + def crop_bbox(self, results: dict) -> tuple: + """get a crop bounding box. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: Coordinates of the cropped image. + """ + + def generate_crop_bbox(img: np.ndarray) -> tuple: + """Randomly get a crop bounding box. + + Args: + img (np.ndarray): Original input image. + + Returns: + tuple: Coordinates of the cropped image. + """ + + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + img = results['img'] + crop_bbox = generate_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_seg_map'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = generate_crop_bbox(img) + + return crop_bbox + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.crop_bbox(results) + + # crop the image + img = self.crop(img, crop_bbox) + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + results['img'] = img + results['img_shape'] = img.shape[:2] + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@TRANSFORMS.register_module() +class RandomRotate(BaseTransform): + """Rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + @cache_randomness + def generate_degree(self): + return np.random.rand() < self.prob, np.random.uniform( + min(*self.degree), max(*self.degree)) + + def transform(self, results: dict) -> dict: + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + + rotate, degree = self.generate_degree() + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@TRANSFORMS.register_module() +class RGB2Gray(BaseTransform): + """Convert RGB image to grayscale image. + + Required Keys: + + - img + + Modified Keys: + + - img + - img_shape + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def transform(self, results: dict) -> dict: + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. + """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@TRANSFORMS.register_module() +class AdjustGamma(BaseTransform): + """Using gamma correction to process the image. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def transform(self, results: dict) -> dict: + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + + +@TRANSFORMS.register_module() +class SegRescale(BaseTransform): + """Rescale semantic segmentation maps. + + Required Keys: + + - gt_seg_map + + Modified Keys: + + - gt_seg_map + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def transform(self, results: dict) -> dict: + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@TRANSFORMS.register_module() +class PhotoMetricDistortion(BaseTransform): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__(self, + brightness_delta: int = 32, + contrast_range: Sequence[float] = (0.5, 1.5), + saturation_range: Sequence[float] = (0.5, 1.5), + hue_delta: int = 18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, + img: np.ndarray, + alpha: int = 1, + beta: int = 0) -> np.ndarray: + """Multiple with alpha and add beat with clip. + + Args: + img (np.ndarray): The input image. + alpha (int): Image weights, change the contrast/saturation + of the image. Default: 1 + beta (int): Image bias, change the brightness of the + image. Default: 0 + + Returns: + np.ndarray: The transformed image. + """ + + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img: np.ndarray) -> np.ndarray: + """Brightness distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after brightness change. + """ + + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img: np.ndarray) -> np.ndarray: + """Contrast distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after contrast change. + """ + + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img: np.ndarray) -> np.ndarray: + """Saturation distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after saturation change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img: np.ndarray) -> np.ndarray: + """Hue distortion. + + Args: + img (np.ndarray): The input image. + Returns: + np.ndarray: Image after hue change. + """ + + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def transform(self, results: dict) -> dict: + """Transform function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. + """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str + + +@TRANSFORMS.register_module() +class RandomCutOut(BaseTransform): + """CutOut operation. + + Randomly drop some regions of image used in + `Cutout `_. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + prob (float): cutout probability. + n_holes (int | tuple[int, int]): Number of regions to be dropped. + If it is given as a list, number of holes will be randomly + selected from the closed interval [`n_holes[0]`, `n_holes[1]`]. + cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate + shape of dropped regions. It can be `tuple[int, int]` to use a + fixed cutout shape, or `list[tuple[int, int]]` to randomly choose + shape from the list. + cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The + candidate ratio of dropped regions. It can be `tuple[float, float]` + to use a fixed ratio or `list[tuple[float, float]]` to randomly + choose ratio from the list. Please note that `cutout_shape` + and `cutout_ratio` cannot be both given at the same time. + fill_in (tuple[float, float, float] | tuple[int, int, int]): The value + of pixel to fill in the dropped regions. Default: (0, 0, 0). + seg_fill_in (int): The labels of pixel to fill in the dropped regions. + If seg_fill_in is None, skip. Default: None. + """ + + def __init__(self, + prob, + n_holes, + cutout_shape=None, + cutout_ratio=None, + fill_in=(0, 0, 0), + seg_fill_in=None): + + assert 0 <= prob and prob <= 1 + assert (cutout_shape is None) ^ (cutout_ratio is None), \ + 'Either cutout_shape or cutout_ratio should be specified.' + assert (isinstance(cutout_shape, (list, tuple)) + or isinstance(cutout_ratio, (list, tuple))) + if isinstance(n_holes, tuple): + assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1] + else: + n_holes = (n_holes, n_holes) + if seg_fill_in is not None: + assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in + and seg_fill_in <= 255) + self.prob = prob + self.n_holes = n_holes + self.fill_in = fill_in + self.seg_fill_in = seg_fill_in + self.with_ratio = cutout_ratio is not None + self.candidates = cutout_ratio if self.with_ratio else cutout_shape + if not isinstance(self.candidates, list): + self.candidates = [self.candidates] + + @cache_randomness + def do_cutout(self): + return np.random.rand() < self.prob + + @cache_randomness + def generate_patches(self, results): + cutout = self.do_cutout() + + h, w, _ = results['img'].shape + if cutout: + n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1) + else: + n_holes = 0 + x1_lst = [] + y1_lst = [] + index_lst = [] + for _ in range(n_holes): + x1_lst.append(np.random.randint(0, w)) + y1_lst.append(np.random.randint(0, h)) + index_lst.append(np.random.randint(0, len(self.candidates))) + return cutout, n_holes, x1_lst, y1_lst, index_lst + + def transform(self, results: dict) -> dict: + """Call function to drop some regions of image.""" + cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches( + results) + if cutout: + h, w, c = results['img'].shape + for i in range(n_holes): + x1 = x1_lst[i] + y1 = y1_lst[i] + index = index_lst[i] + if not self.with_ratio: + cutout_w, cutout_h = self.candidates[index] + else: + cutout_w = int(self.candidates[index][0] * w) + cutout_h = int(self.candidates[index][1] * h) + + x2 = np.clip(x1 + cutout_w, 0, w) + y2 = np.clip(y1 + cutout_h, 0, h) + results['img'][y1:y2, x1:x2, :] = self.fill_in + + if self.seg_fill_in is not None: + for key in results.get('seg_fields', []): + results[key][y1:y2, x1:x2] = self.seg_fill_in + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'n_holes={self.n_holes}, ' + repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio + else f'cutout_shape={self.candidates}, ') + repr_str += f'fill_in={self.fill_in}, ' + repr_str += f'seg_fill_in={self.seg_fill_in})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomRotFlip(BaseTransform): + """Rotate and flip the image & seg or just rotate the image & seg. + + Required Keys: + + - img + - gt_seg_map + + Modified Keys: + + - img + - gt_seg_map + + Args: + rotate_prob (float): The probability of rotate image. + flip_prob (float): The probability of rotate&flip image. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + """ + + def __init__(self, rotate_prob=0.5, flip_prob=0.5, degree=(-20, 20)): + self.rotate_prob = rotate_prob + self.flip_prob = flip_prob + assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + + def random_rot_flip(self, results: dict) -> dict: + k = np.random.randint(0, 4) + results['img'] = np.rot90(results['img'], k) + for key in results.get('seg_fields', []): + results[key] = np.rot90(results[key], k) + axis = np.random.randint(0, 2) + results['img'] = np.flip(results['img'], axis=axis).copy() + for key in results.get('seg_fields', []): + results[key] = np.flip(results[key], axis=axis).copy() + return results + + def random_rotate(self, results: dict) -> dict: + angle = np.random.uniform(min(*self.degree), max(*self.degree)) + results['img'] = mmcv.imrotate(results['img'], angle=angle) + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate(results[key], angle=angle) + return results + + def transform(self, results: dict) -> dict: + """Call function to rotate or rotate & flip image, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated or rotated & flipped results. + """ + rotate_flag = 0 + if random.random() < self.rotate_prob: + results = self.random_rotate(results) + rotate_flag = 1 + if random.random() < self.flip_prob and rotate_flag == 0: + results = self.random_rot_flip(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(rotate_prob={self.rotate_prob}, ' \ + f'flip_prob={self.flip_prob}, ' \ + f'degree={self.degree})' + return repr_str + + +@TRANSFORMS.register_module() +class RandomFlip(MMCV_RandomFlip): + """Flip the image & bbox & segmentation map. Added or Updated + keys: flip, flip_direction, img, gt_bboxes, gt_seg_map, and gt_depth_map. + There are 3 flip modes: + + - ``prob`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``prob`` . + E.g., ``prob=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + + - ``prob`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``prob/len(direction)``. + E.g., ``prob=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + + - ``prob`` is list of float, ``direction`` is list of string: + given ``len(prob) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``prob[i]``. + E.g., ``prob=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with + probability of 0.3, vertically with probability of 0.5. + + Required Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Modified Keys: + + - img + - gt_bboxes (optional) + - gt_seg_map (optional) + - gt_depth_map (optional) + + Added Keys: + + - flip + - flip_direction + - swap_seg_labels (optional) + + Args: + prob (float | list[float], optional): The flipping probability. + Defaults to None. + direction(str | list[str]): The flipping direction. Options + If input is a list, the length must equal ``prob``. Each + element in ``prob`` indicates the flip probability of + corresponding direction. Defaults to 'horizontal'. + swap_seg_labels (list, optional): The label pair need to be swapped + for ground truth, like 'left arm' and 'right arm' need to be + swapped after horizontal flipping. For example, ``[(1, 5)]``, + where 1/5 is the label of the left/right arm. Defaults to None. + """ + + def _flip(self, results: dict) -> None: + """Flip images, bounding boxes and semantic segmentation map.""" + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + img_shape = results['img'].shape[:2] + + # flip bboxes + if results.get('gt_bboxes', None) is not None: + results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'], + img_shape, + results['flip_direction']) + + # flip seg map + for key in results.get('seg_fields', []): + if results.get(key, None) is not None: + results[key] = self._flip_seg_map( + results[key], direction=results['flip_direction']).copy() + results['swap_seg_labels'] = self.swap_seg_labels + + +@TRANSFORMS.register_module() +class Resize(MMCV_Resize): + """Resize images & seg & depth map. + + This transform resizes the input image according to ``scale`` or + ``scale_factor``. Seg map, depth map and other relative annotations are + then resized with the same scale factor. + if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to + resize. + + Required Keys: + + - img + - gt_seg_map (optional) + - gt_depth_map (optional) + + Modified Keys: + + - img + - gt_seg_map + - gt_depth_map + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + Args: + scale (int or tuple): Images scales for resizing. Defaults to None + scale_factor (float or tuple[float]): Scale factors for resizing. + Defaults to None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + def _resize_seg(self, results: dict) -> None: + """Resize semantic segmentation map with ``results['scale']``.""" + for seg_key in results.get('seg_fields', []): + if results.get(seg_key, None) is not None: + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[seg_key], + results['scale'], + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize( + results[seg_key], + results['scale'], + interpolation='nearest', + backend=self.backend) + results[seg_key] = gt_seg + + +@TRANSFORMS.register_module() +class RandomMosaic(BaseTransform): + """Mosaic augmentation. Given 4 images, mosaic transform combines them into + one output image. The output image is composed of the parts from each sub- + image. + + .. code:: text + + mosaic transform + center_x + +------------------------------+ + | pad | pad | + | +-----------+ | + | | | | + | | image1 |--------+ | + | | | | | + | | | image2 | | + center_y |----+-------------+-----------| + | | cropped | | + |pad | image3 | image4 | + | | | | + +----|-------------+-----------+ + | | + +-------------+ + + The mosaic transform steps are as follows: + 1. Choose the mosaic center as the intersections of 4 images + 2. Get the left top image according to the index, and randomly + sample another 3 images from the custom dataset. + 3. Sub image will be cropped if image is larger than mosaic patch + + Required Keys: + + - img + - gt_seg_map + - mix_results + + Modified Keys: + + - img + - img_shape + - ori_shape + - gt_seg_map + + Args: + prob (float): mosaic probability. + img_scale (Sequence[int]): Image size after mosaic pipeline of + a single image. The size of the output image is four times + that of a single image. The output image comprises 4 single images. + Default: (640, 640). + center_ratio_range (Sequence[float]): Center ratio range of mosaic + output. Default: (0.5, 1.5). + pad_val (int): Pad value. Default: 0. + seg_pad_val (int): Pad value of segmentation map. Default: 255. + """ + + def __init__(self, + prob, + img_scale=(640, 640), + center_ratio_range=(0.5, 1.5), + pad_val=0, + seg_pad_val=255): + assert 0 <= prob and prob <= 1 + assert isinstance(img_scale, tuple) + self.prob = prob + self.img_scale = img_scale + self.center_ratio_range = center_ratio_range + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + @cache_randomness + def do_mosaic(self): + return np.random.rand() < self.prob + + def transform(self, results: dict) -> dict: + """Call function to make a mosaic of image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with mosaic transformed. + """ + mosaic = self.do_mosaic() + if mosaic: + results = self._mosaic_transform_img(results) + results = self._mosaic_transform_seg(results) + return results + + def get_indices(self, dataset: MultiImageMixDataset) -> list: + """Call function to collect indices. + + Args: + dataset (:obj:`MultiImageMixDataset`): The dataset. + + Returns: + list: indices. + """ + + indices = [random.randint(0, len(dataset)) for _ in range(3)] + return indices + + @cache_randomness + def generate_mosaic_center(self): + # mosaic center x, y + center_x = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[0]) + return center_x, center_y + + def _mosaic_transform_img(self, results: dict) -> dict: + """Mosaic transform function. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + if len(results['img'].shape) == 3: + c = results['img'].shape[2] + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c), + self.pad_val, + dtype=results['img'].dtype) + else: + mosaic_img = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.pad_val, + dtype=results['img'].dtype) + + # mosaic center x, y + self.center_x, self.center_y = self.generate_mosaic_center() + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + img_i = result_patch['img'] + h_i, w_i = img_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + img_i = mmcv.imresize( + img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, img_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c] + + results['img'] = mosaic_img + results['img_shape'] = mosaic_img.shape + results['ori_shape'] = mosaic_img.shape + + return results + + def _mosaic_transform_seg(self, results: dict) -> dict: + """Mosaic transform function for label annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + + assert 'mix_results' in results + for key in results.get('seg_fields', []): + mosaic_seg = np.full( + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + self.seg_pad_val, + dtype=results[key].dtype) + + # mosaic center x, y + center_position = (self.center_x, self.center_y) + + loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') + for i, loc in enumerate(loc_strs): + if loc == 'top_left': + result_patch = copy.deepcopy(results) + else: + result_patch = copy.deepcopy(results['mix_results'][i - 1]) + + gt_seg_i = result_patch[key] + h_i, w_i = gt_seg_i.shape[:2] + # keep_ratio resize + scale_ratio_i = min(self.img_scale[0] / h_i, + self.img_scale[1] / w_i) + gt_seg_i = mmcv.imresize( + gt_seg_i, + (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)), + interpolation='nearest') + + # compute the combine parameters + paste_coord, crop_coord = self._mosaic_combine( + loc, center_position, gt_seg_i.shape[:2][::-1]) + x1_p, y1_p, x2_p, y2_p = paste_coord + x1_c, y1_c, x2_c, y2_c = crop_coord + + # crop and paste image + mosaic_seg[y1_p:y2_p, x1_p:x2_p] = \ + gt_seg_i[y1_c:y2_c, x1_c:x2_c] + + results[key] = mosaic_seg + + return results + + def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float], + img_shape_wh: Sequence[int]) -> tuple: + """Calculate global coordinate of mosaic image and local coordinate of + cropped sub-image. + + Args: + loc (str): Index for the sub-image, loc in ('top_left', + 'top_right', 'bottom_left', 'bottom_right'). + center_position_xy (Sequence[float]): Mixing center for 4 images, + (x, y). + img_shape_wh (Sequence[int]): Width and height of sub-image + + Returns: + tuple[tuple[float]]: Corresponding coordinate of pasting and + cropping + - paste_coord (tuple): paste corner coordinate in mosaic image. + - crop_coord (tuple): crop corner coordinate in mosaic image. + """ + + assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right') + if loc == 'top_left': + # index0 to top left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + center_position_xy[0], \ + center_position_xy[1] + crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - ( + y2 - y1), img_shape_wh[0], img_shape_wh[1] + + elif loc == 'top_right': + # index1 to top right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + max(center_position_xy[1] - img_shape_wh[1], 0), \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + center_position_xy[1] + crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( + img_shape_wh[0], x2 - x1), img_shape_wh[1] + + elif loc == 'bottom_left': + # index2 to bottom left part of image + x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ + center_position_xy[1], \ + center_position_xy[0], \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( + y2 - y1, img_shape_wh[1]) + + else: + # index3 to bottom right part of image + x1, y1, x2, y2 = center_position_xy[0], \ + center_position_xy[1], \ + min(center_position_xy[0] + img_shape_wh[0], + self.img_scale[1] * 2), \ + min(self.img_scale[0] * 2, center_position_xy[1] + + img_shape_wh[1]) + crop_coord = 0, 0, min(img_shape_wh[0], + x2 - x1), min(y2 - y1, img_shape_wh[1]) + + paste_coord = x1, y1, x2, y2 + return paste_coord, crop_coord + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'img_scale={self.img_scale}, ' + repr_str += f'center_ratio_range={self.center_ratio_range}, ' + repr_str += f'pad_val={self.pad_val}, ' + repr_str += f'seg_pad_val={self.pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class GenerateEdge(BaseTransform): + """Generate Edge for CE2P approach. + + Edge will be used to calculate loss of + `CE2P `_. + + Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501 + + Required Keys: + + - img_shape + - gt_seg_map + + Added Keys: + - gt_edge_map (np.ndarray, uint8): The edge annotation generated from the + seg map by extracting border between different semantics. + + Args: + edge_width (int): The width of edge. Default to 3. + ignore_index (int): Index that will be ignored. Default to 255. + """ + + def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None: + super().__init__() + self.edge_width = edge_width + self.ignore_index = ignore_index + + def transform(self, results: Dict) -> Dict: + """Call function to generate edge from segmentation map. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with edge mask. + """ + h, w = results['img_shape'] + edge = np.zeros((h, w), dtype=np.uint8) + seg_map = results['gt_seg_map'] + + # down + edge_down = edge[1:h, :] + edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :]) + & (seg_map[1:h, :] != self.ignore_index) & + (seg_map[:h - 1, :] != self.ignore_index)] = 1 + # left + edge_left = edge[:, :w - 1] + edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w]) + & (seg_map[:, :w - 1] != self.ignore_index) & + (seg_map[:, 1:w] != self.ignore_index)] = 1 + # up_left + edge_upleft = edge[:h - 1, :w - 1] + edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w]) + & (seg_map[:h - 1, :w - 1] != self.ignore_index) & + (seg_map[1:h, 1:w] != self.ignore_index)] = 1 + # up_right + edge_upright = edge[:h - 1, 1:w] + edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1]) + & (seg_map[:h - 1, 1:w] != self.ignore_index) & + (seg_map[1:h, :w - 1] != self.ignore_index)] = 1 + + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, + (self.edge_width, self.edge_width)) + edge = cv2.dilate(edge, kernel) + + results['gt_edge_map'] = edge + results['edge_width'] = self.edge_width + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'edge_width={self.edge_width}, ' + repr_str += f'ignore_index={self.ignore_index})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + Copyright (c) Facebook, Inc. and its affiliates. + Licensed under the Apache-2.0 License + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + + - img + - gt_seg_map (optional) + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional)) + + Added Keys: + + - scale + - scale_factor + - keep_ratio + + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, scale: Union[int, Tuple[int, int]], + max_size: int) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + # Create a empty Resize object + self.resize = TRANSFORMS.build({ + 'type': 'Resize', + 'scale': 0, + 'keep_ratio': True + }) + + def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. + """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return (new_w, new_h) + + def transform(self, results: Dict) -> Dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + +@TRANSFORMS.register_module() +class BioMedical3DRandomCrop(BaseTransform): + """Crop the input patch for medical image & segmentation mask. + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask + with shape (Z, Y, X). + + Modified Keys: + + - img + - img_shape + - gt_seg_map (optional) + + Args: + crop_shape (Union[int, Tuple[int, int, int]]): Expected size after + cropping with the format of (z, y, x). If set to an integer, + then cropping width and height are equal to this integer. + keep_foreground (bool): If keep_foreground is True, it will sample a + voxel of foreground classes randomly, and will take it as the + center of the crop bounding-box. Default to True. + """ + + def __init__(self, + crop_shape: Union[int, Tuple[int, int, int]], + keep_foreground: bool = True): + super().__init__() + assert isinstance(crop_shape, int) or ( + isinstance(crop_shape, tuple) and len(crop_shape) == 3 + ), 'The expected crop_shape is an integer, or a tuple containing ' + 'three integers' + + if isinstance(crop_shape, int): + crop_shape = (crop_shape, crop_shape, crop_shape) + assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0 + self.crop_shape = crop_shape + self.keep_foreground = keep_foreground + + def random_sample_location(self, seg_map: np.ndarray) -> dict: + """sample foreground voxel when keep_foreground is True. + + Args: + seg_map (np.ndarray): gt seg map. + + Returns: + dict: Coordinates of selected foreground voxel. + """ + num_samples = 10000 + # at least 1% of the class voxels need to be selected, + # otherwise it may be too sparse + min_percent_coverage = 0.01 + class_locs = {} + foreground_classes = [] + all_classes = np.unique(seg_map) + for c in all_classes: + if c == 0: + # to avoid the segmentation mask full of background 0 + # and the class_locs is just void dictionary {} when it return + # there add a void list for background 0. + class_locs[c] = [] + else: + all_locs = np.argwhere(seg_map == c) + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max( + target_num_samples, + int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[np.random.choice( + len(all_locs), target_num_samples, replace=False)] + class_locs[c] = selected + foreground_classes.append(c) + + selected_voxel = None + if len(foreground_classes) > 0: + selected_class = np.random.choice(foreground_classes) + voxels_of_that_class = class_locs[selected_class] + selected_voxel = voxels_of_that_class[np.random.choice( + len(voxels_of_that_class))] + + return selected_voxel + + def random_generate_crop_bbox(self, margin_z: int, margin_y: int, + margin_x: int) -> tuple: + """Randomly get a crop bounding box. + + Args: + seg_map (np.ndarray): Ground truth segmentation map. + + Returns: + tuple: Coordinates of the cropped image. + """ + offset_z = np.random.randint(0, margin_z + 1) + offset_y = np.random.randint(0, margin_y + 1) + offset_x = np.random.randint(0, margin_x + 1) + crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0] + crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1] + crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2] + + return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 + + def generate_margin(self, results: dict) -> tuple: + """Generate margin of crop bounding-box. + + If keep_foreground is True, it will sample a voxel of foreground + classes randomly, and will take it as the center of the bounding-box, + and return the margin between of the bounding-box and image. + If keep_foreground is False, it will return the difference from crop + shape and image shape. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + tuple: The margin for 3 dimensions of crop bounding-box and image. + """ + + seg_map = results['gt_seg_map'] + if self.keep_foreground: + selected_voxel = self.random_sample_location(seg_map) + if selected_voxel is None: + # this only happens if some image does not contain + # foreground voxels at all + warnings.warn(f'case does not contain any foreground classes' + f': {results["img_path"]}') + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + else: + margin_z = max(0, selected_voxel[0] - self.crop_shape[0] // 2) + margin_y = max(0, selected_voxel[1] - self.crop_shape[1] // 2) + margin_x = max(0, selected_voxel[2] - self.crop_shape[2] // 2) + margin_z = max( + 0, min(seg_map.shape[0] - self.crop_shape[0], margin_z)) + margin_y = max( + 0, min(seg_map.shape[1] - self.crop_shape[1], margin_y)) + margin_x = max( + 0, min(seg_map.shape[2] - self.crop_shape[2], margin_x)) + else: + margin_z = max(seg_map.shape[0] - self.crop_shape[0], 0) + margin_y = max(seg_map.shape[1] - self.crop_shape[1], 0) + margin_x = max(seg_map.shape[2] - self.crop_shape[2], 0) + + return margin_z, margin_y, margin_x + + def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: + """Crop from ``img`` + + Args: + img (np.ndarray): Original input image. + crop_bbox (tuple): Coordinates of the cropped image. + + Returns: + np.ndarray: The cropped image. + """ + crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + if len(img.shape) == 3: + # crop seg map + img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + else: + # crop image + assert len(img.shape) == 4 + img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def transform(self, results: dict) -> dict: + """Transform function to randomly crop images, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + margin = self.generate_margin(results) + crop_bbox = self.random_generate_crop_bbox(*margin) + + # crop the image + img = results['img'] + results['img'] = self.crop(img, crop_bbox) + results['img_shape'] = results['img'].shape[1:] + + # crop semantic seg + seg_map = results['gt_seg_map'] + results['gt_seg_map'] = self.crop(seg_map, crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_shape={self.crop_shape})' + + +@TRANSFORMS.register_module() +class BioMedicalGaussianNoise(BaseTransform): + """Add random Gaussian noise to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + prob (float): Probability to add Gaussian noise for + each sample. Default to 0.1. + mean (float): Mean or “centre” of the distribution. Default to 0.0. + std (float): Standard deviation of distribution. Default to 0.1. + """ + + def __init__(self, + prob: float = 0.1, + mean: float = 0.0, + std: float = 0.1) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 and std >= 0.0 + self.prob = prob + self.mean = mean + self.std = std + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian noise to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + rand_std = np.random.uniform(0, self.std) + noise = np.random.normal( + self.mean, rand_std, size=results['img'].shape) + # noise is float64 array, convert to the results['img'].dtype + noise = noise.astype(results['img'].dtype) + results['img'] = results['img'] + noise + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'mean={self.mean}, ' + repr_str += f'std={self.std})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalGaussianBlur(BaseTransform): + """Add Gaussian blur with random sigma to image. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501 + + Copyright (c) German Cancer Research Center (DKFZ) + Licensed under the Apache License, Version 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + + - img + + Args: + sigma_range (Tuple[float, float]|float): range to randomly + select sigma value. Default to (0.5, 1.0). + prob (float): Probability to apply Gaussian blur + for each sample. Default to 0.2. + prob_per_channel (float): Probability to apply Gaussian blur + for each channel (axis N of the image). Default to 0.5. + different_sigma_per_channel (bool): whether to use different + sigma for each channel (axis N of the image). Default to True. + different_sigma_per_axis (bool): whether to use different + sigma for axis Z, X and Y of the image. Default to True. + """ + + def __init__(self, + sigma_range: Tuple[float, float] = (0.5, 1.0), + prob: float = 0.2, + prob_per_channel: float = 0.5, + different_sigma_per_channel: bool = True, + different_sigma_per_axis: bool = True) -> None: + super().__init__() + assert 0.0 <= prob <= 1.0 + assert 0.0 <= prob_per_channel <= 1.0 + assert isinstance(sigma_range, Sequence) and len(sigma_range) == 2 + self.sigma_range = sigma_range + self.prob = prob + self.prob_per_channel = prob_per_channel + self.different_sigma_per_channel = different_sigma_per_channel + self.different_sigma_per_axis = different_sigma_per_axis + + def _get_valid_sigma(self, value_range) -> Tuple[float, ...]: + """Ensure the `value_range` to be either a single value or a sequence + of two values. If the `value_range` is a sequence, generate a random + value with `[value_range[0], value_range[1]]` based on uniform + sampling. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501 + + Args: + value_range (tuple|list|float|int): the input value range + """ + if (isinstance(value_range, (list, tuple))): + if (value_range[0] == value_range[1]): + value = value_range[0] + else: + orig_type = type(value_range[0]) + value = np.random.uniform(value_range[0], value_range[1]) + value = orig_type(value) + return value + + def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray: + """Random generate sigma and apply Gaussian Blur to the data + Args: + data_sample (np.ndarray): data sample with multiple modalities, + the data shape is (N, Z, Y, X) + """ + sigma = None + for c in range(data_sample.shape[0]): + if np.random.rand() < self.prob_per_channel: + # if no `sigma` is generated, generate one + # if `self.different_sigma_per_channel` is True, + # re-generate random sigma for each channel + if (sigma is None or self.different_sigma_per_channel): + if (not self.different_sigma_per_axis): + sigma = self._get_valid_sigma(self.sigma_range) + else: + sigma = [ + self._get_valid_sigma(self.sigma_range) + for _ in data_sample.shape[1:] + ] + # apply gaussian filter with `sigma` + data_sample[c] = gaussian_filter( + data_sample[c], sigma, order=0) + return data_sample + + def transform(self, results: Dict) -> Dict: + """Call function to add random Gaussian blur to image. + + Args: + results (dict): Result dict. + + Returns: + dict: Result dict with random Gaussian noise. + """ + if np.random.rand() < self.prob: + results['img'] = self._gaussian_blur(results['img']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'prob_per_channel={self.prob_per_channel}, ' + repr_str += f'sigma_range={self.sigma_range}, ' + repr_str += 'different_sigma_per_channel=' \ + f'{self.different_sigma_per_channel}, ' + repr_str += 'different_sigma_per_axis=' \ + f'{self.different_sigma_per_axis})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedicalRandomGamma(BaseTransform): + """Using random gamma correction to process the biomedical image. + + Modified from + https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501 + With licence: Apache 2.0 + + Required Keys: + + - img (np.ndarray): Biomedical image with shape (N, Z, Y, X), + N is the number of modalities, and data type is float32. + + Modified Keys: + - img + + Args: + prob (float): The probability to perform this transform. Default: 0.5. + gamma_range (Tuple[float]): Range of gamma values. Default: (0.5, 2). + invert_image (bool): Whether invert the image before applying gamma + augmentation. Default: False. + per_channel (bool): Whether perform the transform each channel + individually. Default: False + retain_stats (bool): Gamma transformation will alter the mean and std + of the data in the patch. If retain_stats=True, the data will be + transformed to match the mean and standard deviation before gamma + augmentation. Default: False. + """ + + def __init__(self, + prob: float = 0.5, + gamma_range: Tuple[float] = (0.5, 2), + invert_image: bool = False, + per_channel: bool = False, + retain_stats: bool = False): + assert 0 <= prob and prob <= 1 + assert isinstance(gamma_range, tuple) and len(gamma_range) == 2 + assert isinstance(invert_image, bool) + assert isinstance(per_channel, bool) + assert isinstance(retain_stats, bool) + self.prob = prob + self.gamma_range = gamma_range + self.invert_image = invert_image + self.per_channel = per_channel + self.retain_stats = retain_stats + + @cache_randomness + def _do_gamma(self): + """Whether do adjust gamma for image.""" + return np.random.rand() < self.prob + + def _adjust_gamma(self, img: np.array): + """Gamma adjustment for image. + + Args: + img (np.array): Input image before gamma adjust. + + Returns: + np.arrays: Image after gamma adjust. + """ + + if self.invert_image: + img = -img + + def _do_adjust(img): + if retain_stats_here: + img_mean = img.mean() + img_std = img.std() + if np.random.random() < 0.5 and self.gamma_range[0] < 1: + gamma = np.random.uniform(self.gamma_range[0], 1) + else: + gamma = np.random.uniform( + max(self.gamma_range[0], 1), self.gamma_range[1]) + img_min = img.min() + img_range = img.max() - img_min # range + img = np.power(((img - img_min) / float(img_range + 1e-7)), + gamma) * img_range + img_min + if retain_stats_here: + img = img - img.mean() + img = img / (img.std() + 1e-8) * img_std + img = img + img_mean + return img + + if not self.per_channel: + retain_stats_here = self.retain_stats + img = _do_adjust(img) + else: + for c in range(img.shape[0]): + img[c] = _do_adjust(img[c]) + if self.invert_image: + img = -img + return img + + def transform(self, results: dict) -> dict: + """Call function to perform random gamma correction + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with random gamma correction performed. + """ + do_gamma = self._do_gamma() + + if do_gamma: + results['img'] = self._adjust_gamma(results['img']) + else: + pass + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' + repr_str += f'gamma_range={self.gamma_range},' + repr_str += f'invert_image={self.invert_image},' + repr_str += f'per_channel={self.per_channel},' + repr_str += f'retain_stats={self.retain_stats}' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DPad(BaseTransform): + """Pad the biomedical 3d image & biomedical 3d semantic segmentation maps. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - pad_shape (Tuple[int, int, int]): The padded shape. + + Args: + pad_shape (Tuple[int, int, int]): Fixed padding size. + Expected padding shape (Z, Y, X). + pad_val (float): Padding value for biomedical image. + The padding mode is set to "constant". The value + to be filled in padding area. Default: 0. + seg_pad_val (int): Padding value for biomedical 3d semantic + segmentation maps. The padding mode is set to "constant". + The value to be filled in padding area. Default: 0. + """ + + def __init__(self, + pad_shape: Tuple[int, int, int], + pad_val: float = 0., + seg_pad_val: int = 0) -> None: + + # check pad_shape + assert pad_shape is not None + if not isinstance(pad_shape, tuple): + assert len(pad_shape) == 3 + + self.pad_shape = pad_shape + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + def _pad_img(self, results: dict) -> None: + """Pad images according to ``self.pad_shape`` + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: The dict contains the padded image and shape + information. + """ + padded_img = self._to_pad( + results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val) + + results['img'] = padded_img + results['pad_shape'] = padded_img.shape[1:] + + def _pad_seg(self, results: dict) -> None: + """Pad semantic segmentation map according to ``self.pad_shape`` if + ``gt_seg_map`` is not None in results dict. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Update the padded gt seg map in dict. + """ + if results.get('gt_seg_map', None) is not None: + pad_gt_seg = self._to_pad( + results['gt_seg_map'][None, ...], + pad_shape=results['pad_shape'], + pad_val=self.seg_pad_val) + results['gt_seg_map'] = pad_gt_seg[1:] + + @staticmethod + def _to_pad(img: np.ndarray, + pad_shape: Tuple[int, int, int], + pad_val: Union[int, float] = 0) -> np.ndarray: + """Pad the given 3d image to a certain shape with specified padding + value. + + Args: + img (ndarray): Biomedical image with shape (N, Z, Y, X) + to be padded. N is the number of modalities. + pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X). + pad_val (float, int): Values to be filled in padding areas + and the padding_mode is set to 'constant'. Default: 0. + + Returns: + ndarray: The padded image. + """ + # compute pad width + d = max(pad_shape[0] - img.shape[1], 0) + pad_d = (d // 2, d - d // 2) + h = max(pad_shape[1] - img.shape[2], 0) + pad_h = (h // 2, h - h // 2) + w = max(pad_shape[2] - img.shape[2], 0) + pad_w = (w // 2, w - w // 2) + + pad_list = [(0, 0), pad_d, pad_h, pad_w] + + img = np.pad(img, pad_list, mode='constant', constant_values=pad_val) + return img + + def transform(self, results: dict) -> dict: + """Call function to pad images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + self._pad_img(results) + self._pad_seg(results) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'pad_shape={self.pad_shape}, ' + repr_str += f'pad_val={self.pad_val}), ' + repr_str += f'seg_pad_val={self.seg_pad_val})' + return repr_str + + +@TRANSFORMS.register_module() +class BioMedical3DRandomFlip(BaseTransform): + """Flip biomedical 3D images and segmentations. + + Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/spatial_transforms.py # noqa:E501 + + Copyright 2021 Division of + Medical Image Computing, German Cancer Research Center (DKFZ) and Applied + Computer Vision Lab, Helmholtz Imaging Platform. + Licensed under the Apache-2.0 License. + + Required Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Modified Keys: + + - img (np.ndarry): Biomedical image with shape (N, Z, Y, X) by default, + N is the number of modalities. + - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape + (Z, Y, X) by default. + + Added Keys: + + - do_flip + - flip_axes + + Args: + prob (float): Flipping probability. + axes (Tuple[int, ...]): Flipping axes with order 'ZXY'. + swap_label_pairs (Optional[List[Tuple[int, int]]]): + The segmentation label pairs that are swapped when flipping. + """ + + def __init__(self, + prob: float, + axes: Tuple[int, ...], + swap_label_pairs: Optional[List[Tuple[int, int]]] = None): + self.prob = prob + self.axes = axes + self.swap_label_pairs = swap_label_pairs + assert prob >= 0 and prob <= 1 + if axes is not None: + assert max(axes) <= 2 + + @staticmethod + def _flip(img, direction: Tuple[bool, bool, bool]) -> np.ndarray: + if direction[0]: + img[:, :] = img[:, ::-1] + if direction[1]: + img[:, :, :] = img[:, :, ::-1] + if direction[2]: + img[:, :, :, :] = img[:, :, :, ::-1] + return img + + def _do_flip(self, img: np.ndarray) -> Tuple[bool, bool, bool]: + """Call function to determine which axis to flip. + + Args: + img (np.ndarry): Image or segmentation map array. + Returns: + tuple: Flip action, whether to flip on the z, x, and y axes. + """ + flip_c, flip_x, flip_y = False, False, False + if self.axes is not None: + flip_c = 0 in self.axes and np.random.rand() < self.prob + flip_x = 1 in self.axes and np.random.rand() < self.prob + if len(img.shape) == 4: + flip_y = 2 in self.axes and np.random.rand() < self.prob + return flip_c, flip_x, flip_y + + def _swap_label(self, seg: np.ndarray) -> np.ndarray: + out = seg.copy() + for first, second in self.swap_label_pairs: + first_area = (seg == first) + second_area = (seg == second) + out[first_area] = second + out[second_area] = first + return out + + def transform(self, results: Dict) -> Dict: + """Call function to flip and swap pair labels. + + Args: + results (dict): Result dict. + Returns: + dict: Flipped results, 'do_flip', 'flip_axes' keys are added into + result dict. + """ + # get actual flipped axis + if 'do_flip' not in results: + results['do_flip'] = self._do_flip(results['img']) + if 'flip_axes' not in results: + results['flip_axes'] = self.axes + # flip image + results['img'] = self._flip( + results['img'], direction=results['do_flip']) + # flip seg + if results['gt_seg_map'] is not None: + if results['gt_seg_map'].shape != results['img'].shape: + results['gt_seg_map'] = results['gt_seg_map'][None, :] + results['gt_seg_map'] = self._flip( + results['gt_seg_map'], direction=results['do_flip']) + results['gt_seg_map'] = results['gt_seg_map'].squeeze() + # swap label pairs + if self.swap_label_pairs is not None: + results['gt_seg_map'] = self._swap_label(results['gt_seg_map']) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, axes={self.axes}, ' \ + f'swap_label_pairs={self.swap_label_pairs})' + return repr_str + + +@TRANSFORMS.register_module() +class Albu(BaseTransform): + """Albumentation augmentation. Adds custom transformations from + Albumentations library. Please, visit + `https://albumentations.readthedocs.io` to get more information. An example + of ``transforms`` is as followed: + + .. code-block:: + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + Args: + transforms (list[dict]): A list of albu transformations + keymap (dict): Contains {'input key':'albumentation-style key'} + additional_targets(dict): Allows applying same augmentations to \ + multiple objects of same type. + update_pad_shape (bool): Whether to update padding shape according to \ + the output shape of the last transform + bgr_to_rgb (bool): Whether to convert the band order to RGB + """ + + def __init__(self, + transforms: List[dict], + keymap: Optional[dict] = None, + additional_targets: Optional[dict] = None, + update_pad_shape: bool = False, + bgr_to_rgb: bool = True): + if not ALBU_INSTALLED: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + + self.transforms = transforms + self.keymap = keymap + self.additional_targets = additional_targets + self.update_pad_shape = update_pad_shape + self.bgr_to_rgb = bgr_to_rgb + + self.aug = Compose([self.albu_builder(t) for t in self.transforms], + additional_targets=self.additional_targets) + + if not keymap: + self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'} + else: + self.keymap_to_albu = copy.deepcopy(keymap) + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg: dict) -> object: + """Build a callable object from a dict containing albu arguments. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + + Returns: + Callable: A callable object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if mmengine.is_str(obj_type): + if not ALBU_INSTALLED: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a valid type or str, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(t) for t in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d: dict, keymap: dict): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, _ in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def transform(self, results): + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + + # Convert to RGB since Albumentations works with RGB images + if self.bgr_to_rgb: + results['image'] = cv2.cvtColor(results['image'], + cv2.COLOR_BGR2RGB) + if self.additional_targets: + for key, value in self.additional_targets.items(): + if value == 'image': + results[key] = cv2.cvtColor(results[key], + cv2.COLOR_BGR2RGB) + + # Apply Transform + results = self.aug(**results) + + # Convert back to BGR + if self.bgr_to_rgb: + results['image'] = cv2.cvtColor(results['image'], + cv2.COLOR_RGB2BGR) + if self.additional_targets: + for key, value in self.additional_targets.items(): + if value == 'image': + results[key] = cv2.cvtColor(results['image2'], + cv2.COLOR_RGB2BGR) + + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str + + +@TRANSFORMS.register_module() +class ConcatCDInput(BaseTransform): + """Concat images for change detection. + + Required Keys: + + - img + - img2 + + Args: + input_keys (tuple): Input image keys for change detection. + Default: ('img', 'img2'). + """ + + def __init__(self, input_keys=('img', 'img2')): + self.input_keys = input_keys + + def transform(self, results: dict) -> dict: + img = [] + for input_key in self.input_keys: + img.append(results.pop(input_key)) + results['img'] = np.concatenate(img, axis=2) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(input_keys={self.input_keys}, ' + return repr_str + + +@TRANSFORMS.register_module() +class RandomDepthMix(BaseTransform): + """This class implements the RandomDepthMix transform. + + Args: + prob (float): Probability of applying the transformation. + Defaults to 0.25. + mix_scale_ratio (float): Ratio to scale the mix width. + Defaults to 0.75. + """ + + def __init__( + self, + prob: float = 0.25, + mix_scale_ratio: float = 0.75, + ): + super().__init__() + + self.prob = prob + self.mix_scale_ratio = mix_scale_ratio + + def transform(self, results: dict) -> dict: + if random.random() > self.prob: + return results + + h, w = results['img_shape'][:2] + left = int(w * random.random()) + width_ratio = self.mix_scale_ratio * random.random() + width = int(max(1, (w - left) * width_ratio)) + + img = results['img'] + depth_rescale_factor = results.get('depth_rescale_factor', 1) + depth_map = results['gt_depth_map'] / depth_rescale_factor + + if img.ndim == 3: + for c in range(img.shape[-1]): + img[:, left:left + width, c] = depth_map[:, left:left + width] + elif img.ndim == 2: + img[:, left:left + width] = depth_map[:, left:left + width] + else: + raise ValueError(f'Invalid image shape ({img.shape})') + + results['img'] = img + return results diff --git a/finetune/mmseg/engine/__init__.py b/finetune/mmseg/engine/__init__.py new file mode 100644 index 0000000..98139a0 --- /dev/null +++ b/finetune/mmseg/engine/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hooks import SegVisualizationHook +from .optimizers import (ForceDefaultOptimWrapperConstructor, + LayerDecayOptimizerConstructor, + LearningRateDecayOptimizerConstructor) +from .schedulers import PolyLRRatio + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'SegVisualizationHook', 'PolyLRRatio', + 'ForceDefaultOptimWrapperConstructor' +] diff --git a/finetune/mmseg/engine/hooks/__init__.py b/finetune/mmseg/engine/hooks/__init__.py new file mode 100644 index 0000000..c604808 --- /dev/null +++ b/finetune/mmseg/engine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .visualization_hook import SegVisualizationHook + +__all__ = ['SegVisualizationHook'] diff --git a/finetune/mmseg/engine/hooks/visualization_hook.py b/finetune/mmseg/engine/hooks/visualization_hook.py new file mode 100644 index 0000000..21cddde --- /dev/null +++ b/finetune/mmseg/engine/hooks/visualization_hook.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Optional, Sequence + +import mmcv +from mmengine.fileio import get +from mmengine.hooks import Hook +from mmengine.runner import Runner +from mmengine.visualization import Visualizer + +from mmseg.registry import HOOKS +from mmseg.structures import SegDataSample + + +@HOOKS.register_module() +class SegVisualizationHook(Hook): + """Segmentation Visualization Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + + Args: + draw (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + + def __init__(self, + draw: bool = False, + interval: int = 50, + show: bool = False, + wait_time: float = 0., + backend_args: Optional[dict] = None): + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.interval = interval + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.backend_args = backend_args.copy() if backend_args else None + self.draw = draw + if not self.draw: + warnings.warn('The draw is False, it means that the ' + 'hook for visualization will not take ' + 'effect. The results will NOT be ' + 'visualized or stored.') + self._test_index = 0 + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[SegDataSample]) -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = outputs[0].img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + window_name = f'val_{osp.basename(img_path)}' + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + window_name, + img, + data_sample=outputs[0], + show=self.show, + wait_time=self.wait_time, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[SegDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`SegDataSample`]): A batch of data samples + that contain annotations and predictions. + """ + if self.draw is False: + return + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.img_path + window_name = f'test_{osp.basename(img_path)}' + + img_path = data_sample.img_path + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + + self._visualizer.add_datasample( + window_name, + img, + data_sample=data_sample, + show=self.show, + wait_time=self.wait_time, + step=self._test_index) diff --git a/finetune/mmseg/engine/optimizers/__init__.py b/finetune/mmseg/engine/optimizers/__init__.py new file mode 100644 index 0000000..e4cf587 --- /dev/null +++ b/finetune/mmseg/engine/optimizers/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .force_default_constructor import ForceDefaultOptimWrapperConstructor +from .layer_decay_optimizer_constructor import ( + LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) + +__all__ = [ + 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor', + 'ForceDefaultOptimWrapperConstructor' +] diff --git a/finetune/mmseg/engine/optimizers/force_default_constructor.py b/finetune/mmseg/engine/optimizers/force_default_constructor.py new file mode 100644 index 0000000..12c642a --- /dev/null +++ b/finetune/mmseg/engine/optimizers/force_default_constructor.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.logging import print_log +from mmengine.optim import DefaultOptimWrapperConstructor +from mmengine.utils.dl_utils import mmcv_full_available +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm +from torch.nn import GroupNorm, LayerNorm + +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class ForceDefaultOptimWrapperConstructor(DefaultOptimWrapperConstructor): + """Default constructor with forced optimizer settings. + + This constructor extends the default constructor to add an option for + forcing default optimizer settings. This is useful for ensuring that + certain parameters or layers strictly adhere to pre-defined default + settings, regardless of any custom settings specified. + + By default, each parameter share the same optimizer settings, and we + provide an argument ``paramwise_cfg`` to specify parameter-wise settings. + It is a dict and may contain various fields like 'custom_keys', + 'bias_lr_mult', etc., as well as the additional field + `force_default_settings` which allows for enforcing default settings on + optimizer parameters. + + - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If + one of the keys in ``custom_keys`` is a substring of the name of one + parameter, then the setting of the parameter will be specified by + ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will + be ignored. It should be noted that the aforementioned ``key`` is the + longest key that is a substring of the name of the parameter. If there + are multiple matched keys with the same length, then the key with lower + alphabet order will be chosen. + ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult`` + and ``decay_mult``. See Example 2 below. + - ``bias_lr_mult`` (float): It will be multiplied to the learning + rate for all bias parameters (except for those in normalization + layers and offset layers of DCN). + - ``bias_decay_mult`` (float): It will be multiplied to the weight + decay for all bias parameters (except for those in + normalization layers, depthwise conv layers, offset layers of DCN). + - ``norm_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of normalization + layers. + - ``flat_decay_mult`` (float): It will be multiplied to the weight + decay for all one-dimensional parameters + - ``dwconv_decay_mult`` (float): It will be multiplied to the weight + decay for all weight and bias parameters of depthwise conv + layers. + - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning + rate for parameters of offset layer in the deformable convs + of a model. + - ``bypass_duplicate`` (bool): If true, the duplicate parameters + would not be added into optimizer. Defaults to False. + - ``force_default_settings`` (bool): If true, this will override any + custom settings defined by ``custom_keys`` and enforce the use of + default settings for optimizer parameters like ``bias_lr_mult``. + This is particularly useful when you want to ensure that certain layers + or parameters adhere strictly to the pre-defined default settings. + + Note: + + 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will + override the effect of ``bias_lr_mult`` in the bias of offset layer. + So be careful when using both ``bias_lr_mult`` and + ``dcn_offset_lr_mult``. If you wish to apply both of them to the offset + layer in deformable convs, set ``dcn_offset_lr_mult`` to the original + ``dcn_offset_lr_mult`` * ``bias_lr_mult``. + + 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will + apply it to all the DCN layers in the model. So be careful when the + model contains multiple DCN layers in places other than backbone. + + 3. When the option ``force_default_settings`` is true, it will override + any custom settings provided in ``custom_keys``. This ensures that the + default settings for the optimizer parameters are used. + + Args: + optim_wrapper_cfg (dict): The config dict of the optimizer wrapper. + + Required fields of ``optim_wrapper_cfg`` are + + - ``type``: class name of the OptimizerWrapper + - ``optimizer``: The configuration of optimizer. + + Optional fields of ``optim_wrapper_cfg`` are + + - any arguments of the corresponding optimizer wrapper type, + e.g., accumulative_counts, clip_grad, etc. + + Required fields of ``optimizer`` are + + - `type`: class name of the optimizer. + + Optional fields of ``optimizer`` are + + - any arguments of the corresponding optimizer type, e.g., + lr, weight_decay, momentum, etc. + + paramwise_cfg (dict, optional): Parameter-wise options. + + Example 1: + >>> model = torch.nn.modules.Conv1d(1, 1, 1) + >>> optim_wrapper_cfg = dict( + >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01, + >>> momentum=0.9, weight_decay=0.0001)) + >>> paramwise_cfg = dict(norm_decay_mult=0.) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + + Example 2: + >>> # assume model have attribute model.backbone and model.cls_head + >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict( + >>> type='SGD', lr=0.01, weight_decay=0.95)) + >>> paramwise_cfg = dict(custom_keys={ + >>> 'backbone': dict(lr_mult=0.1, decay_mult=0.9)}) + >>> optim_wrapper_builder = DefaultOptimWrapperConstructor( + >>> optim_wrapper_cfg, paramwise_cfg) + >>> optim_wrapper = optim_wrapper_builder(model) + >>> # Then the `lr` and `weight_decay` for model.backbone is + >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for + >>> # model.cls_head is (0.01, 0.95). + """ + + def add_params(self, + params: List[dict], + module: nn.Module, + prefix: str = '', + is_dcn_module: Optional[Union[int, float]] = None) -> None: + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', None) + bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', None) + norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', None) + dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', None) + flat_decay_mult = self.paramwise_cfg.get('flat_decay_mult', None) + bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False) + dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', None) + force_default_settings = self.paramwise_cfg.get( + 'force_default_settings', False) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, + (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) + + for name, param in module.named_parameters(recurse=False): + param_group = {'params': [param]} + if bypass_duplicate and self._is_in(param_group, params): + print_log( + f'{prefix} is duplicate. It is skipped since ' + f'bypass_duplicate={bypass_duplicate}', + logger='current', + level=logging.WARNING) + continue + if not param.requires_grad: + params.append(param_group) + continue + + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f'{prefix}.{name}': + is_custom = True + lr_mult = custom_keys[key].get('lr_mult', 1.) + param_group['lr'] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get('decay_mult', 1.) + param_group['weight_decay'] = self.base_wd * decay_mult + # add custom settings to param_group + for k, v in custom_keys[key].items(): + param_group[k] = v + break + + if not is_custom or force_default_settings: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == 'bias' and not ( + is_norm or is_dcn_module) and bias_lr_mult is not None: + param_group['lr'] = self.base_lr * bias_lr_mult + + if (prefix.find('conv_offset') != -1 and is_dcn_module + and dcn_offset_lr_mult is not None + and isinstance(module, torch.nn.Conv2d)): + # deal with both dcn_offset's bias & weight + param_group['lr'] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm and norm_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * norm_decay_mult + # bias lr and decay + elif (name == 'bias' and not is_dcn_module + and bias_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * bias_decay_mult + # depth-wise conv + elif is_dwconv and dwconv_decay_mult is not None: + param_group[ + 'weight_decay'] = self.base_wd * dwconv_decay_mult + # flatten parameters except dcn offset + elif (param.ndim == 1 and not is_dcn_module + and flat_decay_mult is not None): + param_group[ + 'weight_decay'] = self.base_wd * flat_decay_mult + params.append(param_group) + for key, value in param_group.items(): + if key == 'params': + continue + full_name = f'{prefix}.{name}' if prefix else name + print_log( + f'paramwise_options -- {full_name}:{key}={value}', + logger='current') + + if mmcv_full_available(): + from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + is_dcn_module = isinstance(module, + (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f'{prefix}.{child_name}' if prefix else child_name + self.add_params( + params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) diff --git a/finetune/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py b/finetune/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000..fdae3ca --- /dev/null +++ b/finetune/mmseg/engine/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import warnings + +from mmengine.dist import get_dist_info +from mmengine.logging import print_log +from mmengine.optim import DefaultOptimWrapperConstructor + +from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum number of backbone layers. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +def get_layer_id_for_vit(var_name, max_layer_id): + """Get the layer id to set the different learning rates. + + Args: + var_name (str): The key of the model. + num_max_layer (int): Maximum number of backbone layers. + + Returns: + int: Returns the layer id of the key. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('backbone.layers'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + else: + return max_layer_id - 1 + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimWrapperConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for ConvNeXt, + BEiT and MAE. + """ + + def add_params(self, params, module, **kwargs): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + + parameter_groups = {} + print_log(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + print_log('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. + else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + print_log(f'set param {name} as id {layer_id}') + elif 'BEiT' in module.backbone.__class__.__name__ or \ + 'MAE' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vit(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + print_log(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print_log(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) + + +@OPTIM_WRAPPER_CONSTRUCTORS.register_module() +class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): + """Different learning rates are set for different layers of backbone. + + Note: Currently, this optimizer constructor is built for BEiT, + and it will be deprecated. + Please use ``LearningRateDecayOptimizerConstructor`` instead. + """ + + def __init__(self, optim_wrapper_cfg, paramwise_cfg): + warnings.warn('DeprecationWarning: Original ' + 'LayerDecayOptimizerConstructor of BEiT ' + 'will be deprecated. Please use ' + 'LearningRateDecayOptimizerConstructor instead, ' + 'and set decay_type = layer_wise_vit in paramwise_cfg.') + paramwise_cfg.update({'decay_type': 'layer_wise_vit'}) + warnings.warn('DeprecationWarning: Layer_decay_rate will ' + 'be deleted, please use decay_rate instead.') + paramwise_cfg['decay_rate'] = paramwise_cfg.pop('layer_decay_rate') + super().__init__(optim_wrapper_cfg, paramwise_cfg) diff --git a/finetune/mmseg/engine/schedulers/__init__.py b/finetune/mmseg/engine/schedulers/__init__.py new file mode 100644 index 0000000..3cd3f62 --- /dev/null +++ b/finetune/mmseg/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .poly_ratio_scheduler import PolyLRRatio + +__all__ = ['PolyLRRatio'] diff --git a/finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py b/finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py new file mode 100644 index 0000000..057203a --- /dev/null +++ b/finetune/mmseg/engine/schedulers/poly_ratio_scheduler.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from mmengine.optim.scheduler import PolyLR + +from mmseg.registry import PARAM_SCHEDULERS + + +@PARAM_SCHEDULERS.register_module() +class PolyLRRatio(PolyLR): + """Implements polynomial learning rate decay with ratio. + + This scheduler adjusts the learning rate of each parameter group + following a polynomial decay equation. The decay can occur in + conjunction with external parameter adjustments made outside this + scheduler. + + Args: + optimizer (Optimizer or OptimWrapper): Wrapped optimizer. + eta_min (float): Minimum learning rate at the end of scheduling. + Defaults to 0. + eta_min_ratio (float, optional): The ratio of the minimum parameter + value to the base parameter value. Either `eta_min` or + `eta_min_ratio` should be specified. Defaults to None. + power (float): The power of the polynomial. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, eta_min_ratio: Optional[int] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta_min_ratio = eta_min_ratio + + def _get_value(self): + """Compute value using chainable form of the scheduler.""" + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + param_groups_value = [] + for base_value, param_group in zip(self.base_values, + self.optimizer.param_groups): + eta_min = self.eta_min if self.eta_min_ratio is None else \ + base_value * self.eta_min_ratio + step_ratio = (1 - 1 / + (self.total_iters - self.last_step + 1))**self.power + step_value = (param_group[self.param_name] - + eta_min) * step_ratio + eta_min + param_groups_value.append(step_value) + + return param_groups_value diff --git a/finetune/mmseg/evaluation/__init__.py b/finetune/mmseg/evaluation/__init__.py new file mode 100644 index 0000000..82b3a8d --- /dev/null +++ b/finetune/mmseg/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .metrics import CityscapesMetric, DepthMetric, IoUMetric + +__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric'] diff --git a/finetune/mmseg/evaluation/metrics/__init__.py b/finetune/mmseg/evaluation/metrics/__init__.py new file mode 100644 index 0000000..848d471 --- /dev/null +++ b/finetune/mmseg/evaluation/metrics/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .citys_metric import CityscapesMetric +from .depth_metric import DepthMetric +from .iou_metric import IoUMetric + +__all__ = ['IoUMetric', 'CityscapesMetric', 'DepthMetric'] diff --git a/finetune/mmseg/evaluation/metrics/citys_metric.py b/finetune/mmseg/evaluation/metrics/citys_metric.py new file mode 100644 index 0000000..3298465 --- /dev/null +++ b/finetune/mmseg/evaluation/metrics/citys_metric.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import shutil +from collections import OrderedDict +from typing import Dict, Optional, Sequence + +try: + + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + import cityscapesscripts.helpers.labels as CSLabels +except ImportError: + CSLabels = None + CSEval = None + +import numpy as np +from mmengine.dist import is_main_process, master_only +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class CityscapesMetric(BaseMetric): + """Cityscapes evaluation metric. + + Args: + output_dir (str): The directory for output prediction + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + keep_results (bool): Whether to keep the results. When ``format_only`` + is True, ``keep_results`` must be True. Defaults to False. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + output_dir: str, + ignore_index: int = 255, + format_only: bool = False, + keep_results: bool = False, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + if CSEval is None: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + self.output_dir = output_dir + self.ignore_index = ignore_index + + self.format_only = format_only + if format_only: + assert keep_results, ( + 'When format_only is True, the results must be keep, please ' + f'set keep_results as True, but got {keep_results}') + self.keep_results = keep_results + self.prefix = prefix + if is_main_process(): + mkdir_or_exist(self.output_dir) + + @master_only + def __del__(self) -> None: + """Clean up.""" + if not self.keep_results: + shutil.rmtree(self.output_dir) + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to computed the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + mkdir_or_exist(self.output_dir) + + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy() + # when evaluating with official cityscapesscripts, + # labelIds should be used + pred_label = self._convert_to_label_id(pred_label) + basename = osp.splitext(osp.basename(data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') + output.save(png_filename) + if self.format_only: + # format_only always for test dataset without ground truth + gt_filename = '' + else: + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + gt_filename = data_sample['seg_map_path'].replace( + 'labelTrainIds.png', 'labelIds.png') + self.results.append((png_filename, gt_filename)) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): Testing results of the dataset. + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + eval_results = dict() + print_log( + f'Evaluating results under {self.output_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(self.output_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + pred_list, gt_list = zip(*results) + metric = dict() + eval_results.update( + CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args)) + metric['averageScoreCategories'] = eval_results[ + 'averageScoreCategories'] + metric['averageScoreInstCategories'] = eval_results[ + 'averageScoreInstCategories'] + return metric + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy diff --git a/finetune/mmseg/evaluation/metrics/depth_metric.py b/finetune/mmseg/evaluation/metrics/depth_metric.py new file mode 100644 index 0000000..621d4a3 --- /dev/null +++ b/finetune/mmseg/evaluation/metrics/depth_metric.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict, defaultdict +from typing import Dict, List, Optional, Sequence + +import cv2 +import numpy as np +import torch +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from prettytable import PrettyTable +from torch import Tensor + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class DepthMetric(BaseMetric): + """Depth estimation evaluation metric. + + Args: + depth_metrics (List[str], optional): List of metrics to compute. If + not specified, defaults to all metrics in self.METRICS. + min_depth_eval (float): Minimum depth value for evaluation. + Defaults to 0.0. + max_depth_eval (float): Maximum depth value for evaluation. + Defaults to infinity. + crop_type (str, optional): Specifies the type of cropping to be used + during evaluation. This option can affect how the evaluation mask + is generated. Currently, 'nyu_crop' is supported, but other + types can be added in future. Defaults to None if no cropping + should be applied. + depth_scale_factor (float): Factor to scale the depth values. + Defaults to 1.0. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + output_dir (str): The directory for output prediction. Defaults to + None. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to save the result + to a specific format and submit it to the test server. + Defaults to False. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + METRICS = ('d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log', + 'log10', 'silog') + + def __init__(self, + depth_metrics: Optional[List[str]] = None, + min_depth_eval: float = 0.0, + max_depth_eval: float = float('inf'), + crop_type: Optional[str] = None, + depth_scale_factor: float = 1.0, + collect_device: str = 'cpu', + output_dir: Optional[str] = None, + format_only: bool = False, + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if depth_metrics is None: + self.metrics = self.METRICS + elif isinstance(depth_metrics, [tuple, list]): + for metric in depth_metrics: + assert metric in self.METRICS, f'the metric {metric} is not ' \ + f'supported. Please use metrics in {self.METRICS}' + self.metrics = depth_metrics + + # Validate crop_type, if provided + assert crop_type in [ + None, 'nyu_crop' + ], (f'Invalid value for crop_type: {crop_type}. Supported values are ' + 'None or \'nyu_crop\'.') + self.crop_type = crop_type + self.min_depth_eval = min_depth_eval + self.max_depth_eval = max_depth_eval + self.output_dir = output_dir + if self.output_dir and is_main_process(): + mkdir_or_exist(self.output_dir) + self.format_only = format_only + self.depth_scale_factor = depth_scale_factor + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + for data_sample in data_samples: + pred_label = data_sample['pred_depth_map']['data'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + gt_depth = data_sample['gt_depth_map']['data'].squeeze().to( + pred_label) + + eval_mask = self._get_eval_mask(gt_depth) + self.results.append( + (gt_depth[eval_mask], pred_label[eval_mask])) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy( + ) * self.depth_scale_factor + + cv2.imwrite(png_filename, output_mask.astype(np.uint16), + [cv2.IMWRITE_PNG_COMPRESSION, 0]) + + def _get_eval_mask(self, gt_depth: Tensor): + """Generates an evaluation mask based on ground truth depth and + cropping. + + Args: + gt_depth (Tensor): Ground truth depth map. + + Returns: + Tensor: Boolean mask where evaluation should be performed. + """ + valid_mask = torch.logical_and(gt_depth > self.min_depth_eval, + gt_depth < self.max_depth_eval) + + if self.crop_type == 'nyu_crop': + # this implementation is adapted from + # https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blob/main/depth/datasets/nyu.py # noqa + crop_mask = torch.zeros_like(valid_mask) + crop_mask[45:471, 41:601] = 1 + else: + crop_mask = torch.ones_like(valid_mask) + + eval_mask = torch.logical_and(valid_mask, crop_mask) + return eval_mask + + @staticmethod + def _calc_all_metrics(gt_depth, pred_depth): + """Computes final evaluation metrics based on accumulated results.""" + assert gt_depth.shape == pred_depth.shape + + thresh = torch.max((gt_depth / pred_depth), (pred_depth / gt_depth)) + diff = pred_depth - gt_depth + diff_log = torch.log(pred_depth) - torch.log(gt_depth) + + d1 = torch.sum(thresh < 1.25).float() / len(thresh) + d2 = torch.sum(thresh < 1.25**2).float() / len(thresh) + d3 = torch.sum(thresh < 1.25**3).float() / len(thresh) + + abs_rel = torch.mean(torch.abs(diff) / gt_depth) + sq_rel = torch.mean(torch.pow(diff, 2) / gt_depth) + + rmse = torch.sqrt(torch.mean(torch.pow(diff, 2))) + rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2))) + + log10 = torch.mean( + torch.abs(torch.log10(pred_depth) - torch.log10(gt_depth))) + silog = torch.sqrt( + torch.pow(diff_log, 2).mean() - + 0.5 * torch.pow(diff_log.mean(), 2)) + + return { + 'd1': d1.item(), + 'd2': d2.item(), + 'd3': d3.item(), + 'abs_rel': abs_rel.item(), + 'sq_rel': sq_rel.item(), + 'rmse': rmse.item(), + 'rmse_log': rmse_log.item(), + 'log10': log10.item(), + 'silog': silog.item() + } + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The keys + are identical with self.metrics. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + metrics = defaultdict(list) + for gt_depth, pred_depth in results: + for key, value in self._calc_all_metrics(gt_depth, + pred_depth).items(): + metrics[key].append(value) + metrics = {k: sum(metrics[k]) / len(metrics[k]) for k in self.metrics} + + table_data = PrettyTable() + for key, val in metrics.items(): + table_data.add_column(key, [round(val, 5)]) + + print_log('results:', logger) + print_log('\n' + table_data.get_string(), logger=logger) + + return metrics diff --git a/finetune/mmseg/evaluation/metrics/iou_metric.py b/finetune/mmseg/evaluation/metrics/iou_metric.py new file mode 100644 index 0000000..16014c7 --- /dev/null +++ b/finetune/mmseg/evaluation/metrics/iou_metric.py @@ -0,0 +1,286 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence + +import numpy as np +import torch +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image +from prettytable import PrettyTable + +from mmseg.registry import METRICS + + +@METRICS.register_module() +class IoUMetric(BaseMetric): + """IoU evaluation metric. + + Args: + ignore_index (int): Index that will be ignored in evaluation. + Default: 255. + iou_metrics (list[str] | str): Metrics to be calculated, the options + includes 'mIoU', 'mDice' and 'mFscore'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + output_dir (str): The directory for output prediction. Defaults to + None. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to save the result + to a specific format and submit it to the test server. + Defaults to False. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + ignore_index: int = 255, + iou_metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1, + collect_device: str = 'cpu', + output_dir: Optional[str] = None, + format_only: bool = False, + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + self.ignore_index = ignore_index + self.metrics = iou_metrics + self.nan_to_num = nan_to_num + self.beta = beta + self.output_dir = output_dir + if self.output_dir and is_main_process(): + mkdir_or_exist(self.output_dir) + self.format_only = format_only + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + num_classes = len(self.dataset_meta['classes']) + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['data'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + label = data_sample['gt_sem_seg']['data'].squeeze().to( + pred_label) + self.results.append( + self.intersect_and_union(pred_label, label, num_classes, + self.ignore_index)) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy() + # The index range of official ADE20k dataset is from 0 to 150. + # But the index range of output is from 0 to 149. + # That is because we set reduce_zero_label=True. + if data_sample.get('reduce_zero_label', False): + output_mask = output_mask + 1 + output = Image.fromarray(output_mask.astype(np.uint8)) + output.save(png_filename) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The key + mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, + mRecall. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + # convert list of tuples to tuple of lists, e.g. + # [(A_1, B_1, C_1, D_1), ..., (A_n, B_n, C_n, D_n)] to + # ([A_1, ..., A_n], ..., [D_1, ..., D_n]) + results = tuple(zip(*results)) + assert len(results) == 4 + + total_area_intersect = sum(results[0]) + total_area_union = sum(results[1]) + total_area_pred_label = sum(results[2]) + total_area_label = sum(results[3]) + ret_metrics = self.total_area_to_metrics( + total_area_intersect, total_area_union, total_area_pred_label, + total_area_label, self.metrics, self.nan_to_num, self.beta) + + class_names = self.dataset_meta['classes'] + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + metrics = dict() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + metrics[key] = val + else: + metrics['m' + key] = val + + # each class table + ret_metrics.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + class_table_data = PrettyTable() + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + + print_log('per class results:', logger) + print_log('\n' + class_table_data.get_string(), logger=logger) + + return metrics + + @staticmethod + def intersect_and_union(pred_label: torch.tensor, label: torch.tensor, + num_classes: int, ignore_index: int): + """Calculate Intersection and Union. + + Args: + pred_label (torch.tensor): Prediction segmentation map + or predict result filename. The shape is (H, W). + label (torch.tensor): Ground truth segmentation map + or label filename. The shape is (H, W). + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_label = torch.histc( + label.float(), bins=(num_classes), min=0, + max=num_classes - 1).cpu() + area_union = area_pred_label + area_label - area_intersect + return area_intersect, area_union, area_pred_label, area_label + + @staticmethod + def total_area_to_metrics(total_area_intersect: np.ndarray, + total_area_union: np.ndarray, + total_area_pred_label: np.ndarray, + total_area_label: np.ndarray, + metrics: List[str] = ['mIoU'], + nan_to_num: Optional[int] = None, + beta: int = 1): + """Calculate evaluation metrics + Args: + total_area_intersect (np.ndarray): The intersection of prediction + and ground truth histogram on all classes. + total_area_union (np.ndarray): The union of prediction and ground + truth histogram on all classes. + total_area_pred_label (np.ndarray): The prediction histogram on + all classes. + total_area_label (np.ndarray): The ground truth histogram on + all classes. + metrics (List[str] | str): Metrics to be evaluated, 'mIoU' and + 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be + replaced by the numbers defined by the user. Default: None. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + Returns: + Dict[str, np.ndarray]: per category evaluation metrics, + shape (num_classes, ). + """ + + def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined + score. Default: 1. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + if isinstance(metrics, str): + metrics = [metrics] + allowed_metrics = ['mIoU', 'mDice', 'mFscore'] + if not set(metrics).issubset(set(allowed_metrics)): + raise KeyError(f'metrics {metrics} is not supported') + + all_acc = total_area_intersect.sum() / total_area_label.sum() + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([ + f_score(x[0], x[1], beta) for x in zip(precision, recall) + ]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.numpy() + for metric, value in ret_metrics.items() + } + if nan_to_num is not None: + ret_metrics = OrderedDict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in ret_metrics.items() + }) + return ret_metrics diff --git a/finetune/mmseg/models/__init__.py b/finetune/mmseg/models/__init__.py new file mode 100644 index 0000000..a989512 --- /dev/null +++ b/finetune/mmseg/models/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .assigners import * # noqa: F401,F403 +from .backbones import * # noqa: F401,F403 +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) +from .data_preprocessor import SegDataPreProcessor +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .necks import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 +from .text_encoder import * # noqa: F401,F403 + +__all__ = [ + 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone', + 'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor' +] diff --git a/finetune/mmseg/models/assigners/__init__.py b/finetune/mmseg/models/assigners/__init__.py new file mode 100644 index 0000000..d49b1b1 --- /dev/null +++ b/finetune/mmseg/models/assigners/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_assigner import BaseAssigner +from .hungarian_assigner import HungarianAssigner +from .match_cost import ClassificationCost, CrossEntropyLossCost, DiceCost + +__all__ = [ + 'BaseAssigner', + 'HungarianAssigner', + 'ClassificationCost', + 'CrossEntropyLossCost', + 'DiceCost', +] diff --git a/finetune/mmseg/models/assigners/base_assigner.py b/finetune/mmseg/models/assigners/base_assigner.py new file mode 100644 index 0000000..97895cd --- /dev/null +++ b/finetune/mmseg/models/assigners/base_assigner.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import Optional + +from mmengine.structures import InstanceData + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns masks to ground truth class labels.""" + + @abstractmethod + def assign(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + gt_instances_ignore: Optional[InstanceData] = None, + **kwargs): + """Assign masks to either a ground truth class label or a negative + label.""" diff --git a/finetune/mmseg/models/assigners/hungarian_assigner.py b/finetune/mmseg/models/assigners/hungarian_assigner.py new file mode 100644 index 0000000..28868f0 --- /dev/null +++ b/finetune/mmseg/models/assigners/hungarian_assigner.py @@ -0,0 +1,86 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch +from mmengine import ConfigDict +from mmengine.structures import InstanceData +from scipy.optimize import linear_sum_assignment +from torch.cuda.amp import autocast + +from mmseg.registry import TASK_UTILS +from .base_assigner import BaseAssigner + + +@TASK_UTILS.register_module() +class HungarianAssigner(BaseAssigner): + """Computes one-to-one matching between prediction masks and ground truth. + + This class uses bipartite matching-based assignment to computes an + assignment between the prediction masks and the ground truth. The + assignment result is based on the weighted sum of match costs. The + Hungarian algorithm is used to calculate the best matching with the + minimum cost. The prediction masks that are not matched are classified + as background. + + Args: + match_costs (ConfigDict|List[ConfigDict]): Match cost configs. + """ + + def __init__( + self, match_costs: Union[List[Union[dict, ConfigDict]], dict, + ConfigDict] + ) -> None: + + if isinstance(match_costs, dict): + match_costs = [match_costs] + elif isinstance(match_costs, list): + assert len(match_costs) > 0, \ + 'match_costs must not be a empty list.' + + self.match_costs = [ + TASK_UTILS.build(match_cost) for match_cost in match_costs + ] + + def assign(self, pred_instances: InstanceData, gt_instances: InstanceData, + **kwargs): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The assignment first calculates the cost for each + category assigned to each query mask, and then uses the + Hungarian algorithm to calculate the minimum cost as the best + match. + + Args: + pred_instances (InstanceData): Instances of model + predictions. It includes "masks", with shape + (n, h, w) or (n, l), and "cls", with shape (n, num_classes+1) + gt_instances (InstanceData): Ground truth of instance + annotations. It includes "labels", with shape (k, ), + and "masks", with shape (k, h, w) or (k, l). + + Returns: + matched_quiery_inds (Tensor): The indexes of matched quieres. + matched_label_inds (Tensor): The indexes of matched labels. + """ + # compute weighted cost + cost_list = [] + with autocast(enabled=False): + for match_cost in self.match_costs: + cost = match_cost( + pred_instances=pred_instances, gt_instances=gt_instances) + cost_list.append(cost) + cost = torch.stack(cost_list).sum(dim=0) + + device = cost.device + # do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + + matched_quiery_inds, matched_label_inds = linear_sum_assignment(cost) + matched_quiery_inds = torch.from_numpy(matched_quiery_inds).to(device) + matched_label_inds = torch.from_numpy(matched_label_inds).to(device) + + return matched_quiery_inds, matched_label_inds diff --git a/finetune/mmseg/models/assigners/match_cost.py b/finetune/mmseg/models/assigners/match_cost.py new file mode 100644 index 0000000..560df85 --- /dev/null +++ b/finetune/mmseg/models/assigners/match_cost.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Union + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import TASK_UTILS + + +class BaseMatchCost: + """Base match cost class. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, weight: Union[float, int] = 1.) -> None: + self.weight = weight + + @abstractmethod + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Instances of model predictions. + It often includes "labels" and "scores". + gt_instances (InstanceData): Ground truth of instance + annotations. It usually includes "labels". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + pass + + +@TASK_UTILS.register_module() +class ClassificationCost(BaseMatchCost): + """ClsSoftmaxCost. + + Args: + weight (Union[float, int]): Cost weight. Defaults to 1. + + Examples: + >>> from mmseg.models.assigners import ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight: Union[float, int] = 1) -> None: + super().__init__(weight=weight) + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): "scores" inside is + predicted classification logits, of shape + (num_queries, num_class). + gt_instances (InstanceData): "labels" inside should have + shape (num_gt, ). + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'scores'), \ + "pred_instances must contain 'scores'" + assert hasattr(gt_instances, 'labels'), \ + "gt_instances must contain 'labels'" + pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + + pred_scores = pred_scores.softmax(-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight + + +@TASK_UTILS.register_module() +class DiceCost(BaseMatchCost): + """Cost of mask assignments based on dice losses. + + Args: + pred_act (bool): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float): Defaults to 1e-3. + naive_dice (bool): If True, use the naive dice loss + in which the power of the number in the denominator is + the first power. If False, use the second power that + is adopted by K-Net and SOLO. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + pred_act: bool = False, + eps: float = 1e-3, + naive_dice: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.pred_act = pred_act + self.eps = eps + self.naive_dice = naive_dice + + def _binary_mask_dice_loss(self, mask_preds: Tensor, + gt_masks: Tensor) -> Tensor: + """ + Args: + mask_preds (Tensor): Mask prediction in shape (num_queries, *). + gt_masks (Tensor): Ground truth in shape (num_gt, *) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (num_queries, num_gt). + """ + mask_preds = mask_preds.flatten(1) + gt_masks = gt_masks.flatten(1).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + if self.naive_dice: + denominator = mask_preds.sum(-1)[:, None] + \ + gt_masks.sum(-1)[None, :] + else: + denominator = mask_preds.pow(2).sum(1)[:, None] + \ + gt_masks.pow(2).sum(1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (InstanceData): Predicted instances which + must contain "masks". + gt_instances (InstanceData): Ground truth which must contain + "mask". + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + + if self.pred_act: + pred_masks = pred_masks.sigmoid() + dice_cost = self._binary_mask_dice_loss(pred_masks, gt_masks) + return dice_cost * self.weight + + +@TASK_UTILS.register_module() +class CrossEntropyLossCost(BaseMatchCost): + """CrossEntropyLossCost. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to True. + weight (Union[float, int]): Cost weight. Defaults to 1. + """ + + def __init__(self, + use_sigmoid: bool = True, + weight: Union[float, int] = 1.) -> None: + super().__init__(weight=weight) + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred: Tensor, + gt_labels: Tensor) -> Tensor: + """ + Args: + cls_pred (Tensor): The prediction with shape (num_queries, 1, *) or + (num_queries, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + + Returns: + Tensor: Cross entropy cost matrix in shape (num_queries, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, pred_instances: InstanceData, + gt_instances: InstanceData, **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): Predicted instances which + must contain ``masks``. + gt_instances (:obj:`InstanceData`): Ground truth which must contain + ``masks``. + + Returns: + Tensor: Match Cost matrix of shape (num_preds, num_gts). + """ + assert hasattr(pred_instances, 'masks'), \ + "pred_instances must contain 'masks'" + assert hasattr(gt_instances, 'masks'), \ + "gt_instances must contain 'masks'" + pred_masks = pred_instances.masks + gt_masks = gt_instances.masks + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(pred_masks, gt_masks) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/finetune/mmseg/models/backbones/__init__.py b/finetune/mmseg/models/backbones/__init__.py new file mode 100644 index 0000000..784d3df --- /dev/null +++ b/finetune/mmseg/models/backbones/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .beit import BEiT +from .bisenetv1 import BiSeNetV1 +from .bisenetv2 import BiSeNetV2 +from .cgnet import CGNet +from .ddrnet import DDRNet +from .erfnet import ERFNet +from .fast_scnn import FastSCNN +from .hrnet import HRNet +from .icnet import ICNet +from .mae import MAE +from .mit import MixVisionTransformer +from .mobilenet_v2 import MobileNetV2 +from .mobilenet_v3 import MobileNetV3 +from .mscan import MSCAN +from .pidnet import PIDNet +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1c, ResNetV1d +from .resnext import ResNeXt +from .stdc import STDCContextPathNet, STDCNet +from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone +from .twins import PCPVT, SVT +from .unet import UNet +from .vit import VisionTransformer +from .vpd import VPD + +__all__ = [ + 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', + 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', + 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', + 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN', + 'DDRNet', 'VPD' +] diff --git a/finetune/mmseg/models/backbones/beit.py b/finetune/mmseg/models/backbones/beit.py new file mode 100644 index 0000000..e5da71e --- /dev/null +++ b/finetune/mmseg/models/backbones/beit.py @@ -0,0 +1,554 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from scipy import interpolate +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed +from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer + + +class BEiTAttention(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + bias='qv_bias', + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.bias = bias + self.scale = qk_scale or head_embed_dims**-0.5 + + qkv_bias = bias + if bias == 'qv_bias': + self._init_qv_bias() + qkv_bias = False + + self.window_size = window_size + self._init_rel_pos_embedding() + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + def _init_qv_bias(self): + self.q_bias = nn.Parameter(torch.zeros(self.embed_dims)) + self.v_bias = nn.Parameter(torch.zeros(self.embed_dims)) + + def _init_rel_pos_embedding(self): + Wh, Ww = self.window_size + # cls to token & token 2 cls & cls to cls + self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3 + # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH) + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, self.num_heads)) + + # get pair-wise relative position index for + # each token inside the window + coords_h = torch.arange(Wh) + coords_w = torch.arange(Ww) + # coords shape is (2, Wh, Ww) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + # coords_flatten shape is (2, Wh*Ww) + coords_flatten = torch.flatten(coords, 1) + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :]) + # relative_coords shape is (Wh*Ww, Wh*Ww, 2) + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + # shift to start from 0 + relative_coords[:, :, 0] += Wh - 1 + relative_coords[:, :, 1] += Ww - 1 + relative_coords[:, :, 0] *= 2 * Ww - 1 + relative_position_index = torch.zeros( + size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype) + # relative_position_index shape is (Wh*Ww, Wh*Ww) + relative_position_index[1:, 1:] = relative_coords.sum(-1) + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer('relative_position_index', + relative_position_index) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x): + """ + Args: + x (tensor): input features with shape of (num_windows*B, N, C). + """ + B, N, C = x.shape + + if self.bias == 'qv_bias': + k_bias = torch.zeros_like(self.v_bias, requires_grad=False) + qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + else: + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.relative_position_bias_table is not None: + Wh = self.window_size[0] + Ww = self.window_size[1] + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + Wh * Ww + 1, Wh * Ww + 1, -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + bias (bool): The option to add leanable bias for q, k, v. If bias is + True, it will add leanable bias. If bias is 'qv_bias', it will only + add leanable bias for q, v. If bias is False, it will not add bias + for q, k, v. Default to 'qv_bias'. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (tuple[int], optional): The height and width of the window. + Default: None. + init_values (float, optional): Initialize the values of BEiTAttention + and FFN with learnable scaling. Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + bias='qv_bias', + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=None, + attn_cfg=dict(), + ffn_cfg=dict(add_identity=False), + init_values=None): + attn_cfg.update(dict(window_size=window_size, qk_scale=None)) + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + attn_drop_rate=attn_drop_rate, + drop_path_rate=0., + drop_rate=0., + num_fcs=num_fcs, + qkv_bias=bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + attn_cfg=attn_cfg, + ffn_cfg=ffn_cfg) + + # NOTE: drop path for stochastic depth, we shall see if + # this is better than dropout here + dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate) + self.drop_path = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + self.gamma_1 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + self.gamma_2 = nn.Parameter( + init_values * torch.ones(embed_dims), requires_grad=True) + + def build_attn(self, attn_cfg): + self.attn = BEiTAttention(**attn_cfg) + + def forward(self, x): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + return x + + +@MODELS.register_module() +class BEiT(BaseModule): + """BERT Pre-Training of Image Transformers. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_layers (int): Depth of transformer. Default: 12. + num_heads (int): Number of attention heads. Default: 12. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qv_bias (bool): Enable bias for qv if True. Default: True. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): Model pretrained path. Default: None. + init_values (float): Initialize the values of BEiTAttention and FFN + with learnable scaling. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qv_bias=True, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.img_size = img_size + self.patch_size = patch_size + self.norm_eval = norm_eval + self.pretrained = pretrained + self.num_layers = num_layers + self.embed_dims = embed_dims + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.num_fcs = num_fcs + self.qv_bias = qv_bias + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.patch_norm = patch_norm + self.init_values = init_values + self.window_size = (img_size[0] // patch_size, + img_size[1] // patch_size) + self.patch_shape = self.window_size + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self._build_patch_embedding() + self._build_layers() + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + def _build_patch_embedding(self): + """Build patch embedding layer.""" + self.patch_embed = PatchEmbed( + in_channels=self.in_channels, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + norm_cfg=self.norm_cfg if self.patch_norm else None, + init_cfg=None) + + def _build_layers(self): + """Build transformer encoding layers.""" + + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + BEiTTransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias='qv_bias' if self.qv_bias else False, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.window_size, + init_values=self.init_values)) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def _geometric_sequence_interpolation(self, src_size, dst_size, sequence, + num): + """Get new sequence via geometric sequence interpolation. + + Args: + src_size (int): Pos_embedding size in pre-trained model. + dst_size (int): Pos_embedding size in the current model. + sequence (tensor): The relative position bias of the pretrain + model after removing the extra tokens. + num (int): Number of attention heads. + Returns: + new_sequence (tensor): Geometric sequence interpolate the + pre-trained relative position bias to the size of + the current model. + """ + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + # Here is a binary function. + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + # The position of each interpolated point is determined + # by the ratio obtained by dichotomy. + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + r_ids = [-_ for _ in reversed(dis)] + x = r_ids + [0] + dis + y = r_ids + [0] + dis + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + # Interpolation functions are being executed and called. + new_sequence = [] + for i in range(num): + z = sequence[:, i].view(src_size, src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + new_sequence.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence)) + new_sequence = torch.cat(new_sequence, dim=-1) + return new_sequence + + def resize_rel_pos_embed(self, checkpoint): + """Resize relative pos_embed weights. + + This function is modified from + https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + Args: + checkpoint (dict): Key and value of the pretrain model. + Returns: + state_dict (dict): Interpolate the relative pos_embed weights + in the pre-train model to the current model size. + """ + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + # In order to keep the center of pos_bias as consistent as + # possible after interpolation, and vice versa in the edge + # area, the geometric sequence interpolation method is adopted. + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = self.state_dict()[key].size() + dst_patch_shape = self.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + # Count the number of extra tokens. + num_extra_tokens = dst_num_pos - ( + dst_patch_shape[0] * 2 - 1) * ( + dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + new_rel_pos_bias = self._geometric_sequence_interpolation( + src_size, dst_size, rel_pos_bias, num_attn_heads) + new_rel_pos_bias = torch.cat( + (new_rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + return state_dict + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/bisenetv1.py b/finetune/mmseg/models/backbones/bisenetv1.py new file mode 100644 index 0000000..ca58bf9 --- /dev/null +++ b/finetune/mmseg/models/backbones/bisenetv1.py @@ -0,0 +1,332 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class SpatialPath(BaseModule): + """Spatial Path to preserve the spatial size of the original input image + and encode affluent spatial information. + + Args: + in_channels(int): The number of channels of input + image. Default: 3. + num_channels (Tuple[int]): The number of channels of + each layers in Spatial Path. + Default: (64, 64, 64, 128). + Returns: + x (torch.Tensor): Feature map for Feature Fusion Module. + """ + + def __init__(self, + in_channels=3, + num_channels=(64, 64, 64, 128), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(num_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + self.layers = [] + for i in range(len(num_channels)): + layer_name = f'layer{i + 1}' + self.layers.append(layer_name) + if i == 0: + self.add_module( + layer_name, + ConvModule( + in_channels=in_channels, + out_channels=num_channels[i], + kernel_size=7, + stride=2, + padding=3, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + elif i == len(num_channels) - 1: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + self.add_module( + layer_name, + ConvModule( + in_channels=num_channels[i - 1], + out_channels=num_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + for i, layer_name in enumerate(self.layers): + layer_stage = getattr(self, layer_name) + x = layer_stage(x) + return x + + +class AttentionRefinementModule(BaseModule): + """Attention Refinement Module (ARM) to refine the features of each stage. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Attention Refinement Module. + """ + + def __init__(self, + in_channels, + out_channel, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_layer = ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.atten_conv_layer = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), nn.Sigmoid()) + + def forward(self, x): + x = self.conv_layer(x) + x_atten = self.atten_conv_layer(x) + x_out = x * x_atten + return x_out + + +class ContextPath(BaseModule): + """Context Path to provide sufficient receptive field. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + context_channels (Tuple[int]): The number of channel numbers + of various modules in Context Path. + Default: (128, 256, 512). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + Returns: + x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps + undergoing upsampling from 1/16 and 1/32 downsampling + feature maps. These two feature maps are used for Feature + Fusion Module and Auxiliary Head. + """ + + def __init__(self, + backbone_cfg, + context_channels=(128, 256, 512), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.backbone = MODELS.build(backbone_cfg) + + self.align_corners = align_corners + self.arm16 = AttentionRefinementModule(context_channels[1], + context_channels[0]) + self.arm32 = AttentionRefinementModule(context_channels[2], + context_channels[0]) + self.conv_head32 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_head16 = ConvModule( + in_channels=context_channels[0], + out_channels=context_channels[0], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap_conv = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels=context_channels[2], + out_channels=context_channels[0], + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, x): + x_4, x_8, x_16, x_32 = self.backbone(x) + x_gap = self.gap_conv(x_32) + + x_32_arm = self.arm32(x_32) + x_32_sum = x_32_arm + x_gap + x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest') + x_32_up = self.conv_head32(x_32_up) + + x_16_arm = self.arm16(x_16) + x_16_sum = x_16_arm + x_32_up + x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest') + x_16_up = self.conv_head16(x_16_up) + + return x_16_up, x_32_up + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module to fuse low level output feature of Spatial Path + and high level output feature of Context Path. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + Returns: + x_out (torch.Tensor): Feature map of Feature Fusion Module. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_atten = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), nn.Sigmoid()) + + def forward(self, x_sp, x_cp): + x_concat = torch.cat([x_sp, x_cp], dim=1) + x_fuse = self.conv1(x_concat) + x_atten = self.gap(x_fuse) + # Note: No BN and more 1x1 conv in paper. + x_atten = self.conv_atten(x_atten) + x_atten = x_fuse * x_atten + x_out = x_atten + x_fuse + return x_out + + +@MODELS.register_module() +class BiSeNetV1(BaseModule): + """BiSeNetV1 backbone. + + This backbone is the implementation of `BiSeNet: Bilateral + Segmentation Network for Real-time Semantic + Segmentation `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input + image. Default: 3. + spatial_channels (Tuple[int]): Size of channel numbers of + various layers in Spatial Path. + Default: (64, 64, 64, 128). + context_channels (Tuple[int]): Size of channel numbers of + various modules in Context Path. + Default: (128, 256, 512). + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + out_channels(int): The number of channels of output. + It must be the same with `in_channels` of decode_head. + Default: 256. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + spatial_channels=(64, 64, 64, 128), + context_channels=(128, 256, 512), + out_indices=(0, 1, 2), + align_corners=False, + out_channels=256, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(spatial_channels) == 4, 'Length of input channels \ + of Spatial Path must be 4!' + + assert len(context_channels) == 3, 'Length of input channels \ + of Context Path must be 3!' + + self.out_indices = out_indices + self.align_corners = align_corners + self.context_path = ContextPath(backbone_cfg, context_channels, + self.align_corners) + self.spatial_path = SpatialPath(in_channels, spatial_channels) + self.ffm = FeatureFusionModule(context_channels[1], out_channels) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_context8, x_context16 = self.context_path(x) + x_spatial = self.spatial_path(x) + x_fuse = self.ffm(x_spatial, x_context8) + + outs = [x_fuse, x_context8, x_context16] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/finetune/mmseg/models/backbones/bisenetv2.py b/finetune/mmseg/models/backbones/bisenetv2.py new file mode 100644 index 0000000..32aa498 --- /dev/null +++ b/finetune/mmseg/models/backbones/bisenetv2.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_activation_layer, build_norm_layer) +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DetailBranch(BaseModule): + """Detail Branch with wide channels and shallow layers to capture low-level + details and generate high-resolution feature representation. + + Args: + detail_channels (Tuple[int]): Size of channel numbers of each stage + in Detail Branch, in paper it has 3 stages. + Default: (64, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Feature map of Detail Branch. + """ + + def __init__(self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + detail_branch = [] + for i in range(len(detail_channels)): + if i == 0: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + else: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=detail_channels[i - 1], + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + self.detail_branch = nn.ModuleList(detail_branch) + + def forward(self, x): + for stage in self.detail_branch: + x = stage(x) + return x + + +class StemBlock(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): First feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_first = ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.convs = nn.Sequential( + ConvModule( + in_channels=out_channels, + out_channels=out_channels // 2, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=out_channels // 2, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.pool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=False) + self.fuse_last = ConvModule( + in_channels=out_channels * 2, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + x = self.conv_first(x) + x_left = self.convs(x) + x_right = self.pool(x) + x = self.fuse_last(torch.cat([x_left, x_right], dim=1)) + return x + + +class GELayer(BaseModule): + """Gather-and-Expansion Layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + stride (int): Stride of GELayer. Default: 1 + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Intermediate feature map in + Semantic Branch. + """ + + def __init__(self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + mid_channel = in_channels * exp_ratio + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if stride == 1: + self.dwconv = nn.Sequential( + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.shortcut = None + else: + self.dwconv = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=3, + stride=1, + padding=1, + groups=mid_channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + self.shortcut = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=None, + )) + + self.conv2 = nn.Sequential( + ConvModule( + in_channels=mid_channel, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + )) + + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.dwconv(x) + x = self.conv2(x) + if self.shortcut is not None: + shortcut = self.shortcut(identity) + x = x + shortcut + else: + x = x + identity + x = self.act(x) + return x + + +class CEBlock(BaseModule): + """Context Embedding Block for large receptive filed in Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Last feature map in Semantic Branch. + """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + build_norm_layer(norm_cfg, self.in_channels)[1]) + self.conv_gap = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # Note: in paper here is naive conv2d, no bn-relu + self.conv_last = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + identity = x + x = self.gap(x) + x = self.conv_gap(x) + x = identity + x + x = self.conv_last(x) + return x + + +class SemanticBranch(BaseModule): + """Semantic Branch which is lightweight with narrow channels and deep + layers to obtain high-level semantic context. + + Args: + semantic_channels(Tuple[int]): Size of channel numbers of + various stages in Semantic Branch. + Default: (16, 32, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + semantic_outs (List[torch.Tensor]): List of several feature maps + for auxiliary heads (Booster) and Bilateral + Guided Aggregation Layer. + """ + + def __init__(self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.semantic_channels = semantic_channels + self.semantic_stages = [] + for i in range(len(semantic_channels)): + stage_name = f'stage{i + 1}' + self.semantic_stages.append(stage_name) + if i == 0: + self.add_module( + stage_name, + StemBlock(self.in_channels, semantic_channels[i])) + elif i == (len(semantic_channels) - 1): + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + else: + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + + self.add_module(f'stage{len(semantic_channels)}_CEBlock', + CEBlock(semantic_channels[-1], semantic_channels[-1])) + self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + + def forward(self, x): + semantic_outs = [] + for stage_name in self.semantic_stages: + semantic_stage = getattr(self, stage_name) + x = semantic_stage(x) + semantic_outs.append(x) + return semantic_outs + + +class BGALayer(BaseModule): + """Bilateral Guided Aggregation Layer to fuse the complementary information + from both Detail Branch and Semantic Branch. + + Args: + out_channels (int): Number of output channels. + Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + output (torch.Tensor): Output feature map for Segment heads. + """ + + def __init__(self, + out_channels=128, + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.out_channels = out_channels + self.align_corners = align_corners + self.detail_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.detail_down = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)) + self.semantic_conv = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None)) + self.semantic_dwconv = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=None, + pw_act_cfg=None, + )) + self.conv = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + inplace=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + + def forward(self, x_d, x_s): + detail_dwconv = self.detail_dwconv(x_d) + detail_down = self.detail_down(x_d) + semantic_conv = self.semantic_conv(x_s) + semantic_dwconv = self.semantic_dwconv(x_s) + semantic_conv = resize( + input=semantic_conv, + size=detail_dwconv.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv) + fuse_2 = detail_down * torch.sigmoid(semantic_dwconv) + fuse_2 = resize( + input=fuse_2, + size=fuse_1.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = self.conv(fuse_1 + fuse_2) + return output + + +@MODELS.register_module() +class BiSeNetV2(BaseModule): + """BiSeNetV2: Bilateral Network with Guided Aggregation for + Real-time Semantic Segmentation. + + This backbone is the implementation of + `BiSeNetV2 `_. + + Args: + in_channels (int): Number of channel of input image. Default: 3. + detail_channels (Tuple[int], optional): Channels of each stage + in Detail Branch. Default: (64, 64, 128). + semantic_channels (Tuple[int], optional): Channels of each stage + in Semantic Branch. Default: (16, 32, 64, 128). + See Table 1 and Figure 3 of paper for more details. + semantic_expansion_ratio (int, optional): The expansion factor + expanding channel number of middle channels in Semantic Branch. + Default: 6. + bga_channels (int, optional): Number of middle channels in + Bilateral Guided Aggregation Layer. Default: 128. + out_indices (Tuple[int] | int, optional): Output from which stages. + Default: (0, 1, 2, 3, 4). + align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_indices = out_indices + self.detail_channels = detail_channels + self.semantic_channels = semantic_channels + self.semantic_expansion_ratio = semantic_expansion_ratio + self.bga_channels = bga_channels + self.align_corners = align_corners + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.detail = DetailBranch(self.detail_channels, self.in_channels) + self.semantic = SemanticBranch(self.semantic_channels, + self.in_channels, + self.semantic_expansion_ratio) + self.bga = BGALayer(self.bga_channels, self.align_corners) + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_detail = self.detail(x) + x_semantic_lst = self.semantic(x) + x_head = self.bga(x_detail, x_semantic_lst[-1]) + outs = [x_head] + x_semantic_lst[:-1] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/finetune/mmseg/models/backbones/cgnet.py b/finetune/mmseg/models/backbones/cgnet.py new file mode 100644 index 0000000..b74b494 --- /dev/null +++ b/finetune/mmseg/models/backbones/cgnet.py @@ -0,0 +1,372 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS + + +class GlobalContextExtractor(nn.Module): + """Global Context Extractor for CGNet. + + This class is employed to refine the joint feature of both local feature + and surrounding context. + + Args: + channel (int): Number of input feature channels. + reduction (int): Reductions for global context extractor. Default: 16. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, channel, reduction=16, with_cp=False): + super().__init__() + self.channel = channel + self.reduction = reduction + assert reduction >= 1 and channel >= reduction + self.with_cp = with_cp + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), nn.Sigmoid()) + + def forward(self, x): + + def _inner_forward(x): + num_batch, num_channel = x.size()[:2] + y = self.avg_pool(x).view(num_batch, num_channel) + y = self.fc(y).view(num_batch, num_channel, 1, 1) + return x * y + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class ContextGuidedBlock(nn.Module): + """Context Guided Block for CGNet. + + This class consists of four components: local feature extractor, + surrounding feature extractor, joint feature extractor and global + context extractor. + + Args: + in_channels (int): Number of input feature channels. + out_channels (int): Number of output feature channels. + dilation (int): Dilation rate for surrounding context extractor. + Default: 2. + reduction (int): Reduction for global context extractor. Default: 16. + skip_connect (bool): Add input to output or not. Default: True. + downsample (bool): Downsample the input to 1/2 or not. Default: False. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels, + out_channels, + dilation=2, + reduction=16, + skip_connect=True, + downsample=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + with_cp=False): + super().__init__() + self.with_cp = with_cp + self.downsample = downsample + + channels = out_channels if downsample else out_channels // 2 + if 'type' in act_cfg and act_cfg['type'] == 'PReLU': + act_cfg['num_parameters'] = channels + kernel_size = 3 if downsample else 1 + stride = 2 if downsample else 1 + padding = (kernel_size - 1) // 2 + + self.conv1x1 = ConvModule( + in_channels, + channels, + kernel_size, + stride, + padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.f_loc = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=False) + self.f_sur = build_conv_layer( + conv_cfg, + channels, + channels, + kernel_size=3, + padding=dilation, + groups=channels, + dilation=dilation, + bias=False) + + self.bn = build_norm_layer(norm_cfg, 2 * channels)[1] + self.activate = nn.PReLU(2 * channels) + + if downsample: + self.bottleneck = build_conv_layer( + conv_cfg, + 2 * channels, + out_channels, + kernel_size=1, + bias=False) + + self.skip_connect = skip_connect and not downsample + self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp) + + def forward(self, x): + + def _inner_forward(x): + out = self.conv1x1(x) + loc = self.f_loc(out) + sur = self.f_sur(out) + + joi_feat = torch.cat([loc, sur], 1) # the joint feature + joi_feat = self.bn(joi_feat) + joi_feat = self.activate(joi_feat) + if self.downsample: + joi_feat = self.bottleneck(joi_feat) # channel = out_channels + # f_glo is employed to refine the joint feature + out = self.f_glo(joi_feat) + + if self.skip_connect: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InputInjection(nn.Module): + """Downsampling module for CGNet.""" + + def __init__(self, num_downsampling): + super().__init__() + self.pool = nn.ModuleList() + for i in range(num_downsampling): + self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) + + def forward(self, x): + for pool in self.pool: + x = pool(x) + return x + + +@MODELS.register_module() +class CGNet(BaseModule): + """CGNet backbone. + + This backbone is the implementation of `A Light-weight Context Guided + Network for Semantic Segmentation `_. + + Args: + in_channels (int): Number of input image channels. Normally 3. + num_channels (tuple[int]): Numbers of feature channels at each stages. + Default: (32, 64, 128). + num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2. + Default: (3, 21). + dilations (tuple[int]): Dilation rate for surrounding context + extractors at stage 1 and stage 2. Default: (2, 4). + reductions (tuple[int]): Reductions for global context extractors at + stage 1 and stage 2. Default: (8, 16). + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='PReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='PReLU'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + + super().__init__(init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer=['Conv2d', 'Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']), + dict(type='Constant', val=0, layer='PReLU') + ] + else: + raise TypeError('pretrained must be a str or None') + + self.in_channels = in_channels + self.num_channels = num_channels + assert isinstance(self.num_channels, tuple) and len( + self.num_channels) == 3 + self.num_blocks = num_blocks + assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2 + self.dilations = dilations + assert isinstance(self.dilations, tuple) and len(self.dilations) == 2 + self.reductions = reductions + assert isinstance(self.reductions, tuple) and len(self.reductions) == 2 + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU': + self.act_cfg['num_parameters'] = num_channels[0] + self.norm_eval = norm_eval + self.with_cp = with_cp + + cur_channels = in_channels + self.stem = nn.ModuleList() + for i in range(3): + self.stem.append( + ConvModule( + cur_channels, + num_channels[0], + 3, + 2 if i == 0 else 1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + cur_channels = num_channels[0] + + self.inject_2x = InputInjection(1) # down-sample for Input, factor=2 + self.inject_4x = InputInjection(2) # down-sample for Input, factor=4 + + cur_channels += in_channels + self.norm_prelu_0 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 1 + self.level1 = nn.ModuleList() + for i in range(num_blocks[0]): + self.level1.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[1], + num_channels[1], + dilations[0], + reductions[0], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[1] + in_channels + self.norm_prelu_1 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + # stage 2 + self.level2 = nn.ModuleList() + for i in range(num_blocks[1]): + self.level2.append( + ContextGuidedBlock( + cur_channels if i == 0 else num_channels[2], + num_channels[2], + dilations[1], + reductions[1], + downsample=(i == 0), + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + with_cp=with_cp)) # CG block + + cur_channels = 2 * num_channels[2] + self.norm_prelu_2 = nn.Sequential( + build_norm_layer(norm_cfg, cur_channels)[1], + nn.PReLU(cur_channels)) + + def forward(self, x): + output = [] + + # stage 0 + inp_2x = self.inject_2x(x) + inp_4x = self.inject_4x(x) + for layer in self.stem: + x = layer(x) + x = self.norm_prelu_0(torch.cat([x, inp_2x], 1)) + output.append(x) + + # stage 1 + for i, layer in enumerate(self.level1): + x = layer(x) + if i == 0: + down1 = x + x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1)) + output.append(x) + + # stage 2 + for i, layer in enumerate(self.level2): + x = layer(x) + if i == 0: + down2 = x + x = self.norm_prelu_2(torch.cat([down2, x], 1)) + output.append(x) + + return output + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/ddrnet.py b/finetune/mmseg/models/backbones/ddrnet.py new file mode 100644 index 0000000..4508aad --- /dev/null +++ b/finetune/mmseg/models/backbones/ddrnet.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer +from mmengine.model import BaseModule + +from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType + + +@MODELS.register_module() +class DDRNet(BaseModule): + """DDRNet backbone. + + This backbone is the implementation of `Deep Dual-resolution Networks for + Real-time and Accurate Semantic Segmentation of Road Scenes + `_. + Modified from https://github.com/ydhongHIT/DDRNet. + + Args: + in_channels (int): Number of input image channels. Default: 3. + channels: (int): The base channels of DDRNet. Default: 32. + ppm_channels (int): The channels of PPM module. Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + norm_cfg (dict): Config dict to build norm layer. + Default: dict(type='BN', requires_grad=True). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict, optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels: int = 3, + channels: int = 32, + ppm_channels: int = 128, + align_corners: bool = False, + norm_cfg: OptConfigType = dict(type='BN', requires_grad=True), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + + self.in_channels = in_channels + self.ppm_channels = ppm_channels + + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + + # stage 0-2 + self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2) + self.relu = nn.ReLU() + + # low resolution(context) branch + self.context_branch_layers = nn.ModuleList() + for i in range(3): + self.context_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + inplanes=channels * 2**(i + 1), + planes=channels * 8 if i > 0 else channels * 4, + num_blocks=2 if i < 2 else 1, + stride=2)) + + # bilateral fusion + self.compression_1 = ConvModule( + channels * 4, + channels * 2, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.down_1 = ConvModule( + channels * 2, + channels * 4, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + + self.compression_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=1, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.down_2 = nn.Sequential( + ConvModule( + channels * 2, + channels * 4, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels * 4, + channels * 8, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=None)) + + # high resolution(spatial) branch + self.spatial_branch_layers = nn.ModuleList() + for i in range(3): + self.spatial_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + inplanes=channels * 2, + planes=channels * 2, + num_blocks=2 if i < 2 else 1, + )) + + self.spp = DAPPM( + channels * 16, ppm_channels, channels * 4, num_scales=5) + + def _make_stem_layer(self, in_channels, channels, num_blocks): + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + + layers.extend([ + self._make_layer(BasicBlock, channels, channels, num_blocks), + nn.ReLU(), + self._make_layer( + BasicBlock, channels, channels * 2, num_blocks, stride=2), + nn.ReLU(), + ]) + + return nn.Sequential(*layers) + + def _make_layer(self, block, inplanes, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [ + block( + in_channels=inplanes, + channels=planes, + stride=stride, + downsample=downsample) + ] + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + in_channels=inplanes, + channels=planes, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg_out=None if i == num_blocks - 1 else self.act_cfg)) + + return nn.Sequential(*layers) + + def forward(self, x): + """Forward function.""" + out_size = (x.shape[-2] // 8, x.shape[-1] // 8) + + # stage 0-2 + x = self.stem(x) + + # stage3 + x_c = self.context_branch_layers[0](x) + x_s = self.spatial_branch_layers[0](x) + comp_c = self.compression_1(self.relu(x_c)) + x_c += self.down_1(self.relu(x_s)) + x_s += resize( + comp_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_context = x_s.clone() + + # stage4 + x_c = self.context_branch_layers[1](self.relu(x_c)) + x_s = self.spatial_branch_layers[1](self.relu(x_s)) + comp_c = self.compression_2(self.relu(x_c)) + x_c += self.down_2(self.relu(x_s)) + x_s += resize( + comp_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + + # stage5 + x_s = self.spatial_branch_layers[2](self.relu(x_s)) + x_c = self.context_branch_layers[2](self.relu(x_c)) + x_c = self.spp(x_c) + x_c = resize( + x_c, + size=out_size, + mode='bilinear', + align_corners=self.align_corners) + + return (temp_context, x_s + x_c) if self.training else x_s + x_c diff --git a/finetune/mmseg/models/backbones/erfnet.py b/finetune/mmseg/models/backbones/erfnet.py new file mode 100644 index 0000000..2c5ec67 --- /dev/null +++ b/finetune/mmseg/models/backbones/erfnet.py @@ -0,0 +1,329 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class DownsamplerBlock(BaseModule): + """Downsampler block of ERFNet. + + This module is a little different from basical ConvModule. + The features from Conv and MaxPool layers are + concatenated before BatchNorm. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels - in_channels, + kernel_size=3, + stride=2, + padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + conv_out = self.conv(input) + pool_out = self.pool(input) + pool_out = resize( + input=pool_out, + size=conv_out.size()[2:], + mode='bilinear', + align_corners=False) + output = torch.cat([conv_out, pool_out], 1) + output = self.bn(output) + output = self.act(output) + return output + + +class NonBottleneck1d(BaseModule): + """Non-bottleneck block of ERFNet. + + Args: + channels (int): Number of channels in Non-bottleneck block. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + dilation (int): Dilation rate for last two conv layers. + Default 1. + num_conv_layer (int): Number of 3x1 and 1x3 convolution layers. + Default 2. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + channels, + drop_rate=0, + dilation=1, + num_conv_layer=2, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.act = build_activation_layer(self.act_cfg) + + self.convs_layers = nn.ModuleList() + for conv_layer in range(num_conv_layer): + first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0) + first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1) + second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation) + second_conv_dilation = 1 if conv_layer == 0 else (1, dilation) + + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=first_conv_padding, + bias=True, + dilation=first_conv_dilation)) + self.convs_layers.append(self.act) + self.convs_layers.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=second_conv_padding, + bias=True, + dilation=second_conv_dilation)) + self.convs_layers.append( + build_norm_layer(self.norm_cfg, channels)[1]) + if conv_layer == 0: + self.convs_layers.append(self.act) + else: + self.convs_layers.append(nn.Dropout(p=drop_rate)) + + def forward(self, input): + output = input + for conv in self.convs_layers: + output = conv(output) + output = self.act(output + input) + return output + + +class UpsamplerBlock(BaseModule): + """Upsampler block of ERFNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.act(output) + return output + + +@MODELS.register_module() +class ERFNet(BaseModule): + """ERFNet backbone. + + This backbone is the implementation of `ERFNet: Efficient Residual + Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Args: + in_channels (int): The number of channels of input + image. Default: 3. + enc_downsample_channels (Tuple[int]): Size of channel + numbers of various Downsampler block in encoder. + Default: (16, 64, 128). + enc_stage_non_bottlenecks (Tuple[int]): Number of stages of + Non-bottleneck block in encoder. + Default: (5, 8). + enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each + stage of Non-bottleneck block of encoder. + Default: (2, 4, 8, 16). + enc_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in encoder. + Default: (64, 128). + dec_upsample_channels (Tuple[int]): Size of channel numbers of + various Deconvolution block in decoder. + Default: (64, 16). + dec_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in decoder. + Default: (2, 2). + dec_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in decoder. + Default: (64, 16). + drop_rate (float): Probability of an element to be zeroed. + Default 0.1. + """ + + def __init__(self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + assert len(enc_downsample_channels) \ + == len(dec_upsample_channels)+1, 'Number of downsample\ + block of encoder does not \ + match number of upsample block of decoder!' + assert len(enc_downsample_channels) \ + == len(enc_stage_non_bottlenecks)+1, 'Number of \ + downsample block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(enc_downsample_channels) \ + == len(enc_non_bottleneck_channels)+1, 'Number of \ + downsample block of encoder does not match \ + number of channels of Non-bottleneck block of encoder!' + assert enc_stage_non_bottlenecks[-1] \ + % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + Non-bottleneck block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(dec_upsample_channels) \ + == len(dec_stages_non_bottleneck), 'Number of \ + upsample block of decoder does not match \ + number of Non-bottleneck block of decoder!' + assert len(dec_stages_non_bottleneck) \ + == len(dec_non_bottleneck_channels), 'Number of \ + Non-bottleneck block of decoder does not match \ + number of channels of Non-bottleneck block of decoder!' + + self.in_channels = in_channels + self.enc_downsample_channels = enc_downsample_channels + self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks + self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations + self.enc_non_bottleneck_channels = enc_non_bottleneck_channels + self.dec_upsample_channels = dec_upsample_channels + self.dec_stages_non_bottleneck = dec_stages_non_bottleneck + self.dec_non_bottleneck_channels = dec_non_bottleneck_channels + self.dropout_ratio = dropout_ratio + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.encoder.append( + DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + + for i in range(len(enc_downsample_channels) - 1): + self.encoder.append( + DownsamplerBlock(enc_downsample_channels[i], + enc_downsample_channels[i + 1])) + # Last part of encoder is some dilated NonBottleneck1d blocks. + if i == len(enc_downsample_channels) - 2: + iteration_times = int(enc_stage_non_bottlenecks[-1] / + len(enc_non_bottleneck_dilations)) + for j in range(iteration_times): + for k in range(len(enc_non_bottleneck_dilations)): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) + else: + for j in range(enc_stage_non_bottlenecks[i]): + self.encoder.append( + NonBottleneck1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) + + for i in range(len(dec_upsample_channels)): + if i == 0: + self.decoder.append( + UpsamplerBlock(enc_downsample_channels[-1], + dec_non_bottleneck_channels[i])) + else: + self.decoder.append( + UpsamplerBlock(dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i])) + for j in range(dec_stages_non_bottleneck[i]): + self.decoder.append( + NonBottleneck1d(dec_non_bottleneck_channels[i])) + + def forward(self, x): + for enc in self.encoder: + x = enc(x) + for dec in self.decoder: + x = dec(x) + return [x] diff --git a/finetune/mmseg/models/backbones/fast_scnn.py b/finetune/mmseg/models/backbones/fast_scnn.py new file mode 100644 index 0000000..6ff7a31 --- /dev/null +++ b/finetune/mmseg/models/backbones/fast_scnn.py @@ -0,0 +1,408 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.models.decode_heads.psp_head import PPM +from mmseg.registry import MODELS +from ..utils import InvertedResidual, resize + + +class LearningToDownsample(nn.Module): + """Learning to downsample module. + + Args: + in_channels (int): Number of input channels. + dw_channels (tuple[int]): Number of output channels of the first and + the second depthwise conv (dwconv) layers. + out_channels (int): Number of output channels of the whole + 'learning to downsample' module. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + """ + + def __init__(self, + in_channels, + dw_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dw_act_cfg=None): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.dw_act_cfg = dw_act_cfg + dw_channels1 = dw_channels[0] + dw_channels2 = dw_channels[1] + + self.conv = ConvModule( + in_channels, + dw_channels1, + 3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.dsconv1 = DepthwiseSeparableConvModule( + dw_channels1, + dw_channels2, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + self.dsconv2 = DepthwiseSeparableConvModule( + dw_channels2, + out_channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + dw_act_cfg=self.dw_act_cfg) + + def forward(self, x): + x = self.conv(x) + x = self.dsconv1(x) + x = self.dsconv2(x) + return x + + +class GlobalFeatureExtractor(nn.Module): + """Global feature extractor module. + + Args: + in_channels (int): Number of input channels of the GFE module. + Default: 64 + block_channels (tuple[int]): Tuple of ints. Each int specifies the + number of output channels of each Inverted Residual module. + Default: (64, 96, 128) + out_channels(int): Number of output channels of the GFE module. + Default: 128 + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + Default: 6 + num_blocks (tuple[int]): Tuple of ints. Each int specifies the + number of times each Inverted Residual module is repeated. + The repeated Inverted Residual modules are called a 'group'. + Default: (3, 3, 3) + strides (tuple[int]): Tuple of ints. Each int specifies + the downsampling factor of each 'group'. + Default: (2, 2, 1) + pool_scales (tuple[int]): Tuple of ints. Each int specifies + the parameter required in 'global average pooling' within PPM. + Default: (1, 2, 3, 6) + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + """ + + def __init__(self, + in_channels=64, + block_channels=(64, 96, 128), + out_channels=128, + expand_ratio=6, + num_blocks=(3, 3, 3), + strides=(2, 2, 1), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + assert len(block_channels) == len(num_blocks) == 3 + self.bottleneck1 = self._make_layer(in_channels, block_channels[0], + num_blocks[0], strides[0], + expand_ratio) + self.bottleneck2 = self._make_layer(block_channels[0], + block_channels[1], num_blocks[1], + strides[1], expand_ratio) + self.bottleneck3 = self._make_layer(block_channels[1], + block_channels[2], num_blocks[2], + strides[2], expand_ratio) + self.ppm = PPM( + pool_scales, + block_channels[2], + block_channels[2] // 4, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=align_corners) + + self.out = ConvModule( + block_channels[2] * 2, + out_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _make_layer(self, + in_channels, + out_channels, + blocks, + stride=1, + expand_ratio=6): + layers = [ + InvertedResidual( + in_channels, + out_channels, + stride, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + for i in range(1, blocks): + layers.append( + InvertedResidual( + out_channels, + out_channels, + 1, + expand_ratio, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.bottleneck1(x) + x = self.bottleneck2(x) + x = self.bottleneck3(x) + x = torch.cat([x, *self.ppm(x)], dim=1) + x = self.out(x) + return x + + +class FeatureFusionModule(nn.Module): + """Feature fusion module. + + Args: + higher_in_channels (int): Number of input channels of the + higher-resolution branch. + lower_in_channels (int): Number of input channels of the + lower-resolution branch. + out_channels (int): Number of output channels. + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + dwconv_act_cfg (dict): Config of activation layers in 3x3 conv. + Default: dict(type='ReLU'). + conv_act_cfg (dict): Config of activation layers in the two 1x1 conv. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + """ + + def __init__(self, + higher_in_channels, + lower_in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dwconv_act_cfg=dict(type='ReLU'), + conv_act_cfg=None, + align_corners=False): + super().__init__() + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dwconv_act_cfg = dwconv_act_cfg + self.conv_act_cfg = conv_act_cfg + self.align_corners = align_corners + self.dwconv = ConvModule( + lower_in_channels, + out_channels, + 3, + padding=1, + groups=out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.dwconv_act_cfg) + self.conv_lower_res = ConvModule( + out_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.conv_higher_res = ConvModule( + higher_in_channels, + out_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.conv_act_cfg) + + self.relu = nn.ReLU(True) + + def forward(self, higher_res_feature, lower_res_feature): + lower_res_feature = resize( + lower_res_feature, + size=higher_res_feature.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + lower_res_feature = self.dwconv(lower_res_feature) + lower_res_feature = self.conv_lower_res(lower_res_feature) + + higher_res_feature = self.conv_higher_res(higher_res_feature) + out = higher_res_feature + lower_res_feature + return self.relu(out) + + +@MODELS.register_module() +class FastSCNN(BaseModule): + """Fast-SCNN Backbone. + + This backbone is the implementation of `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels (int): Number of input image channels. Default: 3. + downsample_dw_channels (tuple[int]): Number of output channels after + the first conv layer & the second conv layer in + Learning-To-Downsample (LTD) module. + Default: (32, 48). + global_in_channels (int): Number of input channels of + Global Feature Extractor(GFE). + Equal to number of output channels of LTD. + Default: 64. + global_block_channels (tuple[int]): Tuple of integers that describe + the output channels for each of the MobileNet-v2 bottleneck + residual blocks in GFE. + Default: (64, 96, 128). + global_block_strides (tuple[int]): Tuple of integers + that describe the strides (downsampling factors) for each of the + MobileNet-v2 bottleneck residual blocks in GFE. + Default: (2, 2, 1). + global_out_channels (int): Number of output channels of GFE. + Default: 128. + higher_in_channels (int): Number of input channels of the higher + resolution branch in FFM. + Equal to global_in_channels. + Default: 64. + lower_in_channels (int): Number of input channels of the lower + resolution branch in FFM. + Equal to global_out_channels. + Default: 128. + fusion_out_channels (int): Number of output channels of FFM. + Default: 128. + out_indices (tuple): Tuple of indices of list + [higher_res_features, lower_res_features, fusion_output]. + Often set to (0,1,2) to enable aux. heads. + Default: (0, 1, 2). + conv_cfg (dict | None): Config of conv layers. Default: None + norm_cfg (dict | None): Config of norm layers. Default: + dict(type='BN') + act_cfg (dict): Config of activation layers. Default: + dict(type='ReLU') + align_corners (bool): align_corners argument of F.interpolate. + Default: False + dw_act_cfg (dict): In DepthwiseSeparableConvModule, activation config + of depthwise ConvModule. If it is 'default', it will be the same + as `act_cfg`. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + in_channels=3, + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + dw_act_cfg=None, + init_cfg=None): + + super().__init__(init_cfg) + + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + + if global_in_channels != higher_in_channels: + raise AssertionError('Global Input Channels must be the same \ + with Higher Input Channels!') + elif global_out_channels != lower_in_channels: + raise AssertionError('Global Output Channels must be the same \ + with Lower Input Channels!') + + self.in_channels = in_channels + self.downsample_dw_channels1 = downsample_dw_channels[0] + self.downsample_dw_channels2 = downsample_dw_channels[1] + self.global_in_channels = global_in_channels + self.global_block_channels = global_block_channels + self.global_block_strides = global_block_strides + self.global_out_channels = global_out_channels + self.higher_in_channels = higher_in_channels + self.lower_in_channels = lower_in_channels + self.fusion_out_channels = fusion_out_channels + self.out_indices = out_indices + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.learning_to_downsample = LearningToDownsample( + in_channels, + downsample_dw_channels, + global_in_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + dw_act_cfg=dw_act_cfg) + self.global_feature_extractor = GlobalFeatureExtractor( + global_in_channels, + global_block_channels, + global_out_channels, + strides=self.global_block_strides, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.feature_fusion = FeatureFusionModule( + higher_in_channels, + lower_in_channels, + fusion_out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dwconv_act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, x): + higher_res_features = self.learning_to_downsample(x) + lower_res_features = self.global_feature_extractor(higher_res_features) + fusion_output = self.feature_fusion(higher_res_features, + lower_res_features) + + outs = [higher_res_features, lower_res_features, fusion_output] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/finetune/mmseg/models/backbones/hrnet.py b/finetune/mmseg/models/backbones/hrnet.py new file mode 100644 index 0000000..2da755e --- /dev/null +++ b/finetune/mmseg/models/backbones/hrnet.py @@ -0,0 +1,642 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .resnet import BasicBlock, Bottleneck + + +class HRModule(BaseModule): + """High-Resolution Module for HRNet. + + In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange + is in this module. + """ + + def __init__(self, + num_branches, + blocks, + num_blocks, + in_channels, + num_channels, + multiscale_output=True, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + block_init_cfg=None, + init_cfg=None): + super().__init__(init_cfg) + self.block_init_cfg = block_init_cfg + self._check_branches(num_branches, num_blocks, in_channels, + num_channels) + + self.in_channels = in_channels + self.num_branches = num_branches + + self.multiscale_output = multiscale_output + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.with_cp = with_cp + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=False) + + def _check_branches(self, num_branches, num_blocks, in_channels, + num_channels): + """Check branches configuration.""" + if num_branches != len(num_blocks): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \ + f'{len(num_blocks)})' + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \ + f'{len(num_channels)})' + raise ValueError(error_msg) + + if num_branches != len(in_channels): + error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \ + f'{len(in_channels)})' + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + """Build one branch.""" + downsample = None + if stride != 1 or \ + self.in_channels[branch_index] != \ + num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + self.in_channels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, num_channels[branch_index] * + block.expansion)[1]) + + layers = [] + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + self.in_channels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.in_channels[branch_index], + num_channels[branch_index], + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=self.block_init_cfg)) + + return Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + """Build multiple branch.""" + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return ModuleList(branches) + + def _make_fuse_layers(self): + """Build fuse layer.""" + if self.num_branches == 1: + return None + + num_branches = self.num_branches + in_channels = self.in_channels + fuse_layers = [] + num_out_branches = num_branches if self.multiscale_output else 1 + for i in range(num_out_branches): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=1, + stride=1, + padding=0, + bias=False), + build_norm_layer(self.norm_cfg, in_channels[i])[1], + # we set align_corners=False for HRNet + Upsample( + scale_factor=2**(j - i), + mode='bilinear', + align_corners=False))) + elif j == i: + fuse_layer.append(None) + else: + conv_downsamples = [] + for k in range(i - j): + if k == i - j - 1: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[i], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[i])[1])) + else: + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels[j], + in_channels[j], + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + in_channels[j])[1], + nn.ReLU(inplace=False))) + fuse_layer.append(nn.Sequential(*conv_downsamples)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def forward(self, x): + """Forward function.""" + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = 0 + for j in range(self.num_branches): + if i == j: + y += x[j] + elif j > i: + y = y + resize( + self.fuse_layers[i][j](x[j]), + size=x[i].shape[2:], + mode='bilinear', + align_corners=False) + else: + y += self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + return x_fuse + + +@MODELS.register_module() +class HRNet(BaseModule): + """HRNet backbone. + + This backbone is the implementation of `High-Resolution Representations + for Labeling Pixels and Regions `_. + + Args: + extra (dict): Detailed configuration for each stage of HRNet. + There must be 4 stages, the configuration for each stage must have + 5 keys: + + - num_modules (int): The number of HRModule in this stage. + - num_branches (int): The number of branches in the HRModule. + - block (str): The type of convolution block. + - num_blocks (tuple): The number of blocks in each branch. + The length must be equal to num_branches. + - num_channels (tuple): The number of channels in each branch. + The length must be equal to num_branches. + in_channels (int): Number of input image channels. Normally 3. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Use `BN` by default. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: False. + multiscale_output (bool): Whether to output multi-level features + produced by multiple branches. If False, only the first level + feature will be output. Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import HRNet + >>> import torch + >>> extra = dict( + >>> stage1=dict( + >>> num_modules=1, + >>> num_branches=1, + >>> block='BOTTLENECK', + >>> num_blocks=(4, ), + >>> num_channels=(64, )), + >>> stage2=dict( + >>> num_modules=1, + >>> num_branches=2, + >>> block='BASIC', + >>> num_blocks=(4, 4), + >>> num_channels=(32, 64)), + >>> stage3=dict( + >>> num_modules=4, + >>> num_branches=3, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4), + >>> num_channels=(32, 64, 128)), + >>> stage4=dict( + >>> num_modules=3, + >>> num_branches=4, + >>> block='BASIC', + >>> num_blocks=(4, 4, 4, 4), + >>> num_channels=(32, 64, 128, 256))) + >>> self = HRNet(extra, in_channels=1) + >>> self.eval() + >>> inputs = torch.rand(1, 1, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 32, 8, 8) + (1, 64, 4, 4) + (1, 128, 2, 2) + (1, 256, 1, 1) + """ + + blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + def __init__(self, + extra, + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + with_cp=False, + frozen_stages=-1, + zero_init_residual=False, + multiscale_output=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + # Assert configurations of 4 stages are in extra + assert 'stage1' in extra and 'stage2' in extra \ + and 'stage3' in extra and 'stage4' in extra + # Assert whether the length of `num_blocks` and `num_channels` are + # equal to `num_branches` + for i in range(4): + cfg = extra[f'stage{i + 1}'] + assert len(cfg['num_blocks']) == cfg['num_branches'] and \ + len(cfg['num_channels']) == cfg['num_branches'] + + self.extra = extra + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.frozen_stages = frozen_stages + + # stem net + self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + self.conv_cfg, + 64, + 64, + kernel_size=3, + stride=2, + padding=1, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = self.extra['stage1'] + num_channels = self.stage1_cfg['num_channels'][0] + block_type = self.stage1_cfg['block'] + num_blocks = self.stage1_cfg['num_blocks'][0] + + block = self.blocks_dict[block_type] + stage1_out_channels = num_channels * block.expansion + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + + # stage 2 + self.stage2_cfg = self.extra['stage2'] + num_channels = self.stage2_cfg['num_channels'] + block_type = self.stage2_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition1 = self._make_transition_layer([stage1_out_channels], + num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = self.extra['stage3'] + num_channels = self.stage3_cfg['num_channels'] + block_type = self.stage3_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = self.extra['stage4'] + num_channels = self.stage4_cfg['num_channels'] + block_type = self.stage4_cfg['block'] + + block = self.blocks_dict[block_type] + num_channels = [channel * block.expansion for channel in num_channels] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multiscale_output=multiscale_output) + + self._freeze_stages() + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: the normalization layer named "norm2" """ + return getattr(self, self.norm2_name) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + """Make transition layer.""" + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, + num_channels_cur_layer[i])[1], + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv_downsamples = [] + for j in range(i + 1 - num_branches_pre): + in_channels = num_channels_pre_layer[-1] + out_channels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else in_channels + conv_downsamples.append( + nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, out_channels)[1], + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv_downsamples)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + """Make each layer.""" + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + build_conv_layer( + self.conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + build_norm_layer(self.norm_cfg, planes * block.expansion)[1]) + + layers = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + layers.append( + block( + inplanes, + planes, + stride, + downsample=downsample, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block( + inplanes, + planes, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + init_cfg=block_init_cfg)) + + return Sequential(*layers) + + def _make_stage(self, layer_config, in_channels, multiscale_output=True): + """Make each stage.""" + num_modules = layer_config['num_modules'] + num_branches = layer_config['num_branches'] + num_blocks = layer_config['num_blocks'] + num_channels = layer_config['num_channels'] + block = self.blocks_dict[layer_config['block']] + + hr_modules = [] + block_init_cfg = None + if self.pretrained is None and not hasattr( + self, 'init_cfg') and self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', val=0, override=dict(name='norm3')) + + for i in range(num_modules): + # multi_scale_output is only used for the last module + if not multiscale_output and i == num_modules - 1: + reset_multiscale_output = False + else: + reset_multiscale_output = True + + hr_modules.append( + HRModule( + num_branches, + block, + num_blocks, + in_channels, + num_channels, + reset_multiscale_output, + with_cp=self.with_cp, + norm_cfg=self.norm_cfg, + conv_cfg=self.conv_cfg, + block_init_cfg=block_init_cfg)) + + return Sequential(*hr_modules), in_channels + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + + self.norm1.eval() + self.norm2.eval() + for m in [self.conv1, self.norm1, self.conv2, self.norm2]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + if i == 1: + m = getattr(self, f'layer{i}') + t = getattr(self, f'transition{i}') + elif i == 4: + m = getattr(self, f'stage{i}') + else: + m = getattr(self, f'stage{i}') + t = getattr(self, f'transition{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + t.eval() + for param in t.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['num_branches']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['num_branches']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['num_branches']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + return y_list + + def train(self, mode=True): + """Convert the model into training mode will keeping the normalization + layer freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/icnet.py b/finetune/mmseg/models/backbones/icnet.py new file mode 100644 index 0000000..8ff3448 --- /dev/null +++ b/finetune/mmseg/models/backbones/icnet.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..decode_heads.psp_head import PPM +from ..utils import resize + + +@MODELS.register_module() +class ICNet(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This backbone is the implementation of + `ICNet `_. + + Args: + backbone_cfg (dict): Config dict to build backbone. Usually it is + ResNet but it can also be other backbones. + in_channels (int): The number of input image channels. Default: 3. + layer_channels (Sequence[int]): The numbers of feature channels at + layer 2 and layer 4 in ResNet. It can also be other backbones. + Default: (512, 2048). + light_branch_middle_channels (int): The number of channels of the + middle layer in light branch. Default: 32. + psp_out_channels (int): The number of channels of the output of PSP + module. Default: 512. + out_channels (Sequence[int]): The numbers of output feature channels + at each branches. Default: (64, 256, 256). + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + backbone_cfg, + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + pool_scales=(1, 2, 3, 6), + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + if backbone_cfg is None: + raise TypeError('backbone_cfg must be passed from config file!') + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', mode='fan_out', layer='Conv2d'), + dict(type='Constant', val=1, layer='_BatchNorm'), + dict(type='Normal', mean=0.01, layer='Linear') + ] + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.backbone = MODELS.build(backbone_cfg) + + # Note: Default `ceil_mode` is false in nn.MaxPool2d, set + # `ceil_mode=True` to keep information in the corner of feature map. + self.backbone.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, ceil_mode=True) + + self.psp_modules = PPM( + pool_scales=pool_scales, + in_channels=layer_channels[1], + channels=psp_out_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + align_corners=align_corners) + + self.psp_bottleneck = ConvModule( + layer_channels[1] + len(pool_scales) * psp_out_channels, + psp_out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.conv_sub1 = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=light_branch_middle_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg), + ConvModule( + in_channels=light_branch_middle_channels, + out_channels=out_channels[0], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg)) + + self.conv_sub2 = ConvModule( + layer_channels[0], + out_channels[1], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + self.conv_sub4 = ConvModule( + psp_out_channels, + out_channels[2], + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg) + + def forward(self, x): + output = [] + + # sub 1 + output.append(self.conv_sub1(x)) + + # sub 2 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.stem(x) + x = self.backbone.maxpool(x) + x = self.backbone.layer1(x) + x = self.backbone.layer2(x) + output.append(self.conv_sub2(x)) + + # sub 4 + x = resize( + x, + scale_factor=0.5, + mode='bilinear', + align_corners=self.align_corners) + x = self.backbone.layer3(x) + x = self.backbone.layer4(x) + psp_outs = self.psp_modules(x) + [x] + psp_outs = torch.cat(psp_outs, dim=1) + x = self.psp_bottleneck(psp_outs) + + output.append(self.conv_sub4(x)) + + return output diff --git a/finetune/mmseg/models/backbones/mae.py b/finetune/mmseg/models/backbones/mae.py new file mode 100644 index 0000000..a1f243f --- /dev/null +++ b/finetune/mmseg/models/backbones/mae.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved.import math +import math + +import torch +import torch.nn as nn +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import _load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer + + +class MAEAttention(BEiTAttention): + """Multi-head self-attention with relative position bias used in MAE. + + This module is different from ``BEiTAttention`` by initializing the + relative bias table with zeros. + """ + + def init_weights(self): + """Initialize relative position bias with zeros.""" + + # As MAE initializes relative position bias as zeros and this class + # inherited from BEiT which initializes relative position bias + # with `trunc_normal`, `init_weights` here does + # nothing and just passes directly + + pass + + +class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): + """Implements one encoder layer in Vision Transformer. + + This module is different from ``BEiTTransformerEncoderLayer`` by replacing + ``BEiTAttention`` with ``MAEAttention``. + """ + + def build_attn(self, attn_cfg): + self.attn = MAEAttention(**attn_cfg) + + +@MODELS.register_module() +class MAE(BEiT): + """VisionTransformer with support for patch. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_values (float): Initialize the values of Attention and FFN + with learnable scaling. Defaults to 0.1. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + num_fcs=2, + norm_eval=False, + pretrained=None, + init_values=0.1, + init_cfg=None): + super().__init__( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + out_indices=out_indices, + qv_bias=False, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + patch_norm=patch_norm, + final_norm=final_norm, + num_fcs=num_fcs, + norm_eval=norm_eval, + pretrained=pretrained, + init_values=init_values, + init_cfg=init_cfg) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + + self.num_patches = self.patch_shape[0] * self.patch_shape[1] + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, embed_dims)) + + def _build_layers(self): + dpr = [ + x.item() + for x in torch.linspace(0, self.drop_path_rate, self.num_layers) + ] + self.layers = ModuleList() + for i in range(self.num_layers): + self.layers.append( + MAETransformerEncoderLayer( + embed_dims=self.embed_dims, + num_heads=self.num_heads, + feedforward_channels=self.mlp_ratio * self.embed_dims, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=dpr[i], + num_fcs=self.num_fcs, + bias=True, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg, + window_size=self.patch_shape, + init_values=self.init_values)) + + def fix_init_weight(self): + """Rescale the initialization according to layer id. + + This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501 + Copyright (c) Microsoft Corporation + Licensed under the MIT License + """ + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def init_weights(self): + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + self.fix_init_weight() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + checkpoint = _load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + state_dict = self.resize_rel_pos_embed(checkpoint) + state_dict = self.resize_abs_pos_embed(state_dict) + self.load_state_dict(state_dict, False) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # Copyright 2019 Ross Wightman + # Licensed under the Apache License, Version 2.0 (the "License") + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def resize_abs_pos_embed(self, state_dict): + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(self.num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + return state_dict + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + out = x[:, 1:] + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) diff --git a/finetune/mmseg/models/backbones/mit.py b/finetune/mmseg/models/backbones/mit.py new file mode 100644 index 0000000..66556bd --- /dev/null +++ b/finetune/mmseg/models/backbones/mit.py @@ -0,0 +1,450 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmengine.model import BaseModule, ModuleList, Sequential +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Conv to encode positional information. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super().__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class EfficientMultiheadAttention(MultiheadAttention): + """An implementation of Efficient Multi-head Attention of Segformer. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type='LN'), + sr_ratio=1): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + dropout_layer=dropout_layer, + init_cfg=init_cfg, + batch_first=batch_first, + bias=qkv_bias) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa + from mmseg import digit_version, mmcv_version + if mmcv_version < digit_version('1.3.17'): + warnings.warn('The legacy version of forward function in' + 'EfficientMultiheadAttention is deprecated in' + 'mmcv>=1.3.17 and will no longer support in the' + 'future. Please upgrade your mmcv.') + self.forward = self.legacy_forward + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + x_q = x_q.transpose(0, 1) + x_kv = x_kv.transpose(0, 1) + + out = self.attn(query=x_q, key=x_kv, value=x_kv)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + def legacy_forward(self, x, hw_shape, identity=None): + """multi head attention forward in mmcv version < 1.3.17.""" + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # `need_weights=True` will let nn.MultiHeadAttention + # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` + # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set + # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. + # This issue - `https://github.com/pytorch/pytorch/issues/37583` report + # the error that large scale tensor sum operation may cause cuda error. + out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Segformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + init_cfg (dict, optional): Initialization config dict. + Default:None. + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + sr_ratio=1, + with_cp=False): + super().__init__() + + # The ret[0] of build_norm_layer is norm name. + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = EfficientMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + self.with_cp = with_cp + + def forward(self, x, hw_shape): + + def _inner_forward(x): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class MixVisionTransformer(BaseModule): + """The backbone of Segformer. + + This backbone is the implementation of `SegFormer: Simple and + Efficient Design for Semantic Segmentation with + Transformers `_. + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 4, 8]. + patch_sizes (Sequence[int]): The patch_size of each overlapped patch + embedding. Default: [7, 3, 3, 3]. + strides (Sequence[int]): The stride of each overlapped patch embedding. + Default: [4, 2, 2, 2]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + in_channels=3, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrained=None, + init_cfg=None, + with_cp=False): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + self.with_cp = with_cp + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=patch_sizes[i] // 2, + norm_cfg=norm_cfg) + layer = ModuleList([ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + sr_ratio=sr_ratios[i]) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + self.layers.append(ModuleList([patch_embed, layer, norm])) + cur += num_layer + + def init_weights(self): + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + outs = [] + + for i, layer in enumerate(self.layers): + x, hw_shape = layer[0](x) + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs diff --git a/finetune/mmseg/models/backbones/mobilenet_v2.py b/finetune/mmseg/models/backbones/mobilenet_v2.py new file mode 100644 index 0000000..1c21b5d --- /dev/null +++ b/finetune/mmseg/models/backbones/mobilenet_v2.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidual, make_divisible + + +@MODELS.register_module() +class MobileNetV2(BaseModule): + """MobileNetV2 backbone. + + This backbone is the implementation of + `MobileNetV2: Inverted Residuals and Linear Bottlenecks + `_. + + Args: + widen_factor (float): Width multiplier, multiply number of + channels in each layer by this amount. Default: 1.0. + strides (Sequence[int], optional): Strides of the first block of each + layer. If not specified, default config in ``arch_setting`` will + be used. + dilations (Sequence[int]): Dilation of each layer. + out_indices (None or Sequence[int]): Output from which stages. + Default: (7, ). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + # Parameters to build layers. 3 parameters are needed to construct a + # layer, from left to right: expand_ratio, channel, num_blocks. + arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], + [6, 96, 3], [6, 160, 3], [6, 320, 1]] + + def __init__(self, + widen_factor=1., + strides=(1, 2, 2, 2, 1, 2, 1), + dilations=(1, 1, 1, 1, 1, 1, 1), + out_indices=(1, 2, 4, 6), + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + self.widen_factor = widen_factor + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == len(self.arch_settings) + self.out_indices = out_indices + for index in out_indices: + if index not in range(0, 7): + raise ValueError('the item in out_indices must in ' + f'range(0, 7). But received {index}') + + if frozen_stages not in range(-1, 7): + raise ValueError('frozen_stages must be in range(-1, 7). ' + f'But received {frozen_stages}') + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + + self.in_channels = make_divisible(32 * widen_factor, 8) + + self.conv1 = ConvModule( + in_channels=3, + out_channels=self.in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.layers = [] + + for i, layer_cfg in enumerate(self.arch_settings): + expand_ratio, channel, num_blocks = layer_cfg + stride = self.strides[i] + dilation = self.dilations[i] + out_channels = make_divisible(channel * widen_factor, 8) + inverted_res_layer = self.make_layer( + out_channels=out_channels, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + expand_ratio=expand_ratio) + layer_name = f'layer{i + 1}' + self.add_module(layer_name, inverted_res_layer) + self.layers.append(layer_name) + + def make_layer(self, out_channels, num_blocks, stride, dilation, + expand_ratio): + """Stack InvertedResidual blocks to build a layer for MobileNetV2. + + Args: + out_channels (int): out_channels of block. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. + dilation (int): Dilation of the first block. + expand_ratio (int): Expand the number of channels of the + hidden layer in InvertedResidual by this ratio. + """ + layers = [] + for i in range(num_blocks): + layers.append( + InvertedResidual( + self.in_channels, + out_channels, + stride if i == 0 else 1, + expand_ratio=expand_ratio, + dilation=dilation if i == 0 else 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.in_channels = out_channels + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if len(outs) == 1: + return outs[0] + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/mobilenet_v3.py b/finetune/mmseg/models/backbones/mobilenet_v3.py new file mode 100644 index 0000000..1efb6e0 --- /dev/null +++ b/finetune/mmseg/models/backbones/mobilenet_v3.py @@ -0,0 +1,267 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks import Conv2dAdaptivePadding +from mmengine.model import BaseModule +from mmengine.utils import is_tuple_of +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import InvertedResidualV3 as InvertedResidual + + +@MODELS.register_module() +class MobileNetV3(BaseModule): + """MobileNetV3 backbone. + + This backbone is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + arch (str): Architecture of mobilnetv3, from {'small', 'large'}. + Default: 'small'. + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + out_indices (tuple[int]): Output from which layer. + Default: (0, 1, 12). + frozen_stages (int): Stages to be frozen (all param fixed). + Default: -1, which means not freezing any parameters. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + # Parameters to build each block: + # [kernel size, mid channels, out channels, with_se, act type, stride] + arch_settings = { + 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 + [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 + [3, 88, 24, False, 'ReLU', 1], + [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 + [5, 240, 40, True, 'HSwish', 1], + [5, 240, 40, True, 'HSwish', 1], + [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 + [5, 144, 48, True, 'HSwish', 1], + [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 + [5, 576, 96, True, 'HSwish', 1], + [5, 576, 96, True, 'HSwish', 1]], + 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 + [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 + [3, 72, 24, False, 'ReLU', 1], + [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 + [5, 120, 40, True, 'ReLU', 1], + [5, 120, 40, True, 'ReLU', 1], + [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 + [3, 200, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 184, 80, False, 'HSwish', 1], + [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 + [3, 672, 112, True, 'HSwish', 1], + [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 + [5, 960, 160, True, 'HSwish', 1], + [5, 960, 160, True, 'HSwish', 1]] + } # yapf: disable + + def __init__(self, + arch='small', + conv_cfg=None, + norm_cfg=dict(type='BN'), + out_indices=(0, 1, 12), + frozen_stages=-1, + reduction_factor=1, + norm_eval=False, + with_cp=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert arch in self.arch_settings + assert isinstance(reduction_factor, int) and reduction_factor > 0 + assert is_tuple_of(out_indices, int) + for index in out_indices: + if index not in range(0, len(self.arch_settings[arch]) + 2): + raise ValueError( + 'the item in out_indices must in ' + f'range(0, {len(self.arch_settings[arch])+2}). ' + f'But received {index}') + + if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): + raise ValueError('frozen_stages must be in range(-1, ' + f'{len(self.arch_settings[arch])+2}). ' + f'But received {frozen_stages}') + self.arch = arch + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.reduction_factor = reduction_factor + self.norm_eval = norm_eval + self.with_cp = with_cp + self.layers = self._make_layer() + + def _make_layer(self): + layers = [] + + # build the first layer (layer0) + in_channels = 16 + layer = ConvModule( + in_channels=3, + out_channels=in_channels, + kernel_size=3, + stride=2, + padding=1, + conv_cfg=dict(type='Conv2dAdaptivePadding'), + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + self.add_module('layer0', layer) + layers.append('layer0') + + layer_setting = self.arch_settings[self.arch] + for i, params in enumerate(layer_setting): + (kernel_size, mid_channels, out_channels, with_se, act, + stride) = params + + if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ + i >= 8: + mid_channels = mid_channels // self.reduction_factor + out_channels = out_channels // self.reduction_factor + + if with_se: + se_cfg = dict( + channels=mid_channels, + ratio=4, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))) + else: + se_cfg = None + + layer = InvertedResidual( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + se_cfg=se_cfg, + with_expand_conv=(in_channels != mid_channels), + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type=act), + with_cp=self.with_cp) + in_channels = out_channels + layer_name = f'layer{i + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # build the last layer + # block5 layer12 os=32 for small model + # block6 layer16 os=32 for large model + layer = ConvModule( + in_channels=in_channels, + out_channels=576 if self.arch == 'small' else 960, + kernel_size=1, + stride=1, + dilation=4, + padding=0, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=dict(type='HSwish')) + layer_name = f'layer{len(layer_setting) + 1}' + self.add_module(layer_name, layer) + layers.append(layer_name) + + # next, convert backbone MobileNetV3 to a semantic segmentation version + if self.arch == 'small': + self.layer4.depthwise_conv.conv.stride = (1, 1) + self.layer9.depthwise_conv.conv.stride = (1, 1) + for i in range(4, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 9: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + else: + self.layer7.depthwise_conv.conv.stride = (1, 1) + self.layer13.depthwise_conv.conv.stride = (1, 1) + for i in range(7, len(layers)): + layer = getattr(self, layers[i]) + if isinstance(layer, InvertedResidual): + modified_module = layer.depthwise_conv.conv + else: + modified_module = layer.conv + + if i < 13: + modified_module.dilation = (2, 2) + pad = 2 + else: + modified_module.dilation = (4, 4) + pad = 4 + + if not isinstance(modified_module, Conv2dAdaptivePadding): + # Adjust padding + pad *= (modified_module.kernel_size[0] - 1) // 2 + modified_module.padding = (pad, pad) + + return layers + + def forward(self, x): + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + return outs + + def _freeze_stages(self): + for i in range(self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/mscan.py b/finetune/mmseg/models/backbones/mscan.py new file mode 100644 index 0000000..7150cb7 --- /dev/null +++ b/finetune/mmseg/models/backbones/mscan.py @@ -0,0 +1,467 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS + + +class Mlp(BaseModule): + """Multi Layer Perceptron (MLP) Module. + + Args: + in_features (int): The dimension of input features. + hidden_features (int): The dimension of hidden features. + Defaults: None. + out_features (int): The dimension of output features. + Defaults: None. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward function.""" + + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): The dimension of input channels. + out_channels (int): The dimension of output channels. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + in_channels, + out_channels, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels // 2)[1], + build_activation_layer(act_cfg), + nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAAttention(BaseModule): + """Attention Module in Multi-Scale Convolutional Attention Module (MSCA). + + Args: + channels (int): The dimension of channels. + kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + """ + + def __init__(self, + channels, + kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + paddings=[2, [0, 3], [0, 5], [0, 10]]): + super().__init__() + self.conv0 = nn.Conv2d( + channels, + channels, + kernel_size=kernel_sizes[0], + padding=paddings[0], + groups=channels) + for i, (kernel_size, + padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])): + kernel_size_ = [kernel_size, kernel_size[::-1]] + padding_ = [padding, padding[::-1]] + conv_name = [f'conv{i}_1', f'conv{i}_2'] + for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, + conv_name): + self.add_module( + i_conv, + nn.Conv2d( + channels, + channels, + tuple(i_kernel), + padding=i_pad, + groups=channels)) + self.conv3 = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + """Forward function.""" + + u = x.clone() + + attn = self.conv0(x) + + # Multi-Scale Feature extraction + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + + attn = attn + attn_0 + attn_1 + attn_2 + # Channel Mixing + attn = self.conv3(attn) + + # Convolutional Attention + x = attn * u + + return x + + +class MSCASpatialAttention(BaseModule): + """Spatial Attention Module in Multi-Scale Convolutional Attention Module + (MSCA). + + Args: + in_channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + """ + + def __init__(self, + in_channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU')): + super().__init__() + self.proj_1 = nn.Conv2d(in_channels, in_channels, 1) + self.activation = build_activation_layer(act_cfg) + self.spatial_gating_unit = MSCAAttention(in_channels, + attention_kernel_sizes, + attention_kernel_paddings) + self.proj_2 = nn.Conv2d(in_channels, in_channels, 1) + + def forward(self, x): + """Forward function.""" + + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class MSCABlock(BaseModule): + """Basic Multi-Scale Convolutional Attention Block. It leverage the large- + kernel attention (LKA) mechanism to build both channel and spatial + attention. In each branch, it uses two depth-wise strip convolutions to + approximate standard depth-wise convolutions with large kernels. The kernel + size for each branch is set to 7, 11, and 21, respectively. + + Args: + channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + mlp_ratio (float): The ratio of multiple input dimension to + calculate hidden feature in MLP layer. Defaults: 4.0. + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + drop_path (float): The ratio of drop paths. + Defaults: 0.0. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + channels, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + mlp_ratio=4., + drop=0., + drop_path=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.norm1 = build_norm_layer(norm_cfg, channels)[1] + self.attn = MSCASpatialAttention(channels, attention_kernel_sizes, + attention_kernel_paddings, act_cfg) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = build_norm_layer(norm_cfg, channels)[1] + mlp_hidden_channels = int(channels * mlp_ratio) + self.mlp = Mlp( + in_features=channels, + hidden_features=mlp_hidden_channels, + act_cfg=act_cfg, + drop=drop) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones(channels), requires_grad=True) + + def forward(self, x, H, W): + """Forward function.""" + + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * + self.attn(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * + self.mlp(self.norm2(x))) + x = x.view(B, C, N).permute(0, 2, 1) + return x + + +class OverlapPatchEmbed(BaseModule): + """Image to Patch Embedding. + + Args: + patch_size (int): The patch size. + Defaults: 7. + stride (int): Stride of the convolutional layer. + Default: 4. + in_channels (int): The number of input channels. + Defaults: 3. + embed_dims (int): The dimensions of embedding. + Defaults: 768. + norm_cfg (dict): Config dict for normalization layer. + Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + patch_size=7, + stride=4, + in_channels=3, + embed_dim=768, + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=patch_size // 2) + self.norm = build_norm_layer(norm_cfg, embed_dim)[1] + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + + x = x.flatten(2).transpose(1, 2) + + return x, H, W + + +@MODELS.register_module() +class MSCAN(BaseModule): + """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone. + + This backbone is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Args: + in_channels (int): The number of input channels. Defaults: 3. + embed_dims (list[int]): Embedding dimension. + Defaults: [64, 128, 256, 512]. + mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim. + Defaults: [4, 4, 4, 4]. + drop_rate (float): Dropout rate. Defaults: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0. + depths (list[int]): Depths of each Swin Transformer stage. + Default: [3, 4, 6, 3]. + num_stages (int): MSCAN stages. Default: 4. + attention_kernel_sizes (list): Size of attention kernel in + Attention Module (Figure 2(b) of original paper). + Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): Size of attention paddings + in Attention Module (Figure 2(b) of original paper). + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + norm_cfg (dict): Config of norm layers. + Defaults: dict(type='SyncBN', requires_grad=True). + pretrained (str, optional): model pretrained path. + Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + mlp_ratios=[4, 4, 4, 4], + drop_rate=0., + drop_path_rate=0., + depths=[3, 4, 6, 3], + num_stages=4, + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True), + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.depths = depths + self.num_stages = num_stages + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for i in range(num_stages): + if i == 0: + patch_embed = StemConv(3, embed_dims[0], norm_cfg=norm_cfg) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dim=embed_dims[i], + norm_cfg=norm_cfg) + + block = nn.ModuleList([ + MSCABlock( + channels=embed_dims[i], + attention_kernel_sizes=attention_kernel_sizes, + attention_kernel_paddings=attention_kernel_paddings, + mlp_ratio=mlp_ratios[i], + drop=drop_rate, + drop_path=dpr[cur + j], + act_cfg=act_cfg, + norm_cfg=norm_cfg) for j in range(depths[i]) + ]) + norm = nn.LayerNorm(embed_dims[i]) + cur += depths[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + def init_weights(self): + """Initialize modules of MSCAN.""" + + print('init cfg', self.init_cfg) + if self.init_cfg is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + """Forward function.""" + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs diff --git a/finetune/mmseg/models/backbones/pidnet.py b/finetune/mmseg/models/backbones/pidnet.py new file mode 100644 index 0000000..0b711a3 --- /dev/null +++ b/finetune/mmseg/models/backbones/pidnet.py @@ -0,0 +1,522 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from mmengine.runner import CheckpointLoader +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType +from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck + + +class PagFM(BaseModule): + """Pixel-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + channels (int): The number of channels. + after_relu (bool): Whether to use ReLU before attention. + Default: False. + with_channel (bool): Whether to use channel attention. + Default: False. + upsample_mode (str): The mode of upsample. Default: 'bilinear'. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(typ='ReLU', inplace=True). + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + channels: int, + after_relu: bool = False, + with_channel: bool = False, + upsample_mode: str = 'bilinear', + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(typ='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.after_relu = after_relu + self.with_channel = with_channel + self.upsample_mode = upsample_mode + self.f_i = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None) + self.f_p = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None) + if with_channel: + self.up = ConvModule( + channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + if after_relu: + self.relu = MODELS.build(act_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor: + """Forward function. + + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + + Returns: + Tensor: The feature map with pixel-attention-guided fusion. + """ + if self.after_relu: + x_p = self.relu(x_p) + x_i = self.relu(x_i) + + f_i = self.f_i(x_i) + f_i = F.interpolate( + f_i, + size=x_p.shape[2:], + mode=self.upsample_mode, + align_corners=False) + + f_p = self.f_p(x_p) + + if self.with_channel: + sigma = torch.sigmoid(self.up(f_p * f_i)) + else: + sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1)) + + x_i = F.interpolate( + x_i, + size=x_p.shape[2:], + mode=self.upsample_mode, + align_corners=False) + + out = sigma * x_i + (1 - sigma) * x_p + return out + + +class Bag(BaseModule): + """Boundary-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_size (int): The kernel size of the convolution. Default: 3. + padding (int): The padding of the convolution. Default: 1. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer. + Default: dict(order=('norm', 'act', 'conv')). + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + + self.conv = ConvModule( + in_channels, + out_channels, + kernel_size, + padding=padding, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with boundary-attention-guided fusion. + """ + sigma = torch.sigmoid(x_d) + return self.conv(sigma * x_p + (1 - sigma) * x_i) + + +class LightBag(BaseModule): + """Light Boundary-attention-guided fusion module. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. Default: None. + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.f_p = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.f_i = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor: + """Forward function. + Args: + x_p (Tensor): The featrue map from P branch. + x_i (Tensor): The featrue map from I branch. + x_d (Tensor): The featrue map from D branch. + + Returns: + Tensor: The feature map with light boundary-attention-guided + fusion. + """ + sigma = torch.sigmoid(x_d) + + f_p = self.f_p((1 - sigma) * x_i + x_p) + f_i = self.f_i(x_i + sigma * x_p) + + return f_p + f_i + + +@MODELS.register_module() +class PIDNet(BaseModule): + """PIDNet backbone. + + This backbone is the implementation of `PIDNet: A Real-time Semantic + Segmentation Network Inspired from PID Controller + `_. + Modified from https://github.com/XuJiacong/PIDNet. + + Licensed under the MIT License. + + Args: + in_channels (int): The number of input channels. Default: 3. + channels (int): The number of channels in the stem layer. Default: 64. + ppm_channels (int): The number of channels in the PPM layer. + Default: 96. + num_stem_blocks (int): The number of blocks in the stem layer. + Default: 2. + num_branch_blocks (int): The number of blocks in the branch layer. + Default: 3. + align_corners (bool): The align_corners argument of F.interpolate. + Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict): Config dict for initialization. Default: None. + """ + + def __init__(self, + in_channels: int = 3, + channels: int = 64, + ppm_channels: int = 96, + num_stem_blocks: int = 2, + num_branch_blocks: int = 3, + align_corners: bool = False, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None, + **kwargs): + super().__init__(init_cfg) + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + + # stem layer + self.stem = self._make_stem_layer(in_channels, channels, + num_stem_blocks) + self.relu = nn.ReLU() + + # I Branch + self.i_branch_layers = nn.ModuleList() + for i in range(3): + self.i_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2**(i + 1), + channels=channels * 8 if i > 0 else channels * 4, + num_blocks=num_branch_blocks if i < 2 else 2, + stride=2)) + + # P Branch + self.p_branch_layers = nn.ModuleList() + for i in range(3): + self.p_branch_layers.append( + self._make_layer( + block=BasicBlock if i < 2 else Bottleneck, + in_channels=channels * 2, + channels=channels * 2, + num_blocks=num_stem_blocks if i < 2 else 1)) + self.compression_1 = ConvModule( + channels * 4, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.compression_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.pag_1 = PagFM(channels * 2, channels) + self.pag_2 = PagFM(channels * 2, channels) + + # D Branch + if num_stem_blocks == 2: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, channels), + self._make_layer(Bottleneck, channels, channels, 1) + ]) + channel_expand = 1 + spp_module = PAPPM + dfm_module = LightBag + act_cfg_dfm = None + else: + self.d_branch_layers = nn.ModuleList([ + self._make_single_layer(BasicBlock, channels * 2, + channels * 2), + self._make_single_layer(BasicBlock, channels * 2, channels * 2) + ]) + channel_expand = 2 + spp_module = DAPPM + dfm_module = Bag + act_cfg_dfm = act_cfg + + self.diff_1 = ConvModule( + channels * 4, + channels * channel_expand, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + self.diff_2 = ConvModule( + channels * 8, + channels * 2, + kernel_size=3, + padding=1, + bias=False, + norm_cfg=norm_cfg, + act_cfg=None) + + self.spp = spp_module( + channels * 16, ppm_channels, channels * 4, num_scales=5) + self.dfm = dfm_module( + channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm) + + self.d_branch_layers.append( + self._make_layer(Bottleneck, channels * 2, channels * 2, 1)) + + def _make_stem_layer(self, in_channels: int, channels: int, + num_blocks: int) -> nn.Sequential: + """Make stem layer. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + + Returns: + nn.Sequential: The stem layer. + """ + + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + channels, + channels, + kernel_size=3, + stride=2, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + ] + + layers.append( + self._make_layer(BasicBlock, channels, channels, num_blocks)) + layers.append(nn.ReLU()) + layers.append( + self._make_layer( + BasicBlock, channels, channels * 2, num_blocks, stride=2)) + layers.append(nn.ReLU()) + + return nn.Sequential(*layers) + + def _make_layer(self, + block: BasicBlock, + in_channels: int, + channels: int, + num_blocks: int, + stride: int = 1) -> nn.Sequential: + """Make layer for PIDNet backbone. + Args: + block (BasicBlock): Basic block. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_blocks (int): Number of blocks. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Sequential: The Branch Layer. + """ + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + + layers = [block(in_channels, channels, stride, downsample)] + in_channels = channels * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + in_channels, + channels, + stride=1, + act_cfg_out=None if i == num_blocks - 1 else self.act_cfg)) + return nn.Sequential(*layers) + + def _make_single_layer(self, + block: Union[BasicBlock, Bottleneck], + in_channels: int, + channels: int, + stride: int = 1) -> nn.Module: + """Make single layer for PIDNet backbone. + Args: + block (BasicBlock or Bottleneck): Basic block or Bottleneck. + in_channels (int): Number of input channels. + channels (int): Number of output channels. + stride (int): Stride of the first block. Default: 1. + + Returns: + nn.Module + """ + + downsample = None + if stride != 1 or in_channels != channels * block.expansion: + downsample = ConvModule( + in_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + norm_cfg=self.norm_cfg, + act_cfg=None) + return block( + in_channels, channels, stride, downsample, act_cfg_out=None) + + def init_weights(self): + """Initialize the weights in backbone. + + Since the D branch is not initialized by the pre-trained model, we + initialize it with the same method as the ResNet. + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if self.init_cfg is not None: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], map_location='cpu') + self.load_state_dict(ckpt, strict=False) + + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + + Args: + x (Tensor): Input tensor with shape (B, C, H, W). + + Returns: + Tensor or tuple[Tensor]: If self.training is True, return + tuple[Tensor], else return Tensor. + """ + w_out = x.shape[-1] // 8 + h_out = x.shape[-2] // 8 + + # stage 0-2 + x = self.stem(x) + + # stage 3 + x_i = self.relu(self.i_branch_layers[0](x)) + x_p = self.p_branch_layers[0](x) + x_d = self.d_branch_layers[0](x) + + comp_i = self.compression_1(x_i) + x_p = self.pag_1(x_p, comp_i) + diff_i = self.diff_1(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_p = x_p.clone() + + # stage 4 + x_i = self.relu(self.i_branch_layers[1](x_i)) + x_p = self.p_branch_layers[1](self.relu(x_p)) + x_d = self.d_branch_layers[1](self.relu(x_d)) + + comp_i = self.compression_2(x_i) + x_p = self.pag_2(x_p, comp_i) + diff_i = self.diff_2(x_i) + x_d += F.interpolate( + diff_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + if self.training: + temp_d = x_d.clone() + + # stage 5 + x_i = self.i_branch_layers[2](x_i) + x_p = self.p_branch_layers[2](self.relu(x_p)) + x_d = self.d_branch_layers[2](self.relu(x_d)) + + x_i = self.spp(x_i) + x_i = F.interpolate( + x_i, + size=[h_out, w_out], + mode='bilinear', + align_corners=self.align_corners) + out = self.dfm(x_p, x_i, x_d) + return (temp_p, out, temp_d) if self.training else out diff --git a/finetune/mmseg/models/backbones/resnest.py b/finetune/mmseg/models/backbones/resnest.py new file mode 100644 index 0000000..3cc380b --- /dev/null +++ b/finetune/mmseg/models/backbones/resnest.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNetV1d + + +class RSoftmax(nn.Module): + """Radix Softmax module in ``SplitAttentionConv2d``. + + Args: + radix (int): Radix of input. + groups (int): Groups of input. + """ + + def __init__(self, radix, groups): + super().__init__() + self.radix = radix + self.groups = groups + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttentionConv2d(nn.Module): + """Split-Attention Conv2d in ResNeSt. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int | tuple[int]): Same as nn.Conv2d. + stride (int | tuple[int]): Same as nn.Conv2d. + padding (int | tuple[int]): Same as nn.Conv2d. + dilation (int | tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels. Default: 4. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + dcn (dict): Config dict for DCN. Default: None. + """ + + def __init__(self, + in_channels, + channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + radix=2, + reduction_factor=4, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None): + super().__init__() + inter_channels = max(in_channels * radix // reduction_factor, 32) + self.radix = radix + self.groups = groups + self.channels = channels + self.with_dcn = dcn is not None + self.dcn = dcn + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if self.with_dcn and not fallback_on_stride: + assert conv_cfg is None, 'conv_cfg must be None for DCN' + conv_cfg = dcn + self.conv = build_conv_layer( + conv_cfg, + in_channels, + channels * radix, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups * radix, + bias=False) + self.norm0_name, norm0 = build_norm_layer( + norm_cfg, channels * radix, postfix=0) + self.add_module(self.norm0_name, norm0) + self.relu = nn.ReLU(inplace=True) + self.fc1 = build_conv_layer( + None, channels, inter_channels, 1, groups=self.groups) + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, inter_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.fc2 = build_conv_layer( + None, inter_channels, channels * radix, 1, groups=self.groups) + self.rsoftmax = RSoftmax(radix, groups) + + @property + def norm0(self): + """nn.Module: the normalization layer named "norm0" """ + return getattr(self, self.norm0_name) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def forward(self, x): + x = self.conv(x) + x = self.norm0(x) + x = self.relu(x) + + batch, rchannel = x.shape[:2] + batch = x.size(0) + if self.radix > 1: + splits = x.view(batch, self.radix, -1, *x.shape[2:]) + gap = splits.sum(dim=1) + else: + gap = x + gap = F.adaptive_avg_pool2d(gap, 1) + gap = self.fc1(gap) + + gap = self.norm1(gap) + gap = self.relu(gap) + + atten = self.fc2(gap) + atten = self.rsoftmax(atten).view(batch, -1, 1, 1) + + if self.radix > 1: + attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) + out = torch.sum(attens * splits, dim=1) + else: + out = atten * x + return out.contiguous() + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeSt. + + Args: + inplane (int): Input planes of this block. + planes (int): Middle planes of this block. + groups (int): Groups of conv2. + width_per_group (int): Width per group of conv2. 64x4d indicates + ``groups=64, width_per_group=4`` and 32x8d indicates + ``groups=32, width_per_group=8``. + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Key word arguments for base class. + """ + expansion = 4 + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + """Bottleneck block for ResNeSt.""" + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + self.with_modulated_dcn = False + self.conv2 = SplitAttentionConv2d( + width, + width, + kernel_size=3, + stride=1 if self.avg_down_stride else self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + radix=radix, + reduction_factor=reduction_factor, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + dcn=self.dcn) + delattr(self, self.norm2_name) + + if self.avg_down_stride: + self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) + + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + + if self.avg_down_stride: + out = self.avd_layer(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNeSt(ResNetV1d): + """ResNeSt backbone. + + This backbone is the implementation of `ResNeSt: + Split-Attention Networks `_. + + Args: + groups (int): Number of groups of Bottleneck. Default: 1 + base_width (int): Base width of Bottleneck. Default: 4 + radix (int): Radix of SpltAtConv2d. Default: 2 + reduction_factor (int): Reduction factor of inter_channels in + SplitAttentionConv2d. Default: 4. + avg_down_stride (bool): Whether to use average pool for stride in + Bottleneck. Default: True. + kwargs (dict): Keyword arguments for ResNet. + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)), + 200: (Bottleneck, (3, 24, 36, 3)) + } + + def __init__(self, + groups=1, + base_width=4, + radix=2, + reduction_factor=4, + avg_down_stride=True, + **kwargs): + self.groups = groups + self.base_width = base_width + self.radix = radix + self.reduction_factor = reduction_factor + self.avg_down_stride = avg_down_stride + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + radix=self.radix, + reduction_factor=self.reduction_factor, + avg_down_stride=self.avg_down_stride, + **kwargs) diff --git a/finetune/mmseg/models/backbones/resnet.py b/finetune/mmseg/models/backbones/resnet.py new file mode 100644 index 0000000..9226c90 --- /dev/null +++ b/finetune/mmseg/models/backbones/resnet.py @@ -0,0 +1,712 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super().__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. + """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + """Forward function for plugins.""" + out = x + for name in plugin_names: + out = getattr(self, name)(x) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + This backbone is the improved implementation of `Deep Residual Learning + for Image Recognition `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Default: 3. + stem_channels (int): Number of stem channels. Default: 64. + base_channels (int): Number of base channels of res layer. Default: 64. + num_stages (int): Resnet stages, normally 4. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + Default: (1, 2, 2, 2). + dilations (Sequence[int]): Dilation of each stage. + Default: (1, 1, 1, 1). + out_indices (Sequence[int]): Output from which stages. + Default: (0, 1, 2, 3). + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. Default: 'pytorch'. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Default: False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Default: -1. + conv_cfg (dict | None): Dictionary to construct and config conv layer. + When conv_cfg is None, cfg will be set to dict(type='Conv2d'). + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN', requires_grad=True). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (dict | None): Dictionary to construct and config DCN conv layer. + When dcn is not None, conv_cfg must be None. Default: None. + stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each + stage. The length of stage_with_dcn is equal to num_stages. + Default: (False, False, False, False). + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + + - position (str, required): Position inside block to insert plugin, + options: 'after_conv1', 'after_conv2', 'after_conv3'. + + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + Default: None. + multi_grid (Sequence[int]|None): Multi grid dilation rates of last + stage. Default: None. + contract_dilation (bool): Whether contract first dilation of each layer + Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Default: True. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> from mmseg.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=64, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=False, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + multi_grid=None, + contract_dilation=False, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + self.pretrained = pretrained + self.zero_init_residual = zero_init_residual + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.multi_grid = multi_grid + self.contract_dilation = contract_dilation + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + # multi grid is applied to last layer only + stage_multi_grid = multi_grid if i == len( + self.stage_blocks) - 1 else None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + multi_grid=stage_multi_grid, + contract_dilation=contract_dilation, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i+1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """make plugins for ResNet 'stage_idx'th stage . + + Currently we support to insert 'context_block', + 'empirical_attention_block', 'nonlocal_block' into the backbone like + ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be : + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose 'stage_idx=0', the structure of blocks in the stage would be: + conv1-> conv2->conv3->yyy->zzz1->zzz2 + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. + stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + """Make stem layer for ResNet.""" + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def _freeze_stages(self): + """Freeze stages param and norm stats.""" + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = self.maxpool(x) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@MODELS.register_module() +class ResNetV1c(ResNet): + """ResNetV1c variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv in + the input stem with three 3x3 convs. For more details please refer to `Bag + of Tricks for Image Classification with Convolutional Neural Networks + `_. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=False, **kwargs) + + +@MODELS.register_module() +class ResNetV1d(ResNet): + """ResNetV1d variant described in [1]_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. + """ + + def __init__(self, **kwargs): + super().__init__(deep_stem=True, avg_down=True, **kwargs) diff --git a/finetune/mmseg/models/backbones/resnext.py b/finetune/mmseg/models/backbones/resnext.py new file mode 100644 index 0000000..67a244a --- /dev/null +++ b/finetune/mmseg/models/backbones/resnext.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmcv.cnn import build_conv_layer, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import ResLayer +from .resnet import Bottleneck as _Bottleneck +from .resnet import ResNet + + +class Bottleneck(_Bottleneck): + """Bottleneck block for ResNeXt. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is + "caffe", the stride-two layer is the first 1x1 conv layer. + """ + + def __init__(self, + inplanes, + planes, + groups=1, + base_width=4, + base_channels=64, + **kwargs): + super().__init__(inplanes, planes, **kwargs) + + if groups == 1: + width = self.planes + else: + width = math.floor(self.planes * + (base_width / base_channels)) * groups + + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer( + self.norm_cfg, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + self.norm_cfg, self.planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + self.conv_cfg, + self.inplanes, + width, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = self.dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + self.conv_cfg, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + self.dcn, + width, + width, + kernel_size=3, + stride=self.conv2_stride, + padding=self.dilation, + dilation=self.dilation, + groups=groups, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + self.conv_cfg, + width, + self.planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + +@MODELS.register_module() +class ResNeXt(ResNet): + """ResNeXt backbone. + + This backbone is the implementation of `Aggregated + Residual Transformations for Deep Neural + Networks `_. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Normally 3. + num_stages (int): Resnet stages, normally 4. + groups (int): Group of resnext. + base_width (int): Base width of resnext. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + frozen_stages (int): Stages to be frozen (all param fixed). -1 means + not freezing any parameters. + norm_cfg (dict): dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. + zero_init_residual (bool): whether to use zero init for last norm layer + in resblocks to let them behave as identity. + + Example: + >>> from mmseg.models import ResNeXt + >>> import torch + >>> self = ResNeXt(depth=50) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 256, 8, 8) + (1, 512, 4, 4) + (1, 1024, 2, 2) + (1, 2048, 1, 1) + """ + + arch_settings = { + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, groups=1, base_width=4, **kwargs): + self.groups = groups + self.base_width = base_width + super().__init__(**kwargs) + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``""" + return ResLayer( + groups=self.groups, + base_width=self.base_width, + base_channels=self.base_channels, + **kwargs) diff --git a/finetune/mmseg/models/backbones/stdc.py b/finetune/mmseg/models/backbones/stdc.py new file mode 100644 index 0000000..758a3c9 --- /dev/null +++ b/finetune/mmseg/models/backbones/stdc.py @@ -0,0 +1,422 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/MichaelFan01/STDC-Seg.""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential + +from mmseg.registry import MODELS +from ..utils import resize +from .bisenetv1 import AttentionRefinementModule + + +class STDCModule(BaseModule): + """STDCModule. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels before scaling. + stride (int): The number of stride for the first conv layer. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layers. + fusion_type (str): Type of fusion operation. Default: 'add'. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + norm_cfg=None, + act_cfg=None, + num_convs=4, + fusion_type='add', + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert num_convs > 1 + assert fusion_type in ['add', 'cat'] + self.stride = stride + self.with_downsample = True if self.stride == 2 else False + self.fusion_type = fusion_type + + self.layers = ModuleList() + conv_0 = ConvModule( + in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) + + if self.with_downsample: + self.downsample = ConvModule( + out_channels // 2, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + groups=out_channels // 2, + norm_cfg=norm_cfg, + act_cfg=None) + + if self.fusion_type == 'add': + self.layers.append(nn.Sequential(conv_0, self.downsample)) + self.skip = Sequential( + ConvModule( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + norm_cfg=norm_cfg, + act_cfg=None), + ConvModule( + in_channels, + out_channels, + 1, + norm_cfg=norm_cfg, + act_cfg=None)) + else: + self.layers.append(conv_0) + self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + else: + self.layers.append(conv_0) + + for i in range(1, num_convs): + out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i + self.layers.append( + ConvModule( + out_channels // 2**i, + out_channels // out_factor, + kernel_size=3, + stride=1, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + if self.fusion_type == 'add': + out = self.forward_add(inputs) + else: + out = self.forward_cat(inputs) + return out + + def forward_add(self, inputs): + layer_outputs = [] + x = inputs.clone() + for layer in self.layers: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + inputs = self.skip(inputs) + + return torch.cat(layer_outputs, dim=1) + inputs + + def forward_cat(self, inputs): + x0 = self.layers[0](inputs) + layer_outputs = [x0] + for i, layer in enumerate(self.layers[1:]): + if i == 0: + if self.with_downsample: + x = layer(self.downsample(x0)) + else: + x = layer(x0) + else: + x = layer(x) + layer_outputs.append(x) + if self.with_downsample: + layer_outputs[0] = self.skip(x0) + return torch.cat(layer_outputs, dim=1) + + +class FeatureFusionModule(BaseModule): + """Feature Fusion Module. This module is different from FeatureFusionModule + in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter + channel number is calculated by given `scale_factor`, while + FeatureFusionModule in BiSeNetV1 only uses one ConvModule in + `self.conv_atten`. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + scale_factor (int): The number of channel scale factor. + Default: 4. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): The activation config for conv layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scale_factor=4, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + channels = out_channels // scale_factor + self.conv0 = ConvModule( + in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + out_channels, + channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=act_cfg), + ConvModule( + channels, + out_channels, + 1, + norm_cfg=None, + bias=False, + act_cfg=None), nn.Sigmoid()) + + def forward(self, spatial_inputs, context_inputs): + inputs = torch.cat([spatial_inputs, context_inputs], dim=1) + x = self.conv0(inputs) + attn = self.attention(x) + x_attn = x * attn + return x_attn + x + + +@MODELS.register_module() +class STDCNet(BaseModule): + """This backbone is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + stdc_type (int): The type of backbone structure, + `STDCNet1` and`STDCNet2` denotes two main backbones in paper, + whose FLOPs is 813M and 1446M, respectively. + in_channels (int): The num of input_channels. + channels (tuple[int]): The output channels for each stage. + bottleneck_type (str): The type of STDC Module type, the value must + be 'add' or 'cat'. + norm_cfg (dict): Config dict for normalization layer. + act_cfg (dict): The activation config for conv layers. + num_convs (int): Numbers of conv layer at each STDC Module. + Default: 4. + with_final_conv (bool): Whether add a conv layer at the Module output. + Default: True. + pretrained (str, optional): Model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Example: + >>> import torch + >>> stdc_type = 'STDCNet1' + >>> in_channels = 3 + >>> channels = (32, 64, 256, 512, 1024) + >>> bottleneck_type = 'cat' + >>> inputs = torch.rand(1, 3, 1024, 2048) + >>> self = STDCNet(stdc_type, in_channels, + ... channels, bottleneck_type).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 256, 128, 256]) + outputs[1].shape = torch.Size([1, 512, 64, 128]) + outputs[2].shape = torch.Size([1, 1024, 32, 64]) + """ + + arch_settings = { + 'STDCNet1': [(2, 1), (2, 1), (2, 1)], + 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] + } + + def __init__(self, + stdc_type, + in_channels, + channels, + bottleneck_type, + norm_cfg, + act_cfg, + num_convs=4, + with_final_conv=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert stdc_type in self.arch_settings, \ + f'invalid structure {stdc_type} for STDCNet.' + assert bottleneck_type in ['add', 'cat'],\ + f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' + + assert len(channels) == 5,\ + f'invalid channels length {len(channels)} for STDCNet.' + + self.in_channels = in_channels + self.channels = channels + self.stage_strides = self.arch_settings[stdc_type] + self.prtrained = pretrained + self.num_convs = num_convs + self.with_final_conv = with_final_conv + + self.stages = ModuleList([ + ConvModule( + self.in_channels, + self.channels[0], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=2, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ]) + # `self.num_shallow_features` is the number of shallow modules in + # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. + # They are both not used for following modules like Attention + # Refinement Module and Feature Fusion Module. + # Thus they would be cut from `outs`. Please refer to Figure 4 + # of original paper for more details. + self.num_shallow_features = len(self.stages) + + for strides in self.stage_strides: + idx = len(self.stages) - 1 + self.stages.append( + self._make_stage(self.channels[idx], self.channels[idx + 1], + strides, norm_cfg, act_cfg, bottleneck_type)) + # After appending, `self.stages` is a ModuleList including several + # shallow modules and STDCModules. + # (len(self.stages) == + # self.num_shallow_features + len(self.stage_strides)) + if self.with_final_conv: + self.final_conv = ConvModule( + self.channels[-1], + max(1024, self.channels[-1]), + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def _make_stage(self, in_channels, out_channels, strides, norm_cfg, + act_cfg, bottleneck_type): + layers = [] + for i, stride in enumerate(strides): + layers.append( + STDCModule( + in_channels if i == 0 else out_channels, + out_channels, + stride, + norm_cfg, + act_cfg, + num_convs=self.num_convs, + fusion_type=bottleneck_type)) + return Sequential(*layers) + + def forward(self, x): + outs = [] + for stage in self.stages: + x = stage(x) + outs.append(x) + if self.with_final_conv: + outs[-1] = self.final_conv(outs[-1]) + outs = outs[self.num_shallow_features:] + return tuple(outs) + + +@MODELS.register_module() +class STDCContextPathNet(BaseModule): + """STDCNet with Context Path. The `outs` below is a list of three feature + maps from deep to shallow, whose height and width is from small to big, + respectively. The biggest feature map of `outs` is outputted for + `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth. + The other two feature maps are used for Attention Refinement Module, + respectively. Besides, the biggest feature map of `outs` and the last + output of Attention Refinement Module are concatenated for Feature Fusion + Module. Then, this fusion feature map `feat_fuse` would be outputted for + `decode_head`. More details please refer to Figure 4 of original paper. + + Args: + backbone_cfg (dict): Config dict for stdc backbone. + last_in_channels (tuple(int)), The number of channels of last + two feature maps from stdc backbone. Default: (1024, 512). + out_channels (int): The channels of output feature maps. + Default: 128. + ffm_cfg (dict): Config dict for Feature Fusion Module. Default: + `dict(in_channels=512, out_channels=256, scale_factor=4)`. + upsample_mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'``. + align_corners (str): align_corners argument of F.interpolate. It + must be `None` if upsample_mode is ``'nearest'``. Default: None. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Return: + outputs (tuple): The tuple of list of output feature map for + auxiliary heads and decoder head. + """ + + def __init__(self, + backbone_cfg, + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict( + in_channels=512, out_channels=256, scale_factor=4), + upsample_mode='nearest', + align_corners=None, + norm_cfg=dict(type='BN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.backbone = MODELS.build(backbone_cfg) + self.arms = ModuleList() + self.convs = ModuleList() + for channels in last_in_channels: + self.arms.append(AttentionRefinementModule(channels, out_channels)) + self.convs.append( + ConvModule( + out_channels, + out_channels, + 3, + padding=1, + norm_cfg=norm_cfg)) + self.conv_avg = ConvModule( + last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) + + self.ffm = FeatureFusionModule(**ffm_cfg) + + self.upsample_mode = upsample_mode + self.align_corners = align_corners + + def forward(self, x): + outs = list(self.backbone(x)) + avg = F.adaptive_avg_pool2d(outs[-1], 1) + avg_feat = self.conv_avg(avg) + + feature_up = resize( + avg_feat, + size=outs[-1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + arms_out = [] + for i in range(len(self.arms)): + x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up + feature_up = resize( + x_arm, + size=outs[len(outs) - 1 - i - 1].shape[2:], + mode=self.upsample_mode, + align_corners=self.align_corners) + feature_up = self.convs[i](feature_up) + arms_out.append(feature_up) + + feat_fuse = self.ffm(outs[0], arms_out[1]) + + # The `outputs` has four feature maps. + # `outs[0]` is outputted for `STDCHead` auxiliary head. + # Two feature maps of `arms_out` are outputted for auxiliary head. + # `feat_fuse` is outputted for decoder head. + outputs = [outs[0]] + list(arms_out) + [feat_fuse] + return tuple(outputs) diff --git a/finetune/mmseg/models/backbones/swin.py b/finetune/mmseg/models/backbones/swin.py new file mode 100644 index 0000000..67b28a9 --- /dev/null +++ b/finetune/mmseg/models/backbones/swin.py @@ -0,0 +1,757 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, build_dropout +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) +from mmengine.runner import CheckpointLoader +from mmengine.utils import to_2tuple + +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed, PatchMerging + + +class WindowMSA(BaseModule): + """Window based multi-head self-attention (W-MSA) module with relative + position bias. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): The height and width of the window. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + init_cfg (dict | None, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_embed_dims = embed_dims // num_heads + self.scale = qk_scale or head_embed_dims**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # About 2x faster than original impl + Wh, Ww = self.window_size + rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww) + rel_position_index = rel_index_coords + rel_index_coords.T + rel_position_index = rel_position_index.flip(1).contiguous() + self.register_buffer('relative_position_index', rel_position_index) + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + + self.softmax = nn.Softmax(dim=-1) + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x, mask=None): + """ + Args: + + x (tensor): input features with shape of (num_windows*B, N, C) + mask (tensor | None, Optional): mask with shape of (num_windows, + Wh*Ww, Wh*Ww), value should be between (-inf, 0]. + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @staticmethod + def double_step_seq(step1, len1, step2, len2): + seq1 = torch.arange(0, step1 * len1, step1) + seq2 = torch.arange(0, step2 * len2, step2) + return (seq1[:, None] + seq2[None, :]).reshape(1, -1) + + +class ShiftWindowMSA(BaseModule): + """Shifted Window Multihead Self-Attention Module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. + shift_size (int, optional): The shift step of each window towards + right-bottom. If zero, act as regular window-msa. Defaults to 0. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Defaults: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Defaults: 0. + proj_drop_rate (float, optional): Dropout ratio of output. + Defaults: 0. + dropout_layer (dict, optional): The dropout_layer used before output. + Defaults: dict(type='DropPath', drop_prob=0.). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size, + shift_size=0, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0, + proj_drop_rate=0, + dropout_layer=dict(type='DropPath', drop_prob=0.), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.window_size = window_size + self.shift_size = shift_size + assert 0 <= self.shift_size < self.window_size + + self.w_msa = WindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=to_2tuple(window_size), + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate, + init_cfg=None) + + self.drop = build_dropout(dropout_layer) + + def forward(self, query, hw_shape): + B, L, C = query.shape + H, W = hw_shape + assert L == H * W, 'input feature has wrong size' + query = query.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b)) + H_pad, W_pad = query.shape[1], query.shape[2] + + # cyclic shift + if self.shift_size > 0: + shifted_query = torch.roll( + query, + shifts=(-self.shift_size, -self.shift_size), + dims=(1, 2)) + + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device) + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = self.window_partition(img_mask) + mask_windows = mask_windows.view( + -1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + else: + shifted_query = query + attn_mask = None + + # nW*B, window_size, window_size, C + query_windows = self.window_partition(shifted_query) + # nW*B, window_size*window_size, C + query_windows = query_windows.view(-1, self.window_size**2, C) + + # W-MSA/SW-MSA (nW*B, window_size*window_size, C) + attn_windows = self.w_msa(query_windows, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + + # B H' W' C + shifted_x = self.window_reverse(attn_windows, H_pad, W_pad) + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = self.drop(x) + return x + + def window_reverse(self, windows, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + window_size = self.window_size + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + def window_partition(self, x): + """ + Args: + x: (B, H, W, C) + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + window_size = self.window_size + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + +class SwinBlock(BaseModule): + """" + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + window_size (int, optional): The local window scale. Default: 7. + shift (bool, optional): whether to shift window or not. Default False. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + window_size=7, + shift=False, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.with_cp = with_cp + + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + self.attn = ShiftWindowMSA( + embed_dims=embed_dims, + num_heads=num_heads, + window_size=window_size, + shift_size=window_size // 2 if shift else 0, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + init_cfg=None) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=2, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=True, + init_cfg=None) + + def forward(self, x, hw_shape): + + def _inner_forward(x): + identity = x + x = self.norm1(x) + x = self.attn(x, hw_shape) + + x = x + identity + + identity = x + x = self.norm2(x) + x = self.ffn(x, identity=identity) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockSequence(BaseModule): + """Implements one stage in Swin Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + depth (int): The number of blocks in this stage. + window_size (int, optional): The local window scale. Default: 7. + qkv_bias (bool, optional): enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout rate. Default: 0. + attn_drop_rate (float, optional): Attention dropout rate. Default: 0. + drop_path_rate (float | list[float], optional): Stochastic depth + rate. Default: 0. + downsample (BaseModule | None, optional): The downsample operation + module. Default: None. + act_cfg (dict, optional): The config dict of activation function. + Default: dict(type='GELU'). + norm_cfg (dict, optional): The config dict of normalization. + Default: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + init_cfg (dict | list | None, optional): The init config. + Default: None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + depth, + window_size=7, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + downsample=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(drop_path_rate, list): + drop_path_rates = drop_path_rate + assert len(drop_path_rates) == depth + else: + drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)] + + self.blocks = ModuleList() + for i in range(depth): + block = SwinBlock( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=feedforward_channels, + window_size=window_size, + shift=False if i % 2 == 0 else True, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rates[i], + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.blocks.append(block) + + self.downsample = downsample + + def forward(self, x, hw_shape): + for block in self.blocks: + x = block(x, hw_shape) + + if self.downsample: + x_down, down_hw_shape = self.downsample(x, hw_shape) + return x_down, down_hw_shape, x, hw_shape + else: + return x, hw_shape, x, hw_shape + + +@MODELS.register_module() +class SwinTransformer(BaseModule): + """Swin Transformer backbone. + + This backbone is the implementation of `Swin Transformer: + Hierarchical Vision Transformer using Shifted + Windows `_. + Inspiration from https://github.com/microsoft/Swin-Transformer. + + Args: + pretrain_img_size (int | tuple[int]): The size of input image when + pretrain. Defaults: 224. + in_channels (int): The num of input channels. + Defaults: 3. + embed_dims (int): The feature dimension. Default: 96. + patch_size (int | tuple[int]): Patch size. Default: 4. + window_size (int): Window size. Default: 7. + mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim. + Default: 4. + depths (tuple[int]): Depths of each Swin Transformer stage. + Default: (2, 2, 6, 2). + num_heads (tuple[int]): Parallel attention heads of each Swin + Transformer stage. Default: (3, 6, 12, 24). + strides (tuple[int]): The patch merging or patch embedding stride of + each Swin Transformer stage. (In swin, we set kernel size equal to + stride.) Default: (4, 2, 2, 2). + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool, optional): If True, add a learnable bias to query, key, + value. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + patch_norm (bool): If add a norm layer for patch embed and patch + merging. Default: True. + drop_rate (float): Dropout rate. Defaults: 0. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults: False. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='LN'). + norm_cfg (dict): Config dict for normalization layer at + output of backone. Defaults: dict(type='LN'). + with_cp (bool, optional): Use checkpoint or not. Using checkpoint + will save some memory while slowing down the training speed. + Default: False. + pretrained (str, optional): model pretrained path. Default: None. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + pretrain_img_size=224, + in_channels=3, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained=None, + frozen_stages=-1, + init_cfg=None): + self.frozen_stages = frozen_stages + + if isinstance(pretrain_img_size, int): + pretrain_img_size = to_2tuple(pretrain_img_size) + elif isinstance(pretrain_img_size, tuple): + if len(pretrain_img_size) == 1: + pretrain_img_size = to_2tuple(pretrain_img_size[0]) + assert len(pretrain_img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(pretrain_img_size)}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + init_cfg = init_cfg + else: + raise TypeError('pretrained must be a str or None') + + super().__init__(init_cfg=init_cfg) + + num_layers = len(depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + + assert strides[0] == patch_size, 'Use non-overlapping patch embed.' + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=strides[0], + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + + if self.use_abs_pos_embed: + patch_row = pretrain_img_size[0] // patch_size + patch_col = pretrain_img_size[1] // patch_size + num_patches = patch_row * patch_col + self.absolute_pos_embed = nn.Parameter( + torch.zeros((1, num_patches, embed_dims))) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + + # set stochastic depth decay rule + total_depth = sum(depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] + + self.stages = ModuleList() + in_channels = embed_dims + for i in range(num_layers): + if i < num_layers - 1: + downsample = PatchMerging( + in_channels=in_channels, + out_channels=2 * in_channels, + stride=strides[i + 1], + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None) + else: + downsample = None + + stage = SwinBlockSequence( + embed_dims=in_channels, + num_heads=num_heads[i], + feedforward_channels=int(mlp_ratio * in_channels), + depth=depths[i], + window_size=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])], + downsample=downsample, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + init_cfg=None) + self.stages.append(stage) + if downsample: + in_channels = downsample.out_channels + + self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] + # Add a norm layer for each output + for i in out_indices: + layer = build_norm_layer(norm_cfg, self.num_features[i])[1] + layer_name = f'norm{i}' + self.add_module(layer_name, layer) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.use_abs_pos_embed: + self.absolute_pos_embed.requires_grad = False + self.drop_after_pos.eval() + + for i in range(1, self.frozen_stages + 1): + + if (i - 1) in self.out_indices: + norm_layer = getattr(self, f'norm{i-1}') + norm_layer.eval() + for param in norm_layer.parameters(): + param.requires_grad = False + + m = self.stages[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.init_cfg is None: + print_log(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.) + else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + ckpt = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = ckpt + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = self.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + print_log('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2).contiguous() + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() + if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + if table_key in self.state_dict(): + table_current = self.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + print_log(f'Error in loading {table_key}, pass') + elif L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).reshape( + 1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0).contiguous() + + # load state_dict + self.load_state_dict(state_dict, strict=False) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape, out, out_hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(out) + out = out.view(-1, *out_hw_shape, + self.num_features[i]).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs diff --git a/finetune/mmseg/models/backbones/timm_backbone.py b/finetune/mmseg/models/backbones/timm_backbone.py new file mode 100644 index 0000000..1eef302 --- /dev/null +++ b/finetune/mmseg/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmengine.model import BaseModule +from mmengine.registry import MODELS as MMENGINE_MODELS + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. + """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super().__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = MMENGINE_MODELS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/finetune/mmseg/models/backbones/twins.py b/finetune/mmseg/models/backbones/twins.py new file mode 100644 index 0000000..b6a6eea --- /dev/null +++ b/finetune/mmseg/models/backbones/twins.py @@ -0,0 +1,588 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import FFN +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) +from torch.nn.modules.batchnorm import _BatchNorm + +from mmseg.models.backbones.mit import EfficientMultiheadAttention +from mmseg.registry import MODELS +from ..utils.embed import PatchEmbed + + +class GlobalSubsampledAttention(EfficientMultiheadAttention): + """Global Sub-sampled Attention (Spatial Reduction Attention) + + This module is modified from EfficientMultiheadAttention, + which is a module from mmseg.models.backbones.mit.py. + Specifically, there is no difference between + `GlobalSubsampledAttention` and `EfficientMultiheadAttention`, + `GlobalSubsampledAttention` is built as a brand new class + because it is renamed as `Global sub-sampled attention (GSA)` + in paper. + + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dims) + or (n, batch, embed_dims). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default: True. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of GSA of PCPVT. + Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + batch_first=True, + qkv_bias=True, + norm_cfg=dict(type='LN'), + sr_ratio=1, + init_cfg=None): + super().__init__( + embed_dims, + num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + dropout_layer=dropout_layer, + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio, + init_cfg=init_cfg) + + +class GSAEncoderLayer(BaseModule): + """Implements one encoder layer with GSA. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (float): Kernel_size of conv in Attention modules. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=1., + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = GlobalSubsampledAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape, identity=0.)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class LocallyGroupedSelfAttention(BaseModule): + """Locally-grouped Self Attention (LSA) module. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop_rate (float, optional): Dropout ratio of output. Default: 0. + window_size(int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + window_size=1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + assert embed_dims % num_heads == 0, f'dim {embed_dims} should be ' \ + f'divided by num_heads ' \ + f'{num_heads}.' + self.embed_dims = embed_dims + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop_rate) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop_rate) + self.window_size = window_size + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + x = x.view(b, h, w, c) + + # pad feature maps to multiples of Local-groups + pad_l = pad_t = 0 + pad_r = (self.window_size - w % self.window_size) % self.window_size + pad_b = (self.window_size - h % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + # calculate attention mask for LSA + Hp, Wp = x.shape[1:-1] + _h, _w = Hp // self.window_size, Wp // self.window_size + mask = torch.zeros((1, Hp, Wp), device=x.device) + mask[:, -pad_b:, :].fill_(1) + mask[:, :, -pad_r:].fill_(1) + + # [B, _h, _w, window_size, window_size, C] + x = x.reshape(b, _h, self.window_size, _w, self.window_size, + c).transpose(2, 3) + mask = mask.reshape(1, _h, self.window_size, _w, + self.window_size).transpose(2, 3).reshape( + 1, _h * _w, + self.window_size * self.window_size) + # [1, _h*_w, window_size*window_size, window_size*window_size] + attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-1000.0)).masked_fill( + attn_mask == 0, float(0.0)) + + # [3, B, _w*_h, nhead, window_size*window_size, dim] + qkv = self.qkv(x).reshape(b, _h * _w, + self.window_size * self.window_size, 3, + self.num_heads, c // self.num_heads).permute( + 3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + # [B, _h*_w, n_head, window_size*window_size, window_size*window_size] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn + attn_mask.unsqueeze(2) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(b, _h, _w, self.window_size, + self.window_size, c) + x = attn.transpose(2, 3).reshape(b, _h * self.window_size, + _w * self.window_size, c) + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + x = x.reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LSAEncoderLayer(BaseModule): + """Implements one encoder layer in Twins-SVT. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + window_size (int): Window size of LSA. Default: 1. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + qk_scale=None, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + window_size=1, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1] + self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads, + qkv_bias, qk_scale, + attn_drop_rate, drop_rate, + window_size) + + self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1] + self.ffn = FFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg, + add_identity=False) + + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path_rate) + ) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, hw_shape): + x = x + self.drop_path(self.attn(self.norm1(x), hw_shape)) + x = x + self.drop_path(self.ffn(self.norm2(x))) + return x + + +class ConditionalPositionEncoding(BaseModule): + """The Conditional Position Encoding (CPE) module. + + The CPE is the implementation of 'Conditional Positional Encodings + for Vision Transformers '_. + + Args: + in_channels (int): Number of input channels. + embed_dims (int): The feature dimension. Default: 768. + stride (int): Stride of conv layer. Default: 1. + """ + + def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.proj = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + groups=embed_dims) + self.stride = stride + + def forward(self, x, hw_shape): + b, n, c = x.shape + h, w = hw_shape + feat_token = x + cnn_feat = feat_token.transpose(1, 2).view(b, c, h, w) + if self.stride == 1: + x = self.proj(cnn_feat) + cnn_feat + else: + x = self.proj(cnn_feat) + x = x.flatten(2).transpose(1, 2) + return x + + +@MODELS.register_module() +class PCPVT(BaseModule): + """The backbone of Twins-PCPVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4, 8]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [3, 4, 6, 3] + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [8, 4, 2, 1]. + norm_after_stage(bool): Add extra norm. Default False. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN'), + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + self.depths = depths + + # patch_embed + self.patch_embeds = ModuleList() + self.position_encoding_drops = ModuleList() + self.layers = ModuleList() + + for i in range(len(depths)): + self.patch_embeds.append( + PatchEmbed( + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dims=embed_dims[i], + conv_type='Conv2d', + kernel_size=patch_sizes[i], + stride=strides[i], + padding='corner', + norm_cfg=norm_cfg)) + + self.position_encoding_drops.append(nn.Dropout(p=drop_rate)) + + self.position_encodings = ModuleList([ + ConditionalPositionEncoding(embed_dim, embed_dim) + for embed_dim in embed_dims + ]) + + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + cur = 0 + + for k in range(len(depths)): + _block = ModuleList([ + GSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[cur + i], + num_fcs=2, + qkv_bias=qkv_bias, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + sr_ratio=sr_ratios[k]) for i in range(depths[k]) + ]) + self.layers.append(_block) + cur += depths[k] + + self.norm_name, norm = build_norm_layer( + norm_cfg, embed_dims[-1], postfix=1) + + self.out_indices = out_indices + self.norm_after_stage = norm_after_stage + if self.norm_after_stage: + self.norm_list = ModuleList() + for dim in embed_dims: + self.norm_list.append(build_norm_layer(norm_cfg, dim)[1]) + + def init_weights(self): + if self.init_cfg is not None: + super().init_weights() + else: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + + def forward(self, x): + outputs = list() + + b = x.shape[0] + + for i in range(len(self.depths)): + x, hw_shape = self.patch_embeds[i](x) + h, w = hw_shape + x = self.position_encoding_drops[i](x) + for j, blk in enumerate(self.layers[i]): + x = blk(x, hw_shape) + if j == 0: + x = self.position_encodings[i](x, hw_shape) + if self.norm_after_stage: + x = self.norm_list[i](x) + x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() + + if i in self.out_indices: + outputs.append(x) + + return tuple(outputs) + + +@MODELS.register_module() +class SVT(PCPVT): + """The backbone of Twins-SVT. + + This backbone is the implementation of `Twins: Revisiting the Design + of Spatial Attention in Vision Transformers + `_. + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list): Embedding dimension. Default: [64, 128, 256, 512]. + patch_sizes (list): The patch sizes. Default: [4, 2, 2, 2]. + strides (list): The strides. Default: [4, 2, 2, 2]. + num_heads (int): Number of attention heads. Default: [1, 2, 4]. + mlp_ratios (int): Ratio of mlp hidden dim to embedding dim. + Default: [4, 4, 4]. + out_indices (tuple[int]): Output from which stages. + Default: (0, 1, 2, 3). + qkv_bias (bool): Enable bias for qkv if True. Default: False. + drop_rate (float): Dropout rate. Default 0. + attn_drop_rate (float): Dropout ratio of attention weight. + Default 0.0 + drop_path_rate (float): Stochastic depth rate. Default 0.2. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + depths (list): Depths of each stage. Default [4, 4, 4]. + sr_ratios (list): Kernel_size of conv in each Attn module in + Transformer encoder layer. Default: [4, 2, 1]. + windiow_sizes (list): Window size of LSA. Default: [7, 7, 7], + input_features_slice(bool): Input features need slice. Default: False. + norm_after_stage(bool): Add extra norm. Default False. + strides (list): Strides in patch-Embedding modules. Default: (2, 2, 2) + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_cfg=dict(type='LN'), + depths=[4, 4, 4], + sr_ratios=[4, 2, 1], + windiow_sizes=[7, 7, 7], + norm_after_stage=True, + pretrained=None, + init_cfg=None): + super().__init__(in_channels, embed_dims, patch_sizes, strides, + num_heads, mlp_ratios, out_indices, qkv_bias, + drop_rate, attn_drop_rate, drop_path_rate, norm_cfg, + depths, sr_ratios, norm_after_stage, pretrained, + init_cfg) + # transformer encoder + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + for k in range(len(depths)): + for i in range(depths[k]): + if i % 2 == 0: + self.layers[k][i] = \ + LSAEncoderLayer( + embed_dims=embed_dims[k], + num_heads=num_heads[k], + feedforward_channels=mlp_ratios[k] * embed_dims[k], + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[sum(depths[:k])+i], + qkv_bias=qkv_bias, + window_size=windiow_sizes[k]) diff --git a/finetune/mmseg/models/backbones/unet.py b/finetune/mmseg/models/backbones/unet.py new file mode 100644 index 0000000..545921d --- /dev/null +++ b/finetune/mmseg/models/backbones/unet.py @@ -0,0 +1,436 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm + +from mmseg.registry import MODELS +from ..utils import UpConvBlock, Upsample + + +class BasicConvBlock(nn.Module): + """Basic convolutional block for UNet. + + This module consists of several plain convolutional layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers. Default: 2. + stride (int): Whether use stride convolution to downsample + the input feature map. If stride=2, it only uses stride convolution + in the first convolutional layer to downsample the input feature + map. Options are 1 or 2. Default: 1. + dilation (int): Whether use dilated convolution to expand the + receptive field. Set dilation rate of each convolutional layer and + the dilation rate of the first convolutional layer is always 1. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.with_cp = with_cp + convs = [] + for i in range(num_convs): + convs.append( + ConvModule( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride if i == 0 else 1, + dilation=1 if i == 0 else dilation, + padding=1 if i == 0 else dilation, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + self.convs = nn.Sequential(*convs) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.convs, x) + else: + out = self.convs(x) + return out + + +@MODELS.register_module() +class DeconvModule(nn.Module): + """Deconvolution upsample module in decoder for UNet (2X upsample). + + This module uses deconvolution to upsample feature map in the decoder + of UNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of the convolutional layer. Default: 4. + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + kernel_size=4, + scale_factor=2): + super().__init__() + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + stride = scale_factor + padding = (kernel_size - scale_factor) // 2 + self.with_cp = with_cp + deconv = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + norm_name, norm = build_norm_layer(norm_cfg, out_channels) + activate = build_activation_layer(act_cfg) + self.deconv_upsamping = nn.Sequential(deconv, norm, activate) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.deconv_upsamping, x) + else: + out = self.deconv_upsamping(x) + return out + + +@MODELS.register_module() +class InterpConv(nn.Module): + """Interpolation upsample module in decoder for UNet. + + This module uses interpolation to upsample feature map in the decoder + of UNet. It consists of one interpolation upsample layer and one + convolutional layer. It can be one interpolation upsample layer followed + by one convolutional layer (conv_first=False) or one convolutional layer + followed by one interpolation upsample layer (conv_first=True). + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + conv_first (bool): Whether convolutional layer or interpolation + upsample layer first. Default: False. It means interpolation + upsample layer followed by one convolutional layer. + kernel_size (int): Kernel size of the convolutional layer. Default: 1. + stride (int): Stride of the convolutional layer. Default: 1. + padding (int): Padding of the convolutional layer. Default: 1. + upsample_cfg (dict): Interpolation config of the upsample layer. + Default: dict( + scale_factor=2, mode='bilinear', align_corners=False). + """ + + def __init__(self, + in_channels, + out_channels, + with_cp=False, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + *, + conv_cfg=None, + conv_first=False, + kernel_size=1, + stride=1, + padding=0, + upsample_cfg=dict( + scale_factor=2, mode='bilinear', align_corners=False)): + super().__init__() + + self.with_cp = with_cp + conv = ConvModule( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + upsample = Upsample(**upsample_cfg) + if conv_first: + self.interp_upsample = nn.Sequential(conv, upsample) + else: + self.interp_upsample = nn.Sequential(upsample, conv) + + def forward(self, x): + """Forward function.""" + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.interp_upsample, x) + else: + out = self.interp_upsample(x) + return out + + +@MODELS.register_module() +class UNet(BaseModule): + """UNet backbone. + + This backbone is the implementation of `U-Net: Convolutional Networks + for Biomedical Image Segmentation `_. + + Args: + in_channels (int): Number of input image channels. Default" 3. + base_channels (int): Number of base channels of each stage. + The output channels of the first stage. Default: 64. + num_stages (int): Number of stages in encoder, normally 5. Default: 5. + strides (Sequence[int 1 | 2]): Strides of each stage in encoder. + len(strides) is equal to num_stages. Normally the stride of the + first stage in encoder is 1. If strides[i]=2, it uses stride + convolution to downsample in the correspondence encoder stage. + Default: (1, 1, 1, 1, 1). + enc_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence encoder stage. + Default: (2, 2, 2, 2, 2). + dec_num_convs (Sequence[int]): Number of convolutional layers in the + convolution block of the correspondence decoder stage. + Default: (2, 2, 2, 2). + downsamples (Sequence[int]): Whether use MaxPool to downsample the + feature map after the first stage of encoder + (stages: [1, num_stages)). If the correspondence encoder stage use + stride convolution (strides[i]=2), it will never use MaxPool to + downsample, even downsamples[i-1]=True. + Default: (True, True, True, True). + enc_dilations (Sequence[int]): Dilation rate of each stage in encoder. + Default: (1, 1, 1, 1, 1). + dec_dilations (Sequence[int]): Dilation rate of each stage in decoder. + Default: (1, 1, 1, 1). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Notice: + The input image size should be divisible by the whole downsample rate + of the encoder. More detail of the whole downsample rate can be found + in UNet._check_input_divisible. + """ + + def __init__(self, + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False, + dcn=None, + plugins=None, + pretrained=None, + init_cfg=None): + super().__init__(init_cfg) + + self.pretrained = pretrained + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + else: + raise TypeError('pretrained must be a str or None') + + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + assert len(strides) == num_stages, \ + 'The length of strides should be equal to num_stages, '\ + f'while the strides is {strides}, the length of '\ + f'strides is {len(strides)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_num_convs) == num_stages, \ + 'The length of enc_num_convs should be equal to num_stages, '\ + f'while the enc_num_convs is {enc_num_convs}, the length of '\ + f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_num_convs) == (num_stages-1), \ + 'The length of dec_num_convs should be equal to (num_stages-1), '\ + f'while the dec_num_convs is {dec_num_convs}, the length of '\ + f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\ + f'{num_stages}.' + assert len(downsamples) == (num_stages-1), \ + 'The length of downsamples should be equal to (num_stages-1), '\ + f'while the downsamples is {downsamples}, the length of '\ + f'downsamples is {len(downsamples)}, and the num_stages is '\ + f'{num_stages}.' + assert len(enc_dilations) == num_stages, \ + 'The length of enc_dilations should be equal to num_stages, '\ + f'while the enc_dilations is {enc_dilations}, the length of '\ + f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\ + f'{num_stages}.' + assert len(dec_dilations) == (num_stages-1), \ + 'The length of dec_dilations should be equal to (num_stages-1), '\ + f'while the dec_dilations is {dec_dilations}, the length of '\ + f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\ + f'{num_stages}.' + self.num_stages = num_stages + self.strides = strides + self.downsamples = downsamples + self.norm_eval = norm_eval + self.base_channels = base_channels + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(num_stages): + enc_conv_block = [] + if i != 0: + if strides[i] == 1 and downsamples[i - 1]: + enc_conv_block.append(nn.MaxPool2d(kernel_size=2)) + upsample = (strides[i] != 1 or downsamples[i - 1]) + self.decoder.append( + UpConvBlock( + conv_block=BasicConvBlock, + in_channels=base_channels * 2**i, + skip_channels=base_channels * 2**(i - 1), + out_channels=base_channels * 2**(i - 1), + num_convs=dec_num_convs[i - 1], + stride=1, + dilation=dec_dilations[i - 1], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + upsample_cfg=upsample_cfg if upsample else None, + dcn=None, + plugins=None)) + + enc_conv_block.append( + BasicConvBlock( + in_channels=in_channels, + out_channels=base_channels * 2**i, + num_convs=enc_num_convs[i], + stride=strides[i], + dilation=enc_dilations[i], + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None)) + self.encoder.append(nn.Sequential(*enc_conv_block)) + in_channels = base_channels * 2**i + + def forward(self, x): + self._check_input_divisible(x) + enc_outs = [] + for enc in self.encoder: + x = enc(x) + enc_outs.append(x) + dec_outs = [x] + for i in reversed(range(len(self.decoder))): + x = self.decoder[i](enc_outs[i], x) + dec_outs.append(x) + + return dec_outs + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _check_input_divisible(self, x): + h, w = x.shape[-2:] + whole_downsample_rate = 1 + for i in range(1, self.num_stages): + if self.strides[i] == 2 or self.downsamples[i - 1]: + whole_downsample_rate *= 2 + assert (h % whole_downsample_rate == 0) \ + and (w % whole_downsample_rate == 0),\ + f'The input image size {(h, w)} should be divisible by the whole '\ + f'downsample rate {whole_downsample_rate}, when num_stages is '\ + f'{self.num_stages}, strides is {self.strides}, and downsamples '\ + f'is {self.downsamples}.' diff --git a/finetune/mmseg/models/backbones/vit.py b/finetune/mmseg/models/backbones/vit.py new file mode 100644 index 0000000..dd0f688 --- /dev/null +++ b/finetune/mmseg/models/backbones/vit.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine.logging import print_log +from mmengine.model import BaseModule, ModuleList +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.registry import MODELS +from ..utils import PatchEmbed, resize + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super().__init__() + + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer( + norm_cfg, embed_dims, postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class VisionTransformer(BaseModule): + """Vision Transformer. + + This backbone is the implementation of `An Image is Worth 16x16 Words: + Transformers for Image Recognition at + Scale `_. + + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + patch_pad (str | int | None): The padding method in patch embedding. + Default: 'corner'. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_origin (bool): Whether to output the original input embedding. + Default: False + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Default: True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_bias (dict): Whether use bias in convolution of PatchEmbed Block. + Default: True. + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + pre_norm (bool): Whether to add a norm before Transformer Layers. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + frozen_exclude (List): List of parameters that are not to be frozen. + Default: ["all"], "all" means there are no frozen parameters. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + patch_pad='corner', + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_origin=False, + out_indices=-1, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + patch_bias=False, + pre_norm=False, + final_norm=False, + interpolate_mode='bicubic', + num_fcs=2, + norm_eval=False, + with_cp=False, + frozen_exclude=['all'], + pretrained=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.img_size = img_size + self.patch_size = patch_size + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrained = pretrained + self.out_origin = out_origin + self.frozen_exclude = frozen_exclude + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding=patch_pad, + bias=patch_bias, + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + num_patches = (img_size[0] // patch_size) * \ + (img_size[1] // patch_size) + + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.pre_norm = pre_norm + + if self.pre_norm: + self.pre_ln_name, pre_ln = build_norm_layer( + norm_cfg, embed_dims, postfix='_pre') + self.add_module(self.pre_ln_name, pre_ln) + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer( + norm_cfg, embed_dims, postfix=1) + self.add_module(self.norm1_name, norm1) + + self._freeze() + + @property + def pre_ln(self): + return getattr(self, self.pre_ln_name) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def init_weights(self): + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']: + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + if self.init_cfg.get('type') == 'Pretrained': + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + elif self.init_cfg.get('type') == 'Pretrained_Part': + state_dict = checkpoint.copy() + para_prefix = 'image_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + print_log(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=None) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def _freeze(self): + if 'all' in self.frozen_exclude: + return + for name, param in self.named_parameters(): + if not any([exclude in name for exclude in self.frozen_exclude]): + param.requires_grad = False + + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + cls_token_weight = pos_embed[:, 0] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize( + pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + cls_token_weight = cls_token_weight.unsqueeze(1) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + if self.pre_norm: + x = self.pre_ln(x) + + outs = [] + if self.out_origin: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder head + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super().train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() diff --git a/finetune/mmseg/models/backbones/vpd.py b/finetune/mmseg/models/backbones/vpd.py new file mode 100644 index 0000000..e0536d3 --- /dev/null +++ b/finetune/mmseg/models/backbones/vpd.py @@ -0,0 +1,395 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------------ +# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py +# Original licence: MIT License +# ------------------------------------------------------------------------------ + +import math +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule +from mmengine.runner import CheckpointLoader, load_checkpoint + +from mmseg.registry import MODELS +from mmseg.utils import ConfigType, OptConfigType + +try: + from ldm.modules.diffusionmodules.util import timestep_embedding + from ldm.util import instantiate_from_config + has_ldm = True +except ImportError: + has_ldm = False + + +def register_attention_control(model, controller): + """Registers a control function to manage attention within a model. + + Args: + model: The model to which attention is to be registered. + controller: The control function responsible for managing attention. + """ + + def ca_forward(self, place_in_unet): + """Custom forward method for attention. + + Args: + self: Reference to the current object. + place_in_unet: The location in UNet (down/mid/up). + + Returns: + The modified forward method. + """ + + def forward(x, context=None, mask=None): + h = self.heads + is_cross = context is not None + context = context or x # if context is None, use x + + q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) + q, k, v = ( + tensor.view(tensor.shape[0] * h, tensor.shape[1], + tensor.shape[2] // h) for tensor in [q, k, v]) + + sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale + + if mask is not None: + mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1) + max_neg_value = -torch.finfo(sim.dtype).max + sim.masked_fill_(~mask, max_neg_value) + + attn = sim.softmax(dim=-1) + attn_mean = attn.view(h, attn.shape[0] // h, + *attn.shape[1:]).mean(0) + controller(attn_mean, is_cross, place_in_unet) + + out = torch.matmul(attn, v) + out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h) + return self.to_out(out) + + return forward + + def register_recr(net_, count, place_in_unet): + """Recursive function to register the custom forward method to all + CrossAttention layers. + + Args: + net_: The network layer currently being processed. + count: The current count of layers processed. + place_in_unet: The location in UNet (down/mid/up). + + Returns: + The updated count of layers processed. + """ + if net_.__class__.__name__ == 'CrossAttention': + net_.forward = ca_forward(net_, place_in_unet) + return count + 1 + if hasattr(net_, 'children'): + return sum( + register_recr(child, 0, place_in_unet) + for child in net_.children()) + return count + + cross_att_count = sum( + register_recr(net[1], 0, place) for net, place in [ + (child, 'down') if 'input_blocks' in name else ( + child, 'up') if 'output_blocks' in name else + (child, + 'mid') if 'middle_block' in name else (None, None) # Default case + for name, child in model.diffusion_model.named_children() + ] if net is not None) + + controller.num_att_layers = cross_att_count + + +class AttentionStore: + """A class for storing attention information in the UNet model. + + Attributes: + base_size (int): Base size for storing attention information. + max_size (int): Maximum size for storing attention information. + """ + + def __init__(self, base_size=64, max_size=None): + """Initialize AttentionStore with default or custom sizes.""" + self.reset() + self.base_size = base_size + self.max_size = max_size or (base_size // 2) + self.num_att_layers = -1 + + @staticmethod + def get_empty_store(): + """Returns an empty store for holding attention values.""" + return { + key: [] + for key in [ + 'down_cross', 'mid_cross', 'up_cross', 'down_self', 'mid_self', + 'up_self' + ] + } + + def reset(self): + """Resets the step and attention stores to their initial states.""" + self.cur_step = 0 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + """Processes a single forward step, storing the attention. + + Args: + attn: The attention tensor. + is_cross (bool): Whether it's cross attention. + place_in_unet (str): The location in UNet (down/mid/up). + + Returns: + The unmodified attention tensor. + """ + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= (self.max_size)**2: + self.step_store[key].append(attn) + return attn + + def between_steps(self): + """Processes and stores attention information between steps.""" + if not self.attention_store: + self.attention_store = self.step_store + else: + for key in self.attention_store: + self.attention_store[key] = [ + stored + step for stored, step in zip( + self.attention_store[key], self.step_store[key]) + ] + self.step_store = self.get_empty_store() + + def get_average_attention(self): + """Calculates and returns the average attention across all steps.""" + return { + key: [item for item in self.step_store[key]] + for key in self.step_store + } + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + """Allows the class instance to be callable.""" + return self.forward(attn, is_cross, place_in_unet) + + @property + def num_uncond_att_layers(self): + """Returns the number of unconditional attention layers (default is + 0).""" + return 0 + + def step_callback(self, x_t): + """A placeholder for a step callback. + + Returns the input unchanged. + """ + return x_t + + +class UNetWrapper(nn.Module): + """A wrapper for UNet with optional attention mechanisms. + + Args: + unet (nn.Module): The UNet model to wrap + use_attn (bool): Whether to use attention. Defaults to True + base_size (int): Base size for the attention store. Defaults to 512 + max_attn_size (int, optional): Maximum size for the attention store. + Defaults to None + attn_selector (str): The types of attention to use. + Defaults to 'up_cross+down_cross' + """ + + def __init__(self, + unet, + use_attn=True, + base_size=512, + max_attn_size=None, + attn_selector='up_cross+down_cross'): + super().__init__() + + assert has_ldm, 'To use UNetWrapper, please install required ' \ + 'packages via `pip install -r requirements/optional.txt`.' + + self.unet = unet + self.attention_store = AttentionStore( + base_size=base_size // 8, max_size=max_attn_size) + self.attn_selector = attn_selector.split('+') + self.use_attn = use_attn + self.init_sizes(base_size) + if self.use_attn: + register_attention_control(unet, self.attention_store) + + def init_sizes(self, base_size): + """Initialize sizes based on the base size.""" + self.size16 = base_size // 32 + self.size32 = base_size // 16 + self.size64 = base_size // 8 + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """Forward pass through the model.""" + diffusion_model = self.unet.diffusion_model + if self.use_attn: + self.attention_store.reset() + hs, emb, out_list = self._unet_forward(x, timesteps, context, y, + diffusion_model) + if self.use_attn: + self._append_attn_to_output(out_list) + return out_list[::-1] + + def _unet_forward(self, x, timesteps, context, y, diffusion_model): + hs = [] + t_emb = timestep_embedding( + timesteps, diffusion_model.model_channels, repeat_only=False) + emb = diffusion_model.time_embed(t_emb) + h = x.type(diffusion_model.dtype) + for module in diffusion_model.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = diffusion_model.middle_block(h, emb, context) + out_list = [] + for i_out, module in enumerate(diffusion_model.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if i_out in [1, 4, 7]: + out_list.append(h) + h = h.type(x.dtype) + out_list.append(h) + return hs, emb, out_list + + def _append_attn_to_output(self, out_list): + avg_attn = self.attention_store.get_average_attention() + attns = {self.size16: [], self.size32: [], self.size64: []} + for k in self.attn_selector: + for up_attn in avg_attn[k]: + size = int(math.sqrt(up_attn.shape[1])) + up_attn = up_attn.transpose(-1, -2).reshape( + *up_attn.shape[:2], size, -1) + attns[size].append(up_attn) + attn16 = torch.stack(attns[self.size16]).mean(0) + attn32 = torch.stack(attns[self.size32]).mean(0) + attn64 = torch.stack(attns[self.size64]).mean(0) if len( + attns[self.size64]) > 0 else None + out_list[1] = torch.cat([out_list[1], attn16], dim=1) + out_list[2] = torch.cat([out_list[2], attn32], dim=1) + if attn64 is not None: + out_list[3] = torch.cat([out_list[3], attn64], dim=1) + + +class TextAdapter(nn.Module): + """A PyTorch Module that serves as a text adapter. + + This module takes text embeddings and adjusts them based on a scaling + factor gamma. + """ + + def __init__(self, text_dim=768): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(text_dim, text_dim), nn.GELU(), + nn.Linear(text_dim, text_dim)) + + def forward(self, texts, gamma): + texts_after = self.fc(texts) + texts = texts + gamma * texts_after + return texts + + +@MODELS.register_module() +class VPD(BaseModule): + """VPD (Visual Perception Diffusion) model. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + diffusion_cfg (dict): Configuration for diffusion model. + class_embed_path (str): Path for class embeddings. + unet_cfg (dict, optional): Configuration for U-Net. + gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4. + class_embed_select (bool, optional): If True, enables class embedding + selection. Defaults to False. + pad_shape (Optional[Union[int, List[int]]], optional): Padding shape. + Defaults to None. + pad_val (Union[int, List[int]], optional): Padding value. + Defaults to 0. + init_cfg (dict, optional): Configuration for network initialization. + """ + + def __init__(self, + diffusion_cfg: ConfigType, + class_embed_path: str, + unet_cfg: OptConfigType = dict(), + gamma: float = 1e-4, + class_embed_select=False, + pad_shape: Optional[Union[int, List[int]]] = None, + pad_val: Union[int, List[int]] = 0, + init_cfg: OptConfigType = None): + + super().__init__(init_cfg=init_cfg) + + assert has_ldm, 'To use VPD model, please install required packages' \ + ' via `pip install -r requirements/optional.txt`.' + + if pad_shape is not None: + if not isinstance(pad_shape, (list, tuple)): + pad_shape = (pad_shape, pad_shape) + + self.pad_shape = pad_shape + self.pad_val = pad_val + + # diffusion model + diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None) + sd_model = instantiate_from_config(diffusion_cfg) + if diffusion_checkpoint is not None: + load_checkpoint(sd_model, diffusion_checkpoint, strict=False) + + self.encoder_vq = sd_model.first_stage_model + self.unet = UNetWrapper(sd_model.model, **unet_cfg) + + # class embeddings & text adapter + class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path) + text_dim = class_embeddings.size(-1) + self.text_adapter = TextAdapter(text_dim=text_dim) + self.class_embed_select = class_embed_select + if class_embed_select: + class_embeddings = torch.cat( + (class_embeddings, class_embeddings.mean(dim=0, + keepdims=True)), + dim=0) + self.register_buffer('class_embeddings', class_embeddings) + self.gamma = nn.Parameter(torch.ones(text_dim) * gamma) + + def forward(self, x): + """Extract features from images.""" + + # calculate cross-attn map + if self.class_embed_select: + if isinstance(x, (tuple, list)): + x, class_ids = x[:2] + class_ids = class_ids.tolist() + else: + class_ids = [-1] * x.size(0) + class_embeddings = self.class_embeddings[class_ids] + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(1) + else: + class_embeddings = self.class_embeddings + c_crossattn = self.text_adapter(class_embeddings, self.gamma) + c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1) + + # pad to required input shape for pretrained diffusion model + if self.pad_shape is not None: + pad_width = max(0, self.pad_shape[1] - x.shape[-1]) + pad_height = max(0, self.pad_shape[0] - x.shape[-2]) + x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val) + + # forward the denoising model + with torch.no_grad(): + latents = self.encoder_vq.encode(x).mode().detach() + t = torch.ones((x.shape[0], ), device=x.device).long() + outs = self.unet(latents, t, context=c_crossattn) + + return outs diff --git a/finetune/mmseg/models/builder.py b/finetune/mmseg/models/builder.py new file mode 100644 index 0000000..081c646 --- /dev/null +++ b/finetune/mmseg/models/builder.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmseg.registry import MODELS + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + warnings.warn('``build_backbone`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + warnings.warn('``build_neck`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + warnings.warn('``build_head`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + warnings.warn('``build_loss`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/finetune/mmseg/models/data_preprocessor.py b/finetune/mmseg/models/data_preprocessor.py new file mode 100644 index 0000000..8d32bc6 --- /dev/null +++ b/finetune/mmseg/models/data_preprocessor.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from numbers import Number +from typing import Any, Dict, List, Optional, Sequence + +import torch +from mmengine.model import BaseDataPreprocessor + +from mmseg.registry import MODELS +from mmseg.utils import stack_batch + + +@MODELS.register_module() +class SegDataPreProcessor(BaseDataPreprocessor): + """Image pre-processor for segmentation tasks. + + Comparing with the :class:`mmengine.ImgDataPreprocessor`, + + 1. It won't do normalization if ``mean`` is not specified. + 2. It does normalization and color space conversion after stacking batch. + 3. It supports batch augmentations like mixup and cutmix. + + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the input size with defined ``pad_val``, and pad seg map + with defined ``seg_pad_val``. + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations like Mixup and Cutmix during training. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B channels. + Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + padding_mode (str): Type of padding. Default: constant. + - constant: pads with a constant value, this value is specified + with pad_val. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to RGB. + Defaults to False. + batch_augments (list[dict], optional): Batch-level augmentations + test_cfg (dict, optional): The padding size config in testing, if not + specify, will use `size` and `size_divisor` params as default. + Defaults to None, only supports keys `size` or `size_divisor`. + """ + + def __init__( + self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Number = 0, + seg_pad_val: Number = 255, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + batch_augments: Optional[List[dict]] = None, + test_cfg: dict = None, + ): + super().__init__() + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + + assert not (bgr_to_rgb and rgb_to_bgr), ( + '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + self.channel_conversion = rgb_to_bgr or bgr_to_rgb + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + self.register_buffer('mean', + torch.tensor(mean).view(-1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(-1, 1, 1), False) + else: + self._enable_normalize = False + + # TODO: support batch augmentations. + self.batch_augments = batch_augments + + # Support different padding methods in testing + self.test_cfg = test_cfg + + def forward(self, data: dict, training: bool = False) -> Dict[str, Any]: + """Perform normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + Dict: Data in the same format as the model input. + """ + data = self.cast_data(data) # type: ignore + inputs = data['inputs'] + data_samples = data.get('data_samples', None) + # TODO: whether normalize should be after stack_batch + if self.channel_conversion and inputs[0].size(0) == 3: + inputs = [_input[[2, 1, 0], ...] for _input in inputs] + + inputs = [_input.float() for _input in inputs] + if self._enable_normalize: + inputs = [(_input - self.mean) / self.std for _input in inputs] + + if training: + assert data_samples is not None, ('During training, ', + '`data_samples` must be define.') + inputs, data_samples = stack_batch( + inputs=inputs, + data_samples=data_samples, + size=self.size, + size_divisor=self.size_divisor, + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + + if self.batch_augments is not None: + inputs, data_samples = self.batch_augments( + inputs, data_samples) + else: + img_size = inputs[0].shape[1:] + assert all(input_.shape[1:] == img_size for input_ in inputs), \ + 'The image size in a batch should be the same.' + # pad images when testing + if self.test_cfg: + inputs, padded_samples = stack_batch( + inputs=inputs, + size=self.test_cfg.get('size', None), + size_divisor=self.test_cfg.get('size_divisor', None), + pad_val=self.pad_val, + seg_pad_val=self.seg_pad_val) + for data_sample, pad_info in zip(data_samples, padded_samples): + data_sample.set_metainfo({**pad_info}) + else: + inputs = torch.stack(inputs, dim=0) + + return dict(inputs=inputs, data_samples=data_samples) diff --git a/finetune/mmseg/models/decode_heads/__init__.py b/finetune/mmseg/models/decode_heads/__init__.py new file mode 100644 index 0000000..4229763 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .ann_head import ANNHead +from .apc_head import APCHead +from .aspp_head import ASPPHead +from .cc_head import CCHead +from .da_head import DAHead +from .ddr_head import DDRHead +from .dm_head import DMHead +from .dnl_head import DNLHead +from .dpt_head import DPTHead +from .ema_head import EMAHead +from .enc_head import EncHead +from .fcn_head import FCNHead +from .fpn_head import FPNHead +from .gc_head import GCHead +from .ham_head import LightHamHead +from .isa_head import ISAHead +from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator +from .lraspp_head import LRASPPHead +from .mask2former_head import Mask2FormerHead +from .maskformer_head import MaskFormerHead +from .nl_head import NLHead +from .ocr_head import OCRHead +from .pid_head import PIDHead +from .point_head import PointHead +from .psa_head import PSAHead +from .psp_head import PSPHead +from .san_head import SideAdapterCLIPHead +from .segformer_head import SegformerHead +from .segmenter_mask_head import SegmenterMaskTransformerHead +from .sep_aspp_head import DepthwiseSeparableASPPHead +from .sep_fcn_head import DepthwiseSeparableFCNHead +from .setr_mla_head import SETRMLAHead +from .setr_up_head import SETRUPHead +from .stdc_head import STDCHead +from .uper_head import UPerHead +from .vpd_depth_head import VPDDepthHead + +__all__ = [ + 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', + 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', + 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', + 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', + 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', + 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', + 'LightHamHead', 'PIDHead', 'DDRHead', 'VPDDepthHead', 'SideAdapterCLIPHead' +] diff --git a/finetune/mmseg/models/decode_heads/ann_head.py b/finetune/mmseg/models/decode_heads/ann_head.py new file mode 100644 index 0000000..2b40ef5 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/ann_head.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PPMConcat(nn.ModuleList): + """Pyramid Pooling Module that only concat the features of each layer. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + """ + + def __init__(self, pool_scales=(1, 3, 6, 8)): + super().__init__( + [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales]) + + def forward(self, feats): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(feats) + ppm_outs.append(ppm_out.view(*feats.shape[:2], -1)) + concat_outs = torch.cat(ppm_outs, dim=2) + return concat_outs + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Make a ANN used SelfAttentionBlock. + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_scale (int): The scale of query feature map. + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, share_key_query, query_scale, key_pool_scales, + conv_cfg, norm_cfg, act_cfg): + key_psp = PPMConcat(key_pool_scales) + if query_scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=query_scale) + else: + query_downsample = None + super().__init__( + key_in_channels=low_in_channels, + query_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=share_key_query, + query_downsample=query_downsample, + key_downsample=key_psp, + key_query_num_convs=1, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + +class AFNB(nn.Module): + """Asymmetric Fusion Non-local Block(AFNB) + + Args: + low_in_channels (int): Input channels of lower level feature, + which is the key feature for self-attention. + high_in_channels (int): Input channels of higher level feature, + which is the query feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + and query projection. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, low_in_channels, high_in_channels, channels, + out_channels, query_scales, key_pool_scales, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=False, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + out_channels + high_in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, low_feats, high_feats): + """Forward function.""" + priors = [stage(high_feats, low_feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, high_feats], 1)) + return output + + +class APNB(nn.Module): + """Asymmetric Pyramid Non-local Block (APNB) + + Args: + in_channels (int): Input channels of key/query feature, + which is the key feature for self-attention. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module of key feature. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, out_channels, query_scales, + key_pool_scales, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.stages = nn.ModuleList() + for query_scale in query_scales: + self.stages.append( + SelfAttentionBlock( + low_in_channels=in_channels, + high_in_channels=in_channels, + channels=channels, + out_channels=out_channels, + share_key_query=True, + query_scale=query_scale, + key_pool_scales=key_pool_scales, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.bottleneck = ConvModule( + 2 * in_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, feats): + """Forward function.""" + priors = [stage(feats, feats) for stage in self.stages] + context = torch.stack(priors, dim=0).sum(dim=0) + output = self.bottleneck(torch.cat([context, feats], 1)) + return output + + +@MODELS.register_module() +class ANNHead(BaseDecodeHead): + """Asymmetric Non-local Neural Networks for Semantic Segmentation. + + This head is the implementation of `ANNNet + `_. + + Args: + project_channels (int): Projection channels for Nonlocal. + query_scales (tuple[int]): The scales of query feature map. + Default: (1,) + key_pool_scales (tuple[int]): The pooling scales of key feature map. + Default: (1, 3, 6, 8). + """ + + def __init__(self, + project_channels, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(self.in_channels) == 2 + low_in_channels, high_in_channels = self.in_channels + self.project_channels = project_channels + self.fusion = AFNB( + low_in_channels=low_in_channels, + high_in_channels=high_in_channels, + out_channels=high_in_channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + high_in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.context = APNB( + in_channels=self.channels, + out_channels=self.channels, + channels=project_channels, + query_scales=query_scales, + key_pool_scales=key_pool_scales, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + low_feats, high_feats = self._transform_inputs(inputs) + output = self.fusion(low_feats, high_feats) + output = self.dropout(output) + output = self.bottleneck(output) + output = self.context(output) + output = self.cls_seg(output) + + return output diff --git a/finetune/mmseg/models/decode_heads/apc_head.py b/finetune/mmseg/models/decode_heads/apc_head.py new file mode 100644 index 0000000..728f396 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/apc_head.py @@ -0,0 +1,159 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ACM(nn.Module): + """Adaptive Context Module used in APCNet. + + Args: + pool_scale (int): Pooling scale used in Adaptive Context + Module to extract region features. + fusion (bool): Add one conv to fuse residual feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.pool_scale = pool_scale + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.pooled_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.global_info = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0) + + self.residual_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale) + # [batch_size, channels, h, w] + x = self.input_redu_conv(x) + # [batch_size, channels, pool_scale, pool_scale] + pooled_x = self.pooled_redu_conv(pooled_x) + batch_size = x.size(0) + # [batch_size, pool_scale * pool_scale, channels] + pooled_x = pooled_x.view(batch_size, self.channels, + -1).permute(0, 2, 1).contiguous() + # [batch_size, h * w, pool_scale * pool_scale] + affinity_matrix = self.gla(x + resize( + self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:]) + ).permute(0, 2, 3, 1).reshape( + batch_size, -1, self.pool_scale**2) + affinity_matrix = F.sigmoid(affinity_matrix) + # [batch_size, h * w, channels] + z_out = torch.matmul(affinity_matrix, pooled_x) + # [batch_size, channels, h * w] + z_out = z_out.permute(0, 2, 1).contiguous() + # [batch_size, channels, h, w] + z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3)) + z_out = self.residual_conv(z_out) + z_out = F.relu(z_out + x) + if self.fusion: + z_out = self.fusion_conv(z_out) + + return z_out + + +@MODELS.register_module() +class APCHead(BaseDecodeHead): + """Adaptive Pyramid Context Network for Semantic Segmentation. + + This head is the implementation of + `APCNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Adaptive Context + Module. Default: (1, 2, 3, 6). + fusion (bool): Add one conv to fuse residual feature. + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.fusion = fusion + acm_modules = [] + for pool_scale in self.pool_scales: + acm_modules.append( + ACM(pool_scale, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.acm_modules = nn.ModuleList(acm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + acm_outs = [x] + for acm_module in self.acm_modules: + acm_outs.append(acm_module(x)) + acm_outs = torch.cat(acm_outs, dim=1) + output = self.bottleneck(acm_outs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/aspp_head.py b/finetune/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 0000000..6d7185d --- /dev/null +++ b/finetune/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ASPPModule(nn.ModuleList): + """Atrous Spatial Pyramid Pooling (ASPP) Module. + + Args: + dilations (tuple[int]): Dilation rate of each layer. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg, + act_cfg): + super().__init__() + self.dilations = dilations + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for dilation in dilations: + self.append( + ConvModule( + self.in_channels, + self.channels, + 1 if dilation == 1 else 3, + dilation=dilation, + padding=0 if dilation == 1 else dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, x): + """Forward function.""" + aspp_outs = [] + for aspp_module in self: + aspp_outs.append(aspp_module(x)) + + return aspp_outs + + +@MODELS.register_module() +class ASPPHead(BaseDecodeHead): + """Rethinking Atrous Convolution for Semantic Image Segmentation. + + This head is the implementation of `DeepLabV3 + `_. + + Args: + dilations (tuple[int]): Dilation rates for ASPP module. + Default: (1, 6, 12, 18). + """ + + def __init__(self, dilations=(1, 6, 12, 18), **kwargs): + super().__init__(**kwargs) + assert isinstance(dilations, (list, tuple)) + self.dilations = dilations + self.image_pool = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.aspp_modules = ASPPModule( + dilations, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + (len(dilations) + 1) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + feats = self.bottleneck(aspp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/cascade_decode_head.py b/finetune/mmseg/models/decode_heads/cascade_decode_head.py new file mode 100644 index 0000000..fe2bcb9 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/cascade_decode_head.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List + +from torch import Tensor + +from mmseg.utils import ConfigType +from .decode_head import BaseDecodeHead + + +class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): + """Base class for cascade decode head used in + :class:`CascadeEncoderDecoder.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @abstractmethod + def forward(self, inputs, prev_output): + """Placeholder of forward function.""" + pass + + def loss(self, inputs: List[Tensor], prev_output: Tensor, + batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor: + """Forward function for training. + + Args: + inputs (List[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs, prev_output) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + + return losses + + def predict(self, inputs: List[Tensor], prev_output: Tensor, + batch_img_metas: List[dict], tese_cfg: ConfigType): + """Forward function for testing. + + Args: + inputs (List[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + seg_logits = self.forward(inputs, prev_output) + + return self.predict_by_feat(seg_logits, batch_img_metas) diff --git a/finetune/mmseg/models/decode_heads/cc_head.py b/finetune/mmseg/models/decode_heads/cc_head.py new file mode 100644 index 0000000..e9075a2 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/cc_head.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + +try: + from mmcv.ops import CrissCrossAttention +except ModuleNotFoundError: + CrissCrossAttention = None + + +@MODELS.register_module() +class CCHead(FCNHead): + """CCNet: Criss-Cross Attention for Semantic Segmentation. + + This head is the implementation of `CCNet + `_. + + Args: + recurrence (int): Number of recurrence of Criss Cross Attention + module. Default: 2. + """ + + def __init__(self, recurrence=2, **kwargs): + if CrissCrossAttention is None: + raise RuntimeError('Please install mmcv-full for ' + 'CrissCrossAttention ops') + super().__init__(num_convs=2, **kwargs) + self.recurrence = recurrence + self.cca = CrissCrossAttention(self.channels) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + for _ in range(self.recurrence): + output = self.cca(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/da_head.py b/finetune/mmseg/models/decode_heads/da_head.py new file mode 100644 index 0000000..d872143 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/da_head.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale +from torch import Tensor, nn + +from mmseg.registry import MODELS +from mmseg.utils import SampleList, add_prefix +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class PAM(_SelfAttentionBlock): + """Position Attention Module (PAM) + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + """ + + def __init__(self, in_channels, channels): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=1, + key_query_norm=False, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=False, + with_out=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None) + + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + out = super().forward(x, x) + + out = self.gamma(out) + x + return out + + +class CAM(nn.Module): + """Channel Attention Module (CAM)""" + + def __init__(self): + super().__init__() + self.gamma = Scale(0) + + def forward(self, x): + """Forward function.""" + batch_size, channels, height, width = x.size() + proj_query = x.view(batch_size, channels, -1) + proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) + energy = torch.bmm(proj_query, proj_key) + energy_new = torch.max( + energy, -1, keepdim=True)[0].expand_as(energy) - energy + attention = F.softmax(energy_new, dim=-1) + proj_value = x.view(batch_size, channels, -1) + + out = torch.bmm(attention, proj_value) + out = out.view(batch_size, channels, height, width) + + out = self.gamma(out) + x + return out + + +@MODELS.register_module() +class DAHead(BaseDecodeHead): + """Dual Attention Network for Scene Segmentation. + + This head is the implementation of `DANet + `_. + + Args: + pam_channels (int): The channels of Position Attention Module(PAM). + """ + + def __init__(self, pam_channels, **kwargs): + super().__init__(**kwargs) + self.pam_channels = pam_channels + self.pam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam = PAM(self.channels, pam_channels) + self.pam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.pam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + self.cam_in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam = CAM() + self.cam_out_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.cam_conv_seg = nn.Conv2d( + self.channels, self.num_classes, kernel_size=1) + + def pam_cls_seg(self, feat): + """PAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.pam_conv_seg(feat) + return output + + def cam_cls_seg(self, feat): + """CAM feature classification.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.cam_conv_seg(feat) + return output + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + pam_feat = self.pam_in_conv(x) + pam_feat = self.pam(pam_feat) + pam_feat = self.pam_out_conv(pam_feat) + pam_out = self.pam_cls_seg(pam_feat) + + cam_feat = self.cam_in_conv(x) + cam_feat = self.cam(cam_feat) + cam_feat = self.cam_out_conv(cam_feat) + cam_out = self.cam_cls_seg(cam_feat) + + feat_sum = pam_feat + cam_feat + pam_cam_out = self.cls_seg(feat_sum) + + return pam_cam_out, pam_out, cam_out + + def predict(self, inputs, batch_img_metas: List[dict], test_cfg, + **kwargs) -> List[Tensor]: + """Forward function for testing, only ``pam_cam`` is used.""" + seg_logits = self.forward(inputs)[0] + return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute ``pam_cam``, ``pam``, ``cam`` loss.""" + pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit + loss = dict() + loss.update( + add_prefix( + super().loss_by_feat(pam_cam_seg_logit, batch_data_samples), + 'pam_cam')) + loss.update( + add_prefix(super().loss_by_feat(pam_seg_logit, batch_data_samples), + 'pam')) + loss.update( + add_prefix(super().loss_by_feat(cam_seg_logit, batch_data_samples), + 'cam')) + return loss diff --git a/finetune/mmseg/models/decode_heads/ddr_head.py b/finetune/mmseg/models/decode_heads/ddr_head.py new file mode 100644 index 0000000..ba26d65 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/ddr_head.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType, SampleList + + +@MODELS.register_module() +class DDRHead(BaseDecodeHead): + """Decode head for DDRNet. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_classes (int): Number of classes. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + """ + + def __init__(self, + in_channels: int, + channels: int, + num_classes: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + + self.head = self._make_base_head(self.in_channels, self.channels) + self.aux_head = self._make_base_head(self.in_channels // 2, + self.channels) + self.aux_cls_seg = nn.Conv2d( + self.channels, self.out_channels, kernel_size=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + inputs: Union[Tensor, + Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: + if self.training: + c3_feat, c5_feat = inputs + x_c = self.head(c5_feat) + x_c = self.cls_seg(x_c) + x_s = self.aux_head(c3_feat) + x_s = self.aux_cls_seg(x_s) + + return x_c, x_s + else: + x_c = self.head(inputs) + x_c = self.cls_seg(x_c) + return x_c + + def _make_base_head(self, in_channels: int, + channels: int) -> nn.Sequential: + layers = [ + ConvModule( + in_channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + order=('norm', 'act', 'conv')), + build_norm_layer(self.norm_cfg, channels)[1], + build_activation_layer(self.act_cfg), + ] + + return nn.Sequential(*layers) + + def loss_by_feat(self, seg_logits: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + loss = dict() + context_logit, spatial_logit = seg_logits + seg_label = self._stack_batch_gt(batch_data_samples) + + context_logit = resize( + context_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + spatial_logit = resize( + spatial_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + seg_label = seg_label.squeeze(1) + + loss['loss_context'] = self.loss_decode[0](context_logit, seg_label) + loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label) + loss['acc_seg'] = accuracy( + context_logit, seg_label, ignore_index=self.ignore_index) + + return loss diff --git a/finetune/mmseg/models/decode_heads/decode_head.py b/finetune/mmseg/models/decode_heads/decode_head.py new file mode 100644 index 0000000..fd53afe --- /dev/null +++ b/finetune/mmseg/models/decode_heads/decode_head.py @@ -0,0 +1,366 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import build_pixel_sampler +from mmseg.utils import ConfigType, SampleList +from ..losses import accuracy +from ..utils import resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + 1. The ``init_weights`` method is used to initialize decode_head's + model parameters. After segmentor initialization, ``init_weights`` + is triggered when ``segmentor.init_weights()`` is called externally. + + 2. The ``loss`` method is used to calculate the loss of decode_head, + which includes two steps: (1) the decode_head model performs forward + propagation to obtain the feature maps (2) The ``loss_by_feat`` method + is called based on the feature maps to calculate the loss. + + .. code:: text + + loss(): forward() -> loss_by_feat() + + 3. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) the decode_head model performs forward + propagation to obtain the feature maps (2) The ``predict_by_feat`` method + is called based on the feature maps to predict segmentation results + including post-processing. + + .. code:: text + + predict(): forward() -> predict_by_feat() + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + out_channels (int): Output channels of conv_seg. Default: None. + threshold (float): Threshold for binary segmentation in the case of + `num_classes==1`. Default: None. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + out_channels=None, + threshold=None, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super().__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.ignore_index = ignore_index + self.align_corners = align_corners + + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert `seg_logits` into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Forward function for training. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.loss_by_feat(seg_logits, batch_data_samples) + return losses + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): List of multi-level img features. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + seg_logits = self.forward(inputs) + + return self.predict_by_feat(seg_logits, batch_img_metas) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + return torch.stack(gt_semantic_segs, dim=0) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute segmentation loss. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + seg_label = self._stack_batch_gt(batch_data_samples) + loss = dict() + seg_logits = resize( + input=seg_logits, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logits, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + else: + loss[loss_decode.loss_name] += loss_decode( + seg_logits, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + + loss['acc_seg'] = accuracy( + seg_logits, seg_label, ignore_index=self.ignore_index) + return loss + + def predict_by_feat(self, seg_logits: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Transform a batch of output seg_logits to the input shape. + + Args: + seg_logits (Tensor): The output from decode head forward function. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tensor: Outputs segmentation logits map. + """ + + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] + + seg_logits = resize( + input=seg_logits, + size=size, + mode='bilinear', + align_corners=self.align_corners) + return seg_logits diff --git a/finetune/mmseg/models/decode_heads/dm_head.py b/finetune/mmseg/models/decode_heads/dm_head.py new file mode 100644 index 0000000..7694abd --- /dev/null +++ b/finetune/mmseg/models/decode_heads/dm_head.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +class DCM(nn.Module): + """Dynamic Convolutional Module used in DMNet. + + Args: + filter_size (int): The filter size of generated convolution kernel + used in Dynamic Convolutional Module. + fusion (bool): Add one conv to fuse DCM output feature. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg, + norm_cfg, act_cfg): + super().__init__() + self.filter_size = filter_size + self.fusion = fusion + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1, + 0) + + self.input_redu_conv = ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + if self.norm_cfg is not None: + self.norm = build_norm_layer(self.norm_cfg, self.channels)[1] + else: + self.norm = None + self.activate = build_activation_layer(self.act_cfg) + + if self.fusion: + self.fusion_conv = ConvModule( + self.channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, x): + """Forward function.""" + generated_filter = self.filter_gen_conv( + F.adaptive_avg_pool2d(x, self.filter_size)) + x = self.input_redu_conv(x) + b, c, h, w = x.shape + # [1, b * c, h, w], c = self.channels + x = x.view(1, b * c, h, w) + # [b * c, 1, filter_size, filter_size] + generated_filter = generated_filter.view(b * c, 1, self.filter_size, + self.filter_size) + pad = (self.filter_size - 1) // 2 + if (self.filter_size - 1) % 2 == 0: + p2d = (pad, pad, pad, pad) + else: + p2d = (pad + 1, pad, pad + 1, pad) + x = F.pad(input=x, pad=p2d, mode='constant', value=0) + # [1, b * c, h, w] + output = F.conv2d(input=x, weight=generated_filter, groups=b * c) + # [b, c, h, w] + output = output.view(b, c, h, w) + if self.norm is not None: + output = self.norm(output) + output = self.activate(output) + + if self.fusion: + output = self.fusion_conv(output) + + return output + + +@MODELS.register_module() +class DMHead(BaseDecodeHead): + """Dynamic Multi-scale Filters for Semantic Segmentation. + + This head is the implementation of + `DMNet `_. + + Args: + filter_sizes (tuple[int]): The size of generated convolutional filters + used in Dynamic Convolutional Module. Default: (1, 3, 5, 7). + fusion (bool): Add one conv to fuse DCM output feature. + """ + + def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs): + super().__init__(**kwargs) + assert isinstance(filter_sizes, (list, tuple)) + self.filter_sizes = filter_sizes + self.fusion = fusion + dcm_modules = [] + for filter_size in self.filter_sizes: + dcm_modules.append( + DCM(filter_size, + self.fusion, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.dcm_modules = nn.ModuleList(dcm_modules) + self.bottleneck = ConvModule( + self.in_channels + len(filter_sizes) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + dcm_outs = [x] + for dcm_module in self.dcm_modules: + dcm_outs.append(dcm_module(x)) + dcm_outs = torch.cat(dcm_outs, dim=1) + output = self.bottleneck(dcm_outs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/dnl_head.py b/finetune/mmseg/models/decode_heads/dnl_head.py new file mode 100644 index 0000000..248c118 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/dnl_head.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import NonLocal2d +from torch import nn + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +class DisentangledNonLocal2d(NonLocal2d): + """Disentangled Non-Local Blocks. + + Args: + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, *arg, temperature, **kwargs): + super().__init__(*arg, **kwargs) + self.temperature = temperature + self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1) + + def embedded_gaussian(self, theta_x, phi_x): + """Embedded gaussian with temperature.""" + + # NonLocal2d pairwise_weight: [N, HxW, HxW] + pairwise_weight = torch.matmul(theta_x, phi_x) + if self.use_scale: + # theta_x.shape[-1] is `self.inter_channels` + pairwise_weight /= torch.tensor( + theta_x.shape[-1], + dtype=torch.float, + device=pairwise_weight.device)**torch.tensor( + 0.5, device=pairwise_weight.device) + pairwise_weight /= torch.tensor( + self.temperature, device=pairwise_weight.device) + pairwise_weight = pairwise_weight.softmax(dim=-1) + return pairwise_weight + + def forward(self, x): + # x: [N, C, H, W] + n = x.size(0) + + # g_x: [N, HxW, C] + g_x = self.g(x).view(n, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta_x: [N, HxW, C], phi_x: [N, C, HxW] + if self.mode == 'gaussian': + theta_x = x.view(n, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + if self.sub_sample: + phi_x = self.phi(x).view(n, self.in_channels, -1) + else: + phi_x = x.view(n, self.in_channels, -1) + elif self.mode == 'concatenation': + theta_x = self.theta(x).view(n, self.inter_channels, -1, 1) + phi_x = self.phi(x).view(n, self.inter_channels, 1, -1) + else: + theta_x = self.theta(x).view(n, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(n, self.inter_channels, -1) + + # subtract mean + theta_x -= theta_x.mean(dim=-2, keepdim=True) + phi_x -= phi_x.mean(dim=-1, keepdim=True) + + pairwise_func = getattr(self, self.mode) + # pairwise_weight: [N, HxW, HxW] + pairwise_weight = pairwise_func(theta_x, phi_x) + + # y: [N, HxW, C] + y = torch.matmul(pairwise_weight, g_x) + # y: [N, C, H, W] + y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels, + *x.size()[2:]) + + # unary_mask: [N, 1, HxW] + unary_mask = self.conv_mask(x) + unary_mask = unary_mask.view(n, 1, -1) + unary_mask = unary_mask.softmax(dim=-1) + # unary_x: [N, 1, C] + unary_x = torch.matmul(unary_mask, g_x) + # unary_x: [N, C, 1, 1] + unary_x = unary_x.permute(0, 2, 1).contiguous().reshape( + n, self.inter_channels, 1, 1) + + output = x + self.conv_out(y + unary_x) + + return output + + +@MODELS.register_module() +class DNLHead(FCNHead): + """Disentangled Non-Local Neural Networks. + + This head is the implementation of `DNLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: False. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + temperature (float): Temperature to adjust attention. Default: 0.05 + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + temperature=0.05, + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.temperature = temperature + self.dnl_block = DisentangledNonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode, + temperature=self.temperature) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.dnl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/dpt_head.py b/finetune/mmseg/models/decode_heads/dpt_head.py new file mode 100644 index 0000000..d2cfd89 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/dpt_head.py @@ -0,0 +1,294 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels=768, + out_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + init_cfg=None): + super().__init__(init_cfg) + + assert readout_type in ['ignore', 'add', 'project'] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList([ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + if self.readout_type == 'project': + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + Linear(2 * in_channels, in_channels), + build_activation_layer(dict(type='GELU')))) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == 'project': + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == 'add': + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + stride=1, + dilation=1, + init_cfg=None): + super().__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=('act', 'conv', 'norm')) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, + in_channels, + act_cfg, + norm_cfg, + expand=False, + align_corners=True, + init_cfg=None): + super().__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule( + self.in_channels, + self.out_channels, + kernel_size=1, + act_cfg=None, + bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize( + inputs[1], + size=(x.shape[2], x.shape[3]), + mode='bilinear', + align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize( + x, + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners) + x = self.project(x) + return x + + +@MODELS.register_module() +class DPTHead(BaseDecodeHead): + """Vision Transformers for Dense Prediction. + + This head is implemented of `DPT `_. + + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + act_cfg (dict): The activation config for residual conv unit. + Default dict(type='ReLU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + """ + + def __init__(self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type='ignore', + patch_size=16, + expand_channels=False, + act_cfg=dict(type='ReLU'), + norm_cfg=dict(type='BN'), + **kwargs): + super().__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, + post_process_channels, + readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel + for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append( + ConvModule( + channel, + self.channels, + kernel_size=3, + padding=1, + act_cfg=None, + bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append( + FeatureFusionBlock(self.channels, act_cfg, norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule( + self.channels, + self.channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + + def forward(self, inputs): + assert len(inputs) == self.num_reassemble_blocks + x = self._transform_inputs(inputs) + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.cls_seg(out) + return out diff --git a/finetune/mmseg/models/decode_heads/ema_head.py b/finetune/mmseg/models/decode_heads/ema_head.py new file mode 100644 index 0000000..ab8dbb0 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/ema_head.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +def reduce_mean(tensor): + """Reduce mean when distributed training.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +class EMAModule(nn.Module): + """Expectation Maximization Attention Module used in EMANet. + + Args: + channels (int): Channels of the whole module. + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + """ + + def __init__(self, channels, num_bases, num_stages, momentum): + super().__init__() + assert num_stages >= 1, 'num_stages must be at least 1!' + self.num_bases = num_bases + self.num_stages = num_stages + self.momentum = momentum + + bases = torch.zeros(1, channels, self.num_bases) + bases.normal_(0, math.sqrt(2. / self.num_bases)) + # [1, channels, num_bases] + bases = F.normalize(bases, dim=1, p=2) + self.register_buffer('bases', bases) + + def forward(self, feats): + """Forward function.""" + batch_size, channels, height, width = feats.size() + # [batch_size, channels, height*width] + feats = feats.view(batch_size, channels, height * width) + # [batch_size, channels, num_bases] + bases = self.bases.repeat(batch_size, 1, 1) + + with torch.no_grad(): + for i in range(self.num_stages): + # [batch_size, height*width, num_bases] + attention = torch.einsum('bcn,bck->bnk', feats, bases) + attention = F.softmax(attention, dim=2) + # l1 norm + attention_normed = F.normalize(attention, dim=1, p=1) + # [batch_size, channels, num_bases] + bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + + feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) + feats_recon = feats_recon.view(batch_size, channels, height, width) + + if self.training: + bases = bases.mean(dim=0, keepdim=True) + bases = reduce_mean(bases) + # l2 norm + bases = F.normalize(bases, dim=1, p=2) + self.bases = (1 - + self.momentum) * self.bases + self.momentum * bases + + return feats_recon + + +@MODELS.register_module() +class EMAHead(BaseDecodeHead): + """Expectation Maximization Attention Networks for Semantic Segmentation. + + This head is the implementation of `EMANet + `_. + + Args: + ema_channels (int): EMA module channels + num_bases (int): Number of bases. + num_stages (int): Number of the EM iterations. + concat_input (bool): Whether concat the input and output of convs + before classification layer. Default: True + momentum (float): Momentum to update the base. Default: 0.1. + """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super().__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/enc_head.py b/finetune/mmseg/models/decode_heads/enc_head.py new file mode 100644 index 0000000..2bba73b --- /dev/null +++ b/finetune/mmseg/models/decode_heads/enc_head.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_norm_layer +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import ConfigType, SampleList +from ..utils import Encoding, resize +from .decode_head import BaseDecodeHead + + +class EncModule(nn.Module): + """Encoding Module used in EncNet. + + Args: + in_channels (int): Input channels. + num_codes (int): Number of code words. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + """ + + def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg): + super().__init__() + self.encoding_project = ConvModule( + in_channels, + in_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # TODO: resolve this hack + # change to 1d + if norm_cfg is not None: + encoding_norm_cfg = norm_cfg.copy() + if encoding_norm_cfg['type'] in ['BN', 'IN']: + encoding_norm_cfg['type'] += '1d' + else: + encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace( + '2d', '1d') + else: + # fallback to BN1d + encoding_norm_cfg = dict(type='BN1d') + self.encoding = nn.Sequential( + Encoding(channels=in_channels, num_codes=num_codes), + build_norm_layer(encoding_norm_cfg, num_codes)[1], + nn.ReLU(inplace=True)) + self.fc = nn.Sequential( + nn.Linear(in_channels, in_channels), nn.Sigmoid()) + + def forward(self, x): + """Forward function.""" + encoding_projection = self.encoding_project(x) + encoding_feat = self.encoding(encoding_projection).mean(dim=1) + batch_size, channels, _, _ = x.size() + gamma = self.fc(encoding_feat) + y = gamma.view(batch_size, channels, 1, 1) + output = F.relu_(x + x * y) + return encoding_feat, output + + +@MODELS.register_module() +class EncHead(BaseDecodeHead): + """Context Encoding for Semantic Segmentation. + + This head is the implementation of `EncNet + `_. + + Args: + num_codes (int): Number of code words. Default: 32. + use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to + regularize the training. Default: True. + add_lateral (bool): Whether use lateral connection to fuse features. + Default: False. + loss_se_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss', use_sigmoid=True). + """ + + def __init__(self, + num_codes=32, + use_se_loss=True, + add_lateral=False, + loss_se_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=0.2), + **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.use_se_loss = use_se_loss + self.add_lateral = add_lateral + self.num_codes = num_codes + self.bottleneck = ConvModule( + self.in_channels[-1], + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if add_lateral: + self.lateral_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the last one + self.lateral_convs.append( + ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.fusion = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.enc_module = EncModule( + self.channels, + num_codes=num_codes, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.use_se_loss: + self.loss_se_decode = MODELS.build(loss_se_decode) + self.se_layer = nn.Linear(self.channels, self.num_classes) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + feat = self.bottleneck(inputs[-1]) + if self.add_lateral: + laterals = [ + resize( + lateral_conv(inputs[i]), + size=feat.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + feat = self.fusion(torch.cat([feat, *laterals], 1)) + encode_feat, output = self.enc_module(feat) + output = self.cls_seg(output) + if self.use_se_loss: + se_output = self.se_layer(encode_feat) + return output, se_output + else: + return output + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType): + """Forward function for testing, ignore se_loss.""" + if self.use_se_loss: + seg_logits = self.forward(inputs)[0] + else: + seg_logits = self.forward(inputs) + return self.predict_by_feat(seg_logits, batch_img_metas) + + @staticmethod + def _convert_to_onehot_labels(seg_label, num_classes): + """Convert segmentation label to onehot. + + Args: + seg_label (Tensor): Segmentation label of shape (N, H, W). + num_classes (int): Number of classes. + + Returns: + Tensor: Onehot labels of shape (N, num_classes). + """ + + batch_size = seg_label.size(0) + onehot_labels = seg_label.new_zeros((batch_size, num_classes)) + for i in range(batch_size): + hist = seg_label[i].float().histc( + bins=num_classes, min=0, max=num_classes - 1) + onehot_labels[i] = hist > 0 + return onehot_labels + + def loss_by_feat(self, seg_logit: Tuple[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + """Compute segmentation and semantic encoding loss.""" + seg_logit, se_seg_logit = seg_logit + loss = dict() + loss.update(super().loss_by_feat(seg_logit, batch_data_samples)) + + seg_label = self._stack_batch_gt(batch_data_samples) + se_loss = self.loss_se_decode( + se_seg_logit, + self._convert_to_onehot_labels(seg_label, self.num_classes)) + loss['loss_se'] = se_loss + return loss diff --git a/finetune/mmseg/models/decode_heads/fcn_head.py b/finetune/mmseg/models/decode_heads/fcn_head.py new file mode 100644 index 0000000..3418018 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/fcn_head.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FCNHead(BaseDecodeHead): + """Fully Convolution Networks for Semantic Segmentation. + + This head is implemented of `FCNNet `_. + + Args: + num_convs (int): Number of convs in the head. Default: 2. + kernel_size (int): The kernel size for convs in the head. Default: 3. + concat_input (bool): Whether concat the input and output of convs + before classification layer. + dilation (int): The dilation rate for convs in the head. Default: 1. + """ + + def __init__(self, + num_convs=2, + kernel_size=3, + concat_input=True, + dilation=1, + **kwargs): + assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) + self.num_convs = num_convs + self.concat_input = concat_input + self.kernel_size = kernel_size + super().__init__(**kwargs) + if num_convs == 0: + assert self.in_channels == self.channels + + conv_padding = (kernel_size // 2) * dilation + convs = [] + convs.append( + ConvModule( + self.in_channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + for i in range(num_convs - 1): + convs.append( + ConvModule( + self.channels, + self.channels, + kernel_size=kernel_size, + padding=conv_padding, + dilation=dilation, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if num_convs == 0: + self.convs = nn.Identity() + else: + self.convs = nn.Sequential(*convs) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + feats = self.convs(x) + if self.concat_input: + feats = self.conv_cat(torch.cat([x, feats], dim=1)) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/fpn_head.py b/finetune/mmseg/models/decode_heads/fpn_head.py new file mode 100644 index 0000000..25f481f --- /dev/null +++ b/finetune/mmseg/models/decode_heads/fpn_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import Upsample, resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class FPNHead(BaseDecodeHead): + """Panoptic Feature Pyramid Networks. + + This head is the implementation of `Semantic FPN + `_. + + Args: + feature_strides (tuple[int]): The strides for input feature maps. + stack_lateral. All strides suppose to be power of 2. The first + one is of largest resolution. + """ + + def __init__(self, feature_strides, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + assert len(feature_strides) == len(self.in_channels) + assert min(feature_strides) == feature_strides[0] + self.feature_strides = feature_strides + + self.scale_heads = nn.ModuleList() + for i in range(len(feature_strides)): + head_length = max( + 1, + int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) + scale_head = [] + for k in range(head_length): + scale_head.append( + ConvModule( + self.in_channels[i] if k == 0 else self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + if feature_strides[i] != feature_strides[0]: + scale_head.append( + Upsample( + scale_factor=2, + mode='bilinear', + align_corners=self.align_corners)) + self.scale_heads.append(nn.Sequential(*scale_head)) + + def forward(self, inputs): + + x = self._transform_inputs(inputs) + + output = self.scale_heads[0](x[0]) + for i in range(1, len(self.feature_strides)): + # non inplace + output = output + resize( + self.scale_heads[i](x[i]), + size=output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/gc_head.py b/finetune/mmseg/models/decode_heads/gc_head.py new file mode 100644 index 0000000..14f0ef0 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/gc_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ContextBlock + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class GCHead(FCNHead): + """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. + + This head is the implementation of `GCNet + `_. + + Args: + ratio (float): Multiplier of channels ratio. Default: 1/4. + pooling_type (str): The pooling type of context aggregation. + Options are 'att', 'avg'. Default: 'avg'. + fusion_types (tuple[str]): The fusion type for feature fusion. + Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) + """ + + def __init__(self, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.ratio = ratio + self.pooling_type = pooling_type + self.fusion_types = fusion_types + self.gc_block = ContextBlock( + in_channels=self.channels, + ratio=self.ratio, + pooling_type=self.pooling_type, + fusion_types=self.fusion_types) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.gc_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/ham_head.py b/finetune/mmseg/models/decode_heads/ham_head.py new file mode 100644 index 0000000..073d801 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/ham_head.py @@ -0,0 +1,255 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.device import get_device + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class Matrix_Decomposition_2D_Base(nn.Module): + """Base class of 2D Matrix Decomposition. + + Args: + MD_S (int): The number of spatial coefficient in + Matrix Decomposition, it may be used for calculation + of the number of latent dimension D in Matrix + Decomposition. Defaults: 1. + MD_R (int): The number of latent dimension R in + Matrix Decomposition. Defaults: 64. + train_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in training. Defaults: 6. + eval_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in evaluation. Defaults: 7. + inv_t (int): Inverted multiple number to make coefficient + smaller in softmax. Defaults: 100. + rand_init (bool): Whether to initialize randomly. + Defaults: True. + """ + + def __init__(self, + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True): + super().__init__() + + self.S = MD_S + self.R = MD_R + + self.train_steps = train_steps + self.eval_steps = eval_steps + + self.inv_t = inv_t + + self.rand_init = rand_init + + def _build_bases(self, B, S, D, R, device=None): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + """Forward Function.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R, device=x.device) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, device=x.device) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + return x + + +class NMF2D(Matrix_Decomposition_2D_Base): + """Non-negative Matrix Factorization (NMF) module. + + It is inherited from ``Matrix_Decomposition_2D_Base`` module. + """ + + def __init__(self, args=dict()): + super().__init__(**args) + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, device=None): + """Build bases in initialization.""" + if device is None: + device = get_device() + bases = torch.rand((B * S, D, R)).to(device) + bases = F.normalize(bases, dim=1) + + return bases + + def local_step(self, x, bases, coef): + """Local step in iteration to renew bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + """Hamburger Module. It consists of one slice of "ham" (matrix + decomposition) and two slices of "bread" (linear transformation). + + Args: + ham_channels (int): Input and output channels of feature. + ham_kwargs (dict): Config of matrix decomposition module. + norm_cfg (dict | None): Config of norm layers. + """ + + def __init__(self, + ham_channels=512, + ham_kwargs=dict(), + norm_cfg=None, + **kwargs): + super().__init__() + + self.ham_in = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None) + + self.ham = NMF2D(ham_kwargs) + + self.ham_out = ConvModule( + ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None) + + def forward(self, x): + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=True) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=True) + + return ham + + +@MODELS.register_module() +class LightHamHead(BaseDecodeHead): + """SegNeXt decode head. + + This decode head is the implementation of `SegNeXt: Rethinking + Convolutional Attention Design for Semantic + Segmentation `_. + Inspiration from https://github.com/visual-attention-network/segnext. + + Specifically, LightHamHead is inspired by HamNet from + `Is Attention Better Than Matrix Decomposition? + `. + + Args: + ham_channels (int): input channels for Hamburger. + Defaults: 512. + ham_kwargs (int): kwagrs for Ham. Defaults: dict(). + """ + + def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.ham_channels = ham_channels + + self.squeeze = ConvModule( + sum(self.in_channels), + self.ham_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs) + + self.align = ConvModule( + self.ham_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + inputs = [ + resize( + level, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + # apply a conv block to squeeze feature map + x = self.squeeze(inputs) + # apply hamburger module + x = self.hamburger(x) + + # apply a conv block to align feature map + output = self.align(x) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/isa_head.py b/finetune/mmseg/models/decode_heads/isa_head.py new file mode 100644 index 0000000..355f215 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/isa_head.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from .decode_head import BaseDecodeHead + + +class SelfAttentionBlock(_SelfAttentionBlock): + """Self-Attention Module. + + Args: + in_channels (int): Input channels of key/query feature. + channels (int): Output channels of key/query transform. + conv_cfg (dict | None): Config of conv layers. + norm_cfg (dict | None): Config of norm layers. + act_cfg (dict | None): Config of activation layers. + """ + + def __init__(self, in_channels, channels, conv_cfg, norm_cfg, act_cfg): + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=None, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=False, + matmul_norm=True, + with_out=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + self.output_project = self.build_project( + in_channels, + in_channels, + num_convs=1, + use_conv_module=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + """Forward function.""" + context = super().forward(x, x) + return self.output_project(context) + + +@MODELS.register_module() +class ISAHead(BaseDecodeHead): + """Interlaced Sparse Self-Attention for Semantic Segmentation. + + This head is the implementation of `ISA + `_. + + Args: + isa_channels (int): The channels of ISA Module. + down_factor (tuple[int]): The local group size of ISA. + """ + + def __init__(self, isa_channels, down_factor=(8, 8), **kwargs): + super().__init__(**kwargs) + self.down_factor = down_factor + + self.in_conv = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.global_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.local_relation = SelfAttentionBlock( + self.channels, + isa_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.out_conv = ConvModule( + self.channels * 2, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x_ = self._transform_inputs(inputs) + x = self.in_conv(x_) + residual = x + + n, c, h, w = x.size() + loc_h, loc_w = self.down_factor # size of local group in H- and W-axes + glb_h, glb_w = math.ceil(h / loc_h), math.ceil(w / loc_w) + pad_h, pad_w = glb_h * loc_h - h, glb_w * loc_w - w + if pad_h > 0 or pad_w > 0: # pad if the size is not divisible + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2) + x = F.pad(x, padding) + + # global relation + x = x.view(n, c, glb_h, loc_h, glb_w, loc_w) + # do permutation to gather global group + x = x.permute(0, 3, 5, 1, 2, 4) # (n, loc_h, loc_w, c, glb_h, glb_w) + x = x.reshape(-1, c, glb_h, glb_w) + # apply attention within each global group + x = self.global_relation(x) # (n * loc_h * loc_w, c, glb_h, glb_w) + + # local relation + x = x.view(n, loc_h, loc_w, c, glb_h, glb_w) + # do permutation to gather local group + x = x.permute(0, 4, 5, 3, 1, 2) # (n, glb_h, glb_w, c, loc_h, loc_w) + x = x.reshape(-1, c, loc_h, loc_w) + # apply attention within each local group + x = self.local_relation(x) # (n * glb_h * glb_w, c, loc_h, loc_w) + + # permute each pixel back to its original position + x = x.view(n, glb_h, glb_w, c, loc_h, loc_w) + x = x.permute(0, 3, 1, 4, 2, 5) # (n, c, glb_h, loc_h, glb_w, loc_w) + x = x.reshape(n, c, glb_h * loc_h, glb_w * loc_w) + if pad_h > 0 or pad_w > 0: # remove padding + x = x[:, :, pad_h // 2:pad_h // 2 + h, pad_w // 2:pad_w // 2 + w] + + x = self.out_conv(torch.cat([x, residual], dim=1)) + out = self.cls_seg(x) + + return out diff --git a/finetune/mmseg/models/decode_heads/knet_head.py b/finetune/mmseg/models/decode_heads/knet_head.py new file mode 100644 index 0000000..82d3a28 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/knet_head.py @@ -0,0 +1,461 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks.transformer import (FFN, MultiheadAttention, + build_transformer_layer) +from mmengine.logging import print_log +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +@MODELS.register_module() +class KernelUpdator(nn.Module): + """Dynamic Kernel Updator in Kernel Update Head. + + Args: + in_channels (int): The number of channels of input feature map. + Default: 256. + feat_channels (int): The number of middle-stage channels in + the kernel updator. Default: 64. + out_channels (int): The number of output channels. + gate_sigmoid (bool): Whether use sigmoid function in gate + mechanism. Default: True. + gate_norm_act (bool): Whether add normalization and activation + layer in gate mechanism. Default: False. + activate_out: Whether add activation after gate mechanism. + Default: False. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='LN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + gate_sigmoid=True, + gate_norm_act=False, + activate_out=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='ReLU', inplace=True), + ): + super().__init__() + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.gate_sigmoid = gate_sigmoid + self.gate_norm_act = gate_norm_act + self.activate_out = activate_out + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.feat_channels + self.num_params_out = self.feat_channels + self.dynamic_layer = nn.Linear( + self.in_channels, self.num_params_in + self.num_params_out) + self.input_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out, + 1) + self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1) + if self.gate_norm_act: + self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, update_feature, input_feature): + """Forward function of KernelUpdator. + + Args: + update_feature (torch.Tensor): Feature map assembled from + each group. It would be reshaped with last dimension + shape: `self.in_channels`. + input_feature (torch.Tensor): Intermediate feature + with shape: (N, num_classes, conv_kernel_size**2, channels). + Returns: + Tensor: The output tensor of shape (N*C1/C2, K*K, C2), where N is + the number of classes, C1 and C2 are the feature map channels of + KernelUpdateHead and KernelUpdator, respectively. + """ + + update_feature = update_feature.reshape(-1, self.in_channels) + num_proposals = update_feature.size(0) + # dynamic_layer works for + # phi_1 and psi_3 in Eq.(4) and (5) of K-Net paper + parameters = self.dynamic_layer(update_feature) + param_in = parameters[:, :self.num_params_in].view( + -1, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels) + + # input_layer works for + # phi_2 and psi_4 in Eq.(4) and (5) of K-Net paper + input_feats = self.input_layer( + input_feature.reshape(num_proposals, -1, self.feat_channels)) + input_in = input_feats[..., :self.num_params_in] + input_out = input_feats[..., -self.num_params_out:] + + # `gate_feats` is F^G in K-Net paper + gate_feats = input_in * param_in.unsqueeze(-2) + if self.gate_norm_act: + gate_feats = self.activation(self.gate_norm(gate_feats)) + + input_gate = self.input_norm_in(self.input_gate(gate_feats)) + update_gate = self.norm_in(self.update_gate(gate_feats)) + if self.gate_sigmoid: + input_gate = input_gate.sigmoid() + update_gate = update_gate.sigmoid() + param_out = self.norm_out(param_out) + input_out = self.input_norm_out(input_out) + + if self.activate_out: + param_out = self.activation(param_out) + input_out = self.activation(input_out) + + # Gate mechanism. Eq.(5) in original paper. + # param_out has shape (batch_size, feat_channels, out_channels) + features = update_gate * param_out.unsqueeze( + -2) + input_gate * input_out + + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +@MODELS.register_module() +class KernelUpdateHead(nn.Module): + """Kernel Update Head in K-Net. + + Args: + num_classes (int): Number of classes. Default: 150. + num_ffn_fcs (int): The number of fully-connected layers in + FFNs. Default: 2. + num_heads (int): The number of parallel attention heads. + Default: 8. + num_mask_fcs (int): The number of fully connected layers for + mask prediction. Default: 3. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 2048. + in_channels (int): The number of channels of input feature map. + Default: 256. + out_channels (int): The number of output channels. + Default: 256. + dropout (float): The Probability of an element to be + zeroed in MultiheadAttention and FFN. Default 0.0. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + ffn_act_cfg (dict): Config of activation layers in FFN. + Default: dict(type='ReLU'). + conv_kernel_size (int): The kernel size of convolution in + Kernel Update Head for dynamic kernel updation. + Default: 1. + feat_transform_cfg (dict | None): Config of feature transform. + Default: None. + kernel_init (bool): Whether initiate mask kernel in mask head. + Default: False. + with_ffn (bool): Whether add FFN in kernel update head. + Default: True. + feat_gather_stride (int): Stride of convolution in feature transform. + Default: 1. + mask_transform_stride (int): Stride of mask transform. + Default: 1. + kernel_updator_cfg (dict): Config of kernel updator. + Default: dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN')). + """ + + def __init__(self, + num_classes=150, + num_ffn_fcs=2, + num_heads=8, + num_mask_fcs=3, + feedforward_channels=2048, + in_channels=256, + out_channels=256, + dropout=0.0, + act_cfg=dict(type='ReLU', inplace=True), + ffn_act_cfg=dict(type='ReLU', inplace=True), + conv_kernel_size=1, + feat_transform_cfg=None, + kernel_init=False, + with_ffn=True, + feat_gather_stride=1, + mask_transform_stride=1, + kernel_updator_cfg=dict( + type='DynamicConv', + in_channels=256, + feat_channels=64, + out_channels=256, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'))): + super().__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.out_channels = out_channels + self.fp16_enabled = False + self.dropout = dropout + self.num_heads = num_heads + self.kernel_init = kernel_init + self.with_ffn = with_ffn + self.conv_kernel_size = conv_kernel_size + self.feat_gather_stride = feat_gather_stride + self.mask_transform_stride = mask_transform_stride + + self.attention = MultiheadAttention(in_channels * conv_kernel_size**2, + num_heads, dropout) + self.attention_norm = build_norm_layer( + dict(type='LN'), in_channels * conv_kernel_size**2)[1] + self.kernel_update_conv = build_transformer_layer(kernel_updator_cfg) + + if feat_transform_cfg is not None: + kernel_size = feat_transform_cfg.pop('kernel_size', 1) + transform_channels = in_channels + self.feat_transform = ConvModule( + transform_channels, + in_channels, + kernel_size, + stride=feat_gather_stride, + padding=int(feat_gather_stride // 2), + **feat_transform_cfg) + else: + self.feat_transform = None + + if self.with_ffn: + self.ffn = FFN( + in_channels, + feedforward_channels, + num_ffn_fcs, + act_cfg=ffn_act_cfg, + dropout=dropout) + self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1] + + self.mask_fcs = nn.ModuleList() + for _ in range(num_mask_fcs): + self.mask_fcs.append( + nn.Linear(in_channels, in_channels, bias=False)) + self.mask_fcs.append( + build_norm_layer(dict(type='LN'), in_channels)[1]) + self.mask_fcs.append(build_activation_layer(act_cfg)) + + self.fc_mask = nn.Linear(in_channels, out_channels) + + def init_weights(self): + """Use xavier initialization for all weight parameter and set + classification head bias as a specific value when use focal loss.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + else: + # adopt the default initialization for + # the weight and bias of the layer norm + pass + if self.kernel_init: + print_log( + 'mask kernel in mask head is normal initialized by std 0.01') + nn.init.normal_(self.fc_mask.weight, mean=0, std=0.01) + + def forward(self, x, proposal_feat, mask_preds, mask_shape=None): + """Forward function of Dynamic Instance Interactive Head. + + Args: + x (Tensor): Feature map from FPN with shape + (batch_size, feature_dimensions, H , W). + proposal_feat (Tensor): Intermediate feature get from + diihead in last stage, has shape + (batch_size, num_proposals, feature_dimensions) + mask_preds (Tensor): mask prediction from the former stage in shape + (batch_size, num_proposals, H, W). + + Returns: + Tuple: The first tensor is predicted mask with shape + (N, num_classes, H, W), the second tensor is dynamic kernel + with shape (N, num_classes, channels, K, K). + """ + N, num_proposals = proposal_feat.shape[:2] + if self.feat_transform is not None: + x = self.feat_transform(x) + + C, H, W = x.shape[-3:] + + mask_h, mask_w = mask_preds.shape[-2:] + if mask_h != H or mask_w != W: + gather_mask = F.interpolate( + mask_preds, (H, W), align_corners=False, mode='bilinear') + else: + gather_mask = mask_preds + + sigmoid_masks = gather_mask.softmax(dim=1) + + # Group Feature Assembling. Eq.(3) in original paper. + # einsum is faster than bmm by 30% + x_feat = torch.einsum('bnhw,bchw->bnc', sigmoid_masks, x) + + # obj_feat in shape [B, N, C, K, K] -> [B, N, C, K*K] -> [B, N, K*K, C] + proposal_feat = proposal_feat.reshape(N, num_proposals, + self.in_channels, + -1).permute(0, 1, 3, 2) + obj_feat = self.kernel_update_conv(x_feat, proposal_feat) + + # [B, N, K*K, C] -> [B, N, K*K*C] -> [N, B, K*K*C] + obj_feat = obj_feat.reshape(N, num_proposals, -1).permute(1, 0, 2) + obj_feat = self.attention_norm(self.attention(obj_feat)) + # [N, B, K*K*C] -> [B, N, K*K*C] + obj_feat = obj_feat.permute(1, 0, 2) + + # obj_feat in shape [B, N, K*K*C] -> [B, N, K*K, C] + obj_feat = obj_feat.reshape(N, num_proposals, -1, self.in_channels) + + # FFN + if self.with_ffn: + obj_feat = self.ffn_norm(self.ffn(obj_feat)) + + mask_feat = obj_feat + + for reg_layer in self.mask_fcs: + mask_feat = reg_layer(mask_feat) + + # [B, N, K*K, C] -> [B, N, C, K*K] + mask_feat = self.fc_mask(mask_feat).permute(0, 1, 3, 2) + + if (self.mask_transform_stride == 2 and self.feat_gather_stride == 1): + mask_x = F.interpolate( + x, scale_factor=0.5, mode='bilinear', align_corners=False) + H, W = mask_x.shape[-2:] + else: + mask_x = x + # group conv is 5x faster than unfold and uses about 1/5 memory + # Group conv vs. unfold vs. concat batch, 2.9ms :13.5ms :3.8ms + # Group conv vs. unfold vs. concat batch, 278 : 1420 : 369 + # but in real training group conv is slower than concat batch + # so we keep using concat batch. + # fold_x = F.unfold( + # mask_x, + # self.conv_kernel_size, + # padding=int(self.conv_kernel_size // 2)) + # mask_feat = mask_feat.reshape(N, num_proposals, -1) + # new_mask_preds = torch.einsum('bnc,bcl->bnl', mask_feat, fold_x) + # [B, N, C, K*K] -> [B*N, C, K, K] + mask_feat = mask_feat.reshape(N, num_proposals, C, + self.conv_kernel_size, + self.conv_kernel_size) + # [B, C, H, W] -> [1, B*C, H, W] + new_mask_preds = [] + for i in range(N): + new_mask_preds.append( + F.conv2d( + mask_x[i:i + 1], + mask_feat[i], + padding=int(self.conv_kernel_size // 2))) + + new_mask_preds = torch.cat(new_mask_preds, dim=0) + new_mask_preds = new_mask_preds.reshape(N, num_proposals, H, W) + if self.mask_transform_stride == 2: + new_mask_preds = F.interpolate( + new_mask_preds, + scale_factor=2, + mode='bilinear', + align_corners=False) + + if mask_shape is not None and mask_shape[0] != H: + new_mask_preds = F.interpolate( + new_mask_preds, + mask_shape, + align_corners=False, + mode='bilinear') + + return new_mask_preds, obj_feat.permute(0, 1, 3, 2).reshape( + N, num_proposals, self.in_channels, self.conv_kernel_size, + self.conv_kernel_size) + + +@MODELS.register_module() +class IterativeDecodeHead(BaseDecodeHead): + """K-Net: Towards Unified Image Segmentation. + + This head is the implementation of + `K-Net: `_. + + Args: + num_stages (int): The number of stages (kernel update heads) + in IterativeDecodeHead. Default: 3. + kernel_generate_head:(dict): Config of kernel generate head which + generate mask predictions, dynamic kernels and class predictions + for next kernel update heads. + kernel_update_head (dict): Config of kernel update head which refine + dynamic kernels and class predictions iteratively. + + """ + + def __init__(self, num_stages, kernel_generate_head, kernel_update_head, + **kwargs): + # ``IterativeDecodeHead`` would skip initialization of + # ``BaseDecodeHead`` which would be called when building + # ``self.kernel_generate_head``. + super(BaseDecodeHead, self).__init__(**kwargs) + assert num_stages == len(kernel_update_head) + self.num_stages = num_stages + self.kernel_generate_head = MODELS.build(kernel_generate_head) + self.kernel_update_head = nn.ModuleList() + self.align_corners = self.kernel_generate_head.align_corners + self.num_classes = self.kernel_generate_head.num_classes + self.input_transform = self.kernel_generate_head.input_transform + self.ignore_index = self.kernel_generate_head.ignore_index + self.out_channels = self.num_classes + + for head_cfg in kernel_update_head: + self.kernel_update_head.append(MODELS.build(head_cfg)) + + def forward(self, inputs): + """Forward function.""" + feats = self.kernel_generate_head._forward_feature(inputs) + sem_seg = self.kernel_generate_head.cls_seg(feats) + seg_kernels = self.kernel_generate_head.conv_seg.weight.clone() + seg_kernels = seg_kernels[None].expand( + feats.size(0), *seg_kernels.size()) + + stage_segs = [sem_seg] + for i in range(self.num_stages): + sem_seg, seg_kernels = self.kernel_update_head[i](feats, + seg_kernels, + sem_seg) + stage_segs.append(sem_seg) + if self.training: + return stage_segs + # only return the prediction of the last stage during testing + return stage_segs[-1] + + def loss_by_feat(self, seg_logits: List[Tensor], + batch_data_samples: SampleList, **kwargs) -> dict: + losses = dict() + for i, logit in enumerate(seg_logits): + loss = self.kernel_generate_head.loss_by_feat( + logit, batch_data_samples) + for k, v in loss.items(): + losses[f'{k}.s{i}'] = v + + return losses diff --git a/finetune/mmseg/models/decode_heads/lraspp_head.py b/finetune/mmseg/models/decode_heads/lraspp_head.py new file mode 100644 index 0000000..ba2465f --- /dev/null +++ b/finetune/mmseg/models/decode_heads/lraspp_head.py @@ -0,0 +1,91 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils import is_tuple_of + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class LRASPPHead(BaseDecodeHead): + """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. + + This head is the improved implementation of `Searching for MobileNetV3 + `_. + + Args: + branch_channels (tuple[int]): The number of output channels in every + each branch. Default: (32, 64). + """ + + def __init__(self, branch_channels=(32, 64), **kwargs): + super().__init__(**kwargs) + if self.input_transform != 'multiple_select': + raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' + f'must be \'multiple_select\'. But received ' + f'\'{self.input_transform}\'') + assert is_tuple_of(branch_channels, int) + assert len(branch_channels) == len(self.in_channels) - 1 + self.branch_channels = branch_channels + + self.convs = nn.Sequential() + self.conv_ups = nn.Sequential() + for i in range(len(branch_channels)): + self.convs.add_module( + f'conv{i}', + nn.Conv2d( + self.in_channels[i], branch_channels[i], 1, bias=False)) + self.conv_ups.add_module( + f'conv_up{i}', + ConvModule( + self.channels + branch_channels[i], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False)) + + self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) + + self.aspp_conv = ConvModule( + self.in_channels[-1], + self.channels, + 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + bias=False) + self.image_pool = nn.Sequential( + nn.AvgPool2d(kernel_size=49, stride=(16, 20)), + ConvModule( + self.in_channels[2], + self.channels, + 1, + act_cfg=dict(type='Sigmoid'), + bias=False)) + + def forward(self, inputs): + """Forward function.""" + inputs = self._transform_inputs(inputs) + + x = inputs[-1] + + x = self.aspp_conv(x) * resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = self.conv_up_input(x) + + for i in range(len(self.branch_channels) - 1, -1, -1): + x = resize( + x, + size=inputs[i].size()[2:], + mode='bilinear', + align_corners=self.align_corners) + x = torch.cat([x, self.convs[i](inputs[i])], 1) + x = self.conv_ups[i](x) + + return self.cls_seg(x) diff --git a/finetune/mmseg/models/decode_heads/mask2former_head.py b/finetune/mmseg/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000..0135af0 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/mask2former_head.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import \ + Mask2FormerHead as MMDET_Mask2FormerHead +except ModuleNotFoundError: + MMDET_Mask2FormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class Mask2FormerHead(MMDET_Mask2FormerHead): + """Implements the Mask2Former head. + + See `Mask2Former: Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + num_classes, + align_corners=False, + ignore_index=255, + **kwargs): + super().__init__(**kwargs) + + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_Mask2FormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. + + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. + """ + batch_data_samples = [ + SegDataSample(metainfo=metainfo) for metainfo in batch_img_metas + ] + + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + if 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred_results = F.interpolate( + mask_pred_results, size=size, mode='bilinear', align_corners=False) + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc, bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/finetune/mmseg/models/decode_heads/maskformer_head.py b/finetune/mmseg/models/decode_heads/maskformer_head.py new file mode 100644 index 0000000..6e61a7f --- /dev/null +++ b/finetune/mmseg/models/decode_heads/maskformer_head.py @@ -0,0 +1,174 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +try: + from mmdet.models.dense_heads import MaskFormerHead as MMDET_MaskFormerHead +except ModuleNotFoundError: + MMDET_MaskFormerHead = BaseModule + +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures.seg_data_sample import SegDataSample +from mmseg.utils import ConfigType, SampleList + + +@MODELS.register_module() +class MaskFormerHead(MMDET_MaskFormerHead): + """Implements the MaskFormer head. + + See `Per-Pixel Classification is Not All You Need for Semantic Segmentation + `_ for details. + + Args: + num_classes (int): Number of classes. Default: 150. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + ignore_index (int): The label index to be ignored. Default: 255. + """ + + def __init__(self, + num_classes: int = 150, + align_corners: bool = False, + ignore_index: int = 255, + **kwargs) -> None: + super().__init__(**kwargs) + + self.out_channels = kwargs['out_channels'] + self.align_corners = True + self.num_classes = num_classes + self.align_corners = align_corners + self.out_channels = num_classes + self.ignore_index = ignore_index + + feat_channels = kwargs['feat_channels'] + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + + def _seg_data_to_instance_data(self, batch_data_samples: SampleList): + """Perform forward propagation to convert paradigm from MMSegmentation + to MMDetection to ensure ``MMDET_MaskFormerHead`` could be called + normally. Specifically, ``batch_gt_instances`` would be added. + + Args: + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + + - batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (list[dict]): List of image meta information. + """ + batch_img_metas = [] + batch_gt_instances = [] + for data_sample in batch_data_samples: + # Add `batch_input_shape` in metainfo of data_sample, which would + # be used in MaskFormerHead of MMDetection. + metainfo = data_sample.metainfo + metainfo['batch_input_shape'] = metainfo['img_shape'] + data_sample.set_metainfo(metainfo) + batch_img_metas.append(data_sample.metainfo) + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != self.ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros((0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg) + else: + gt_masks = torch.stack(masks).squeeze(1) + + instance_data = InstanceData( + labels=gt_labels, masks=gt_masks.long()) + batch_gt_instances.append(instance_data) + return batch_gt_instances, batch_img_metas + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances, batch_img_metas = self._seg_data_to_instance_data( + batch_data_samples) + + # forward + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + + # loss + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, + batch_gt_instances, batch_img_metas) + + return losses + + def predict(self, x: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tuple[Tensor]: + """Test without augmentaton. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_img_metas (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + test_cfg (ConfigType): Test config. + + Returns: + Tensor: A tensor of segmentation mask. + """ + + batch_data_samples = [] + for metainfo in batch_img_metas: + metainfo['batch_input_shape'] = metainfo['img_shape'] + batch_data_samples.append(SegDataSample(metainfo=metainfo)) + # Forward function of MaskFormerHead from MMDetection needs + # 'batch_data_samples' as inputs, which is image shape actually. + all_cls_scores, all_mask_preds = self(x, batch_data_samples) + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + + # upsample masks + img_shape = batch_img_metas[0]['batch_input_shape'] + mask_pred_results = F.interpolate( + mask_pred_results, + size=img_shape, + mode='bilinear', + align_corners=False) + + # semantic inference + cls_score = F.softmax(mask_cls_results, dim=-1)[..., :-1] + mask_pred = mask_pred_results.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_logits diff --git a/finetune/mmseg/models/decode_heads/nl_head.py b/finetune/mmseg/models/decode_heads/nl_head.py new file mode 100644 index 0000000..0ffcc2a --- /dev/null +++ b/finetune/mmseg/models/decode_heads/nl_head.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import NonLocal2d + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class NLHead(FCNHead): + """Non-local Neural Networks. + + This head is the implementation of `NLNet + `_. + + Args: + reduction (int): Reduction factor of projection transform. Default: 2. + use_scale (bool): Whether to scale pairwise_weight by + sqrt(1/inter_channels). Default: True. + mode (str): The nonlocal mode. Options are 'embedded_gaussian', + 'dot_product'. Default: 'embedded_gaussian.'. + """ + + def __init__(self, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + **kwargs): + super().__init__(num_convs=2, **kwargs) + self.reduction = reduction + self.use_scale = use_scale + self.mode = mode + self.nl_block = NonLocal2d( + in_channels=self.channels, + reduction=self.reduction, + use_scale=self.use_scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + mode=self.mode) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + output = self.convs[0](x) + output = self.nl_block(output) + output = self.convs[1](output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/ocr_head.py b/finetune/mmseg/models/decode_heads/ocr_head.py new file mode 100644 index 0000000..9afe37b --- /dev/null +++ b/finetune/mmseg/models/decode_heads/ocr_head.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from ..utils import resize +from .cascade_decode_head import BaseCascadeDecodeHead + + +class SpatialGatherModule(nn.Module): + """Aggregate the context features according to the initial predicted + probability distribution. + + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, scale): + super().__init__() + self.scale = scale + + def forward(self, feats, probs): + """Forward function.""" + batch_size, num_classes, height, width = probs.size() + channels = feats.size(1) + probs = probs.view(batch_size, num_classes, -1) + feats = feats.view(batch_size, channels, -1) + # [batch_size, height*width, num_classes] + feats = feats.permute(0, 2, 1) + # [batch_size, channels, height*width] + probs = F.softmax(self.scale * probs, dim=2) + # [batch_size, channels, num_classes] + ocr_context = torch.matmul(probs, feats) + ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) + return ocr_context + + +class ObjectAttentionBlock(_SelfAttentionBlock): + """Make a OCR used SelfAttentionBlock.""" + + def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg, + act_cfg): + if scale > 1: + query_downsample = nn.MaxPool2d(kernel_size=scale) + else: + query_downsample = None + super().__init__( + key_in_channels=in_channels, + query_in_channels=in_channels, + channels=channels, + out_channels=in_channels, + share_key_query=False, + query_downsample=query_downsample, + key_downsample=None, + key_query_num_convs=2, + key_query_norm=True, + value_out_num_convs=1, + value_out_norm=True, + matmul_norm=True, + with_out=True, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.bottleneck = ConvModule( + in_channels * 2, + in_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, query_feats, key_feats): + """Forward function.""" + context = super().forward(query_feats, key_feats) + output = self.bottleneck(torch.cat([context, query_feats], dim=1)) + if self.query_downsample is not None: + output = resize(query_feats) + + return output + + +@MODELS.register_module() +class OCRHead(BaseCascadeDecodeHead): + """Object-Contextual Representations for Semantic Segmentation. + + This head is the implementation of `OCRNet + `_. + + Args: + ocr_channels (int): The intermediate channels of OCR block. + scale (int): The scale of probability map in SpatialGatherModule in + Default: 1. + """ + + def __init__(self, ocr_channels, scale=1, **kwargs): + super().__init__(**kwargs) + self.ocr_channels = ocr_channels + self.scale = scale + self.object_context_block = ObjectAttentionBlock( + self.channels, + self.ocr_channels, + self.scale, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.spatial_gather_module = SpatialGatherModule(self.scale) + + self.bottleneck = ConvModule( + self.in_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs, prev_output): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.bottleneck(x) + context = self.spatial_gather_module(feats, prev_output) + object_context = self.object_context_block(feats, context) + output = self.cls_seg(object_context) + + return output diff --git a/finetune/mmseg/models/decode_heads/pid_head.py b/finetune/mmseg/models/decode_heads/pid_head.py new file mode 100644 index 0000000..c092cb3 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/pid_head.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.models.losses import accuracy +from mmseg.models.utils import resize +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType, SampleList + + +class BasePIDHead(BaseModule): + """Base class for PID head. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + init_cfg (dict or list[dict], optional): Init config dict. + Default: None. + """ + + def __init__(self, + in_channels: int, + channels: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv = ConvModule( + in_channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + order=('norm', 'act', 'conv')) + _, self.norm = build_norm_layer(norm_cfg, num_features=channels) + self.act = build_activation_layer(act_cfg) + + def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor: + """Forward function. + Args: + x (Tensor): Input tensor. + cls_seg (nn.Module, optional): The classification head. + + Returns: + Tensor: Output tensor. + """ + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + if cls_seg is not None: + x = cls_seg(x) + return x + + +@MODELS.register_module() +class PIDHead(BaseDecodeHead): + """Decode head for PIDNet. + + Args: + in_channels (int): Number of input channels. + channels (int): Number of output channels. + num_classes (int): Number of classes. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU', inplace=True). + """ + + def __init__(self, + in_channels: int, + channels: int, + num_classes: int, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + **kwargs): + super().__init__( + in_channels, + channels, + num_classes=num_classes, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs) + self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg) + self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg, + act_cfg) + self.d_head = BasePIDHead( + in_channels // 2, + in_channels // 4, + norm_cfg, + ) + self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward( + self, + inputs: Union[Tensor, + Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]: + """Forward function. + Args: + inputs (Tensor | tuple[Tensor]): Input tensor or tuple of + Tensor. When training, the input is a tuple of three tensors, + (p_feat, i_feat, d_feat), and the output is a tuple of three + tensors, (p_seg_logit, i_seg_logit, d_seg_logit). + When inference, only the head of integral branch is used, and + input is a tensor of integral feature map, and the output is + the segmentation logit. + + Returns: + Tensor | tuple[Tensor]: Output tensor or tuple of tensors. + """ + if self.training: + x_p, x_i, x_d = inputs + x_p = self.p_head(x_p, self.p_cls_seg) + x_i = self.i_head(x_i, self.cls_seg) + x_d = self.d_head(x_d, self.d_cls_seg) + return x_p, x_i, x_d + else: + return self.i_head(inputs, self.cls_seg) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]: + gt_semantic_segs = [ + data_sample.gt_sem_seg.data for data_sample in batch_data_samples + ] + gt_edge_segs = [ + data_sample.gt_edge_map.data for data_sample in batch_data_samples + ] + gt_sem_segs = torch.stack(gt_semantic_segs, dim=0) + gt_edge_segs = torch.stack(gt_edge_segs, dim=0) + return gt_sem_segs, gt_edge_segs + + def loss_by_feat(self, seg_logits: Tuple[Tensor], + batch_data_samples: SampleList) -> dict: + loss = dict() + p_logit, i_logit, d_logit = seg_logits + sem_label, bd_label = self._stack_batch_gt(batch_data_samples) + p_logit = resize( + input=p_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + i_logit = resize( + input=i_logit, + size=sem_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + d_logit = resize( + input=d_logit, + size=bd_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + sem_label = sem_label.squeeze(1) + bd_label = bd_label.squeeze(1) + loss['loss_sem_p'] = self.loss_decode[0]( + p_logit, sem_label, ignore_index=self.ignore_index) + loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label) + loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label) + filler = torch.ones_like(sem_label) * self.ignore_index + sem_bd_label = torch.where( + torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler) + loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label) + loss['acc_seg'] = accuracy( + i_logit, sem_label, ignore_index=self.ignore_index) + return loss diff --git a/finetune/mmseg/models/decode_heads/point_head.py b/finetune/mmseg/models/decode_heads/point_head.py new file mode 100644 index 0000000..e8e433d --- /dev/null +++ b/finetune/mmseg/models/decode_heads/point_head.py @@ -0,0 +1,367 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +try: + from mmcv.ops import point_sample +except ModuleNotFoundError: + point_sample = None + +from typing import List + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..losses import accuracy +from ..utils import resize +from .cascade_decode_head import BaseCascadeDecodeHead + + +def calculate_uncertainty(seg_logits): + """Estimate uncertainty based on seg logits. + + For each location of the prediction ``seg_logits`` we estimate + uncertainty as the difference between top first and top second + predicted logits. + + Args: + seg_logits (Tensor): Semantic segmentation logits, + shape (batch_size, num_classes, height, width). + + Returns: + scores (Tensor): T uncertainty scores with the most uncertain + locations having the highest uncertainty score, shape ( + batch_size, 1, height, width) + """ + top2_scores = torch.topk(seg_logits, k=2, dim=1)[0] + return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) + + +@MODELS.register_module() +class PointHead(BaseCascadeDecodeHead): + """A mask point head use in PointRend. + + This head is implemented of `PointRend: Image Segmentation as + Rendering `_. + ``PointHead`` use shared multi-layer perceptron (equivalent to + nn.Conv1d) to predict the logit of input points. The fine-grained feature + and coarse feature will be concatenate together for predication. + + Args: + num_fcs (int): Number of fc layers in the head. Default: 3. + in_channels (int): Number of input channels. Default: 256. + fc_channels (int): Number of fc channels. Default: 256. + num_classes (int): Number of classes for logits. Default: 80. + class_agnostic (bool): Whether use class agnostic classification. + If so, the output channels of logits will be 1. Default: False. + coarse_pred_each_layer (bool): Whether concatenate coarse feature with + the output of each fc layer. Default: True. + conv_cfg (dict|None): Dictionary to construct and config conv layer. + Default: dict(type='Conv1d')) + norm_cfg (dict|None): Dictionary to construct and config norm layer. + Default: None. + loss_point (dict): Dictionary to construct and config loss layer of + point head. Default: dict(type='CrossEntropyLoss', use_mask=True, + loss_weight=1.0). + """ + + def __init__(self, + num_fcs=3, + coarse_pred_each_layer=True, + conv_cfg=dict(type='Conv1d'), + norm_cfg=None, + act_cfg=dict(type='ReLU', inplace=False), + **kwargs): + super().__init__( + input_transform='multiple_select', + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='fc_seg')), + **kwargs) + if point_sample is None: + raise RuntimeError('Please install mmcv-full for ' + 'point_sample ops') + + self.num_fcs = num_fcs + self.coarse_pred_each_layer = coarse_pred_each_layer + + fc_in_channels = sum(self.in_channels) + self.num_classes + fc_channels = self.channels + self.fcs = nn.ModuleList() + for k in range(num_fcs): + fc = ConvModule( + fc_in_channels, + fc_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.fcs.append(fc) + fc_in_channels = fc_channels + fc_in_channels += self.num_classes if self.coarse_pred_each_layer \ + else 0 + self.fc_seg = nn.Conv1d( + fc_in_channels, + self.num_classes, + kernel_size=1, + stride=1, + padding=0) + if self.dropout_ratio > 0: + self.dropout = nn.Dropout(self.dropout_ratio) + delattr(self, 'conv_seg') + + def cls_seg(self, feat): + """Classify each pixel with fc.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.fc_seg(feat) + return output + + def forward(self, fine_grained_point_feats, coarse_point_feats): + x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1) + for fc in self.fcs: + x = fc(x) + if self.coarse_pred_each_layer: + x = torch.cat((x, coarse_point_feats), dim=1) + return self.cls_seg(x) + + def _get_fine_grained_point_feats(self, x, points): + """Sample from fine grained features. + + Args: + x (list[Tensor]): Feature pyramid from by neck or backbone. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + fine_grained_feats (Tensor): Sampled fine grained feature, + shape (batch_size, sum(channels of x), num_points). + """ + + fine_grained_feats_list = [ + point_sample(_, points, align_corners=self.align_corners) + for _ in x + ] + if len(fine_grained_feats_list) > 1: + fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1) + else: + fine_grained_feats = fine_grained_feats_list[0] + + return fine_grained_feats + + def _get_coarse_point_feats(self, prev_output, points): + """Sample from fine grained features. + + Args: + prev_output (list[Tensor]): Prediction of previous decode head. + points (Tensor): Point coordinates, shape (batch_size, + num_points, 2). + + Returns: + coarse_feats (Tensor): Sampled coarse feature, shape (batch_size, + num_classes, num_points). + """ + + coarse_feats = point_sample( + prev_output, points, align_corners=self.align_corners) + + return coarse_feats + + def loss(self, inputs, prev_output, batch_data_samples: SampleList, + train_cfg, **kwargs): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + batch_data_samples (list[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `img_metas` or `gt_semantic_seg`. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + x = self._transform_inputs(inputs) + with torch.no_grad(): + points = self.get_points_train( + prev_output, calculate_uncertainty, cfg=train_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats(prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + losses = self.loss_by_feat(point_logits, points, batch_data_samples) + + return losses + + def predict(self, inputs, prev_output, batch_img_metas: List[dict], + test_cfg, **kwargs): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + prev_output (Tensor): The output of previous decode head. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. + """ + + x = self._transform_inputs(inputs) + refined_seg_logits = prev_output.clone() + for _ in range(test_cfg.subdivision_steps): + refined_seg_logits = resize( + refined_seg_logits, + scale_factor=test_cfg.scale_factor, + mode='bilinear', + align_corners=self.align_corners) + batch_size, channels, height, width = refined_seg_logits.shape + point_indices, points = self.get_points_test( + refined_seg_logits, calculate_uncertainty, cfg=test_cfg) + fine_grained_point_feats = self._get_fine_grained_point_feats( + x, points) + coarse_point_feats = self._get_coarse_point_feats( + prev_output, points) + point_logits = self.forward(fine_grained_point_feats, + coarse_point_feats) + + point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1) + refined_seg_logits = refined_seg_logits.reshape( + batch_size, channels, height * width) + refined_seg_logits = refined_seg_logits.scatter_( + 2, point_indices, point_logits) + refined_seg_logits = refined_seg_logits.view( + batch_size, channels, height, width) + + return self.predict_by_feat(refined_seg_logits, batch_img_metas, + **kwargs) + + def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs): + """Compute segmentation loss.""" + gt_semantic_seg = self._stack_batch_gt(batch_data_samples) + point_label = point_sample( + gt_semantic_seg.float(), + points, + mode='nearest', + align_corners=self.align_corners) + point_label = point_label.squeeze(1).long() + + loss = dict() + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_module in losses_decode: + loss['point' + loss_module.loss_name] = loss_module( + point_logits, point_label, ignore_index=self.ignore_index) + + loss['acc_point'] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index) + return loss + + def get_points_train(self, seg_logits, uncertainty_func, cfg): + """Sample points for training. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'uncertainty_func' function that takes point's logit prediction as + input. + + Args: + seg_logits (Tensor): Semantic segmentation logits, shape ( + batch_size, num_classes, height, width). + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Training config of point head. + + Returns: + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains the coordinates of ``num_points`` sampled + points. + """ + num_points = cfg.num_points + oversample_ratio = cfg.oversample_ratio + importance_sample_ratio = cfg.importance_sample_ratio + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = seg_logits.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=seg_logits.device) + point_logits = point_sample(seg_logits, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = uncertainty_func(point_logits) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=seg_logits.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_point_coords = torch.rand( + batch_size, num_random_points, 2, device=seg_logits.device) + point_coords = torch.cat((point_coords, rand_point_coords), dim=1) + return point_coords + + def get_points_test(self, seg_logits, uncertainty_func, cfg): + """Sample points for testing. + + Find ``num_points`` most uncertain points from ``uncertainty_map``. + + Args: + seg_logits (Tensor): A tensor of shape (batch_size, num_classes, + height, width) for class-specific or class-agnostic prediction. + uncertainty_func (func): uncertainty calculation function. + cfg (dict): Testing config of point head. + + Returns: + point_indices (Tensor): A tensor of shape (batch_size, num_points) + that contains indices from [0, height x width) of the most + uncertain points. + point_coords (Tensor): A tensor of shape (batch_size, num_points, + 2) that contains [0, 1] x [0, 1] normalized coordinates of the + most uncertain points from the ``height x width`` grid . + """ + + num_points = cfg.subdivision_num_points + uncertainty_map = uncertainty_func(seg_logits) + batch_size, _, height, width = uncertainty_map.shape + h_step = 1.0 / height + w_step = 1.0 / width + + uncertainty_map = uncertainty_map.view(batch_size, height * width) + num_points = min(height * width, num_points) + point_indices = uncertainty_map.topk(num_points, dim=1)[1] + point_coords = torch.zeros( + batch_size, + num_points, + 2, + dtype=torch.float, + device=seg_logits.device) + point_coords[:, :, 0] = w_step / 2.0 + (point_indices % + width).float() * w_step + point_coords[:, :, 1] = h_step / 2.0 + (point_indices // + width).float() * h_step + return point_indices, point_coords diff --git a/finetune/mmseg/models/decode_heads/psa_head.py b/finetune/mmseg/models/decode_heads/psa_head.py new file mode 100644 index 0000000..13ee5c5 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/psa_head.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + +try: + from mmcv.ops import PSAMask +except ModuleNotFoundError: + PSAMask = None + + +@MODELS.register_module() +class PSAHead(BaseDecodeHead): + """Point-wise Spatial Attention Network for Scene Parsing. + + This head is the implementation of `PSANet + `_. + + Args: + mask_size (tuple[int]): The PSA mask size. It usually equals input + size. + psa_type (str): The type of psa module. Options are 'collect', + 'distribute', 'bi-direction'. Default: 'bi-direction' + compact (bool): Whether use compact map for 'collect' mode. + Default: True. + shrink_factor (int): The downsample factors of psa mask. Default: 2. + normalization_factor (float): The normalize factor of attention. + psa_softmax (bool): Whether use softmax for attention. + """ + + def __init__(self, + mask_size, + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + **kwargs): + if PSAMask is None: + raise RuntimeError('Please install mmcv-full for PSAMask ops') + super().__init__(**kwargs) + assert psa_type in ['collect', 'distribute', 'bi-direction'] + self.psa_type = psa_type + self.compact = compact + self.shrink_factor = shrink_factor + self.mask_size = mask_size + mask_h, mask_w = mask_size + self.psa_softmax = psa_softmax + if normalization_factor is None: + normalization_factor = mask_h * mask_w + self.normalization_factor = normalization_factor + + self.reduce = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + if psa_type == 'bi-direction': + self.reduce_p = ConvModule( + self.in_channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.attention_p = nn.Sequential( + ConvModule( + self.channels, + self.channels, + kernel_size=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + nn.Conv2d( + self.channels, mask_h * mask_w, kernel_size=1, bias=False)) + self.psamask_collect = PSAMask('collect', mask_size) + self.psamask_distribute = PSAMask('distribute', mask_size) + else: + self.psamask = PSAMask(psa_type, mask_size) + self.proj = ConvModule( + self.channels * (2 if psa_type == 'bi-direction' else 1), + self.in_channels, + kernel_size=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + self.bottleneck = ConvModule( + self.in_channels * 2, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + identity = x + align_corners = self.align_corners + if self.psa_type in ['collect', 'distribute']: + out = self.reduce(x) + n, c, h, w = out.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + out = resize( + out, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y = self.attention(out) + if self.compact: + if self.psa_type == 'collect': + y = y.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y = self.psamask(y) + if self.psa_softmax: + y = F.softmax(y, dim=1) + out = torch.bmm( + out.view(n, c, h * w), y.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + else: + x_col = self.reduce(x) + x_dis = self.reduce_p(x) + n, c, h, w = x_col.size() + if self.shrink_factor != 1: + if h % self.shrink_factor and w % self.shrink_factor: + h = (h - 1) // self.shrink_factor + 1 + w = (w - 1) // self.shrink_factor + 1 + align_corners = True + else: + h = h // self.shrink_factor + w = w // self.shrink_factor + align_corners = False + x_col = resize( + x_col, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + x_dis = resize( + x_dis, + size=(h, w), + mode='bilinear', + align_corners=align_corners) + y_col = self.attention(x_col) + y_dis = self.attention_p(x_dis) + if self.compact: + y_dis = y_dis.view(n, h * w, + h * w).transpose(1, 2).view(n, h * w, h, w) + else: + y_col = self.psamask_collect(y_col) + y_dis = self.psamask_distribute(y_dis) + if self.psa_softmax: + y_col = F.softmax(y_col, dim=1) + y_dis = F.softmax(y_dis, dim=1) + x_col = torch.bmm( + x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + x_dis = torch.bmm( + x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view( + n, c, h, w) * (1.0 / self.normalization_factor) + out = torch.cat([x_col, x_dis], 1) + out = self.proj(out) + out = resize( + out, + size=identity.shape[2:], + mode='bilinear', + align_corners=align_corners) + out = self.bottleneck(torch.cat((identity, out), dim=1)) + out = self.cls_seg(out) + return out diff --git a/finetune/mmseg/models/decode_heads/psp_head.py b/finetune/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 0000000..a40ec41 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners, **kwargs): + super().__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **kwargs))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs + + +@MODELS.register_module() +class PSPHead(BaseDecodeHead): + """Pyramid Scene Parsing Network. + + This head is the implementation of + `PSPNet `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(**kwargs) + assert isinstance(pool_scales, (list, tuple)) + self.pool_scales = pool_scales + self.psp_modules = PPM( + self.pool_scales, + self.in_channels, + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + x = self._transform_inputs(inputs) + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + feats = self.bottleneck(psp_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/san_head.py b/finetune/mmseg/models/decode_heads/san_head.py new file mode 100644 index 0000000..d20da80 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/san_head.py @@ -0,0 +1,736 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmcv.ops import point_sample +from mmengine.dist import all_reduce +from mmengine.model.weight_init import (caffe2_xavier_init, normal_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from mmengine.structures import InstanceData +from torch import Tensor +from torch.nn import functional as F + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, MatchMasks, SampleList, + seg_data_to_instance_data) +from ..utils import (MLP, LayerNorm2d, PatchEmbed, cross_attn_layer, + get_uncertain_point_coords_with_randomness, resize) +from .decode_head import BaseDecodeHead + + +class MLPMaskDecoder(nn.Module): + """Module for decoding query and visual features with MLP layers to + generate the attention biases and the mask proposals.""" + + def __init__( + self, + *, + in_channels: int, + total_heads: int = 1, + total_layers: int = 1, + embed_channels: int = 256, + mlp_channels: int = 256, + mlp_num_layers: int = 3, + rescale_attn_bias: bool = False, + ): + super().__init__() + self.total_heads = total_heads + self.total_layers = total_layers + + dense_affine_func = partial(nn.Conv2d, kernel_size=1) + # Query Branch + self.query_mlp = MLP(in_channels, mlp_channels, embed_channels, + mlp_num_layers) + # Pixel Branch + self.pix_mlp = MLP( + in_channels, + mlp_channels, + embed_channels, + mlp_num_layers, + affine_func=dense_affine_func, + ) + # Attention Bias Branch + self.attn_mlp = MLP( + in_channels, + mlp_channels, + embed_channels * self.total_heads * self.total_layers, + mlp_num_layers, + affine_func=dense_affine_func, + ) + if rescale_attn_bias: + self.bias_scaling = nn.Linear(1, 1) + else: + self.bias_scaling = nn.Identity() + + def forward(self, query: torch.Tensor, + x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward function. + Args: + query (Tensor): Query Tokens [B,N,C]. + x (Tensor): Visual features [B,C,H,W] + + Return: + mask_preds (Tensor): Mask proposals. + attn_bias (List[Tensor]): List of attention bias. + """ + query = self.query_mlp(query) + pix = self.pix_mlp(x) + b, c, h, w = pix.shape + # preidict mask + mask_preds = torch.einsum('bqc,bchw->bqhw', query, pix) + # generate attn bias + attn = self.attn_mlp(x) + attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w) + attn_bias = torch.einsum('bqc,blnchw->blnqhw', query, attn) + attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1) + attn_bias = attn_bias.chunk(self.total_layers, dim=1) + attn_bias = [attn.squeeze(1) for attn in attn_bias] + return mask_preds, attn_bias + + +class SideAdapterNetwork(nn.Module): + """Side Adapter Network for predicting mask proposals and attention bias. + + Args: + in_channels (int): Number of input channels. Default: 3. + clip_channels (int): Number of channels of visual features. + Default: 768. + embed_dims (int): embedding dimension. Default: 240. + patch_size (int): The patch size. Default: 16. + patch_bias (bool): Whether use bias in patch embedding. + Default: True. + num_queries (int): Number of queries for mask proposals. + Default: 100. + fusion_index (List[int]): The layer number of the encode + transformer to fuse with the CLIP feature. + Default: [0, 1, 2, 3]. + cfg_encoder (ConfigType): Configs for the encode layers. + cfg_decoder (ConfigType): Configs for the decode layers. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + """ + + def __init__( + self, + in_channels: int = 3, + clip_channels: int = 768, + embed_dims: int = 240, + patch_size: int = 16, + patch_bias: bool = True, + num_queries: int = 100, + fusion_index: list = [0, 1, 2, 3], + cfg_encoder: ConfigType = ..., + cfg_decoder: ConfigType = ..., + norm_cfg: dict = dict(type='LN'), + ): + super().__init__() + + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding=0, + input_size=(640, 640), + bias=patch_bias, + norm_cfg=None, + init_cfg=None, + ) + ori_h, ori_w = self.patch_embed.init_out_size + num_patches = ori_h * ori_w + self.pos_embed = nn.Parameter( + torch.randn(1, num_patches, embed_dims) * .02) + self.query_pos_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + self.query_embed = nn.Parameter( + torch.zeros(1, num_queries, embed_dims)) + encode_layers = [] + for i in range(cfg_encoder.num_encode_layer): + encode_layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=cfg_encoder.num_heads, + feedforward_channels=cfg_encoder.mlp_ratio * embed_dims, + norm_cfg=norm_cfg)) + self.encode_layers = nn.ModuleList(encode_layers) + conv_clips = [] + for i in range(len(fusion_index)): + conv_clips.append( + nn.Sequential( + LayerNorm2d(clip_channels), + ConvModule( + clip_channels, + embed_dims, + kernel_size=1, + norm_cfg=None, + act_cfg=None))) + self.conv_clips = nn.ModuleList(conv_clips) + self.fusion_index = fusion_index + self.mask_decoder = MLPMaskDecoder( + in_channels=embed_dims, + total_heads=cfg_decoder.num_heads, + total_layers=cfg_decoder.num_layers, + embed_channels=cfg_decoder.embed_channels, + mlp_channels=cfg_decoder.mlp_channels, + mlp_num_layers=cfg_decoder.num_mlp, + rescale_attn_bias=cfg_decoder.rescale) + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.query_embed, std=0.02) + nn.init.normal_(self.query_pos_embed, std=0.02) + for i in range(len(self.conv_clips)): + caffe2_xavier_init(self.conv_clips[i][1].conv) + + def fuse_clip(self, fused_index: int, x: torch.Tensor, + clip_feature: torch.Tensor, hwshape: Tuple[int, + int], L: int): + """Fuse CLIP feature and visual tokens.""" + fused_clip = (resize( + self.conv_clips[fused_index](clip_feature.contiguous()), + size=hwshape, + mode='bilinear', + align_corners=False)).permute(0, 2, 3, 1).reshape(x[:, -L:, + ...].shape) + x = torch.cat([x[:, :-L, ...], x[:, -L:, ...] + fused_clip], dim=1) + return x + + def encode_feature(self, image: torch.Tensor, + clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int]) -> List[List]: + """Encode images by a lightweight vision transformer.""" + assert len(self.fusion_index) == len(clip_features) + x, hwshape = self.patch_embed(image) + ori_h, ori_w = self.patch_embed.init_out_size + pos_embed = self.pos_embed + if self.pos_embed.shape[1] != x.shape[1]: + # resize the position embedding + pos_embed = ( + resize( + self.pos_embed.reshape(1, ori_h, ori_w, + -1).permute(0, 3, 1, 2), + size=hwshape, + mode='bicubic', + align_corners=False, + ).flatten(2).permute(0, 2, 1)) + pos_embed = torch.cat([ + self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed + ], + dim=1) + x = torch.cat([self.query_embed.expand(x.shape[0], -1, -1), x], dim=1) + x = x + pos_embed + L = hwshape[0] * hwshape[1] + fused_index = 0 + if self.fusion_index[fused_index] == 0: + x = self.fuse_clip(fused_index, x, clip_features[0][0], hwshape, L) + fused_index += 1 + outs = [] + for index, block in enumerate(self.encode_layers, start=1): + x = block(x) + if index < len(self.fusion_index + ) and index == self.fusion_index[fused_index]: + x = self.fuse_clip(fused_index, x, + clip_features[fused_index][0], hwshape, L) + fused_index += 1 + x_query = x[:, :-L, ...] + x_feat = x[:, -L:, ...].permute(0, 2, 1)\ + .reshape(x.shape[0], x.shape[-1], hwshape[0], hwshape[1]) + + if index in deep_supervision_idxs or index == len( + self.encode_layers): + outs.append({'query': x_query, 'x': x_feat}) + + if index < len(self.encode_layers): + x = x + pos_embed + return outs + + def decode_feature(self, features): + mask_embeds = [] + attn_biases = [] + for feature in features: + mask_embed, attn_bias = self.mask_decoder(**feature) + mask_embeds.append(mask_embed) + attn_biases.append(attn_bias) + return mask_embeds, attn_biases + + def forward( + self, image: torch.Tensor, clip_features: List[torch.Tensor], + deep_supervision_idxs: List[int] + ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: + """Forward function.""" + features = self.encode_feature(image, clip_features, + deep_supervision_idxs) + mask_embeds, attn_biases = self.decode_feature(features) + return mask_embeds, attn_biases + + +class RecWithAttnbias(nn.Module): + """Mask recognition module by applying the attention biases to rest deeper + CLIP layers. + + Args: + sos_token_format (str): The format of sos token. It should be + chosen from ["cls_token", "learnable_token", "pos_embedding"]. + Default: 'cls_token'. + sos_token_num (int): Number of sos token. It should be equal to + the number of quries. Default: 100. + num_layers (int): Number of rest CLIP layers for mask recognition. + Default: 3. + cross_attn (bool): Whether use cross attention to update sos token. + Default: False. + embed_dims (int): The feature dimension of CLIP layers. + Default: 768. + num_heads (int): Parallel attention heads of CLIP layers. + Default: 768. + mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Whether to use bias in multihead-attention. + Default: True. + out_dims (int): Number of channels of the output mask proposals. + It should be equal to the out_dims of text_encoder. + Default: 512. + final_norm (True): Whether use norm layer for sos token. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + frozen_exclude (List): List of parameters that are not to be frozen. + """ + + def __init__(self, + sos_token_format: str = 'cls_token', + sos_token_num: int = 100, + num_layers: int = 3, + cross_attn: bool = False, + embed_dims: int = 768, + num_heads: int = 12, + mlp_ratio: int = 4, + num_fcs: int = 2, + qkv_bias: bool = True, + out_dims: int = 512, + final_norm: bool = True, + act_cfg: dict = dict(type='GELU'), + norm_cfg: dict = dict(type='LN'), + frozen_exclude: List = []): + super().__init__() + + assert sos_token_format in [ + 'cls_token', 'learnable_token', 'pos_embedding' + ] + self.sos_token_format = sos_token_format + self.sos_token_num = sos_token_num + self.frozen_exclude = frozen_exclude + self.cross_attn = cross_attn + self.num_layers = num_layers + self.num_heads = num_heads + if sos_token_format in ['learnable_token', 'pos_embedding']: + self.sos_token = nn.Parameter( + torch.randn(sos_token_num, 1, self.proj.shape[0])) + self.frozen.append('sos_token') + + layers = [] + for i in range(num_layers): + layers.append( + BaseTransformerLayer( + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + batch_first=False, + bias=qkv_bias), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=mlp_ratio * embed_dims, + act_cfg=act_cfg), + operation_order=('norm', 'self_attn', 'norm', 'ffn'))) + self.layers = nn.ModuleList(layers) + + self.ln_post = build_norm_layer(norm_cfg, embed_dims)[1] + self.proj = nn.Linear(embed_dims, out_dims, bias=False) + + self.final_norm = final_norm + self._freeze() + + def init_weights(self, rec_state_dict): + if hasattr(self, 'sos_token'): + normal_init(self.sos_token, std=0.02) + if rec_state_dict is not None: + load_state_dict(self, rec_state_dict, strict=False, logger=None) + else: + super().init_weights() + + def _freeze(self): + if 'all' in self.frozen_exclude: + return + for name, param in self.named_parameters(): + if not any([exclude in name for exclude in self.frozen_exclude]): + param.requires_grad = False + + def _build_attn_biases(self, attn_biases, target_shape): + formatted_attn_biases = [] + for attn_bias in attn_biases: + # convert it to proper format: N*num_head,L,L + # attn_bias: [N, num_head/1, num_sos,H,W] + n, num_head, num_sos, h, w = attn_bias.shape + # reshape and downsample + attn_bias = F.adaptive_max_pool2d( + attn_bias.reshape(n, num_head * num_sos, h, w), + output_size=target_shape) + attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape) + + true_num_head = self.num_heads + assert (num_head == 1 or num_head + == true_num_head), f'num_head={num_head} is not supported.' + if num_head == 1: + attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1) + attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1) + L = attn_bias.shape[-1] + if self.cross_attn: + # [n*num_head, num_sos, L] + formatted_attn_biases.append(attn_bias) + else: + # [n*num_head, num_sos+1+L, num_sos+1+L] + new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, + num_sos + 1 + L) + new_attn_bias[:, :num_sos] = -100 + new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0 + new_attn_bias[:num_sos, num_sos] = -100 + new_attn_bias = ( + new_attn_bias[None, ...].expand(n * true_num_head, -1, + -1).clone()) + new_attn_bias[..., :num_sos, -L:] = attn_bias + formatted_attn_biases.append(new_attn_bias) + + if len(formatted_attn_biases) == 1: + formatted_attn_biases = [ + formatted_attn_biases[0] for _ in range(self.num_layers) + ] + return formatted_attn_biases + + def forward(self, bias: List[Tensor], feature: List[Tensor]): + """Forward function to recognize the category of masks + Args: + bias (List[Tensor]): Attention bias for transformer layers + feature (List[Tensor]): Output of the image encoder, + including cls_token and img_feature. + """ + cls_token = feature[1].unsqueeze(0) + img_feature = feature[0] + b, c, h, w = img_feature.shape + # construct clip shadow features + x = torch.cat( + [cls_token, + img_feature.reshape(b, c, -1).permute(2, 0, 1)]) + + # construct sos token + if self.sos_token_format == 'cls_token': + sos_token = cls_token.repeat(self.sos_token_num, 1, 1) + elif self.sos_token_format == 'learnable_token': + sos_token = self.sos_token.expand(-1, b, -1) + elif self.sos_token_format == 'pos_embedding': + sos_token = self.sos_token.expand(-1, b, -1) + cls_token + + # construct attn bias + attn_biases = self._build_attn_biases(bias, target_shape=(h, w)) + + if self.cross_attn: + for i, block in enumerate(self.layers): + if self.cross_attn: + sos_token = cross_attn_layer( + block, + sos_token, + x[1:, ], + attn_biases[i], + ) + if i < len(self.layers) - 1: + x = block(x) + else: + x = torch.cat([sos_token, x], dim=0) + for i, block in enumerate(self.layers): + x = block(x, attn_masks=[attn_biases[i]]) + sos_token = x[:self.sos_token_num] + + sos_token = sos_token.permute(1, 0, 2) # LND -> NLD + sos_token = self.ln_post(sos_token) + sos_token = self.proj(sos_token) + if self.final_norm: + sos_token = F.normalize(sos_token, dim=-1) + return sos_token + + +@MODELS.register_module() +class SideAdapterCLIPHead(BaseDecodeHead): + """Side Adapter Network (SAN) for open-vocabulary semantic segmentation + with pre-trained vision-language model. + + This decode head is the implementation of `Side Adapter Network + for Open-Vocabulary Semantic Segmentation` + . + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/side_adapter/side_adapter.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + num_classes (int): the number of classes. + san_cfg (ConfigType): Configs for SideAdapterNetwork module + maskgen_cfg (ConfigType): Configs for RecWithAttnbias module + """ + + def __init__(self, num_classes: int, san_cfg: ConfigType, + maskgen_cfg: ConfigType, deep_supervision_idxs: List[int], + train_cfg: ConfigType, **kwargs): + super().__init__( + in_channels=san_cfg.in_channels, + channels=san_cfg.embed_dims, + num_classes=num_classes, + **kwargs) + assert san_cfg.num_queries == maskgen_cfg.sos_token_num, \ + 'num_queries in san_cfg should be equal to sos_token_num ' \ + 'in maskgen_cfg' + del self.conv_seg + self.side_adapter_network = SideAdapterNetwork(**san_cfg) + self.rec_with_attnbias = RecWithAttnbias(**maskgen_cfg) + self.deep_supervision_idxs = deep_supervision_idxs + self.train_cfg = train_cfg + if train_cfg: + self.match_masks = MatchMasks( + num_points=train_cfg.num_points, + num_queries=san_cfg.num_queries, + num_classes=num_classes, + assigner=train_cfg.assigner) + + def init_weights(self): + + rec_state_dict = None + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + rec_state_dict = checkpoint.copy() + para_prefix = 'decode_head.rec_with_attnbias' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + rec_state_dict.pop(k) + if para_prefix in k: + rec_state_dict[k[prefix_len:]] = v + + self.side_adapter_network.init_weights() + self.rec_with_attnbias.init_weights(rec_state_dict) + + def forward(self, inputs: Tuple[Tensor], + deep_supervision_idxs) -> Tuple[List]: + """Forward function. + + Args: + inputs (Tuple[Tensor]): A triplet including images, + list of multi-level visual features from image encoder and + class embeddings from text_encoder. + + Returns: + mask_props (List[Tensor]): Mask proposals predicted by SAN. + mask_logits (List[Tensor]): Class logits of mask proposals. + """ + imgs, clip_feature, class_embeds = inputs + # predict mask proposals and attention bias + mask_props, attn_biases = self.side_adapter_network( + imgs, clip_feature, deep_supervision_idxs) + + # mask recognition with attention bias + mask_embeds = [ + self.rec_with_attnbias(att_bias, clip_feature[-1]) + for att_bias in attn_biases + ] + # Obtain class prediction of masks by comparing the similarity + # between the image token and the text embedding of class names. + mask_logits = [ + torch.einsum('bqc,nc->bqn', mask_embed, class_embeds) + for mask_embed in mask_embeds + ] + return mask_props, mask_logits + + def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], + test_cfg: ConfigType) -> Tensor: + """Forward function for prediction. + + Args: + inputs (Tuple[Tensor]): Images, visual features from image encoder + and class embedding from text encoder. + batch_img_metas (dict): List Image info where each dict may also + contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Outputs segmentation logits map. + """ + mask_props, mask_logits = self.forward(inputs, []) + + return self.predict_by_feat([mask_props[-1], mask_logits[-1]], + batch_img_metas) + + def predict_by_feat(self, seg_logits: List[Tensor], + batch_img_metas: List[dict]) -> Tensor: + """1. Transform a batch of mask proposals to the input shape. + 2. Generate segmentation map with mask proposals and class logits. + """ + mask_pred = seg_logits[0] + cls_score = seg_logits[1] + if isinstance(batch_img_metas[0]['img_shape'], torch.Size): + # slide inference + size = batch_img_metas[0]['img_shape'] + elif 'pad_shape' in batch_img_metas[0]: + size = batch_img_metas[0]['pad_shape'][:2] + else: + size = batch_img_metas[0]['img_shape'] + # upsample mask + mask_pred = F.interpolate( + mask_pred, size=size, mode='bilinear', align_corners=False) + + mask_cls = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_logits = torch.einsum('bqc,bqhw->bchw', mask_cls, mask_pred) + return seg_logits + + def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, + train_cfg: ConfigType) -> dict: + """Perform forward propagation and loss calculation of the decoder head + on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`SegDataSample`]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + train_cfg (ConfigType): Training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components. + """ + # batch SegDataSample to InstanceDataSample + batch_gt_instances = seg_data_to_instance_data(self.ignore_index, + batch_data_samples) + + # forward + all_mask_props, all_mask_logits = self.forward( + x, self.deep_supervision_idxs) + + # loss + losses = self.loss_by_feat(all_mask_logits, all_mask_props, + batch_gt_instances) + + return losses + + def loss_by_feat( + self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + + losses = [] + for i in range(num_dec_layers): + cls_scores = all_cls_scores[i] + mask_preds = all_mask_preds[i] + # matching N mask predictions to K category labels + (labels, mask_targets, mask_weights, + avg_factor) = self.match_masks.get_targets( + cls_scores, mask_preds, batch_gt_instances_list[i]) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + num_total_masks = cls_scores.new_tensor([avg_factor], + dtype=torch.float) + all_reduce(num_total_masks, op='mean') + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] != 0: + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, + self.train_cfg.num_points, + self.train_cfg.oversample_ratio, + self.train_cfg.importance_sample_ratio) + # shape (num_total_gts, h, w) + # -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), + points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + loss = dict() + for loss_decode in losses_decode: + if 'loss_cls' in loss_decode.loss_name: + if loss_decode.loss_name == 'loss_cls_ce': + loss[loss_decode.loss_name] = loss_decode( + cls_scores, labels) + else: + assert False, "Only support 'CrossEntropyLoss' in" \ + ' classification loss' + + elif 'loss_mask' in loss_decode.loss_name: + if mask_targets.shape[0] == 0: + loss[loss_decode.loss_name] = mask_preds.sum() + elif loss_decode.loss_name == 'loss_mask_ce': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * + self.train_cfg.num_points) + elif loss_decode.loss_name == 'loss_mask_dice': + loss[loss_decode.loss_name] = loss_decode( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks) + else: + assert False, "Only support 'CrossEntropyLoss' and" \ + " 'DiceLoss' in mask loss" + else: + assert False, "Only support for 'loss_cls' and 'loss_mask'" + + losses.append(loss) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict.update(losses[-1]) + # loss from other decoder layers + for i, loss in enumerate(losses[:-1]): + for k, v in loss.items(): + loss_dict[f'd{self.deep_supervision_idxs[i]}.{k}'] = v + return loss_dict diff --git a/finetune/mmseg/models/decode_heads/segformer_head.py b/finetune/mmseg/models/decode_heads/segformer_head.py new file mode 100644 index 0000000..f9eb0b3 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/segformer_head.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class SegformerHead(BaseDecodeHead): + """The all mlp Head of segformer. + + This head is the implementation of + `Segformer ` _. + + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. + """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + out = self.cls_seg(out) + + return out diff --git a/finetune/mmseg/models/decode_heads/segmenter_mask_head.py b/finetune/mmseg/models/decode_heads/segmenter_mask_head.py new file mode 100644 index 0000000..85d2773 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/segmenter_mask_head.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer +from mmengine.model import ModuleList +from mmengine.model.weight_init import (constant_init, trunc_normal_, + trunc_normal_init) + +from mmseg.models.backbones.vit import TransformerEncoderLayer +from mmseg.registry import MODELS +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SegmenterMaskTransformerHead(BaseDecodeHead): + """Segmenter: Transformer for Semantic Segmentation. + + This head is the implementation of + `Segmenter: `_. + + Args: + backbone_cfg:(dict): Config of backbone of + Context Path. + in_channels (int): The number of channels of input image. + num_layers (int): The depth of transformer. + num_heads (int): The number of attention heads. + embed_dims (int): The number of embedding dimension. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + drop_path_rate (float): stochastic depth rate. Default 0.1. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + init_std (float): The value of std in weight initialization. + Default: 0.02. + """ + + def __init__( + self, + in_channels, + num_layers, + num_heads, + embed_dims, + mlp_ratio=4, + drop_path_rate=0.1, + drop_rate=0.0, + attn_drop_rate=0.0, + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_std=0.02, + **kwargs, + ): + super().__init__(in_channels=in_channels, **kwargs) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)] + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + batch_first=True, + )) + + self.dec_proj = nn.Linear(in_channels, embed_dims) + + self.cls_emb = nn.Parameter( + torch.randn(1, self.num_classes, embed_dims)) + self.patch_proj = nn.Linear(embed_dims, embed_dims, bias=False) + self.classes_proj = nn.Linear(embed_dims, embed_dims, bias=False) + + self.decoder_norm = build_norm_layer( + norm_cfg, embed_dims, postfix=1)[1] + self.mask_norm = build_norm_layer( + norm_cfg, self.num_classes, postfix=2)[1] + + self.init_std = init_std + + delattr(self, 'conv_seg') + + def init_weights(self): + trunc_normal_(self.cls_emb, std=self.init_std) + trunc_normal_init(self.patch_proj, std=self.init_std) + trunc_normal_init(self.classes_proj, std=self.init_std) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=self.init_std, bias=0) + elif isinstance(m, nn.LayerNorm): + constant_init(m, val=1.0, bias=0.0) + + def forward(self, inputs): + x = self._transform_inputs(inputs) + b, c, h, w = x.shape + x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c) + + x = self.dec_proj(x) + cls_emb = self.cls_emb.expand(x.size(0), -1, -1) + x = torch.cat((x, cls_emb), 1) + for layer in self.layers: + x = layer(x) + x = self.decoder_norm(x) + + patches = self.patch_proj(x[:, :-self.num_classes]) + cls_seg_feat = self.classes_proj(x[:, -self.num_classes:]) + + patches = F.normalize(patches, dim=2, p=2) + cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2) + + masks = patches @ cls_seg_feat.transpose(1, 2) + masks = self.mask_norm(masks) + masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w) + + return masks diff --git a/finetune/mmseg/models/decode_heads/sep_aspp_head.py b/finetune/mmseg/models/decode_heads/sep_aspp_head.py new file mode 100644 index 0000000..9dba68c --- /dev/null +++ b/finetune/mmseg/models/decode_heads/sep_aspp_head.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .aspp_head import ASPPHead, ASPPModule + + +class DepthwiseSeparableASPPModule(ASPPModule): + """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable + conv.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + for i, dilation in enumerate(self.dilations): + if dilation > 1: + self[i] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + 3, + dilation=dilation, + padding=dilation, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + +@MODELS.register_module() +class DepthwiseSeparableASPPHead(ASPPHead): + """Encoder-Decoder with Atrous Separable Convolution for Semantic Image + Segmentation. + + This head is the implementation of `DeepLabV3+ + `_. + + Args: + c1_in_channels (int): The input channels of c1 decoder. If is 0, + the no decoder will be used. + c1_channels (int): The intermediate channels of c1 decoder. + """ + + def __init__(self, c1_in_channels, c1_channels, **kwargs): + super().__init__(**kwargs) + assert c1_in_channels >= 0 + self.aspp_modules = DepthwiseSeparableASPPModule( + dilations=self.dilations, + in_channels=self.in_channels, + channels=self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if c1_in_channels > 0: + self.c1_bottleneck = ConvModule( + c1_in_channels, + c1_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + else: + self.c1_bottleneck = None + self.sep_bottleneck = nn.Sequential( + DepthwiseSeparableConvModule( + self.channels + c1_channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + DepthwiseSeparableConvModule( + self.channels, + self.channels, + 3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + aspp_outs = [ + resize( + self.image_pool(x), + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + if self.c1_bottleneck is not None: + c1_output = self.c1_bottleneck(inputs[0]) + output = resize( + input=output, + size=c1_output.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + output = torch.cat([output, c1_output], dim=1) + output = self.sep_bottleneck(output) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/sep_fcn_head.py b/finetune/mmseg/models/decode_heads/sep_fcn_head.py new file mode 100644 index 0000000..3b15983 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/sep_fcn_head.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import DepthwiseSeparableConvModule + +from mmseg.registry import MODELS +from .fcn_head import FCNHead + + +@MODELS.register_module() +class DepthwiseSeparableFCNHead(FCNHead): + """Depthwise-Separable Fully Convolutional Network for Semantic + Segmentation. + + This head is implemented according to `Fast-SCNN: Fast Semantic + Segmentation Network `_. + + Args: + in_channels(int): Number of output channels of FFM. + channels(int): Number of middle-stage channels in the decode head. + concat_input(bool): Whether to concatenate original decode input into + the result of several consecutive convolution layers. + Default: True. + num_classes(int): Used to determine the dimension of + final prediction tensor. + in_index(int): Correspond with 'out_indices' in FastSCNN backbone. + norm_cfg (dict | None): Config of norm layers. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + loss_decode(dict): Config of loss type and some + relevant additional options. + dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is + 'default', it will be the same as `act_cfg`. Default: None. + """ + + def __init__(self, dw_act_cfg=None, **kwargs): + super().__init__(**kwargs) + self.convs[0] = DepthwiseSeparableConvModule( + self.in_channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + for i in range(1, self.num_convs): + self.convs[i] = DepthwiseSeparableConvModule( + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) + + if self.concat_input: + self.conv_cat = DepthwiseSeparableConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + norm_cfg=self.norm_cfg, + dw_act_cfg=dw_act_cfg) diff --git a/finetune/mmseg/models/decode_heads/setr_mla_head.py b/finetune/mmseg/models/decode_heads/setr_mla_head.py new file mode 100644 index 0000000..1975991 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/setr_mla_head.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import Upsample +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SETRMLAHead(BaseDecodeHead): + """Multi level feature aggretation head of SETR. + + MLA head of `SETR `_. + + Args: + mlahead_channels (int): Channels of conv-conv-4x of multi-level feature + aggregation. Default: 128. + up_scale (int): The scale factor of interpolate. Default:4. + """ + + def __init__(self, mla_channels=128, up_scale=4, **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.mla_channels = mla_channels + + num_inputs = len(self.in_channels) + + # Refer to self.cls_seg settings of BaseDecodeHead + assert self.channels == num_inputs * mla_channels + + self.up_convs = nn.ModuleList() + for i in range(num_inputs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=self.in_channels[i], + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + ConvModule( + in_channels=mla_channels, + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + + def forward(self, inputs): + inputs = self._transform_inputs(inputs) + outs = [] + for x, up_conv in zip(inputs, self.up_convs): + outs.append(up_conv(x)) + out = torch.cat(outs, dim=1) + out = self.cls_seg(out) + return out diff --git a/finetune/mmseg/models/decode_heads/setr_up_head.py b/finetune/mmseg/models/decode_heads/setr_up_head.py new file mode 100644 index 0000000..9c796d8 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/setr_up_head.py @@ -0,0 +1,81 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.registry import MODELS +from ..utils import Upsample +from .decode_head import BaseDecodeHead + + +@MODELS.register_module() +class SETRUPHead(BaseDecodeHead): + """Naive upsampling head and Progressive upsampling head of SETR. + + Naive or PUP head of `SETR `_. + + Args: + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + num_convs (int): Number of decoder convolutions. Default: 1. + up_scale (int): The scale factor of interpolate. Default:4. + kernel_size (int): The kernel size of convolution when decoding + feature information from backbone. Default: 3. + init_cfg (dict | list[dict] | None): Initialization config dict. + Default: dict( + type='Constant', val=1.0, bias=0, layer='LayerNorm'). + """ + + def __init__(self, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + num_convs=1, + up_scale=4, + kernel_size=3, + init_cfg=[ + dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'), + dict( + type='Normal', + std=0.01, + override=dict(name='conv_seg')) + ], + **kwargs): + + assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.' + + super().__init__(init_cfg=init_cfg, **kwargs) + + assert isinstance(self.in_channels, int) + + _, self.norm = build_norm_layer(norm_layer, self.in_channels) + + self.up_convs = nn.ModuleList() + in_channels = self.in_channels + out_channels = self.channels + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=int(kernel_size - 1) // 2, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg), + Upsample( + scale_factor=up_scale, + mode='bilinear', + align_corners=self.align_corners))) + in_channels = out_channels + + def forward(self, x): + x = self._transform_inputs(x) + + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + out = self.cls_seg(x) + return out diff --git a/finetune/mmseg/models/decode_heads/stdc_head.py b/finetune/mmseg/models/decode_heads/stdc_head.py new file mode 100644 index 0000000..1c1c21e --- /dev/null +++ b/finetune/mmseg/models/decode_heads/stdc_head.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import SampleList +from .fcn_head import FCNHead + + +@MODELS.register_module() +class STDCHead(FCNHead): + """This head is the implementation of `Rethinking BiSeNet For Real-time + Semantic Segmentation `_. + + Args: + boundary_threshold (float): The threshold of calculating boundary. + Default: 0.1. + """ + + def __init__(self, boundary_threshold=0.1, **kwargs): + super().__init__(**kwargs) + self.boundary_threshold = boundary_threshold + # Using register buffer to make laplacian kernel on the same + # device of `seg_label`. + self.register_buffer( + 'laplacian_kernel', + torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], + dtype=torch.float32, + requires_grad=False).reshape((1, 1, 3, 3))) + self.fusion_kernel = torch.nn.Parameter( + torch.tensor([[6. / 10], [3. / 10], [1. / 10]], + dtype=torch.float32).reshape(1, 3, 1, 1), + requires_grad=False) + + def loss_by_feat(self, seg_logits: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute Detail Aggregation Loss.""" + # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv + # parameters. However, it is a constant in original repo and other + # codebase because it would not be added into computation graph + # after threshold operation. + seg_label = self._stack_batch_gt(batch_data_samples).to( + self.laplacian_kernel) + boundary_targets = F.conv2d( + seg_label, self.laplacian_kernel, padding=1) + boundary_targets = boundary_targets.clamp(min=0) + boundary_targets[boundary_targets > self.boundary_threshold] = 1 + boundary_targets[boundary_targets <= self.boundary_threshold] = 0 + + boundary_targets_x2 = F.conv2d( + seg_label, self.laplacian_kernel, stride=2, padding=1) + boundary_targets_x2 = boundary_targets_x2.clamp(min=0) + + boundary_targets_x4 = F.conv2d( + seg_label, self.laplacian_kernel, stride=4, padding=1) + boundary_targets_x4 = boundary_targets_x4.clamp(min=0) + + boundary_targets_x4_up = F.interpolate( + boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2_up = F.interpolate( + boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + + boundary_targets_x2_up[ + boundary_targets_x2_up > self.boundary_threshold] = 1 + boundary_targets_x2_up[ + boundary_targets_x2_up <= self.boundary_threshold] = 0 + + boundary_targets_x4_up[ + boundary_targets_x4_up > self.boundary_threshold] = 1 + boundary_targets_x4_up[ + boundary_targets_x4_up <= self.boundary_threshold] = 0 + + boundary_targets_pyramids = torch.stack( + (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), + dim=1) + + boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) + boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids, + self.fusion_kernel) + + boudary_targets_pyramid[ + boudary_targets_pyramid > self.boundary_threshold] = 1 + boudary_targets_pyramid[ + boudary_targets_pyramid <= self.boundary_threshold] = 0 + + seg_labels = boudary_targets_pyramid.long() + batch_sample_list = [] + for label in seg_labels: + seg_data_sample = SegDataSample() + seg_data_sample.gt_sem_seg = PixelData(data=label) + batch_sample_list.append(seg_data_sample) + + loss = super().loss_by_feat(seg_logits, batch_sample_list) + return loss diff --git a/finetune/mmseg/models/decode_heads/uper_head.py b/finetune/mmseg/models/decode_heads/uper_head.py new file mode 100644 index 0000000..b1ccc31 --- /dev/null +++ b/finetune/mmseg/models/decode_heads/uper_head.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead +from .psp_head import PPM + + +@MODELS.register_module() +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + feats = self.fpn_bottleneck(fpn_outs) + return feats + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/finetune/mmseg/models/decode_heads/vpd_depth_head.py b/finetune/mmseg/models/decode_heads/vpd_depth_head.py new file mode 100644 index 0000000..65bdfbd --- /dev/null +++ b/finetune/mmseg/models/decode_heads/vpd_depth_head.py @@ -0,0 +1,253 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import SampleList +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class VPDDepthDecoder(BaseModule): + """VPD Depth Decoder class. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_deconv_layers (int): Number of deconvolution layers. + num_deconv_filters (List[int]): List of output channels for + deconvolution layers. + init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration + for weight initialization. Defaults to Normal for Conv2d and + ConvTranspose2d layers. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + num_deconv_layers: int, + num_deconv_filters: List[int], + init_cfg: Optional[Union[Dict, List[Dict]]] = dict( + type='Normal', + std=0.001, + layer=['Conv2d', 'ConvTranspose2d'])): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + ) + + conv_layers = [] + conv_layers.append( + build_conv_layer( + dict(type='Conv2d'), + in_channels=num_deconv_filters[-1], + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1)) + conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1]) + conv_layers.append(nn.ReLU(inplace=True)) + self.conv_layers = nn.Sequential(*conv_layers) + + self.up_sample = nn.Upsample( + scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, x): + """Forward pass through the decoder network.""" + out = self.deconv_layers(x) + out = self.conv_layers(out) + + out = self.up_sample(out) + out = self.up_sample(out) + + return out + + def _make_deconv_layer(self, num_layers, num_deconv_filters): + """Make deconv layers.""" + + layers = [] + in_channels = self.in_channels + for i in range(num_layers): + + num_channels = num_deconv_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=in_channels, + out_channels=num_channels, + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=False)) + layers.append(nn.BatchNorm2d(num_channels)) + layers.append(nn.ReLU(inplace=True)) + in_channels = num_channels + + return nn.Sequential(*layers) + + +@MODELS.register_module() +class VPDDepthHead(BaseDecodeHead): + """Depth Prediction Head for VPD. + + .. _`VPD`: https://arxiv.org/abs/2303.02153 + + Args: + max_depth (float): Maximum depth value. Defaults to 10.0. + in_channels (Sequence[int]): Number of input channels for each + convolutional layer. + embed_dim (int): Dimension of embedding. Defaults to 192. + feature_dim (int): Dimension of aggregated feature. Defaults to 1536. + num_deconv_layers (int): Number of deconvolution layers in the + decoder. Defaults to 3. + num_deconv_filters (Sequence[int]): Number of filters for each deconv + layer. Defaults to (32, 32, 32). + fmap_border (Union[int, Sequence[int]]): Feature map border for + cropping. Defaults to 0. + align_corners (bool): Flag for align_corners in interpolation. + Defaults to False. + loss_decode (dict): Configurations for the loss function. Defaults to + dict(type='SiLogLoss'). + init_cfg (dict): Initialization configurations. Defaults to + dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']). + """ + + num_classes = 1 + out_channels = 1 + input_transform = None + + def __init__( + self, + max_depth: float = 10.0, + in_channels: Sequence[int] = [320, 640, 1280, 1280], + embed_dim: int = 192, + feature_dim: int = 1536, + num_deconv_layers: int = 3, + num_deconv_filters: Sequence[int] = (32, 32, 32), + fmap_border: Union[int, Sequence[int]] = 0, + align_corners: bool = False, + loss_decode: dict = dict(type='SiLogLoss'), + init_cfg=dict( + type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']), + ): + + super(BaseDecodeHead, self).__init__(init_cfg=init_cfg) + + # initialize parameters + self.in_channels = in_channels + self.max_depth = max_depth + self.align_corners = align_corners + + # feature map border + if isinstance(fmap_border, int): + fmap_border = (fmap_border, fmap_border) + self.fmap_border = fmap_border + + # define network layers + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + nn.GroupNorm(16, in_channels[0]), + nn.ReLU(), + nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), + ) + self.conv2 = nn.Conv2d( + in_channels[1], in_channels[1], 3, stride=2, padding=1) + + self.conv_aggregation = nn.Sequential( + nn.Conv2d(sum(in_channels), feature_dim, 1), + nn.GroupNorm(16, feature_dim), + nn.ReLU(), + ) + + self.decoder = VPDDepthDecoder( + in_channels=embed_dim * 8, + out_channels=embed_dim, + num_deconv_layers=num_deconv_layers, + num_deconv_filters=num_deconv_filters) + + self.depth_pred_layer = nn.Sequential( + nn.Conv2d( + embed_dim, embed_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=False), + nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1)) + + # build loss + if isinstance(loss_decode, dict): + self.loss_decode = MODELS.build(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(MODELS.build(loss)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss_decode)}') + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + gt_depth_maps = [ + data_sample.gt_depth_map.data for data_sample in batch_data_samples + ] + return torch.stack(gt_depth_maps, dim=0) + + def forward(self, x): + x = [ + x[0], x[1], + torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1) + ] + x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1) + x = self.conv_aggregation(x) + + x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) - + self.fmap_border[1]].contiguous() + x = self.decoder(x) + out = self.depth_pred_layer(x) + + depth = torch.sigmoid(out) * self.max_depth + + return depth + + def loss_by_feat(self, pred_depth_map: Tensor, + batch_data_samples: SampleList) -> dict: + """Compute depth estimation loss. + + Args: + pred_depth_map (Tensor): The output from decode head forward + function. + batch_data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_dpeth_map`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + gt_depth_map = self._stack_batch_gt(batch_data_samples) + loss = dict() + pred_depth_map = resize( + input=pred_depth_map, + size=gt_depth_map.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode( + pred_depth_map, gt_depth_map) + else: + loss[loss_decode.loss_name] += loss_decode( + pred_depth_map, gt_depth_map) + + return loss diff --git a/finetune/mmseg/models/losses/__init__.py b/finetune/mmseg/models/losses/__init__.py new file mode 100644 index 0000000..0467cb3 --- /dev/null +++ b/finetune/mmseg/models/losses/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .accuracy import Accuracy, accuracy +from .boundary_loss import BoundaryLoss +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .focal_loss import FocalLoss +from .huasdorff_distance_loss import HuasdorffDisstanceLoss +from .lovasz_loss import LovaszLoss +from .ohem_cross_entropy_loss import OhemCrossEntropy +from .silog_loss import SiLogLoss +from .tversky_loss import TverskyLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', + 'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss', + 'HuasdorffDisstanceLoss', 'SiLogLoss' +] diff --git a/finetune/mmseg/models/losses/accuracy.py b/finetune/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000..1d9e2d7 --- /dev/null +++ b/finetune/mmseg/models/losses/accuracy.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + ignore_index (int | None): The label index to be ignored. Default: None + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + if ignore_index is not None: + correct = correct[:, target != ignore_index] + res = [] + eps = torch.finfo(torch.float32).eps + for k in topk: + # Avoid causing ZeroDivisionError when all pixels + # of an image are ignored + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps + if ignore_index is not None: + total_num = target[target != ignore_index].numel() + eps + else: + total_num = target.numel() + eps + res.append(correct_k.mul_(100.0 / total_num)) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None, ignore_index=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + self.ignore_index = ignore_index + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh, + self.ignore_index) diff --git a/finetune/mmseg/models/losses/boundary_loss.py b/finetune/mmseg/models/losses/boundary_loss.py new file mode 100644 index 0000000..e86b850 --- /dev/null +++ b/finetune/mmseg/models/losses/boundary_loss.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class BoundaryLoss(nn.Module): + """Boundary loss. + + This function is modified from + `PIDNet `_. # noqa + Licensed under the MIT License. + + + Args: + loss_weight (float): Weight of the loss. Defaults to 1.0. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. + """ + + def __init__(self, + loss_weight: float = 1.0, + loss_name: str = 'loss_boundary'): + super().__init__() + self.loss_weight = loss_weight + self.loss_name_ = loss_name + + def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor: + """Forward function. + Args: + bd_pre (Tensor): Predictions of the boundary head. + bd_gt (Tensor): Ground truth of the boundary. + + Returns: + Tensor: Loss tensor. + """ + log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1) + target_t = bd_gt.view(1, -1).float() + + pos_index = (target_t == 1) + neg_index = (target_t == 0) + + weight = torch.zeros_like(log_p) + pos_num = pos_index.sum() + neg_num = neg_index.sum() + sum_num = pos_num + neg_num + weight[pos_index] = neg_num * 1.0 / sum_num + weight[neg_index] = pos_num * 1.0 / sum_num + + loss = F.binary_cross_entropy_with_logits( + log_p, target_t, weight, reduction='mean') + + return self.loss_weight * loss + + @property + def loss_name(self): + return self.loss_name_ diff --git a/finetune/mmseg/models/losses/cross_entropy_loss.py b/finetune/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000..988fb78 --- /dev/null +++ b/finetune/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and reduction == 'mean': + if class_weight is None: + if avg_non_ignore: + avg_factor = label.numel() - (label + == ignore_index).sum().item() + else: + avg_factor = label.numel() + + else: + # the average factor should take the class weights into account + label_weights = torch.stack([class_weight[cls] for cls in label + ]).to(device=class_weight.device) + + if avg_non_ignore: + label_weights[label == ignore_index] = 0 + avg_factor = label_weights.sum() + + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. + # As the ignore_index often set as 255, so the + # binary class label check should mask out + # ignore_index + assert label[label != ignore_index].max() <= 1, \ + 'For pred with shape [N, 1, H, W], its label must have at ' \ + 'most 2 classes' + pred = pred.squeeze(1) + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == 'mean' and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@MODELS.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_ce', + avg_non_ignore=False): + super().__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == 'mean': + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/dice_loss.py b/finetune/mmseg/models/losses/dice_loss.py new file mode 100644 index 0000000..fb2ffdb --- /dev/null +++ b/finetune/mmseg/models/losses/dice_loss.py @@ -0,0 +1,202 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +import torch.nn as nn + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +def _expand_onehot_labels_dice(pred: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + """Expand onehot labels to match the size of prediction. + + Args: + pred (torch.Tensor): The prediction, has a shape (N, num_class, H, W). + target (torch.Tensor): The learning label of the prediction, + has a shape (N, H, W). + + Returns: + torch.Tensor: The target after one-hot encoding, + has a shape (N, num_class, H, W). + """ + num_classes = pred.shape[1] + one_hot_target = torch.clamp(target, min=0, max=num_classes) + one_hot_target = torch.nn.functional.one_hot(one_hot_target, + num_classes + 1) + one_hot_target = one_hot_target[..., :num_classes].permute(0, 3, 1, 2) + return one_hot_target + + +def dice_loss(pred: torch.Tensor, + target: torch.Tensor, + weight: Union[torch.Tensor, None], + eps: float = 1e-3, + reduction: Union[str, None] = 'mean', + naive_dice: Union[bool, None] = False, + avg_factor: Union[int, None] = None, + ignore_index: Union[int, None] = 255) -> float: + """Calculate dice loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + `_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + ignore_index (int, optional): The label index to be ignored. + Defaults to 255. + """ + if ignore_index is not None: + num_classes = pred.shape[1] + pred = pred[:, torch.arange(num_classes) != ignore_index, :, :] + target = target[:, torch.arange(num_classes) != ignore_index, :, :] + assert pred.shape[1] != 0 # if the ignored index is the only class + input = pred.flatten(1) + target = target.flatten(1).float() + a = torch.sum(input * target, 1) + if naive_dice: + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + else: + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class DiceLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=False, + loss_weight=1.0, + ignore_index=255, + eps=1e-3, + loss_name='loss_dice'): + """Compute dice loss. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power. Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + ignore_index (int, optional): The label index to be ignored. + Default: 255. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_dice'. + """ + + super().__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + self.ignore_index = ignore_index + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + one_hot_target = target + if (pred.shape != target.shape): + one_hot_target = _expand_onehot_labels_dice(pred, target) + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + elif pred.shape[1] != 1: + # softmax does not work when there is only 1 class + pred = pred.softmax(dim=1) + loss = self.loss_weight * dice_loss( + pred, + one_hot_target, + weight, + eps=self.eps, + reduction=reduction, + naive_dice=self.naive_dice, + avg_factor=avg_factor, + ignore_index=self.ignore_index) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/focal_loss.py b/finetune/mmseg/models/losses/focal_loss.py new file mode 100644 index 0000000..6507ed7 --- /dev/null +++ b/finetune/mmseg/models/losses/focal_loss.py @@ -0,0 +1,337 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/open-mmlab/mmdetection +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +# This method is used when cuda is not available +def py_sigmoid_focal_loss(pred, + target, + one_hot_target=None, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction with + shape (N, C) + one_hot_target (None): Placeholder. It should be None. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + if isinstance(alpha, list): + alpha = pred.new_tensor(alpha) + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + one_minus_pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * one_minus_pt.pow(gamma) + + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + final_weight = torch.ones(1, pred.size(1)).type_as(loss) + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + one_hot_target, + weight=None, + gamma=2.0, + alpha=0.5, + class_weight=None, + valid_mask=None, + reduction='mean', + avg_factor=None): + r"""A wrapper of cuda version `Focal Loss + `_. + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. It's shape + should be (N, ) + one_hot_target (torch.Tensor): The learning label with shape (N, C) + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal Loss. + Defaults to 0.5. + class_weight (list[float], optional): Weight of each class. + Defaults to None. + valid_mask (torch.Tensor, optional): A mask uses 1 to mark the valid + samples and uses 0 to mark the ignored samples. Default: None. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + final_weight = torch.ones(1, pred.size(1)).type_as(pred) + if isinstance(alpha, list): + # _sigmoid_focal_loss doesn't accept alpha of list type. Therefore, if + # a list is given, we set the input alpha as 0.5. This means setting + # equal weight for foreground class and background class. By + # multiplying the loss by 2, the effect of setting alpha as 0.5 is + # undone. The alpha of type list is used to regulate the loss in the + # post-processing process. + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, 0.5, None, 'none') * 2 + alpha = pred.new_tensor(alpha) + final_weight = final_weight * ( + alpha * one_hot_target + (1 - alpha) * (1 - one_hot_target)) + else: + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), + gamma, alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape and weight.size(0) == loss.size(0): + # For most cases, weight is of shape (N, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + assert weight.dim() == loss.dim() + final_weight = final_weight * weight + if class_weight is not None: + final_weight = final_weight * pred.new_tensor(class_weight) + if valid_mask is not None: + final_weight = final_weight * valid_mask + loss = weight_reduce_loss(loss, final_weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class FocalLoss(nn.Module): + + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.5, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_focal'): + """`Focal Loss `_ + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float | list[float], optional): A balanced form for Focal + Loss. Defaults to 0.5. When a list is provided, the length + of the list should be equal to the number of classes. + Please be careful that this parameter is not the + class-wise weight but the weight of a binary classification + problem. This binary classification problem regards the + pixels which belong to one class as the foreground + and the other pixels as the background, each element in + the list is the weight of the corresponding foreground class. + The value of alpha or each element of alpha should be a float + in the interval [0, 1]. If you want to specify the class-wise + weight, please use `class_weight` parameter. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_focal'. + """ + super().__init__() + assert use_sigmoid is True, \ + 'AssertionError: Only sigmoid focal loss supported now.' + assert reduction in ('none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert isinstance(alpha, (float, list)), \ + 'AssertionError: alpha should be of type float' + assert isinstance(gamma, float), \ + 'AssertionError: gamma should be of type float' + assert isinstance(loss_weight, float), \ + 'AssertionError: loss_weight should be of type float' + assert isinstance(loss_name, str), \ + 'AssertionError: loss_name should be of type str' + assert isinstance(class_weight, list) or class_weight is None, \ + 'AssertionError: class_weight must be None or of type list' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.class_weight = class_weight + self.loss_weight = loss_weight + self._loss_name = loss_name + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=255, + **kwargs): + """Forward function. + + Args: + pred (torch.Tensor): The prediction with shape + (N, C) where C = number of classes, or + (N, C, d_1, d_2, ..., d_K) with K≥1 in the + case of K-dimensional loss. + target (torch.Tensor): The ground truth. If containing class + indices, shape (N) where each value is 0≤targets[i]≤C−1, + or (N, d_1, d_2, ..., d_K) with K≥1 in the case of + K-dimensional loss. If containing class probabilities, + same shape as the input. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + ignore_index (int, optional): The label index to be ignored. + Default: 255 + Returns: + torch.Tensor: The calculated loss + """ + assert isinstance(ignore_index, int), \ + 'ignore_index must be of type int' + assert reduction_override in (None, 'none', 'mean', 'sum'), \ + "AssertionError: reduction should be 'none', 'mean' or " \ + "'sum'" + assert pred.shape == target.shape or \ + (pred.size(0) == target.size(0) and + pred.shape[2:] == target.shape[1:]), \ + "The shape of pred doesn't match the shape of target" + + original_shape = pred.shape + + # [B, C, d_1, d_2, ..., d_k] -> [C, B, d_1, d_2, ..., d_k] + pred = pred.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + pred = pred.reshape(pred.size(0), -1) + # [C, N] -> [N, C] + pred = pred.transpose(0, 1).contiguous() + + if original_shape == target.shape: + # target with shape [B, C, d_1, d_2, ...] + # transform it's shape into [N, C] + # [B, C, d_1, d_2, ...] -> [C, B, d_1, d_2, ..., d_k] + target = target.transpose(0, 1) + # [C, B, d_1, d_2, ..., d_k] -> [C, N] + target = target.reshape(target.size(0), -1) + # [C, N] -> [N, C] + target = target.transpose(0, 1).contiguous() + else: + # target with shape [B, d_1, d_2, ...] + # transform it's shape into [N, ] + target = target.view(-1).contiguous() + valid_mask = (target != ignore_index).view(-1, 1) + # avoid raising error when using F.one_hot() + target = torch.where(target == ignore_index, target.new_tensor(0), + target) + + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + num_classes = pred.size(1) + if torch.cuda.is_available() and pred.is_cuda: + if target.dim() == 1: + one_hot_target = F.one_hot( + target, num_classes=num_classes + 1) + if num_classes == 1: + one_hot_target = one_hot_target[:, 1] + target = 1 - target + else: + one_hot_target = one_hot_target[:, :num_classes] + else: + one_hot_target = target + target = target.argmax(dim=1) + valid_mask = (target != ignore_index).view(-1, 1) + calculate_loss_func = sigmoid_focal_loss + else: + one_hot_target = None + if target.dim() == 1: + target = F.one_hot(target, num_classes=num_classes + 1) + if num_classes == 1: + target = target[:, 1] + else: + target = target[:, num_classes] + else: + valid_mask = (target.argmax(dim=1) != ignore_index).view( + -1, 1) + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + one_hot_target, + weight, + gamma=self.gamma, + alpha=self.alpha, + class_weight=self.class_weight, + valid_mask=valid_mask, + reduction=reduction, + avg_factor=avg_factor) + + if reduction == 'none': + # [N, C] -> [C, N] + loss_cls = loss_cls.transpose(0, 1) + # [C, N] -> [C, B, d1, d2, ...] + # original_shape: [B, C, d1, d2, ...] + loss_cls = loss_cls.reshape(original_shape[1], + original_shape[0], + *original_shape[2:]) + # [C, B, d1, d2, ...] -> [B, C, d1, d2, ...] + loss_cls = loss_cls.transpose(0, 1).contiguous() + else: + raise NotImplementedError + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/huasdorff_distance_loss.py b/finetune/mmseg/models/losses/huasdorff_distance_loss.py new file mode 100644 index 0000000..d950ba7 --- /dev/null +++ b/finetune/mmseg/models/losses/huasdorff_distance_loss.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/ +master/code/train_LA_HD.py (Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.ndimage import distance_transform_edt as distance +from torch import Tensor + +from mmseg.registry import MODELS +from .utils import get_class_weight, weighted_loss + + +def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor: + """ + compute the distance transform map of foreground in mask + Args: + img_gt: Ground truth of the image, (b, h, w) + pred: Predictions of the segmentation head after softmax, (b, c, h, w) + + Returns: + output: the foreground Distance Map (SDM) + dtm(x) = 0; x in segmentation boundary + inf|x-y|; x in segmentation + """ + + fg_dtm = torch.zeros_like(pred) + out_shape = pred.shape + for b in range(out_shape[0]): # batch size + for c in range(1, out_shape[1]): # default 0 channel is background + posmask = img_gt[b].byte() + if posmask.any(): + posdis = distance(posmask) + fg_dtm[b][c] = torch.from_numpy(posdis) + + return fg_dtm + + +@weighted_loss +def hd_loss(seg_soft: Tensor, + gt: Tensor, + seg_dtm: Tensor, + gt_dtm: Tensor, + class_weight=None, + ignore_index=255) -> Tensor: + """ + compute huasdorff distance loss for segmentation + Args: + seg_soft: softmax results, shape=(b,c,x,y) + gt: ground truth, shape=(b,x,y) + seg_dtm: segmentation distance transform map, shape=(b,c,x,y) + gt_dtm: ground truth distance transform map, shape=(b,c,x,y) + + Returns: + output: hd_loss + """ + assert seg_soft.shape[0] == gt.shape[0] + total_loss = 0 + num_class = seg_soft.shape[1] + if class_weight is not None: + assert class_weight.ndim == num_class + for i in range(1, num_class): + if i != ignore_index: + delta_s = (seg_soft[:, i, ...] - gt.float())**2 + s_dtm = seg_dtm[:, i, ...]**2 + g_dtm = gt_dtm[:, i, ...]**2 + dtm = s_dtm + g_dtm + multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm) + hd_loss = multiplied.mean() + if class_weight is not None: + hd_loss *= class_weight[i] + total_loss += hd_loss + + return total_loss / num_class + + +@MODELS.register_module() +class HuasdorffDisstanceLoss(nn.Module): + """HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform + Maps Boost Segmentation CNNs: An Empirical Study. + + `_. + Args: + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. + """ + + def __init__(self, + reduction='mean', + class_weight=None, + loss_weight=1.0, + ignore_index=255, + loss_name='loss_huasdorff_disstance', + **kwargs): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + self.ignore_index = ignore_index + + def forward(self, + pred: Tensor, + target: Tensor, + avg_factor=None, + reduction_override=None, + **kwargs) -> Tensor: + """Forward function. + + Args: + pred (Tensor): Predictions of the segmentation head. (B, C, H, W) + target (Tensor): Ground truth of the image. (B, H, W) + avg_factor (int, optional): Average factor that is used to + average the loss. Defaults to None. + reduction_override (str, optional): The reduction method used + to override the original reduction method of the loss. + Options are "none", "mean" and "sum". + Returns: + Tensor: Loss tensor. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred_soft = F.softmax(pred, dim=1) + valid_mask = (target != self.ignore_index).long() + target = target * valid_mask + + with torch.no_grad(): + gt_dtm = compute_dtm(target.cpu(), pred_soft) + gt_dtm = gt_dtm.float() + seg_dtm2 = compute_dtm( + pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft) + seg_dtm2 = seg_dtm2.float() + + loss_hd = self.loss_weight * hd_loss( + pred_soft, + target, + seg_dtm=seg_dtm2, + gt_dtm=gt_dtm, + reduction=reduction, + avg_factor=avg_factor, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss_hd + + @property + def loss_name(self): + return self._loss_name diff --git a/finetune/mmseg/models/losses/kldiv_loss.py b/finetune/mmseg/models/losses/kldiv_loss.py new file mode 100644 index 0000000..496ef97 --- /dev/null +++ b/finetune/mmseg/models/losses/kldiv_loss.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class KLDivLoss(nn.Module): + + def __init__(self, + temperature: float = 1.0, + reduction: str = 'mean', + loss_name: str = 'loss_kld'): + """Kullback-Leibler divergence Loss. + + + + Args: + temperature (float, optional): Temperature param + reduction (str, optional): The method to reduce the loss into a + scalar. Default is "mean". Options are "none", "sum", + and "mean" + """ + + assert isinstance(temperature, (float, int)), \ + 'Expected temperature to be' \ + f'float or int, but got {temperature.__class__.__name__} instead' + assert temperature != 0., 'Temperature must not be zero' + + assert reduction in ['mean', 'none', 'sum'], \ + 'Reduction must be one of the options ("mean", ' \ + f'"sum", "none"), but got {reduction}' + + super().__init__() + self.temperature = temperature + self.reduction = reduction + self._loss_name = loss_name + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """Forward function. Calculate KL divergence Loss. + + Args: + input (Tensor): Logit tensor, + the data type is float32 or float64. + The shape is (N, C) where N is batchsize and C is number of + channels. + If there more than 2 dimensions, shape is (N, C, D1, D2, ... + Dk), k>= 1 + target (Tensor): Logit tensor, + the data type is float32 or float64. + input and target must be with the same shape. + + Returns: + (Tensor): Reduced loss. + """ + assert isinstance(input, torch.Tensor), 'Expected input to' \ + f'be Tensor, but got {input.__class__.__name__} instead' + assert isinstance(target, torch.Tensor), 'Expected target to' \ + f'be Tensor, but got {target.__class__.__name__} instead' + + assert input.shape == target.shape, 'Input and target ' \ + 'must have same shape,' \ + f'but got shapes {input.shape} and {target.shape}' + + input = F.softmax(input / self.temperature, dim=1) + target = F.softmax(target / self.temperature, dim=1) + + loss = F.kl_div(input, target, reduction='none', log_target=False) + loss = loss * self.temperature**2 + + batch_size = input.shape[0] + + if self.reduction == 'sum': + # Change view to calculate instance-wise sum + loss = loss.view(batch_size, -1) + return torch.sum(loss, dim=1) + + elif self.reduction == 'mean': + # Change view to calculate instance-wise mean + loss = loss.view(batch_size, -1) + return torch.mean(loss, dim=1) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/lovasz_loss.py b/finetune/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000..b47f9d8 --- /dev/null +++ b/finetune/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.utils import is_list_of + +from mmseg.registry import MODELS +from .utils import get_class_weight, weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). + labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@MODELS.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes chosen to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_lovasz'. + """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_lovasz'): + super().__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self._loss_name = loss_name + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/ohem_cross_entropy_loss.py b/finetune/mmseg/models/losses/ohem_cross_entropy_loss.py new file mode 100644 index 0000000..a519b4d --- /dev/null +++ b/finetune/mmseg/models/losses/ohem_cross_entropy_loss.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class OhemCrossEntropy(nn.Module): + """OhemCrossEntropy loss. + + This func is modified from + `PIDNet `_. # noqa + + Licensed under the MIT License. + + Args: + ignore_label (int): Labels to ignore when computing the loss. + Default: 255 + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: 0.7. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + loss_weight (float): Weight of the loss. Defaults to 1.0. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_name (str): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_boundary'. + """ + + def __init__(self, + ignore_label: int = 255, + thres: float = 0.7, + min_kept: int = 100000, + loss_weight: float = 1.0, + class_weight: Optional[Union[List[float], str]] = None, + loss_name: str = 'loss_ohem'): + super().__init__() + self.thresh = thres + self.min_kept = max(1, min_kept) + self.ignore_label = ignore_label + self.loss_weight = loss_weight + self.loss_name_ = loss_name + self.class_weight = class_weight + + def forward(self, score: Tensor, target: Tensor) -> Tensor: + """Forward function. + Args: + score (Tensor): Predictions of the segmentation head. + target (Tensor): Ground truth of the image. + + Returns: + Tensor: Loss tensor. + """ + # score: (N, C, H, W) + pred = F.softmax(score, dim=1) + if self.class_weight is not None: + class_weight = score.new_tensor(self.class_weight) + else: + class_weight = None + + pixel_losses = F.cross_entropy( + score, + target, + weight=class_weight, + ignore_index=self.ignore_label, + reduction='none').contiguous().view(-1) # (N*H*W) + mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W) + + tmp_target = target.clone() # (N, H, W) + tmp_target[tmp_target == self.ignore_label] = 0 + # pred: (N, C, H, W) -> (N*H*W, C) + pred = pred.gather(1, tmp_target.unsqueeze(1)) + # pred: (N*H*W, C) -> (N*H*W), ind: (N*H*W) + pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort() + if pred.numel() > 0: + min_value = pred[min(self.min_kept, pred.numel() - 1)] + else: + return score.new_tensor(0.0) + threshold = max(min_value, self.thresh) + + pixel_losses = pixel_losses[mask][ind] + pixel_losses = pixel_losses[pred < threshold] + return self.loss_weight * pixel_losses.mean() + + @property + def loss_name(self): + return self.loss_name_ diff --git a/finetune/mmseg/models/losses/silog_loss.py b/finetune/mmseg/models/losses/silog_loss.py new file mode 100644 index 0000000..ecc07aa --- /dev/null +++ b/finetune/mmseg/models/losses/silog_loss.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmseg.registry import MODELS +from .utils import weight_reduce_loss + + +def silog_loss(pred: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + eps: float = 1e-4, + reduction: Union[str, None] = 'mean', + avg_factor: Optional[int] = None) -> Tensor: + """Computes the Scale-Invariant Logarithmic (SI-Log) loss between + prediction and target. + + Args: + pred (Tensor): Predicted output. + target (Tensor): Ground truth. + weight (Optional[Tensor]): Optional weight to apply on the loss. + eps (float): Epsilon value to avoid division and log(0). + reduction (Union[str, None]): Specifies the reduction to apply to the + output: 'mean', 'sum' or None. + avg_factor (Optional[int]): Optional average factor for the loss. + + Returns: + Tensor: The calculated SI-Log loss. + """ + pred, target = pred.flatten(1), target.flatten(1) + valid_mask = (target > eps).detach().float() + + diff_log = torch.log(target.clamp(min=eps)) - torch.log( + pred.clamp(min=eps)) + + valid_mask = (target > eps).detach() & (~torch.isnan(diff_log)) + diff_log[~valid_mask] = 0.0 + valid_mask = valid_mask.float() + + diff_log_sq_mean = (diff_log.pow(2) * valid_mask).sum( + dim=1) / valid_mask.sum(dim=1).clamp(min=eps) + diff_log_mean = (diff_log * valid_mask).sum(dim=1) / valid_mask.sum( + dim=1).clamp(min=eps) + + loss = torch.sqrt(diff_log_sq_mean - 0.5 * diff_log_mean.pow(2)) + + if weight is not None: + weight = weight.float() + + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@MODELS.register_module() +class SiLogLoss(nn.Module): + """Compute SiLog loss. + + Args: + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + loss_name (str, optional): Name of the loss item. If you want this + loss item to be included into the backward graph, `loss_` must + be the prefix of the name. Defaults to 'loss_silog'. + """ + + def __init__(self, + reduction='mean', + loss_weight=1.0, + eps=1e-6, + loss_name='loss_silog'): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.eps = eps + self._loss_name = loss_name + + def forward( + self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + ): + + assert pred.shape == target.shape, 'the shapes of pred ' \ + f'({pred.shape}) and target ({target.shape}) are mismatch' + + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + loss = self.loss_weight * silog_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor, + ) + + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/tversky_loss.py b/finetune/mmseg/models/losses/tversky_loss.py new file mode 100644 index 0000000..bfca1af --- /dev/null +++ b/finetune/mmseg/models/losses/tversky_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from +https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333 +(Apache-2.0 License)""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import get_class_weight, weighted_loss + + +@weighted_loss +def tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1, + class_weight=None, + ignore_index=255): + assert pred.shape[0] == target.shape[0] + total_loss = 0 + num_classes = pred.shape[1] + for i in range(num_classes): + if i != ignore_index: + tversky_loss = binary_tversky_loss( + pred[:, i], + target[..., i], + valid_mask=valid_mask, + alpha=alpha, + beta=beta, + smooth=smooth) + if class_weight is not None: + tversky_loss *= class_weight[i] + total_loss += tversky_loss + return total_loss / num_classes + + +@weighted_loss +def binary_tversky_loss(pred, + target, + valid_mask, + alpha=0.3, + beta=0.7, + smooth=1): + assert pred.shape[0] == target.shape[0] + pred = pred.reshape(pred.shape[0], -1) + target = target.reshape(target.shape[0], -1) + valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) + + TP = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) + FP = torch.sum(torch.mul(pred, 1 - target) * valid_mask, dim=1) + FN = torch.sum(torch.mul(1 - pred, target) * valid_mask, dim=1) + tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth) + + return 1 - tversky + + +@LOSSES.register_module() +class TverskyLoss(nn.Module): + """TverskyLoss. This loss is proposed in `Tversky loss function for image + segmentation using 3D fully convolutional deep networks. + + `_. + Args: + smooth (float): A float number to smooth loss, and avoid NaN error. + Default: 1. + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Default to 1.0. + ignore_index (int | None): The label index to be ignored. Default: 255. + alpha(float, in [0, 1]): + The coefficient of false positives. Default: 0.3. + beta (float, in [0, 1]): + The coefficient of false negatives. Default: 0.7. + Note: alpha + beta = 1. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_tversky'. + """ + + def __init__(self, + smooth=1, + class_weight=None, + loss_weight=1.0, + ignore_index=255, + alpha=0.3, + beta=0.7, + loss_name='loss_tversky'): + super().__init__() + self.smooth = smooth + self.class_weight = get_class_weight(class_weight) + self.loss_weight = loss_weight + self.ignore_index = ignore_index + assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + self.alpha = alpha + self.beta = beta + self._loss_name = loss_name + + def forward(self, pred, target, **kwargs): + if self.class_weight is not None: + class_weight = pred.new_tensor(self.class_weight) + else: + class_weight = None + + pred = F.softmax(pred, dim=1) + num_classes = pred.shape[1] + one_hot_target = F.one_hot( + torch.clamp(target.long(), 0, num_classes - 1), + num_classes=num_classes) + valid_mask = (target != self.ignore_index).long() + + loss = self.loss_weight * tversky_loss( + pred, + one_hot_target, + valid_mask=valid_mask, + alpha=self.alpha, + beta=self.beta, + smooth=self.smooth, + class_weight=class_weight, + ignore_index=self.ignore_index) + return loss + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/finetune/mmseg/models/losses/utils.py b/finetune/mmseg/models/losses/utils.py new file mode 100644 index 0000000..0478034 --- /dev/null +++ b/finetune/mmseg/models/losses/utils.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine.fileio import load + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction) -> torch.Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, + weight=None, + reduction='mean', + avg_factor=None) -> torch.Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/finetune/mmseg/models/necks/__init__.py b/finetune/mmseg/models/necks/__init__.py new file mode 100644 index 0000000..0083f05 --- /dev/null +++ b/finetune/mmseg/models/necks/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .featurepyramid import Feature2Pyramid +from .fpn import FPN +from .ic_neck import ICNeck +from .jpu import JPU +from .mla_neck import MLANeck +from .multilevel_neck import MultiLevelNeck +from .fusion_transformer import FusionTransformer +from .fusion_multilevel_neck import FusionMultiLevelNeck + +__all__ = [ + 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid', + 'FusionTransformer', 'FusionMultiLevelNeck' +] diff --git a/finetune/mmseg/models/necks/featurepyramid.py b/finetune/mmseg/models/necks/featurepyramid.py new file mode 100644 index 0000000..dc1250d --- /dev/null +++ b/finetune/mmseg/models/necks/featurepyramid.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import build_norm_layer + +from mmseg.registry import MODELS + + +@MODELS.register_module() +class Feature2Pyramid(nn.Module): + """Feature2Pyramid. + + A neck structure connect ViT backbone and decoder_heads. + + Args: + embed_dims (int): Embedding dimension. + rescales (list[float]): Different sampling multiples were + used to obtain pyramid features. Default: [4, 2, 1, 0.5]. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + embed_dim, + rescales=[4, 2, 1, 0.5], + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + self.rescales = rescales + self.upsample_4x = None + for k in self.rescales: + if k == 4: + self.upsample_4x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + build_norm_layer(norm_cfg, embed_dim)[1], + nn.GELU(), + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2), + ) + elif k == 2: + self.upsample_2x = nn.Sequential( + nn.ConvTranspose2d( + embed_dim, embed_dim, kernel_size=2, stride=2)) + elif k == 1: + self.identity = nn.Identity() + elif k == 0.5: + self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2) + elif k == 0.25: + self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4) + else: + raise KeyError(f'invalid {k} for feature2pyramid') + + def forward(self, inputs): + assert len(inputs) == len(self.rescales) + outputs = [] + if self.upsample_4x is not None: + ops = [ + self.upsample_4x, self.upsample_2x, self.identity, + self.downsample_2x + ] + else: + ops = [ + self.upsample_2x, self.identity, self.downsample_2x, + self.downsample_4x + ] + for i in range(len(inputs)): + outputs.append(ops[i](inputs[i])) + return tuple(outputs) diff --git a/finetune/mmseg/models/necks/fpn.py b/finetune/mmseg/models/necks/fpn.py new file mode 100644 index 0000000..ddab74c --- /dev/null +++ b/finetune/mmseg/models/necks/fpn.py @@ -0,0 +1,212 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class FPN(BaseModule): + """Feature Pyramid Network. + + This neck is the implementation of `Feature Pyramid Networks for Object + Detection `_. + + Args: + in_channels (list[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, its actual mode is specified by `extra_convs_on_inputs`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs + on the original feature from the backbone. If True, + it is equivalent to `add_extra_convs='on_input'`. If False, it is + equivalent to set `add_extra_convs='on_output'`. Default to True. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: dict(mode='nearest'). + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + extra_convs_on_inputs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super().__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + if extra_convs_on_inputs: + # For compatibility with previous release + # TODO: deprecate `extra_convs_on_inputs` + self.add_extra_convs = 'on_input' + else: + self.add_extra_convs = 'on_output' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/finetune/mmseg/models/necks/fusion_multilevel_neck.py b/finetune/mmseg/models/necks/fusion_multilevel_neck.py new file mode 100644 index 0000000..b5abb73 --- /dev/null +++ b/finetune/mmseg/models/necks/fusion_multilevel_neck.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +from .multilevel_neck import MultiLevelNeck +from .fusion_transformer import FusionTransformer +from mmseg.registry import MODELS + + +@MODELS.register_module() +class FusionMultiLevelNeck(nn.Module): + def __init__(self, + ts_size=10, + in_channels_ml=[768, 768, 768, 768], + out_channels_ml=768, + scales_ml=[0.5, 1, 2, 4], + norm_cfg_ml=None, + act_cfg_ml=None, + input_dims=768, + embed_dims=768, + num_layers=4, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=True, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + num_fcs=2, + norm_eval=False, + with_cp=False, + init_cfg=None, + *args, + **kwargs): + super(FusionMultiLevelNeck, self).__init__() + self.in_channels = in_channels_ml + self.ts_size = ts_size + self.multilevel_neck = MultiLevelNeck( + in_channels_ml, + out_channels_ml, + scales_ml, + norm_cfg_ml, + act_cfg_ml + ) + # self.up_head = UPHead(1024, 2816, 4) + + self.fusion_transformer = FusionTransformer( + input_dims, + embed_dims, + num_layers, + num_heads, + mlp_ratio, + qkv_bias, + drop_rate, + attn_drop_rate, + drop_path_rate, + with_cls_token, + output_cls_token, + norm_cfg, + act_cfg, + num_fcs, + norm_eval, + with_cp, + init_cfg, + ) + + def init_weights(self): + self.fusion_transformer.init_weights() + + def forward(self, inputs, require_feat: bool = False, require_two: bool = False): + assert len(inputs) == len(self.in_channels) + + inputs = self.multilevel_neck(inputs) + + ts = self.ts_size + b_total, c, h, w = inputs[-1].shape + b = int(b_total / ts) + outs = [] + for idx in range(len(inputs)): + + input_feat = inputs[idx] + b_total, c, h, w = inputs[idx].shape + input_feat = input_feat.reshape(b, ts, c, h, w).permute(0, 3, 4, 1, 2).reshape(b*h*w, ts, c) # b*ts, c, h, w转换为b*h*w, ts, c + feat_fusion = self.fusion_transformer(input_feat, require_feat, require_two) + c_fusion = feat_fusion.shape[-1] + feat_fusion = feat_fusion.reshape(b, h, w, c_fusion).permute(0, 3, 1, 2) # b*h*w, c -> b, c, h, w + outs.append(feat_fusion) + + return tuple(outs) \ No newline at end of file diff --git a/finetune/mmseg/models/necks/fusion_transformer.py b/finetune/mmseg/models/necks/fusion_transformer.py new file mode 100644 index 0000000..3be57c6 --- /dev/null +++ b/finetune/mmseg/models/necks/fusion_transformer.py @@ -0,0 +1,166 @@ +# Copyright (c) Ant Group. All rights reserved. +from collections import OrderedDict +import torch +import torch.nn as nn +from mmengine.model.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn.modules.batchnorm import _BatchNorm +from mmseg.models.backbones.vit import TransformerEncoderLayer + +# from mmseg.utils import get_root_logger +from mmseg.registry import MODELS + +# @MODELS.register_module() +class FusionTransformer(nn.Module): + def __init__(self, + input_dims=768, + embed_dims=768, + num_layers=4, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=True, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + num_fcs=2, + norm_eval=False, + with_cp=False, + init_cfg=None, + *args, + **kwargs): + super(FusionTransformer, self).__init__() + + self.porj_linear = nn.Linear(input_dims, embed_dims) + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + self.init_cfg = init_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer(embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * + embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + def init_weights(self): + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') in ['Pretrained', 'Pretrained_Part']: + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + if self.init_cfg.get('type') == 'Pretrained': + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + elif self.init_cfg.get('type') == 'Pretrained_Part': + state_dict = checkpoint.copy() + para_prefix = 'image_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + # if 'pos_embed' in state_dict.keys(): + # if self.pos_embed.shape != state_dict['pos_embed'].shape: + # print_log(msg=f'Resize the pos_embed shape from ' + # f'{state_dict["pos_embed"].shape} to ' + # f'{self.pos_embed.shape}') + # h, w = self.img_size + # pos_size = int( + # math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + # state_dict['pos_embed'] = self.resize_pos_embed( + # state_dict['pos_embed'], + # (h // self.patch_size, w // self.patch_size), + # (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=None) + elif self.init_cfg is not None: + super().init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + # trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs, require_feat: bool = False, require_two: bool = False): + inputs = self.porj_linear(inputs) + B, N, C = inputs.shape + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, inputs), dim=1) + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + # add hidden and atten state + block_outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if require_feat: + block_outs.append(x) + + if self.output_cls_token: + if require_two: + x = x[:, :2] + else: + x = x[:, 0] + elif not self.output_cls_token and self.with_cls_token: + x = x[:, 1:] + + if require_feat: + return x, block_outs + else: + return x + + def train(self, mode=True): + super(FusionTransformer, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() + +if __name__ == '__main__': + fusion_transformer = FusionTransformer() + print(fusion_transformer) \ No newline at end of file diff --git a/finetune/mmseg/models/necks/ic_neck.py b/finetune/mmseg/models/necks/ic_neck.py new file mode 100644 index 0000000..9763541 --- /dev/null +++ b/finetune/mmseg/models/necks/ic_neck.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +class CascadeFeatureFusion(BaseModule): + """Cascade Feature Fusion Unit in ICNet. + + Args: + low_channels (int): The number of input channels for + low resolution feature map. + high_channels (int): The number of input channels for + high resolution feature map. + out_channels (int): The number of output channels. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + + Returns: + x (Tensor): The output tensor of shape (N, out_channels, H, W). + x_low (Tensor): The output tensor of shape (N, out_channels, H, W) + for Cascade Label Guidance in auxiliary heads. + """ + + def __init__(self, + low_channels, + high_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.align_corners = align_corners + self.conv_low = ConvModule( + low_channels, + out_channels, + 3, + padding=2, + dilation=2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv_high = ConvModule( + high_channels, + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x_low, x_high): + x_low = resize( + x_low, + size=x_high.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + # Note: Different from original paper, `x_low` is underwent + # `self.conv_low` rather than another 1x1 conv classifier + # before being used for auxiliary head. + x_low = self.conv_low(x_low) + x_high = self.conv_high(x_high) + x = x_low + x_high + x = F.relu(x, inplace=True) + return x, x_low + + +@MODELS.register_module() +class ICNeck(BaseModule): + """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. + + This head is the implementation of `ICHead + `_. + + Args: + in_channels (int): The number of input image channels. Default: 3. + out_channels (int): The numbers of output feature channels. + Default: 128. + conv_cfg (dict): Dictionary to construct and config conv layer. + Default: None. + norm_cfg (dict): Dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Dictionary to construct and config act layer. + Default: dict(type='ReLU'). + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(64, 256, 256), + out_channels=128, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + align_corners=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert len(in_channels) == 3, 'Length of input channels \ + must be 3!' + + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.align_corners = align_corners + self.cff_24 = CascadeFeatureFusion( + self.in_channels[2], + self.in_channels[1], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + self.cff_12 = CascadeFeatureFusion( + self.out_channels, + self.in_channels[0], + self.out_channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + + def forward(self, inputs): + assert len(inputs) == 3, 'Length of input feature \ + maps must be 3!' + + x_sub1, x_sub2, x_sub4 = inputs + x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) + x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) + # Note: `x_cff_12` is used for decode_head, + # `x_24` and `x_12` are used for auxiliary head. + return x_24, x_12, x_cff_12 diff --git a/finetune/mmseg/models/necks/jpu.py b/finetune/mmseg/models/necks/jpu.py new file mode 100644 index 0000000..3ea0fe2 --- /dev/null +++ b/finetune/mmseg/models/necks/jpu.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule +from mmengine.model import BaseModule + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class JPU(BaseModule): + """FastFCN: Rethinking Dilated Convolution in the Backbone + for Semantic Segmentation. + + This Joint Pyramid Upsampling (JPU) neck is the implementation of + `FastFCN `_. + + Args: + in_channels (Tuple[int], optional): The number of input channels + for each convolution operations before upsampling. + Default: (512, 1024, 2048). + mid_channels (int): The number of output channels of JPU. + Default: 512. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + dilations (tuple[int]): Dilation rate of each Depthwise + Separable ConvModule. Default: (1, 2, 4, 8). + align_corners (bool, optional): The align_corners argument of + resize operation. Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=(512, 1024, 2048), + mid_channels=512, + start_level=0, + end_level=-1, + dilations=(1, 2, 4, 8), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, tuple) + assert isinstance(dilations, tuple) + self.in_channels = in_channels + self.mid_channels = mid_channels + self.start_level = start_level + self.num_ins = len(in_channels) + if end_level == -1: + self.backbone_end_level = self.num_ins + else: + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + + self.dilations = dilations + self.align_corners = align_corners + + self.conv_layers = nn.ModuleList() + self.dilation_layers = nn.ModuleList() + for i in range(self.start_level, self.backbone_end_level): + conv_layer = nn.Sequential( + ConvModule( + self.in_channels[i], + self.mid_channels, + kernel_size=3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.conv_layers.append(conv_layer) + for i in range(len(dilations)): + dilation_layer = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=(self.backbone_end_level - self.start_level) * + self.mid_channels, + out_channels=self.mid_channels, + kernel_size=3, + stride=1, + padding=dilations[i], + dilation=dilations[i], + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=act_cfg)) + self.dilation_layers.append(dilation_layer) + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels), 'Length of inputs must \ + be the same with self.in_channels!' + + feats = [ + self.conv_layers[i - self.start_level](inputs[i]) + for i in range(self.start_level, self.backbone_end_level) + ] + + h, w = feats[0].shape[2:] + for i in range(1, len(feats)): + feats[i] = resize( + feats[i], + size=(h, w), + mode='bilinear', + align_corners=self.align_corners) + + feat = torch.cat(feats, dim=1) + concat_feat = torch.cat([ + self.dilation_layers[i](feat) for i in range(len(self.dilations)) + ], + dim=1) + + outs = [] + + # Default: outs[2] is the output of JPU for decoder head, outs[1] is + # the feature map from backbone for auxiliary head. Additionally, + # outs[0] can also be used for auxiliary head. + for i in range(self.start_level, self.backbone_end_level - 1): + outs.append(inputs[i]) + outs.append(concat_feat) + return tuple(outs) diff --git a/finetune/mmseg/models/necks/mla_neck.py b/finetune/mmseg/models/necks/mla_neck.py new file mode 100644 index 0000000..db250ae --- /dev/null +++ b/finetune/mmseg/models/necks/mla_neck.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule, build_norm_layer + +from mmseg.registry import MODELS + + +class MLAModule(nn.Module): + + def __init__(self, + in_channels=[1024, 1024, 1024, 1024], + out_channels=256, + norm_cfg=None, + act_cfg=None): + super().__init__() + self.channel_proj = nn.ModuleList() + for i in range(len(in_channels)): + self.channel_proj.append( + ConvModule( + in_channels=in_channels[i], + out_channels=out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.feat_extract = nn.ModuleList() + for i in range(len(in_channels)): + self.feat_extract.append( + ConvModule( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs): + + # feat_list -> [p2, p3, p4, p5] + feat_list = [] + for x, conv in zip(inputs, self.channel_proj): + feat_list.append(conv(x)) + + # feat_list -> [p5, p4, p3, p2] + # mid_list -> [m5, m4, m3, m2] + feat_list = feat_list[::-1] + mid_list = [] + for feat in feat_list: + if len(mid_list) == 0: + mid_list.append(feat) + else: + mid_list.append(mid_list[-1] + feat) + + # mid_list -> [m5, m4, m3, m2] + # out_list -> [o2, o3, o4, o5] + out_list = [] + for mid, conv in zip(mid_list, self.feat_extract): + out_list.append(conv(mid)) + + return tuple(out_list) + + +@MODELS.register_module() +class MLANeck(nn.Module): + """Multi-level Feature Aggregation. + + This neck is `The Multi-level Feature Aggregation construction of + SETR `_. + + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + norm_layer (dict): Config dict for input normalization. + Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True). + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + norm_layer=dict(type='LN', eps=1e-6, requires_grad=True), + norm_cfg=None, + act_cfg=None): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + + # In order to build general vision transformer backbone, we have to + # move MLA to neck. + self.norm = nn.ModuleList([ + build_norm_layer(norm_layer, in_channels[i])[1] + for i in range(len(in_channels)) + ]) + + self.mla = MLAModule( + in_channels=in_channels, + out_channels=out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + + # Convert from nchw to nlc + outs = [] + for i in range(len(inputs)): + x = inputs[i] + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm[i](x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + outs.append(x) + + outs = self.mla(outs) + return tuple(outs) diff --git a/finetune/mmseg/models/necks/multilevel_neck.py b/finetune/mmseg/models/necks/multilevel_neck.py new file mode 100644 index 0000000..c997125 --- /dev/null +++ b/finetune/mmseg/models/necks/multilevel_neck.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model.weight_init import xavier_init + +from mmseg.registry import MODELS +from ..utils import resize + + +@MODELS.register_module() +class MultiLevelNeck(nn.Module): + """MultiLevelNeck. + + A neck structure connect vit backbone and decoder_heads. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + scales (List[float]): Scale factors for each input feature map. + Default: [0.5, 1, 2, 4] + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + scales=[0.5, 1, 2, 4], + norm_cfg=None, + act_cfg=None): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.scales = scales + self.num_outs = len(scales) + self.lateral_convs = nn.ModuleList() + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.lateral_convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + for _ in range(self.num_outs): + self.convs.append( + ConvModule( + out_channels, + out_channels, + kernel_size=3, + padding=1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + # default init_weights for conv(msra) and norm in ConvModule + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + xavier_init(m, distribution='uniform') + + def forward(self, inputs): + assert len(inputs) == len(self.in_channels) + inputs = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + # for len(inputs) not equal to self.num_outs + if len(inputs) == 1: + inputs = [inputs[0] for _ in range(self.num_outs)] + outs = [] + for i in range(self.num_outs): + x_resize = resize( + inputs[i], scale_factor=self.scales[i], mode='bilinear') + outs.append(self.convs[i](x_resize)) + return tuple(outs) diff --git a/finetune/mmseg/models/segmentors/__init__.py b/finetune/mmseg/models/segmentors/__init__.py new file mode 100644 index 0000000..59b012f --- /dev/null +++ b/finetune/mmseg/models/segmentors/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import BaseSegmentor +from .cascade_encoder_decoder import CascadeEncoderDecoder +from .depth_estimator import DepthEstimator +from .encoder_decoder import EncoderDecoder +from .multimodal_encoder_decoder import MultimodalEncoderDecoder +from .seg_tta import SegTTAModel + +__all__ = [ + 'BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder', 'SegTTAModel', + 'MultimodalEncoderDecoder', 'DepthEstimator' +] diff --git a/finetune/mmseg/models/segmentors/base.py b/finetune/mmseg/models/segmentors/base.py new file mode 100644 index 0000000..17a0bb2 --- /dev/null +++ b/finetune/mmseg/models/segmentors/base.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod +from typing import List, Tuple + +from mmengine.model import BaseModel +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.structures import SegDataSample +from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig, + OptSampleList, SampleList) +from ..utils import resize + + +class BaseSegmentor(BaseModel, metaclass=ABCMeta): + """Base class for segmentors. + + Args: + data_preprocessor (dict, optional): Model preprocessing config + for processing the input data. it usually includes + ``to_rgb``, ``pad_size_divisor``, ``pad_val``, + ``mean`` and ``std``. Default to None. + init_cfg (dict, optional): the config to control the + initialization. Default to None. + """ + + def __init__(self, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + + @property + def with_neck(self) -> bool: + """bool: whether the segmentor has neck""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_auxiliary_head(self) -> bool: + """bool: whether the segmentor has auxiliary head""" + return hasattr(self, + 'auxiliary_head') and self.auxiliary_head is not None + + @property + def with_decode_head(self) -> bool: + """bool: whether the segmentor has decode head""" + return hasattr(self, 'decode_head') and self.decode_head is not None + + @abstractmethod + def extract_feat(self, inputs: Tensor) -> bool: + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, inputs: Tensor, batch_data_samples: SampleList): + """Placeholder for encode images with backbone and decode into a + semantic segmentation map of the same size as input.""" + pass + + def forward(self, + inputs: Tensor, + data_samples: OptSampleList = None, + mode: str = 'tensor') -> ForwardResults: + """The unified entry for a forward process in both training and test. + + The method should accept three modes: "tensor", "predict" and "loss": + + - "tensor": Forward the whole network and return tensor or tuple of + tensor without any post-processing, same as a common nn.Module. + - "predict": Forward and return the predictions, which are fully + processed to a list of :obj:`SegDataSample`. + - "loss": Forward and return a dict of losses according to the given + inputs and data samples. + + Note that this method doesn't handle neither back propagation nor + optimizer updating, which are done in the :meth:`train_step`. + + Args: + inputs (torch.Tensor): The input tensor with shape (N, C, ...) in + general. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + mode (str): Return what kind of value. Defaults to 'tensor'. + + Returns: + The return type depends on ``mode``. + + - If ``mode="tensor"``, return a tensor or a tuple of tensor. + - If ``mode="predict"``, return a list of :obj:`DetDataSample`. + - If ``mode="loss"``, return a dict of tensor. + """ + if mode == 'loss': + return self.loss(inputs, data_samples) + elif mode == 'predict': + return self.predict(inputs, data_samples) + elif mode == 'tensor': + return self._forward(inputs, data_samples) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') + + @abstractmethod + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing.""" + pass + + @abstractmethod + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. + + Usually includes backbone, neck and head forward without any post- + processing. + """ + pass + + def postprocess_result(self, + seg_logits: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + seg_logits (Tensor): The segmentation results, seg_logits from + model of each input image. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + batch_size, C, H, W = seg_logits.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_seg_logits shape is 1, C, H, W after remove padding + i_seg_logits = seg_logits[i:i + 1, :, + padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_seg_logits = i_seg_logits.flip(dims=(3, )) + else: + i_seg_logits = i_seg_logits.flip(dims=(2, )) + + # resize as original shape + i_seg_logits = resize( + i_seg_logits, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_seg_logits = seg_logits[i] + + if C > 1: + i_seg_pred = i_seg_logits.argmax(dim=0, keepdim=True) + else: + i_seg_logits = i_seg_logits.sigmoid() + i_seg_pred = (i_seg_logits > + self.decode_head.threshold).to(i_seg_logits) + data_samples[i].set_data({ + 'seg_logits': + PixelData(**{'data': i_seg_logits}), + 'pred_sem_seg': + PixelData(**{'data': i_seg_pred}) + }) + + return data_samples diff --git a/finetune/mmseg/models/segmentors/cascade_encoder_decoder.py b/finetune/mmseg/models/segmentors/cascade_encoder_decoder.py new file mode 100644 index 0000000..0184a35 --- /dev/null +++ b/finetune/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +from torch import Tensor, nn + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .encoder_decoder import EncoderDecoder + + +@MODELS.register_module() +class CascadeEncoderDecoder(EncoderDecoder): + """Cascade Encoder Decoder segmentors. + + CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of + CascadeEncoderDecoder are cascaded. The output of previous decoder_head + will be the input of next decoder_head. + + Args: + + num_stages (int): How many stages will be cascaded. + backbone (ConfigType): The config for the backnone of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + neck (OptConfigType): The config for the neck of segmentor. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + segmentor. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ + + def __init__(self, + num_stages: int, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + self.num_stages = num_stages + super().__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + pretrained=pretrained, + init_cfg=init_cfg) + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + assert isinstance(decode_head, list) + assert len(decode_head) == self.num_stages + self.decode_head = nn.ModuleList() + for i in range(self.num_stages): + self.decode_head.append(MODELS.build(decode_head[i])) + self.align_corners = self.decode_head[-1].align_corners + self.num_classes = self.decode_head[-1].num_classes + self.out_channels = self.decode_head[-1].out_channels + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(inputs) + out = self.decode_head[0].forward(x) + for i in range(1, self.num_stages - 1): + out = self.decode_head[i].forward(x, out) + seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas, + self.test_cfg) + + return seg_logits_list + + def _decode_head_forward_train(self, inputs: Tensor, + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + + loss_decode = self.decode_head[0].loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode_0')) + # get batch_img_metas + batch_size = len(data_samples) + batch_img_metas = [] + for batch_index in range(batch_size): + metainfo = data_samples[batch_index].metainfo + batch_img_metas.append(metainfo) + + for i in range(1, self.num_stages): + # forward test again, maybe unnecessary for most methods. + if i == 1: + prev_outputs = self.decode_head[0].forward(inputs) + else: + prev_outputs = self.decode_head[i - 1].forward( + inputs, prev_outputs) + loss_decode = self.decode_head[i].loss(inputs, prev_outputs, + data_samples, + self.train_cfg) + losses.update(add_prefix(loss_decode, f'decode_{i}')) + + return losses + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_semantic_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + + out = self.decode_head[0].forward(x) + for i in range(1, self.num_stages): + # TODO support PointRend tensor mode + out = self.decode_head[i].forward(x, out) + + return out diff --git a/finetune/mmseg/models/segmentors/depth_estimator.py b/finetune/mmseg/models/segmentors/depth_estimator.py new file mode 100644 index 0000000..1020637 --- /dev/null +++ b/finetune/mmseg/models/segmentors/depth_estimator.py @@ -0,0 +1,392 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.logging import print_log +from mmengine.structures import PixelData +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from ..utils import resize +from .encoder_decoder import EncoderDecoder + + +@MODELS.register_module() +class DepthEstimator(EncoderDecoder): + """Encoder Decoder depth estimator. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional) + _decode_head_forward_train(): decode_head.loss() + _auxiliary_head_forward_train(): auxiliary_head.loss (optional) + + 2. The ``predict`` method is used to predict depth estimation results, + which includes two steps: (1) Run inference function to obtain the list of + depth (2) Call post-processing function to obtain list of + ``SegDataSample`` including ``pred_depth_map``. + + .. code:: text + + predict(): inference() -> postprocess_result() + inference(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + backbone (ConfigType): The config for the backnone of depth estimator. + decode_head (ConfigType): The config for the decode head of depth estimator. + neck (OptConfigType): The config for the neck of depth estimator. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + depth estimator. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + super().__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + pretrained=pretrained, + init_cfg=init_cfg) + + def extract_feat(self, + inputs: Tensor, + batch_img_metas: Optional[List[dict]] = None) -> Tensor: + """Extract features from images.""" + + if getattr(self.backbone, 'class_embed_select', False) and \ + isinstance(batch_img_metas, list) and \ + 'category_id' in batch_img_metas[0]: + cat_ids = [meta['category_id'] for meta in batch_img_metas] + cat_ids = torch.tensor(cat_ids).to(inputs.device) + inputs = (inputs, cat_ids) + + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a depth map of the same + size as input.""" + x = self.extract_feat(inputs, batch_img_metas) + depth = self.decode_head.predict(x, batch_img_metas, self.test_cfg) + + return depth + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_depth_map`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + x = self.extract_feat(inputs, batch_img_metas) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_depth_map`. + + Returns: + list[:obj:`SegDataSample`]: Depth estimation results of the + input images. Each SegDataSample usually contain: + + - ``pred_depth_max``(PixelData): Prediction of depth estimation. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + depth = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(depth, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_depth_map`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_flip_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap and flip. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The depth estimation results. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is depth tensor map + # with shape [N, C, H, W] + crop_depth_map = self.encode_decode(crop_img, batch_img_metas) + + # average out the original and flipped prediction + crop_depth_map_flip = self.encode_decode( + crop_img.flip(dims=(3, )), batch_img_metas) + crop_depth_map_flip = crop_depth_map_flip.flip(dims=(3, )) + crop_depth_map = (crop_depth_map + crop_depth_map_flip) / 2.0 + + preds += F.pad(crop_depth_map, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + depth = preds / count_mat + + return depth + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The depth estimation results. + """ + assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole', + 'slide_flip'], \ + f'Only "slide", "slide_flip" or "whole" test mode are ' \ + f'supported, but got {self.test_cfg["mode"]}.' + ori_shape = batch_img_metas[0]['ori_shape'] + if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas): + print_log( + 'Image shapes are different in the batch.', + logger='current', + level=logging.WARN) + if self.test_cfg.mode == 'slide': + depth_map = self.slide_inference(inputs, batch_img_metas) + if self.test_cfg.mode == 'slide_flip': + depth_map = self.slide_flip_inference(inputs, batch_img_metas) + else: + depth_map = self.whole_inference(inputs, batch_img_metas) + + return depth_map + + def postprocess_result(self, + depth: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """ Convert results list to `SegDataSample`. + Args: + depth (Tensor): The depth estimation results. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_depth_map`. Default to None. + Returns: + list[:obj:`SegDataSample`]: Depth estomation results of the + input images. Each SegDataSample usually contain: + + - ``pred_depth_map``(PixelData): Prediction of depth estimation. + """ + batch_size, C, H, W = depth.shape + + if data_samples is None: + data_samples = [SegDataSample() for _ in range(batch_size)] + only_prediction = True + else: + only_prediction = False + + for i in range(batch_size): + if not only_prediction: + img_meta = data_samples[i].metainfo + # remove padding area + if 'img_padding_size' not in img_meta: + padding_size = img_meta.get('padding_size', [0] * 4) + else: + padding_size = img_meta['img_padding_size'] + padding_left, padding_right, padding_top, padding_bottom =\ + padding_size + # i_depth shape is 1, C, H, W after remove padding + i_depth = depth[i:i + 1, :, padding_top:H - padding_bottom, + padding_left:W - padding_right] + + flip = img_meta.get('flip', None) + if flip: + flip_direction = img_meta.get('flip_direction', None) + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + i_depth = i_depth.flip(dims=(3, )) + else: + i_depth = i_depth.flip(dims=(2, )) + + # resize as original shape + i_depth = resize( + i_depth, + size=img_meta['ori_shape'], + mode='bilinear', + align_corners=self.align_corners, + warning=False).squeeze(0) + else: + i_depth = depth[i] + + data_samples[i].set_data( + {'pred_depth_map': PixelData(**{'data': i_depth})}) + + return data_samples diff --git a/finetune/mmseg/models/segmentors/encoder_decoder.py b/finetune/mmseg/models/segmentors/encoder_decoder.py new file mode 100644 index 0000000..fa4050e --- /dev/null +++ b/finetune/mmseg/models/segmentors/encoder_decoder.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from typing import List, Optional + +import torch.nn as nn +import torch.nn.functional as F +from mmengine.logging import print_log +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .base import BaseSegmentor + + +@MODELS.register_module() +class EncoderDecoder(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() -> _auxiliary_head_forward_train (optional) + _decode_head_forward_train(): decode_head.loss() + _auxiliary_head_forward_train(): auxiliary_head.loss (optional) + + 2. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) Run inference function to obtain the list of + seg_logits (2) Call post-processing function to obtain list of + ``SegDataSample`` including ``pred_sem_seg`` and ``seg_logits``. + + .. code:: text + + predict(): inference() -> postprocess_result() + infercen(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + backbone (ConfigType): The config for the backnone of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + neck (OptConfigType): The config for the neck of segmentor. + Defaults to None. + auxiliary_head (OptConfigType): The config for the auxiliary head of + segmentor. Defaults to None. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def _init_auxiliary_head(self, auxiliary_head: ConfigType) -> None: + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(MODELS.build(head_cfg)) + else: + self.auxiliary_head = MODELS.build(auxiliary_head) + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract features from images.""" + x = self.backbone(inputs) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(inputs) + seg_logits = self.decode_head.predict(x, batch_img_metas, + self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _auxiliary_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.loss(inputs, data_samples, self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.loss(inputs, data_samples, + self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, data_samples) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, data_samples) + losses.update(loss_aux) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + assert self.test_cfg.get('mode', 'whole') in ['slide', 'whole'], \ + f'Only "slide" or "whole" test mode are supported, but got ' \ + f'{self.test_cfg["mode"]}.' + ori_shape = batch_img_metas[0]['ori_shape'] + if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas): + print_log( + 'Image shapes are different in the batch.', + logger='current', + level=logging.WARN) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/finetune/mmseg/models/segmentors/multimodal_encoder_decoder.py b/finetune/mmseg/models/segmentors/multimodal_encoder_decoder.py new file mode 100644 index 0000000..75aa8b9 --- /dev/null +++ b/finetune/mmseg/models/segmentors/multimodal_encoder_decoder.py @@ -0,0 +1,350 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional + +import torch.nn.functional as F +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import (ConfigType, OptConfigType, OptMultiConfig, + OptSampleList, SampleList, add_prefix) +from .base import BaseSegmentor + + +@MODELS.register_module() +class MultimodalEncoderDecoder(BaseSegmentor): + """Multimodal Encoder-Decoder segmentors. + + Multimodal segmentation architecture is used for open-vocabulary + semantic segmentation with combining the visual and language + pretrain models. It consists of a image_encoder (backbone) to extract + visual feature, a text encoder to extract text feature, and a decode + head to generate semantic maps. + Note that the deep supervision during training is implemented in decode head. + + 1. The ``loss`` method is used to calculate the loss of model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2) Call the decode head loss function to forward decode head model and + calculate losses. + + .. code:: text + + loss(): extract_feat() -> _decode_head_forward_train() + _decode_head_forward_train(): decode_head.loss() + + 2. The ``predict`` method is used to predict segmentation results, + which includes two steps: (1) Run inference function to obtain the list of + seg_logits (2) Call post-processing function to obtain list of + ``SegDataSampel`` including ``pred_sem_seg`` and ``seg_logits``. + + .. code:: text + + predict(): inference() -> postprocess_result() + inference(): whole_inference()/slide_inference() + whole_inference()/slide_inference(): encoder_decoder() + encoder_decoder(): extract_feat() -> decode_head.predict() + + 3. The ``_forward`` method is used to output the tensor by running the model, + which includes two steps: (1) Extracts features to obtain the feature maps + (2)Call the decode head forward function to forward decode head model. + + .. code:: text + + _forward(): extract_feat() -> _decode_head.forward() + + Args: + + image_encoder (ConfigType): The config for the visual encoder of segmentor. + text_encoder ((ConfigType): The config for the text encoder of segmentor. + decode_head (ConfigType): The config for the decode head of segmentor. + train_cfg (OptConfigType): The config for training. Defaults to None. + test_cfg (OptConfigType): The config for testing. Defaults to None. + data_preprocessor (dict, optional): The pre-process config of + :class:`BaseDataPreprocessor`. + pretrained (str, optional): The path for pretrained model. + Defaults to None. + asymetric_input (bool): whether to use different size of input for image encoder + and decode head. Defaults to False. + encoder_resolution (float): resize scale of input images for image encoder. + Defaults to None. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. + """ # noqa: E501 + + def __init__(self, + image_encoder: ConfigType, + text_encoder: ConfigType, + decode_head: ConfigType, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + pretrained: Optional[str] = None, + asymetric_input: bool = True, + encoder_resolution: float = None, + init_cfg: OptMultiConfig = None): + super().__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + if pretrained is not None: + image_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + text_encoder.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + decode_head.init_cfg = dict( + type='Pretrained_Part', checkpoint=pretrained) + + if asymetric_input: + assert encoder_resolution is not None, \ + 'if asymetric_input set True, ' \ + 'clip_resolution must be a certain value' + self.asymetric_input = asymetric_input + self.encoder_resolution = encoder_resolution + self.image_encoder = MODELS.build(image_encoder) + self.text_encoder = MODELS.build(text_encoder) + self._init_decode_head(decode_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head: ConfigType) -> None: + """Initialize ``decode_head``""" + self.decode_head = MODELS.build(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + self.out_channels = self.decode_head.out_channels + + def extract_feat(self, inputs: Tensor) -> List[Tensor]: + """Extract visual features from images.""" + x = self.image_encoder(inputs) + return x + + def encode_decode(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Encode the name of classes with text_encoder and encode images with + image_encoder. + + Then decode the class embedding and visual feature into a semantic + segmentation map of the same size as input. + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + seg_logits = self.decode_head.predict([inputs, x, classifier_embeds], + batch_img_metas, self.test_cfg) + + return seg_logits + + def _decode_head_forward_train(self, inputs: List[Tensor], + data_samples: SampleList) -> dict: + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.loss(inputs, data_samples, + self.train_cfg) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: + """Calculate losses from a batch of inputs and data samples. + + Args: + inputs (Tensor): Input images. + data_samples (list[:obj:`SegDataSample`]): The seg data samples. + It usually includes information such as `metainfo` and + `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + classifier_embeds = self.text_encoder() + clip_inputs = inputs + if self.asymetric_input: + clip_inputs = F.interpolate( + inputs, scale_factor=self.encoder_resolution, mode='bilinear') + x = self.image_encoder(clip_inputs) + + losses = dict() + + loss_decode = self._decode_head_forward_train( + [inputs, x, classifier_embeds], data_samples) + losses.update(loss_decode) + + return losses + + def predict(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`], optional): The seg data + samples. It usually includes information such as `metainfo` + and `gt_sem_seg`. + + Returns: + list[:obj:`SegDataSample`]: Segmentation results of the + input images. Each SegDataSample usually contain: + + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic + segmentation before normalization. + """ + if data_samples is not None: + batch_img_metas = [ + data_sample.metainfo for data_sample in data_samples + ] + else: + batch_img_metas = [ + dict( + ori_shape=inputs.shape[2:], + img_shape=inputs.shape[2:], + pad_shape=inputs.shape[2:], + padding_size=[0, 0, 0, 0]) + ] * inputs.shape[0] + + seg_logits = self.inference(inputs, batch_img_metas) + + return self.postprocess_result(seg_logits, data_samples) + + def _forward(self, + inputs: Tensor, + data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + inputs (Tensor): Inputs with shape (N, C, H, W). + data_samples (List[:obj:`SegDataSample`]): The seg + data samples. It usually includes information such + as `metainfo` and `gt_sem_seg`. + + Returns: + Tensor: Forward output of model without any post-processes. + """ + x = self.extract_feat(inputs) + return self.decode_head.forward(x) + + def slide_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + # change the image shape to patch shape + batch_img_metas[0]['img_shape'] = crop_img.shape[2:] + # the output of encode_decode is seg logits tensor map + # with shape [N, C, H, W] + crop_seg_logit = self.encode_decode(crop_img, batch_img_metas) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + + return seg_logits + + def whole_inference(self, inputs: Tensor, + batch_img_metas: List[dict]) -> Tensor: + """Inference with full image. + + Args: + inputs (Tensor): The tensor should have a shape NxCxHxW, which + contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + seg_logits = self.encode_decode(inputs, batch_img_metas) + + return seg_logits + + def inference(self, inputs: Tensor, batch_img_metas: List[dict]) -> Tensor: + """Inference with slide/whole style. + + Args: + inputs (Tensor): The input image of shape (N, 3, H, W). + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', 'pad_shape', and 'padding_size'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = batch_img_metas[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(inputs, batch_img_metas) + else: + seg_logit = self.whole_inference(inputs, batch_img_metas) + + return seg_logit + + def aug_test(self, inputs, batch_img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(inputs[0], batch_img_metas[0], rescale) + for i in range(1, len(inputs)): + cur_seg_logit = self.inference(inputs[i], batch_img_metas[i], + rescale) + seg_logit += cur_seg_logit + seg_logit /= len(inputs) + seg_pred = seg_logit.argmax(dim=1) + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/finetune/mmseg/models/segmentors/seg_tta.py b/finetune/mmseg/models/segmentors/seg_tta.py new file mode 100644 index 0000000..63ef61d --- /dev/null +++ b/finetune/mmseg/models/segmentors/seg_tta.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from mmengine.model import BaseTTAModel +from mmengine.structures import PixelData + +from mmseg.registry import MODELS +from mmseg.utils import SampleList + + +@MODELS.register_module() +class SegTTAModel(BaseTTAModel): + + def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[SampleList]): List of predictions + of all enhanced data. + + Returns: + SampleList: Merged prediction. + """ + predictions = [] + for data_samples in data_samples_list: + seg_logits = data_samples[0].seg_logits.data + logits = torch.zeros(seg_logits.shape).to(seg_logits) + for data_sample in data_samples: + seg_logit = data_sample.seg_logits.data + if self.module.out_channels > 1: + logits += seg_logit.softmax(dim=0) + else: + logits += seg_logit.sigmoid() + logits /= len(data_samples) + if self.module.out_channels == 1: + seg_pred = (logits > self.module.decode_head.threshold + ).to(logits).squeeze(1) + else: + seg_pred = logits.argmax(dim=0) + data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)}) + if hasattr(data_samples[0], 'gt_sem_seg'): + data_sample.set_data( + {'gt_sem_seg': data_samples[0].gt_sem_seg}) + data_sample.set_metainfo({'img_path': data_samples[0].img_path}) + predictions.append(data_sample) + return predictions diff --git a/finetune/mmseg/models/text_encoder/__init__.py b/finetune/mmseg/models/text_encoder/__init__.py new file mode 100644 index 0000000..199856d --- /dev/null +++ b/finetune/mmseg/models/text_encoder/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .clip_text_encoder import CLIPTextEncoder + +__all__ = ['CLIPTextEncoder'] diff --git a/finetune/mmseg/models/text_encoder/clip_text_encoder.py b/finetune/mmseg/models/text_encoder/clip_text_encoder.py new file mode 100644 index 0000000..1a18b86 --- /dev/null +++ b/finetune/mmseg/models/text_encoder/clip_text_encoder.py @@ -0,0 +1,229 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from mmengine.model import BaseModule, ModuleList +from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict +from torch.nn import functional as F + +from mmseg.registry import MODELS +from mmseg.utils import get_classes, get_predefined_templates, tokenizer + + +@MODELS.register_module() +class CLIPTextEncoder(BaseModule): + """A text encoder with transformer architecture to encode the label text. + + Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501 + Copyright (c) 2023 MendelXu. + Licensed under the MIT License + + Args: + dataset_name: (str|None): The name of the dataset to which + the data belongs. + vocabulary: (List[str]|None): The list of class names. Default: None. + templates: (List[str]|None): The prompt template used for labels. + Default: None. + total_vocab_size: (int): Number of all words used by the pre-trained + model. Default: 49408 (CLIP). + context_length: (int): The max length of prompt text. + Default: 77 (CLIP). + embed_dims: (int): Width of transformer model. Default: 512. + num_layers: (int): Depth of transformer. Default: 12, + num_heads: (int): Number of attention heads in transformer. + Default: 8, + mlp_ratio: (int) Ratio of mlp hidden dim to embedding dim in + transformer. Default: 4, + output_dims: (int) Dim of output text embeddings. Default: 512, + cache_feature: (bool) Whether to save class embeddings in cache. + Default: True, + cat_bg: (bool) Whether to add background embedding. Default: True. + norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN') + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + dataset_name: str = None, + vocabulary: List[str] = None, + templates: str = 'vild', + total_vocab_size: int = 49408, + context_length: int = 77, + embed_dims: int = 512, + num_layers: int = 12, + num_heads: int = 8, + mlp_ratio: int = 4, + output_dims: int = 512, + cache_feature: bool = True, + cat_bg: bool = True, + norm_cfg: dict = dict(type='LN'), + init_cfg: dict = None): + super().__init__(init_cfg) + if isinstance(templates, List): + self.templates = templates + else: + self.templates = get_predefined_templates(templates) + + assert dataset_name is not None or vocabulary is not None, \ + "text_encoder required either 'dataset_name' or 'vocabulary'" + assert dataset_name is None or vocabulary is None, \ + "there is conflict between 'dataset_name' and 'vocabulary'" + self.dataset_name = dataset_name + self.vocabulary = vocabulary + self.num_pos = context_length + self.token_embedding = nn.Embedding(total_vocab_size, embed_dims) + self.positional_embedding = nn.Parameter( + torch.empty(context_length, embed_dims)) + self.text_projection = nn.Parameter( + torch.empty(embed_dims, output_dims)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.transformer = ModuleList() + self.register_buffer( + 'attn_mask', self.build_attention_mask(), persistent=False) + for i in range(num_layers): + self.transformer.append( + BaseTransformerLayer( + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=embed_dims, + num_heads=num_heads, + batch_first=False, + bias=True), + ffn_cfgs=dict( + type='FFN', + embed_dims=embed_dims, + feedforward_channels=mlp_ratio * embed_dims, + act_cfg=dict(type='QuickGELU')), + operation_order=('norm', 'self_attn', 'norm', 'ffn'))) + self.ln_final = build_norm_layer( + norm_cfg, embed_dims, postfix='_final')[1] + + self.cache_feature = cache_feature + if self.cache_feature: + self.cache = {} + + self._freeze() + + self.cat_bg = cat_bg + if self.cat_bg: + self.bg_embed = nn.Parameter( + torch.randn(1, self.text_projection.shape[1])) + + @property + def ln_final(self): + return getattr(self, self.final_name) + + def build_attention_mask(self): + """lazily create causal attention mask, with full attention between the + tokens. + + pytorch uses additive attention mask; fill with -inf + """ + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float('-inf')) + mask.triu_(1) # zero out the lower diagonal + return mask + + def _freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def init_weights(self): + if self.cat_bg: + nn.init.normal_( + self.bg_embed, + std=self.bg_embed.shape[1]**-0.5, + ) + if isinstance(self.init_cfg, dict) and \ + self.init_cfg.get('type') == 'Pretrained_Part': + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=None, map_location='cpu') + + state_dict = checkpoint.copy() + para_prefix = 'text_encoder' + prefix_len = len(para_prefix) + 1 + for k, v in checkpoint.items(): + state_dict.pop(k) + if para_prefix in k: + state_dict[k[prefix_len:]] = v + + load_state_dict(self, state_dict, strict=False, logger=None) + + else: + super().init_weights() + + @torch.no_grad() + def encode_text(self, text, normalize=False): + """encode class token.""" + + embed_device = self.token_embedding.weight.device + x = self.token_embedding( + text.to(embed_device)) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + for block in self.transformer: + x = block(query=x, attn_masks=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding + # (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), + text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def template_encode(self, vocabulary): + """Prompt engineering.""" + text_embed_bucket = [] + for template in self.templates: + text_inputs = tokenizer.tokenize( + [template.format(noun) for noun in vocabulary]) + text_embed = self.encode_text(text_inputs, normalize=True) + text_embed_bucket.append(text_embed) + text_embed = torch.stack(text_embed_bucket).mean(dim=0) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + return text_embed + + def forward(self): + """Forward function.""" + if self.dataset_name is None: # encoding vocabulary directly + class_names = self.vocabulary + if self.cache_feature: + new_classes = [ + word for word in class_names if word not in self.cache + ] + if len(new_classes) > 0: + class_embeds = self.template_encode(new_classes) + self.cache.update(dict(zip(new_classes, class_embeds))) + class_embeds = torch.stack( + [self.cache[word] for word in class_names]) + else: + class_embeds = self.template_encode(class_names) + + else: # encoding the classes of the dataset + class_names = get_classes(self.dataset_name) + if class_names[0] == 'background': + class_names = class_names[1:] + if self.cache_feature: + if self.dataset_name not in self.cache: + class_embeds = self.template_encode(class_names) + self.cache[self.dataset_name] = class_embeds + else: + class_embeds = self.cache[self.dataset_name] + else: + class_embeds = self.template_encode(class_names) + + if self.cat_bg: + class_embeds = torch.cat([class_embeds, self.bg_embed]) + class_embeds = F.normalize(class_embeds, p=2, dim=-1) + return self.logit_scale.exp() * class_embeds + + +@MODELS.register_module() +class QuickGELU(nn.Module): + # From https://github.com/openai/CLIP/blob/main/clip/model.py + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) diff --git a/finetune/mmseg/models/utils/__init__.py b/finetune/mmseg/models/utils/__init__.py new file mode 100644 index 0000000..c0751b1 --- /dev/null +++ b/finetune/mmseg/models/utils/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .basic_block import BasicBlock, Bottleneck +from .embed import PatchEmbed +from .encoding import Encoding +from .inverted_residual import InvertedResidual, InvertedResidualV3 +from .make_divisible import make_divisible +from .point_sample import get_uncertain_point_coords_with_randomness +from .ppm import DAPPM, PAPPM +from .res_layer import ResLayer +from .se_layer import SELayer +from .self_attention_block import SelfAttentionBlock +from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) +from .up_conv_block import UpConvBlock + +# isort: off +from .wrappers import Upsample, resize +from .san_layers import MLP, LayerNorm2d, cross_attn_layer + +__all__ = [ + 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', + 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', + 'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck', + 'cross_attn_layer', 'LayerNorm2d', 'MLP', + 'get_uncertain_point_coords_with_randomness' +] diff --git a/finetune/mmseg/models/utils/basic_block.py b/finetune/mmseg/models/utils/basic_block.py new file mode 100644 index 0000000..4e1ad81 --- /dev/null +++ b/finetune/mmseg/models/utils/basic_block.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmseg.registry import MODELS +from mmseg.utils import OptConfigType + + +class BasicBlock(BaseModule): + """Basic block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. + downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at the + last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + expansion = 1 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: nn.Module = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True), + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, + channels, + kernel_size=3, + stride=stride, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=None) + self.downsample = downsample + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + + def forward(self, x: Tensor) -> Tensor: + residual = x + out = self.conv1(x) + out = self.conv2(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out + + +class Bottleneck(BaseModule): + """Bottleneck block from `ResNet `_. + + Args: + in_channels (int): Input channels. + channels (int): Output channels. + stride (int): Stride of the first block. Default: 1. + downsample (nn.Module, optional): Downsample operation on identity. + Default: None. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict, optional): Config dict for activation layer in + ConvModule. Default: dict(type='ReLU', inplace=True). + act_cfg_out (dict, optional): Config dict for activation layer at + the last of the block. Default: None. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + expansion = 2 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + norm_cfg: OptConfigType = dict(type='BN'), + act_cfg: OptConfigType = dict(type='ReLU', inplace=True), + act_cfg_out: OptConfigType = None, + init_cfg: OptConfigType = None): + super().__init__(init_cfg) + self.conv1 = ConvModule( + in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) + self.conv2 = ConvModule( + channels, + channels, + 3, + stride, + 1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv3 = ConvModule( + channels, + channels * self.expansion, + 1, + norm_cfg=norm_cfg, + act_cfg=None) + if act_cfg_out: + self.act = MODELS.build(act_cfg_out) + self.downsample = downsample + + def forward(self, x: Tensor) -> Tensor: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.downsample: + residual = self.downsample(x) + + out += residual + + if hasattr(self, 'act'): + out = self.act(out) + + return out diff --git a/finetune/mmseg/models/utils/embed.py b/finetune/mmseg/models/utils/embed.py new file mode 100644 index 0000000..aef0a40 --- /dev/null +++ b/finetune/mmseg/models/utils/embed.py @@ -0,0 +1,330 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Sequence + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from mmengine.utils import to_2tuple + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1. + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super().__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchEmbed(BaseModule): + """Image to Patch Embedding. + + We use a conv layer to implement PatchEmbed. + + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (str): The config dict for embedding + conv layer type selection. Default: "Conv2d". + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int, optional): The slide stride of embedding conv. + Default: None (Would be set as `kernel_size`). + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int): The dilation rate of embedding conv. Default: 1. + bias (bool): Bias of embed conv. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: None. + input_size (int | tuple | None): The size of input, which will be + used to calculate the out size. Only work when `dynamic_size` + is False. Default: None. + init_cfg (`mmengine.ConfigDict`, optional): The Config for + initialization. Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type='Conv2d', + kernel_size=16, + stride=None, + padding='corner', + dilation=1, + bias=True, + norm_cfg=None, + input_size=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.embed_dims = embed_dims + if stride is None: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of conv + padding = 0 + else: + self.adap_padding = None + padding = to_2tuple(padding) + + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + if input_size: + input_size = to_2tuple(input_size) + # `init_out_size` would be used outside to + # calculate the num_patches + # when `use_abs_pos_embed` outside + self.init_input_size = input_size + if self.adap_padding: + pad_h, pad_w = self.adap_padding.get_pad_shape(input_size) + input_h, input_w = input_size + input_h = input_h + pad_h + input_w = input_w + pad_w + input_size = (input_h, input_w) + + # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + h_out = (input_size[0] + 2 * padding[0] - dilation[0] * + (kernel_size[0] - 1) - 1) // stride[0] + 1 + w_out = (input_size[1] + 2 * padding[1] - dilation[1] * + (kernel_size[1] - 1) - 1) // stride[1] + 1 + self.init_out_size = (h_out, w_out) + else: + self.init_input_size = None + self.init_out_size = None + + def forward(self, x): + """ + Args: + x (Tensor): Has shape (B, C, H, W). In most case, C is 3. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, out_h * out_w, embed_dims) + - out_size (tuple[int]): Spatial shape of x, arrange as + (out_h, out_w). + """ + + if self.adap_padding: + x = self.adap_padding(x) + + x = self.projection(x) + out_size = (x.shape[2], x.shape[3]) + x = x.flatten(2).transpose(1, 2) + if self.norm is not None: + x = self.norm(x) + return x, out_size + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size diff --git a/finetune/mmseg/models/utils/encoding.py b/finetune/mmseg/models/utils/encoding.py new file mode 100644 index 0000000..ee4f057 --- /dev/null +++ b/finetune/mmseg/models/utils/encoding.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super().__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. / ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/finetune/mmseg/models/utils/inverted_residual.py b/finetune/mmseg/models/utils/inverted_residual.py new file mode 100644 index 0000000..56190b3 --- /dev/null +++ b/finetune/mmseg/models/utils/inverted_residual.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from torch import nn +from torch.utils import checkpoint as cp + +from .se_layer import SELayer + + +class InvertedResidual(nn.Module): + """InvertedResidual block for MobileNetV2. + + Args: + in_channels (int): The input channels of the InvertedResidual block. + out_channels (int): The output channels of the InvertedResidual block. + stride (int): Stride of the middle (first) 3x3 convolution. + expand_ratio (int): Adjusts number of channels of the hidden layer + in InvertedResidual by this amount. + dilation (int): Dilation rate of depthwise conv. Default: 1 + conv_cfg (dict): Config dict for convolution layer. + Default: None, which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU6'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + stride, + expand_ratio, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU6'), + with_cp=False, + **kwargs): + super().__init__() + self.stride = stride + assert stride in [1, 2], f'stride must in [1, 2]. ' \ + f'But received {stride}.' + self.with_cp = with_cp + self.use_res_connect = self.stride == 1 and in_channels == out_channels + hidden_dim = int(round(in_channels * expand_ratio)) + + layers = [] + if expand_ratio != 1: + layers.append( + ConvModule( + in_channels=in_channels, + out_channels=hidden_dim, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs)) + layers.extend([ + ConvModule( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + groups=hidden_dim, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **kwargs), + ConvModule( + in_channels=hidden_dim, + out_channels=out_channels, + kernel_size=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + **kwargs) + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + + def _inner_forward(x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +class InvertedResidualV3(nn.Module): + """Inverted Residual Block for MobileNetV3. + + Args: + in_channels (int): The input channels of this Module. + out_channels (int): The output channels of this Module. + mid_channels (int): The input channels of the depthwise convolution. + kernel_size (int): The kernel size of the depthwise convolution. + Default: 3. + stride (int): The stride of the depthwise convolution. Default: 1. + se_cfg (dict): Config dict for se layer. Default: None, which means no + se layer. + with_expand_conv (bool): Use expand conv or not. If set False, + mid_channels must be the same with in_channels. Default: True. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + + Returns: + Tensor: The output tensor. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels, + kernel_size=3, + stride=1, + se_cfg=None, + with_expand_conv=True, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + with_cp=False): + super().__init__() + self.with_res_shortcut = (stride == 1 and in_channels == out_channels) + assert stride in [1, 2] + self.with_cp = with_cp + self.with_se = se_cfg is not None + self.with_expand_conv = with_expand_conv + + if self.with_se: + assert isinstance(se_cfg, dict) + if not self.with_expand_conv: + assert mid_channels == in_channels + + if self.with_expand_conv: + self.expand_conv = ConvModule( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.depthwise_conv = ConvModule( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=mid_channels, + conv_cfg=dict( + type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + if self.with_se: + self.se = SELayer(**se_cfg) + + self.linear_conv = ConvModule( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) + + def forward(self, x): + + def _inner_forward(x): + out = x + + if self.with_expand_conv: + out = self.expand_conv(out) + + out = self.depthwise_conv(out) + + if self.with_se: + out = self.se(out) + + out = self.linear_conv(out) + + if self.with_res_shortcut: + return x + out + else: + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out diff --git a/finetune/mmseg/models/utils/make_divisible.py b/finetune/mmseg/models/utils/make_divisible.py new file mode 100644 index 0000000..ed42c2e --- /dev/null +++ b/finetune/mmseg/models/utils/make_divisible.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def make_divisible(value, divisor, min_value=None, min_ratio=0.9): + """Make divisible function. + + This function rounds the channel number to the nearest value that can be + divisible by the divisor. It is taken from the original tf repo. It ensures + that all layers have a channel number that is divisible by divisor. It can + be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa + + Args: + value (int): The original channel number. + divisor (int): The divisor to fully divide the channel number. + min_value (int): The minimum value of the output channel. + Default: None, means that the minimum value equal to the divisor. + min_ratio (float): The minimum ratio of the rounded channel number to + the original channel number. Default: 0.9. + + Returns: + int: The modified output channel number. + """ + + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than (1-min_ratio). + if new_value < min_ratio * value: + new_value += divisor + return new_value diff --git a/finetune/mmseg/models/utils/point_sample.py b/finetune/mmseg/models/utils/point_sample.py new file mode 100644 index 0000000..1afc957 --- /dev/null +++ b/finetune/mmseg/models/utils/point_sample.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.ops import point_sample +from torch import Tensor + + +def get_uncertainty(mask_preds: Tensor, labels: Tensor) -> Tensor: + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_preds' for the foreground class in `classes`. + + Args: + mask_preds (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (Tensor): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_preds.shape[1] == 1: + gt_class_logits = mask_preds.clone() + else: + inds = torch.arange(mask_preds.shape[0], device=mask_preds.device) + gt_class_logits = mask_preds[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_preds: Tensor, labels: Tensor, num_points: int, + oversample_ratio: float, importance_sample_ratio: float) -> Tensor: + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_preds (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (Tensor): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (float): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_preds.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=mask_preds.device) + point_logits = point_sample(mask_preds, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=mask_preds.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand( + batch_size, num_random_points, 2, device=mask_preds.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/finetune/mmseg/models/utils/ppm.py b/finetune/mmseg/models/utils/ppm.py new file mode 100644 index 0000000..5fe6ff2 --- /dev/null +++ b/finetune/mmseg/models/utils/ppm.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule, ModuleList, Sequential +from torch import Tensor + + +class DAPPM(BaseModule): + """DAPPM module in `DDRNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. + strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. + """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__() + + self.num_scales = num_scales + self.unsample_mode = upsample_mode + self.in_channels = in_channels + self.branch_channels = branch_channels + self.out_channels = out_channels + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.conv_cfg = conv_cfg + + self.scales = ModuleList([ + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ]) + for i in range(1, num_scales - 1): + self.scales.append( + Sequential(*[ + nn.AvgPool2d( + kernel_size=kernel_sizes[i - 1], + stride=strides[i - 1], + padding=paddings[i - 1]), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.scales.append( + Sequential(*[ + nn.AdaptiveAvgPool2d((1, 1)), + ConvModule( + in_channels, + branch_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + ])) + self.processes = ModuleList() + for i in range(num_scales - 1): + self.processes.append( + ConvModule( + branch_channels, + branch_channels, + kernel_size=3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg)) + + self.compression = ConvModule( + branch_channels * num_scales, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + self.shortcut = ConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + **conv_cfg) + + def forward(self, inputs: Tensor): + feats = [] + feats.append(self.scales[0](inputs)) + + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode) + feats.append(self.processes[i - 1](feat_up + feats[i - 1])) + + return self.compression(torch.cat(feats, + dim=1)) + self.shortcut(inputs) + + +class PAPPM(DAPPM): + """PAPPM module in `PIDNet `_. + + Args: + in_channels (int): Input channels. + branch_channels (int): Branch channels. + out_channels (int): Output channels. + num_scales (int): Number of scales. + kernel_sizes (list[int]): Kernel sizes of each scale. + strides (list[int]): Strides of each scale. + paddings (list[int]): Paddings of each scale. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='BN', momentum=0.1). + act_cfg (dict): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU', inplace=True). + conv_cfg (dict): Config dict for convolution layer in ConvModule. + Default: dict(order=('norm', 'act', 'conv'), bias=False). + upsample_mode (str): Upsample mode. Default: 'bilinear'. + """ + + def __init__(self, + in_channels: int, + branch_channels: int, + out_channels: int, + num_scales: int, + kernel_sizes: List[int] = [5, 9, 17], + strides: List[int] = [2, 4, 8], + paddings: List[int] = [2, 4, 8], + norm_cfg: Dict = dict(type='BN', momentum=0.1), + act_cfg: Dict = dict(type='ReLU', inplace=True), + conv_cfg: Dict = dict( + order=('norm', 'act', 'conv'), bias=False), + upsample_mode: str = 'bilinear'): + super().__init__(in_channels, branch_channels, out_channels, + num_scales, kernel_sizes, strides, paddings, norm_cfg, + act_cfg, conv_cfg, upsample_mode) + + self.processes = ConvModule( + self.branch_channels * (self.num_scales - 1), + self.branch_channels * (self.num_scales - 1), + kernel_size=3, + padding=1, + groups=self.num_scales - 1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **self.conv_cfg) + + def forward(self, inputs: Tensor): + x_ = self.scales[0](inputs) + feats = [] + for i in range(1, self.num_scales): + feat_up = F.interpolate( + self.scales[i](inputs), + size=inputs.shape[2:], + mode=self.unsample_mode, + align_corners=False) + feats.append(feat_up + x_) + scale_out = self.processes(torch.cat(feats, dim=1)) + return self.compression(torch.cat([x_, scale_out], + dim=1)) + self.shortcut(inputs) diff --git a/finetune/mmseg/models/utils/res_layer.py b/finetune/mmseg/models/utils/res_layer.py new file mode 100644 index 0000000..3dd7a6f --- /dev/null +++ b/finetune/mmseg/models/utils/res_layer.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import Sequential +from torch import nn as nn + + +class ResLayer(Sequential): + """ResLayer to build ResNet style backbone. + + Args: + block (nn.Module): block used to build ResLayer. + inplanes (int): inplanes of block. + planes (int): planes of block. + num_blocks (int): number of blocks. + stride (int): stride of the first block. Default: 1 + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Default: False + conv_cfg (dict): dictionary to construct and config conv layer. + Default: None + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN') + multi_grid (int | None): Multi grid dilation rates of last + stage. Default: None + contract_dilation (bool): Whether contract first dilation of each layer + Default: False + """ + + def __init__(self, + block, + inplanes, + planes, + num_blocks, + stride=1, + dilation=1, + avg_down=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + multi_grid=None, + contract_dilation=False, + **kwargs): + self.block = block + + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = [] + conv_stride = stride + if avg_down: + conv_stride = 1 + downsample.append( + nn.AvgPool2d( + kernel_size=stride, + stride=stride, + ceil_mode=True, + count_include_pad=False)) + downsample.extend([ + build_conv_layer( + conv_cfg, + inplanes, + planes * block.expansion, + kernel_size=1, + stride=conv_stride, + bias=False), + build_norm_layer(norm_cfg, planes * block.expansion)[1] + ]) + downsample = nn.Sequential(*downsample) + + layers = [] + if multi_grid is None: + if dilation > 1 and contract_dilation: + first_dilation = dilation // 2 + else: + first_dilation = dilation + else: + first_dilation = multi_grid[0] + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=stride, + dilation=first_dilation, + downsample=downsample, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + inplanes = planes * block.expansion + for i in range(1, num_blocks): + layers.append( + block( + inplanes=inplanes, + planes=planes, + stride=1, + dilation=dilation if multi_grid is None else multi_grid[i], + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + **kwargs)) + super().__init__(*layers) diff --git a/finetune/mmseg/models/utils/san_layers.py b/finetune/mmseg/models/utils/san_layers.py new file mode 100644 index 0000000..2267686 --- /dev/null +++ b/finetune/mmseg/models/utils/san_layers.py @@ -0,0 +1,418 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/MendelXu/SAN/blob/main/san/model/attn_helper.py # noqa: E501 +# Copyright (c) 2023 MendelXu. +# Licensed under the MIT License + +import warnings +from typing import Optional + +import torch +from mmcv.cnn.bricks.transformer import BaseTransformerLayer +from torch import Tensor, nn +from torch.nn import functional as F + + +def cross_attn_with_self_bias( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +): + """Forward function of multi-head attention. Modified from + multi_head_attention_forward in + https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + """ # noqa: E501 + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, \ + 'embed_dim must be divisible by num_heads' + scaling = float(head_dim)**-0.5 + + if not use_separate_proj_weight: + if (query is key or torch.equal( + query, key)) and (key is value or torch.equal(key, value)): + # self-attention + raise NotImplementedError('self-attention is not implemented') + + elif key is value or torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function + # with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + q_k = None + q_v = None + else: + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + q_k, q_v = F.linear(query, _w, _b).chunk(2, dim=-1) + else: + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = F.linear(key, _w, _b) + q_k = F.linear(query, _w, _b) + # This is inline in_proj function with + # in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = F.linear(value, _w, _b) + q_v = F.linear(query, _w, _b) + else: + q_proj_weight_non_opt = \ + torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = \ + torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = \ + torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, + in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, + in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, + in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool + ), 'Only float, byte, and bool types are supported for ' \ + 'attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn('Byte tensor for attn_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + 'The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), key.size(0) + ]: + raise RuntimeError( + 'The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + 'Byte tensor for key_padding_mask in nn.MultiheadAttention ' + 'is deprecated. Use bool tensor instead.') + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, 'bias cannot be added to static key.' + assert static_v is None, 'bias cannot be added to static value.' + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_k = q_k.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + q_v = q_v.contiguous().view(tgt_len, bsz * num_heads, + head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat( + [ + k, + torch.zeros( + (k.size(0), 1) + k.size()[2:], + dtype=k.dtype, + device=k.device), + ], + dim=1, + ) + v = torch.cat( + [ + v, + torch.zeros( + (v.size(0), 1) + v.size()[2:], + dtype=v.dtype, + device=v.device), + ], + dim=1, + ) + if attn_mask is not None: + attn_mask = F.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = F.pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list( + attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, + tgt_len, src_len) + # attn_out_weights: [bsz * num_heads, tgt_len, src_len] + # ->[bsz * num_heads, tgt_len, src_len+1] + self_weight = (q * q_k).sum( + dim=-1, keepdim=True) # [bsz * num_heads, tgt_len, 1] + total_attn_output_weights = torch.cat([attn_output_weights, self_weight], + dim=-1) + total_attn_output_weights = F.softmax(total_attn_output_weights, dim=-1) + total_attn_output_weights = F.dropout( + total_attn_output_weights, p=dropout_p, training=training) + attn_output_weights = \ + total_attn_output_weights[:, :, : -1] + # [bsz * num_heads, tgt_len, src_len] + self_weight = \ + total_attn_output_weights[:, :, -1:] # [bsz * num_heads, tgt_len, 1] + + attn_output = torch.bmm(attn_output_weights, + v) # [bsz * num_heads, tgt_len, head_dim] + attn_output = (attn_output + self_weight * q_v + ) # [bsz * num_heads, tgt_len, head_dim] + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view( + tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, + src_len) + return attn_output, attn_output_weights # .sum(dim=1) / num_heads + else: + return attn_output, None + + +def cross_attn_layer(tf_layer: BaseTransformerLayer, x, mem, attn_bias): + """Implementation of transformer layer with cross attention. The cross + attention shares the embedding weights with self-attention of tf_layer. + Args: + tf_layer: (TransformerEncoderLayer): The Module of transformer layer. + x (Tensor): query [K,N,C] + mem (Tensor): key and value [L,N,C] + attn_bias (Tensor): attention bias [N*num_head,K,L] + + Return: + x (Tensor): cross attention output [K,N,C] + """ + self_attn_layer = tf_layer.attentions[0].attn + attn_layer_paras = { + 'embed_dim_to_check': self_attn_layer.embed_dim, + 'num_heads': self_attn_layer.num_heads, + 'in_proj_weight': self_attn_layer.in_proj_weight, + 'in_proj_bias': self_attn_layer.in_proj_bias, + 'bias_k': self_attn_layer.bias_k, + 'bias_v': self_attn_layer.bias_v, + 'add_zero_attn': self_attn_layer.add_zero_attn, + 'dropout_p': self_attn_layer.dropout, + 'out_proj_weight': self_attn_layer.out_proj.weight, + 'out_proj_bias': self_attn_layer.out_proj.bias, + 'training': self_attn_layer.training + } + + q_x = tf_layer.norms[0](x) + k_x = v_x = tf_layer.norms[0](mem) + x = x + cross_attn_with_self_bias( + q_x, + k_x, + v_x, + attn_mask=attn_bias, + need_weights=False, + **attn_layer_paras)[0] + x = tf_layer.ffns[0](tf_layer.norms[1](x), identity=x) + return x + + +class LayerNorm2d(nn.Module): + """A LayerNorm variant, popularized by Transformers, that performs point- + wise mean and variance normalization over the channel dimension for inputs + that have shape (batch_size, channels, height, width). + + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape, ) + + def forward(self, x: torch.Tensor): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, + input_dim, + hidden_dim, + output_dim, + num_layers, + affine_func=nn.Linear): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + affine_func(n, k) + for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x: torch.Tensor): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/finetune/mmseg/models/utils/se_layer.py b/finetune/mmseg/models/utils/se_layer.py new file mode 100644 index 0000000..0ff632c --- /dev/null +++ b/finetune/mmseg/models/utils/se_layer.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.utils import is_tuple_of + +from .make_divisible import make_divisible + + +class SELayer(nn.Module): + """Squeeze-and-Excitation Module. + + Args: + channels (int): The input (and output) channels of the SE layer. + ratio (int): Squeeze ratio in SELayer, the intermediate channel will be + ``int(channels/ratio)``. Default: 16. + conv_cfg (None or dict): Config dict for convolution layer. + Default: None, which means using conv2d. + act_cfg (dict or Sequence[dict]): Config dict for activation layer. + If act_cfg is a dict, two activation layers will be configured + by this dict. If act_cfg is a sequence of dicts, the first + activation layer will be configured by the first dict and the + second activation layer will be configured by the second dict. + Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0, + divisor=6.0)). + """ + + def __init__(self, + channels, + ratio=16, + conv_cfg=None, + act_cfg=(dict(type='ReLU'), + dict(type='HSigmoid', bias=3.0, divisor=6.0))): + super().__init__() + if isinstance(act_cfg, dict): + act_cfg = (act_cfg, act_cfg) + assert len(act_cfg) == 2 + assert is_tuple_of(act_cfg, dict) + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.conv1 = ConvModule( + in_channels=channels, + out_channels=make_divisible(channels // ratio, 8), + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[0]) + self.conv2 = ConvModule( + in_channels=make_divisible(channels // ratio, 8), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=conv_cfg, + act_cfg=act_cfg[1]) + + def forward(self, x): + out = self.global_avgpool(x) + out = self.conv1(out) + out = self.conv2(out) + return x * out diff --git a/finetune/mmseg/models/utils/self_attention_block.py b/finetune/mmseg/models/utils/self_attention_block.py new file mode 100644 index 0000000..5bb6e82 --- /dev/null +++ b/finetune/mmseg/models/utils/self_attention_block.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ConvModule +from mmengine.model.weight_init import constant_init +from torch import nn as nn +from torch.nn import functional as F + + +class SelfAttentionBlock(nn.Module): + """General self-attention block/non-local block. + + Please refer to https://arxiv.org/abs/1706.03762 for details about key, + query and value. + + Args: + key_in_channels (int): Input channels of key feature. + query_in_channels (int): Input channels of query feature. + channels (int): Output channels of key/query transform. + out_channels (int): Output channels. + share_key_query (bool): Whether share projection weight between key + and query projection. + query_downsample (nn.Module): Query downsample module. + key_downsample (nn.Module): Key downsample module. + key_query_num_convs (int): Number of convs for key/query projection. + value_num_convs (int): Number of convs for value projection. + matmul_norm (bool): Whether normalize attention map with sqrt of + channels + with_out (bool): Whether use out projection. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict|None): Config of activation layers. + """ + + def __init__(self, key_in_channels, query_in_channels, channels, + out_channels, share_key_query, query_downsample, + key_downsample, key_query_num_convs, value_out_num_convs, + key_query_norm, value_out_norm, matmul_norm, with_out, + conv_cfg, norm_cfg, act_cfg): + super().__init__() + if share_key_query: + assert key_in_channels == query_in_channels + self.key_in_channels = key_in_channels + self.query_in_channels = query_in_channels + self.out_channels = out_channels + self.channels = channels + self.share_key_query = share_key_query + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.key_project = self.build_project( + key_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if share_key_query: + self.query_project = self.key_project + else: + self.query_project = self.build_project( + query_in_channels, + channels, + num_convs=key_query_num_convs, + use_conv_module=key_query_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.value_project = self.build_project( + key_in_channels, + channels if with_out else out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if with_out: + self.out_project = self.build_project( + channels, + out_channels, + num_convs=value_out_num_convs, + use_conv_module=value_out_norm, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.out_project = None + + self.query_downsample = query_downsample + self.key_downsample = key_downsample + self.matmul_norm = matmul_norm + + self.init_weights() + + def init_weights(self): + """Initialize weight of later layer.""" + if self.out_project is not None: + if not isinstance(self.out_project, ConvModule): + constant_init(self.out_project, 0) + + def build_project(self, in_channels, channels, num_convs, use_conv_module, + conv_cfg, norm_cfg, act_cfg): + """Build projection layer for key/query/value/out.""" + if use_conv_module: + convs = [ + ConvModule( + in_channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + ] + for _ in range(num_convs - 1): + convs.append( + ConvModule( + channels, + channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + else: + convs = [nn.Conv2d(in_channels, channels, 1)] + for _ in range(num_convs - 1): + convs.append(nn.Conv2d(channels, channels, 1)) + if len(convs) > 1: + convs = nn.Sequential(*convs) + else: + convs = convs[0] + return convs + + def forward(self, query_feats, key_feats): + """Forward function.""" + batch_size = query_feats.size(0) + query = self.query_project(query_feats) + if self.query_downsample is not None: + query = self.query_downsample(query) + query = query.reshape(*query.shape[:2], -1) + query = query.permute(0, 2, 1).contiguous() + + key = self.key_project(key_feats) + value = self.value_project(key_feats) + if self.key_downsample is not None: + key = self.key_downsample(key) + value = self.key_downsample(value) + key = key.reshape(*key.shape[:2], -1) + value = value.reshape(*value.shape[:2], -1) + value = value.permute(0, 2, 1).contiguous() + + sim_map = torch.matmul(query, key) + if self.matmul_norm: + sim_map = (self.channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.reshape(batch_size, -1, *query_feats.shape[2:]) + if self.out_project is not None: + context = self.out_project(context) + return context diff --git a/finetune/mmseg/models/utils/shape_convert.py b/finetune/mmseg/models/utils/shape_convert.py new file mode 100644 index 0000000..cce1e22 --- /dev/null +++ b/finetune/mmseg/models/utils/shape_convert.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before conversion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after conversion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before conversion. + + Returns: + Tensor: The output tensor of shape [N, L, C] after conversion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() + + +def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): + """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the + reshaped tensor as the input of `module`, and the convert the output of + `module`, whose shape is. + + [N, L, C], to [N, C, H, W]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, L, C] as input. + x (Tensor): The input tensor of shape [N, C, H, W]. + contiguous: + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, C, H, W]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> norm = nn.LayerNorm(4) + >>> feature_map = torch.rand(4, 4, 5, 5) + >>> output = nchw2nlc2nchw(norm, feature_map) + """ + B, C, H, W = x.shape + if not contiguous: + x = x.flatten(2).transpose(1, 2) + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W) + else: + x = x.flatten(2).transpose(1, 2).contiguous() + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + return x + + +def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): + """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the + reshaped tensor as the input of `module`, and convert the output of + `module`, whose shape is. + + [N, C, H, W], to [N, L, C]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, C, H, W] as input. + x (Tensor): The input tensor of shape [N, L, C]. + hw_shape: (Sequence[int]): The height and width of the + feature map with shape [N, C, H, W]. + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, L, C]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> conv = nn.Conv2d(16, 16, 3, 1, 1) + >>> feature_map = torch.rand(4, 25, 16) + >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + if not contiguous: + x = x.transpose(1, 2).reshape(B, C, H, W) + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2) + else: + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2).contiguous() + return x diff --git a/finetune/mmseg/models/utils/up_conv_block.py b/finetune/mmseg/models/utils/up_conv_block.py new file mode 100644 index 0000000..4fa3b59 --- /dev/null +++ b/finetune/mmseg/models/utils/up_conv_block.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, build_upsample_layer + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convolution in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. + """ + + def __init__(self, + conv_block, + in_channels, + skip_channels, + out_channels, + num_convs=2, + stride=1, + dilation=1, + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + dcn=None, + plugins=None): + super().__init__() + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.conv_block = conv_block( + in_channels=2 * skip_channels, + out_channels=out_channels, + num_convs=num_convs, + stride=stride, + dilation=dilation, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + dcn=None, + plugins=None) + if upsample_cfg is not None: + self.upsample = build_upsample_layer( + cfg=upsample_cfg, + in_channels=in_channels, + out_channels=skip_channels, + with_cp=with_cp, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + else: + self.upsample = ConvModule( + in_channels, + skip_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, skip, x): + """Forward function.""" + + x = self.upsample(x) + out = torch.cat([skip, x], dim=1) + out = self.conv_block(out) + + return out diff --git a/finetune/mmseg/models/utils/wrappers.py b/finetune/mmseg/models/utils/wrappers.py new file mode 100644 index 0000000..abbd0c0 --- /dev/null +++ b/finetune/mmseg/models/utils/wrappers.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super().__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/finetune/mmseg/registry/__init__.py b/finetune/mmseg/registry/__init__.py new file mode 100644 index 0000000..ee514d1 --- /dev/null +++ b/finetune/mmseg/registry/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS, + LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, + PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, + TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, + WEIGHT_INITIALIZERS) + +__all__ = [ + 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', + 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', + 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', + 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS', + 'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS' +] diff --git a/finetune/mmseg/registry/registry.py b/finetune/mmseg/registry/registry.py new file mode 100644 index 0000000..37b6a77 --- /dev/null +++ b/finetune/mmseg/registry/registry.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMSegmentation provides 21 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html. +""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry( + 'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks']) + +# manage data-related modules +DATASETS = Registry( + 'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets']) +DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS) +TRANSFORMS = Registry( + 'transform', + parent=MMENGINE_TRANSFORMS, + locations=['mmseg.datasets.transforms']) + +# mangage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models']) +# mangage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry( + 'model_wrapper', + parent=MMENGINE_MODEL_WRAPPERS, + locations=['mmseg.models']) +# mangage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', + parent=MMENGINE_WEIGHT_INITIALIZERS, + locations=['mmseg.models']) + +# mangage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry( + 'optimizer', + parent=MMENGINE_OPTIMIZERS, + locations=['mmseg.engine.optimizers']) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry( + 'optim_wrapper', + parent=MMENGINE_OPTIM_WRAPPERS, + locations=['mmseg.engine.optimizers']) +# manage constructors that customize the optimization hyperparameters. +OPTIM_WRAPPER_CONSTRUCTORS = Registry( + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS, + locations=['mmseg.engine.optimizers']) +# mangage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', + parent=MMENGINE_PARAM_SCHEDULERS, + locations=['mmseg.engine.schedulers']) + +# manage all kinds of metrics +METRICS = Registry( + 'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation']) +# manage evaluator +EVALUATOR = Registry( + 'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation']) + +# manage task-specific modules like ohem pixel sampler +TASK_UTILS = Registry( + 'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models']) + +# manage visualizer +VISUALIZERS = Registry( + 'visualizer', + parent=MMENGINE_VISUALIZERS, + locations=['mmseg.visualization']) +# manage visualizer backend +VISBACKENDS = Registry( + 'vis_backend', + parent=MMENGINE_VISBACKENDS, + locations=['mmseg.visualization']) + +# manage logprocessor +LOG_PROCESSORS = Registry( + 'log_processor', + parent=MMENGINE_LOG_PROCESSORS, + locations=['mmseg.visualization']) + +# manage inferencer +INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS) diff --git a/finetune/mmseg/structures/__init__.py b/finetune/mmseg/structures/__init__.py new file mode 100644 index 0000000..63d118d --- /dev/null +++ b/finetune/mmseg/structures/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .sampler import BasePixelSampler, OHEMPixelSampler, build_pixel_sampler +from .seg_data_sample import SegDataSample + +__all__ = [ + 'SegDataSample', 'BasePixelSampler', 'OHEMPixelSampler', + 'build_pixel_sampler' +] diff --git a/finetune/mmseg/structures/sampler/__init__.py b/finetune/mmseg/structures/sampler/__init__.py new file mode 100644 index 0000000..91d762d --- /dev/null +++ b/finetune/mmseg/structures/sampler/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base_pixel_sampler import BasePixelSampler +from .builder import build_pixel_sampler +from .ohem_pixel_sampler import OHEMPixelSampler + +__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] diff --git a/finetune/mmseg/structures/sampler/base_pixel_sampler.py b/finetune/mmseg/structures/sampler/base_pixel_sampler.py new file mode 100644 index 0000000..03672cd --- /dev/null +++ b/finetune/mmseg/structures/sampler/base_pixel_sampler.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/finetune/mmseg/structures/sampler/builder.py b/finetune/mmseg/structures/sampler/builder.py new file mode 100644 index 0000000..48e1479 --- /dev/null +++ b/finetune/mmseg/structures/sampler/builder.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmseg.registry import TASK_UTILS + +PIXEL_SAMPLERS = TASK_UTILS + + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + warnings.warn( + '``build_pixel_sampler`` would be deprecated soon, please use ' + '``mmseg.registry.TASK_UTILS.build()`` ') + return TASK_UTILS.build(cfg, default_args=default_args) diff --git a/finetune/mmseg/structures/sampler/ohem_pixel_sampler.py b/finetune/mmseg/structures/sampler/ohem_pixel_sampler.py new file mode 100644 index 0000000..a974273 --- /dev/null +++ b/finetune/mmseg/structures/sampler/ohem_pixel_sampler.py @@ -0,0 +1,85 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_pixel_sampler import BasePixelSampler +from .builder import PIXEL_SAMPLERS + + +@PIXEL_SAMPLERS.register_module() +class OHEMPixelSampler(BasePixelSampler): + """Online Hard Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super().__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. + else: + if not isinstance(self.context.loss_decode, nn.ModuleList): + losses_decode = [self.context.loss_decode] + else: + losses_decode = self.context.loss_decode + losses = 0.0 + for loss_module in losses_decode: + losses += loss_module( + seg_logit, + seg_label, + weight=None, + ignore_index=self.context.ignore_index, + reduction_override='none') + + # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa + _, sort_indices = losses[valid_mask].sort(descending=True) + valid_seg_weight[sort_indices[:batch_kept]] = 1. + + seg_weight[valid_mask] = valid_seg_weight + + return seg_weight diff --git a/finetune/mmseg/structures/seg_data_sample.py b/finetune/mmseg/structures/seg_data_sample.py new file mode 100644 index 0000000..ce68b54 --- /dev/null +++ b/finetune/mmseg/structures/seg_data_sample.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.structures import BaseDataElement, PixelData + + +class SegDataSample(BaseDataElement): + """A data structure interface of MMSegmentation. They are used as + interfaces between different components. + + The attributes in ``SegDataSample`` are divided into several parts: + + - ``gt_sem_seg``(PixelData): Ground truth of semantic segmentation. + - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation. + - ``seg_logits``(PixelData): Predicted logits of semantic segmentation. + + Examples: + >>> import torch + >>> import numpy as np + >>> from mmengine.structures import PixelData + >>> from mmseg.structures import SegDataSample + + >>> data_sample = SegDataSample() + >>> img_meta = dict(img_shape=(4, 4, 3), + ... pad_shape=(4, 4, 3)) + >>> gt_segmentations = PixelData(metainfo=img_meta) + >>> gt_segmentations.data = torch.randint(0, 2, (1, 4, 4)) + >>> data_sample.gt_sem_seg = gt_segmentations + >>> assert 'img_shape' in data_sample.gt_sem_seg.metainfo_keys() + >>> data_sample.gt_sem_seg.shape + (4, 4) + >>> print(data_sample) + + ) at 0x1c2aae44d60> + + >>> data_sample = SegDataSample() + >>> gt_sem_seg_data = dict(sem_seg=torch.rand(1, 4, 4)) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> data_sample.gt_sem_seg = gt_sem_seg + >>> assert 'gt_sem_seg' in data_sample + >>> assert 'sem_seg' in data_sample.gt_sem_seg + """ + + @property + def gt_sem_seg(self) -> PixelData: + return self._gt_sem_seg + + @gt_sem_seg.setter + def gt_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_gt_sem_seg', dtype=PixelData) + + @gt_sem_seg.deleter + def gt_sem_seg(self) -> None: + del self._gt_sem_seg + + @property + def pred_sem_seg(self) -> PixelData: + return self._pred_sem_seg + + @pred_sem_seg.setter + def pred_sem_seg(self, value: PixelData) -> None: + self.set_field(value, '_pred_sem_seg', dtype=PixelData) + + @pred_sem_seg.deleter + def pred_sem_seg(self) -> None: + del self._pred_sem_seg + + @property + def seg_logits(self) -> PixelData: + return self._seg_logits + + @seg_logits.setter + def seg_logits(self, value: PixelData) -> None: + self.set_field(value, '_seg_logits', dtype=PixelData) + + @seg_logits.deleter + def seg_logits(self) -> None: + del self._seg_logits diff --git a/finetune/mmseg/utils/__init__.py b/finetune/mmseg/utils/__init__.py new file mode 100644 index 0000000..0a2af58 --- /dev/null +++ b/finetune/mmseg/utils/__init__.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# yapf: disable +from .class_names import (ade_classes, ade_palette, bdd100k_classes, + bdd100k_palette, cityscapes_classes, + cityscapes_palette, cocostuff_classes, + cocostuff_palette, dataset_aliases, get_classes, + get_palette, isaid_classes, isaid_palette, + loveda_classes, loveda_palette, potsdam_classes, + potsdam_palette, stare_classes, stare_palette, + synapse_classes, synapse_palette, vaihingen_classes, + vaihingen_palette, voc_classes, voc_palette) +# yapf: enable +from .collect_env import collect_env +from .get_templates import get_predefined_templates +from .io import datafrombytes +from .misc import add_prefix, stack_batch +from .set_env import register_all_modules +from .tokenizer import tokenize +from .typing_utils import (ConfigType, ForwardResults, MultiConfig, + OptConfigType, OptMultiConfig, OptSampleList, + SampleList, TensorDict, TensorList) + +# isort: off +from .mask_classification import MatchMasks, seg_data_to_instance_data + +__all__ = [ + 'collect_env', + 'register_all_modules', + 'stack_batch', + 'add_prefix', + 'ConfigType', + 'OptConfigType', + 'MultiConfig', + 'OptMultiConfig', + 'SampleList', + 'OptSampleList', + 'TensorDict', + 'TensorList', + 'ForwardResults', + 'cityscapes_classes', + 'ade_classes', + 'voc_classes', + 'cocostuff_classes', + 'loveda_classes', + 'potsdam_classes', + 'vaihingen_classes', + 'isaid_classes', + 'stare_classes', + 'cityscapes_palette', + 'ade_palette', + 'voc_palette', + 'cocostuff_palette', + 'loveda_palette', + 'potsdam_palette', + 'vaihingen_palette', + 'isaid_palette', + 'stare_palette', + 'dataset_aliases', + 'get_classes', + 'get_palette', + 'datafrombytes', + 'synapse_palette', + 'synapse_classes', + 'get_predefined_templates', + 'tokenize', + 'seg_data_to_instance_data', + 'MatchMasks', + 'bdd100k_classes', + 'bdd100k_palette', +] diff --git a/finetune/mmseg/utils/bpe_simple_vocab_16e6.txt.gz b/finetune/mmseg/utils/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/finetune/mmseg/utils/bpe_simple_vocab_16e6.txt.gz differ diff --git a/finetune/mmseg/utils/class_names.py b/finetune/mmseg/utils/class_names.py new file mode 100644 index 0000000..644e955 --- /dev/null +++ b/finetune/mmseg/utils/class_names.py @@ -0,0 +1,548 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import is_str + + +def cityscapes_classes(): + """Cityscapes class names for external use.""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def ade_classes(): + """ADE20K class names for external use.""" + return [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' + ] + + +def voc_classes(): + """Pascal VOC class names for external use.""" + return [ + 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', + 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor' + ] + + +def pcontext_classes(): + """Pascal Context class names for external use.""" + return [ + 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', + 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', + 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', + 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', + 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', + 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', + 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', + 'wood' + ] + + +def cocostuff_classes(): + """CocoStuff class names for external use.""" + return [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', + 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', + 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', + 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', + 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', + 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', + 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', + 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', + 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper', + 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', + 'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', + 'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', + 'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', + 'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', + 'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', + 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', + 'window-blind', 'window-other', 'wood' + ] + + +def loveda_classes(): + """LoveDA class names for external use.""" + return [ + 'background', 'building', 'road', 'water', 'barren', 'forest', + 'agricultural' + ] + + +def potsdam_classes(): + """Potsdam class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def vaihingen_classes(): + """Vaihingen class names for external use.""" + return [ + 'impervious_surface', 'building', 'low_vegetation', 'tree', 'car', + 'clutter' + ] + + +def isaid_classes(): + """iSAID class names for external use.""" + return [ + 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', + 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', + 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', + 'Soccer_ball_field', 'plane', 'Harbor' + ] + + +def stare_classes(): + """stare class names for external use.""" + return ['background', 'vessel'] + + +def mapillary_v1_classes(): + """mapillary_v1 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', + 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', + 'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', + 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist', + 'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', + 'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', + 'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame', + 'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)', + 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', + 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', + 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled' + ] + + +def mapillary_v1_palette(): + """mapillary_v1_ palette for external use.""" + return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255], + [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], + [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], + [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], + [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], + [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], + [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], + [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142], + [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], + [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] + + +def mapillary_v2_classes(): + """mapillary_v2 class names for external use.""" + return [ + 'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb', + 'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side', + 'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane', + 'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking', + 'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road', + 'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island', + 'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group', + 'Bicyclist', 'Motorcyclist', 'Other Rider', + 'Lane Marking - Dashed Line', 'Lane Marking - Straight Line', + 'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous', + 'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)', + 'Lane Marking - Arrow (Right)', + 'Lane Marking - Arrow (Split Left or Straight)', + 'Lane Marking - Arrow (Split Right or Straight)', + 'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk', + 'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)', + 'Lane Marking - Hatched (Chevron)', + 'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other', + 'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)', + 'Lane Marking - Symbol (Other)', 'Lane Marking - Text', + 'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk', + 'Lane Marking (only) - Other', 'Lane Marking (only) - Test', + 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water', + 'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter', + 'Phone Booth', 'Pothole', 'Signage - Advertisement', + 'Signage - Ambiguous', 'Signage - Back', 'Signage - Information', + 'Signage - Other', 'Signage - Store', 'Street Light', 'Pole', + 'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone', + 'Traffic Light - General (Single)', 'Traffic Light - Pedestrians', + 'Traffic Light - General (Upright)', + 'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists', + 'Traffic Light - Other', 'Traffic Sign - Ambiguous', + 'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)', + 'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)', + 'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)', + 'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', + 'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve', + 'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled' + ] + + +def mapillary_v2_palette(): + """mapillary_v2_ palette for external use.""" + return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32], + [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150], + [250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35], + [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], + [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110], + [244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70], + [150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255], + [255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26], + [250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21], + [250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18], + [250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255], + [250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255], + [255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64], + [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152], + [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30], + [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30], + [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255], + [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30], + [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30], + [210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128], + [0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30], + [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30], + [192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0], + [220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0], + [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100], + [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], + [0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170], + [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81], + [111, 111, 0], [0, 0, 0]] + + +def cityscapes_palette(): + """Cityscapes palette for external use.""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def ade_palette(): + """ADE20K palette for external use.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + +def voc_palette(): + """Pascal VOC palette for external use.""" + return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + +def pcontext_palette(): + """Pascal Context palette for external use.""" + return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], + [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230], + [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], + [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140], + [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], + [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71], + [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], + [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], + [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], + [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], + [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], + [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140], + [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], + [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0], + [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + +def cocostuff_palette(): + """CocoStuff palette for external use.""" + return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], + [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], + [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], + [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], + [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192], + [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128], + [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0], + [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0], + [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32], + [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128], + [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64], + [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32], + [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0], + [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64], + [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32], + [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128], + [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0], + [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96], + [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0], + [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0], + [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96], + [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128], + [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64], + [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96], + [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0], + [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64], + [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96], + [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128], + [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0], + [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32], + [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64], + [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0], + [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32], + [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192], + [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64], + [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32], + [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64], + [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64], + [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32], + [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192], + [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0], + [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96], + [64, 160, 64], [64, 64, 0]] + + +def loveda_palette(): + """LoveDA palette for external use.""" + return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], + [159, 129, 183], [0, 255, 0], [255, 195, 128]] + + +def potsdam_palette(): + """Potsdam palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def vaihingen_palette(): + """Vaihingen palette for external use.""" + return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], + [255, 255, 0], [255, 0, 0]] + + +def isaid_palette(): + """iSAID palette for external use.""" + return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, + 127], [0, 0, 127], + [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] + + +def stare_palette(): + """STARE palette for external use.""" + return [[120, 120, 120], [6, 230, 230]] + + +def synapse_palette(): + """Synapse palette for external use.""" + return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255], + [255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]] + + +def synapse_classes(): + """Synapse class names for external use.""" + return [ + 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach' + ] + + +def lip_classes(): + """LIP class names for external use.""" + return [ + 'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', + 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', + 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', + 'rightShoe' + ] + + +def lip_palette(): + """LIP palette for external use.""" + return [ + 'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes', + 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', + 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', + 'Right-shoe' + ] + + +def bdd100k_classes(): + """BDD100K class names for external use(the class name is compatible with + Cityscapes ).""" + return [ + 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle' + ] + + +def bdd100k_palette(): + """bdd100k palette for external use(same with cityscapes)""" + return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], + [0, 0, 230], [119, 11, 32]] + + +def hsidrive_classes(): + """HSI Drive 2.0 class names for external use.""" + return [ + 'unlabelled', 'road', 'road marks', 'vegetation', 'painted metal', + 'sky', 'concrete', 'pedestrian', 'water', 'unpainted metal', 'glass' + ] + + +def hsidrive_palette(): + """HSI Drive 2.0 palette for external use.""" + return [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], [255, 0, 0], + [0, 0, 255], [102, 51, 0], [255, 255, 0], [0, 207, 250], + [255, 166, 0], [0, 204, 204]] + + +dataset_aliases = { + 'cityscapes': ['cityscapes'], + 'ade': ['ade', 'ade20k'], + 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'], + 'pcontext': ['pcontext', 'pascal_context', 'voc2010'], + 'loveda': ['loveda'], + 'potsdam': ['potsdam'], + 'vaihingen': ['vaihingen'], + 'cocostuff': [ + 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', + 'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', + 'coco_stuff164k' + ], + 'isaid': ['isaid', 'iSAID'], + 'stare': ['stare', 'STARE'], + 'lip': ['LIP', 'lip'], + 'mapillary_v1': ['mapillary_v1'], + 'mapillary_v2': ['mapillary_v2'], + 'bdd100k': ['bdd100k'], + 'hsidrive': [ + 'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20', + 'HSI-Drive20' + ] +} + + +def get_classes(dataset): + """Get class names of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_classes()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels + + +def get_palette(dataset): + """Get class palette (RGB) of a dataset.""" + alias2name = {} + for name, aliases in dataset_aliases.items(): + for alias in aliases: + alias2name[alias] = name + + if is_str(dataset): + if dataset in alias2name: + labels = eval(alias2name[dataset] + '_palette()') + else: + raise ValueError(f'Unrecognized dataset: {dataset}') + else: + raise TypeError(f'dataset must a str, but got {type(dataset)}') + return labels diff --git a/finetune/mmseg/utils/collect_env.py b/finetune/mmseg/utils/collect_env.py new file mode 100644 index 0000000..d5d6ea2 --- /dev/null +++ b/finetune/mmseg/utils/collect_env.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.utils import get_git_hash +from mmengine.utils.dl_utils import collect_env as collect_base_env + +import mmseg + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/finetune/mmseg/utils/get_templates.py b/finetune/mmseg/utils/get_templates.py new file mode 100644 index 0000000..7e9032b --- /dev/null +++ b/finetune/mmseg/utils/get_templates.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +PREDEFINED_TEMPLATES = { + 'imagenet': [ + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + ], + 'vild': [ + 'a photo of a {}.', + 'This is a photo of a {}', + 'There is a {} in the scene', + 'There is the {} in the scene', + 'a photo of a {} in the scene', + 'a photo of a small {}.', + 'a photo of a medium {}.', + 'a photo of a large {}.', + 'This is a photo of a small {}.', + 'This is a photo of a medium {}.', + 'This is a photo of a large {}.', + 'There is a small {} in the scene.', + 'There is a medium {} in the scene.', + 'There is a large {} in the scene.', + ], +} + + +def get_predefined_templates(template_set_name: str) -> List[str]: + if template_set_name not in PREDEFINED_TEMPLATES: + raise ValueError(f'Template set {template_set_name} not found') + return PREDEFINED_TEMPLATES[template_set_name] diff --git a/finetune/mmseg/utils/io.py b/finetune/mmseg/utils/io.py new file mode 100644 index 0000000..7029c3c --- /dev/null +++ b/finetune/mmseg/utils/io.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import gzip +import io +import pickle + +import cv2 +import numpy as np + + +def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray: + """Data decoding from bytes. + + Args: + content (bytes): The data bytes got from files or other streams. + backend (str): The data decoding backend type. Options are 'numpy', + 'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'. + + Returns: + numpy.ndarray: Loaded data array. + """ + if backend == 'pickle': + data = pickle.loads(content) + else: + with io.BytesIO(content) as f: + if backend == 'nifti': + f = gzip.open(f) + try: + from nibabel import FileHolder, Nifti1Image + except ImportError: + print('nifti files io depends on nibabel, please run' + '`pip install nibabel` to install it') + fh = FileHolder(fileobj=f) + data = Nifti1Image.from_file_map({'header': fh, 'image': fh}) + data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata() + elif backend == 'numpy': + data = np.load(f) + elif backend == 'cv2': + data = np.frombuffer(f.read(), dtype=np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED) + else: + raise ValueError + return data diff --git a/finetune/mmseg/utils/mask_classification.py b/finetune/mmseg/utils/mask_classification.py new file mode 100644 index 0000000..205d525 --- /dev/null +++ b/finetune/mmseg/utils/mask_classification.py @@ -0,0 +1,205 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from mmcv.ops import point_sample +from mmengine.structures import InstanceData +from torch import Tensor + +from mmseg.registry import TASK_UTILS +from mmseg.utils import ConfigType, SampleList + + +def seg_data_to_instance_data(ignore_index: int, + batch_data_samples: SampleList): + """Convert the paradigm of ground truth from semantic segmentation to + instance segmentation. + + Args: + ignore_index (int): The label index to be ignored. + batch_data_samples (List[SegDataSample]): The Data + Samples. It usually includes information such as + `gt_sem_seg`. + + Returns: + tuple[Tensor]: A tuple contains two lists. + - batch_gt_instances (List[InstanceData]): Batch of + gt_instance. It usually includes ``labels``, each is + unique ground truth label id of images, with + shape (num_gt, ) and ``masks``, each is ground truth + masks of each instances of a image, shape (num_gt, h, w). + - batch_img_metas (List[Dict]): List of image meta information. + """ + batch_gt_instances = [] + + for data_sample in batch_data_samples: + gt_sem_seg = data_sample.gt_sem_seg.data + classes = torch.unique( + gt_sem_seg, + sorted=False, + return_inverse=False, + return_counts=False) + + # remove ignored region + gt_labels = classes[classes != ignore_index] + + masks = [] + for class_id in gt_labels: + masks.append(gt_sem_seg == class_id) + + if len(masks) == 0: + gt_masks = torch.zeros( + (0, gt_sem_seg.shape[-2], + gt_sem_seg.shape[-1])).to(gt_sem_seg).long() + else: + gt_masks = torch.stack(masks).squeeze(1).long() + + instance_data = InstanceData(labels=gt_labels, masks=gt_masks) + batch_gt_instances.append(instance_data) + return batch_gt_instances + + +class MatchMasks: + """Match the predictions to category labels. + + Args: + num_points (int): the number of sampled points to compute cost. + num_queries (int): the number of prediction masks. + num_classes (int): the number of classes. + assigner (BaseAssigner): the assigner to compute matching. + """ + + def __init__(self, + num_points: int, + num_queries: int, + num_classes: int, + assigner: ConfigType = None): + assert assigner is not None, "\'assigner\' in decode_head.train_cfg" \ + 'cannot be None' + assert num_points > 0, 'num_points should be a positive integer.' + self.num_points = num_points + self.num_queries = num_queries + self.num_classes = num_classes + self.assigner = TASK_UTILS.build(assigner) + + def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor], + batch_gt_instances: List[InstanceData]) -> Tuple: + """Compute best mask matches for all images for a decoder layer. + + Args: + cls_scores (List[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds (List[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (List[InstanceData]): each contains + ``labels`` and ``masks``. + + Returns: + tuple: a tuple containing the following targets. + + - labels (List[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - mask_targets (List[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights (List[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to + average the loss. `avg_factor` is usually equal + to the number of positive priors. + """ + batch_size = cls_scores.shape[0] + results = dict({ + 'labels': [], + 'mask_targets': [], + 'mask_weights': [], + }) + for i in range(batch_size): + labels, mask_targets, mask_weights\ + = self._get_targets_single(cls_scores[i], + mask_preds[i], + batch_gt_instances[i]) + results['labels'].append(labels) + results['mask_targets'].append(mask_targets) + results['mask_weights'].append(mask_weights) + + # shape (batch_size, num_queries) + labels = torch.stack(results['labels'], dim=0) + # shape (batch_size, num_gts, h, w) + mask_targets = torch.cat(results['mask_targets'], dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(results['mask_weights'], dim=0) + + avg_factor = sum( + [len(gt_instances.labels) for gt_instances in batch_gt_instances]) + + res = (labels, mask_targets, mask_weights, avg_factor) + + return res + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData) \ + -> Tuple[Tensor, Tensor, Tensor]: + """Compute a set of best mask matches for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + """ + gt_labels = gt_instances.labels + gt_masks = gt_instances.masks + # when "gt_labels" is empty, classify all queries to background + if len(gt_labels) == 0: + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + mask_targets = gt_labels + mask_weights = gt_labels.new_zeros((self.num_queries, )) + return labels, mask_targets, mask_weights + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + matched_quiery_inds, matched_label_inds = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances) + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[matched_quiery_inds] = gt_labels[matched_label_inds] + + mask_weights = gt_labels.new_zeros((self.num_queries, )) + mask_weights[matched_quiery_inds] = 1 + mask_targets = gt_masks[matched_label_inds] + + return labels, mask_targets, mask_weights diff --git a/finetune/mmseg/utils/misc.py b/finetune/mmseg/utils/misc.py new file mode 100644 index 0000000..dfc469e --- /dev/null +++ b/finetune/mmseg/utils/misc.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from .typing_utils import SampleList + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs + + +def stack_batch(inputs: List[torch.Tensor], + data_samples: Optional[SampleList] = None, + size: Optional[tuple] = None, + size_divisor: Optional[int] = None, + pad_val: Union[int, float] = 0, + seg_pad_val: Union[int, float] = 255) -> torch.Tensor: + """Stack multiple inputs to form a batch and pad the images and gt_sem_segs + to the max shape use the right bottom padding mode. + + Args: + inputs (List[Tensor]): The input multiple tensors. each is a + CHW 3D-tensor. + data_samples (list[:obj:`SegDataSample`]): The list of data samples. + It usually includes information such as `gt_sem_seg`. + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (int, float): The padding value. Defaults to 0 + seg_pad_val (int, float): The padding value. Defaults to 255 + + Returns: + Tensor: The 4D-tensor. + List[:obj:`SegDataSample`]: After the padding of the gt_seg_map. + """ + assert isinstance(inputs, list), \ + f'Expected input type to be list, but got {type(inputs)}' + assert len({tensor.ndim for tensor in inputs}) == 1, \ + f'Expected the dimensions of all inputs must be the same, ' \ + f'but got {[tensor.ndim for tensor in inputs]}' + assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \ + f'but got {inputs[0].ndim}' + assert len({tensor.shape[0] for tensor in inputs}) == 1, \ + f'Expected the channels of all inputs must be the same, ' \ + f'but got {[tensor.shape[0] for tensor in inputs]}' + + # only one of size and size_divisor should be valid + assert (size is not None) ^ (size_divisor is not None), \ + 'only one of size and size_divisor should be valid' + + padded_inputs = [] + padded_samples = [] + inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs] + max_size = np.stack(inputs_sizes).max(0) + if size_divisor is not None and size_divisor > 1: + # the last two dims are H,W, both subject to divisibility requirement + max_size = (max_size + + (size_divisor - 1)) // size_divisor * size_divisor + + for i in range(len(inputs)): + tensor = inputs[i] + if size is not None: + width = max(size[-1] - tensor.shape[-1], 0) + height = max(size[-2] - tensor.shape[-2], 0) + # (padding_left, padding_right, padding_top, padding_bottom) + padding_size = (0, width, 0, height) + elif size_divisor is not None: + width = max(max_size[-1] - tensor.shape[-1], 0) + height = max(max_size[-2] - tensor.shape[-2], 0) + padding_size = (0, width, 0, height) + else: + padding_size = [0, 0, 0, 0] + + # pad img + pad_img = F.pad(tensor, padding_size, value=pad_val) + padded_inputs.append(pad_img) + # pad gt_sem_seg + if data_samples is not None: + data_sample = data_samples[i] + pad_shape = None + if 'gt_sem_seg' in data_sample: + gt_sem_seg = data_sample.gt_sem_seg.data + del data_sample.gt_sem_seg.data + data_sample.gt_sem_seg.data = F.pad( + gt_sem_seg, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_sem_seg.shape + if 'gt_edge_map' in data_sample: + gt_edge_map = data_sample.gt_edge_map.data + del data_sample.gt_edge_map.data + data_sample.gt_edge_map.data = F.pad( + gt_edge_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_edge_map.shape + if 'gt_depth_map' in data_sample: + gt_depth_map = data_sample.gt_depth_map.data + del data_sample.gt_depth_map.data + data_sample.gt_depth_map.data = F.pad( + gt_depth_map, padding_size, value=seg_pad_val) + pad_shape = data_sample.gt_depth_map.shape + data_sample.set_metainfo({ + 'img_shape': tensor.shape[-2:], + 'pad_shape': pad_shape, + 'padding_size': padding_size + }) + padded_samples.append(data_sample) + else: + padded_samples.append( + dict( + img_padding_size=padding_size, + pad_shape=pad_img.shape[-2:])) + + return torch.stack(padded_inputs, dim=0), padded_samples diff --git a/finetune/mmseg/utils/set_env.py b/finetune/mmseg/utils/set_env.py new file mode 100644 index 0000000..c948950 --- /dev/null +++ b/finetune/mmseg/utils/set_env.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import warnings + +from mmengine import DefaultScope + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmseg into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmseg default scope. + When `init_default_scope=True`, the global default scope will be + set to `mmseg`, and all registries will build modules from mmseg's + registry node. To understand more about the registry, please refer + to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa + import mmseg.datasets # noqa: F401,F403 + import mmseg.engine # noqa: F401,F403 + import mmseg.evaluation # noqa: F401,F403 + import mmseg.models # noqa: F401,F403 + import mmseg.structures # noqa: F401,F403 + + if init_default_scope: + never_created = DefaultScope.get_current_instance() is None \ + or not DefaultScope.check_instance_created('mmseg') + if never_created: + DefaultScope.get_instance('mmseg', scope_name='mmseg') + return + current_scope = DefaultScope.get_current_instance() + if current_scope.scope_name != 'mmseg': + warnings.warn('The current default scope ' + f'"{current_scope.scope_name}" is not "mmseg", ' + '`register_all_modules` will force the current' + 'default scope to be "mmseg". If this is not ' + 'expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmseg-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmseg') diff --git a/finetune/mmseg/utils/tokenizer.py b/finetune/mmseg/utils/tokenizer.py new file mode 100644 index 0000000..d56f5fa --- /dev/null +++ b/finetune/mmseg/utils/tokenizer.py @@ -0,0 +1,240 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""CLIP tokenizer. + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright +(c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import List, Union + +import ftfy +import regex as re +import torch + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'bpe_simple_vocab_16e6.txt.gz') + + +@lru_cache() +def bytes_to_unicode(): + """Returns list of utf-8 byte and a corresponding list of unicode strings. + + The reversible bpe codes work on unicode strings. This means you need a + large # of unicode characters in your vocab if you want to avoid UNKs. When + you're at something like a 10B token dataset you end up needing around 5K + for decent coverage. This is a significant percentage of your normal, say, + 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and + unicode strings. And avoids mapping to whitespace/control characters the + bpe code barfs on. + """ + bs = list(range(ord('!'), + ord('~') + 1)) + list(range( + ord('¡'), + ord('¬') + 1)) + list(range(ord('®'), + ord('ÿ') + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer: + + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode('utf-8').split('\n') + merges = merges[1:49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + '' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', '' + ] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = '|'.join(special_tokens) + self.pat = re.compile( + special + + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + '', ) + pairs = get_pairs(word) + + if not pairs: + return token + '' + + while True: + bigram = min( + pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: # noqa: E722, E261 + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[ + i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] + for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] + for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + 'utf-8', errors='replace').replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + + +def tokenize(texts: Union[str, List[str]], + context_length: int = 77) -> torch.LongTensor: + """Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, + shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[''] + eot_token = _tokenizer.encoder[''] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper.""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, + texts: Union[str, List[str]], + context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it + # more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/finetune/mmseg/utils/typing_utils.py b/finetune/mmseg/utils/typing_utils.py new file mode 100644 index 0000000..fba7d3b --- /dev/null +++ b/finetune/mmseg/utils/typing_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Collecting some commonly used type hint in mmflow.""" +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +from mmengine.config import ConfigDict + +from mmseg.structures import SegDataSample + +# Type hint of config data +ConfigType = Union[ConfigDict, dict] +OptConfigType = Optional[ConfigType] +# Type hint of one or more config data +MultiConfig = Union[ConfigType, Sequence[ConfigType]] +OptMultiConfig = Optional[MultiConfig] + +SampleList = Sequence[SegDataSample] +OptSampleList = Optional[SampleList] + +# Type hint of Tensor +TensorDict = Dict[str, torch.Tensor] +TensorList = Sequence[torch.Tensor] + +ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample], + Tuple[torch.Tensor], torch.Tensor] diff --git a/finetune/mmseg/version.py b/finetune/mmseg/version.py new file mode 100644 index 0000000..b76bb45 --- /dev/null +++ b/finetune/mmseg/version.py @@ -0,0 +1,18 @@ +# Copyright (c) Open-MMLab. All rights reserved. + +__version__ = '1.2.2' + + +def parse_version_info(version_str): + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) diff --git a/finetune/mmseg/visualization/__init__.py b/finetune/mmseg/visualization/__init__.py new file mode 100644 index 0000000..8cbb211 --- /dev/null +++ b/finetune/mmseg/visualization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .local_visualizer import SegLocalVisualizer + +__all__ = ['SegLocalVisualizer'] diff --git a/finetune/mmseg/visualization/local_visualizer.py b/finetune/mmseg/visualization/local_visualizer.py new file mode 100644 index 0000000..ee3d652 --- /dev/null +++ b/finetune/mmseg/visualization/local_visualizer.py @@ -0,0 +1,349 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import cv2 +import mmcv +import numpy as np +import torch +from mmengine.dist import master_only +from mmengine.structures import PixelData +from mmengine.visualization import Visualizer + +from mmseg.registry import VISUALIZERS +from mmseg.structures import SegDataSample +from mmseg.utils import get_classes, get_palette + + +@VISUALIZERS.register_module() +class SegLocalVisualizer(Visualizer): + """Local Visualizer. + + Args: + name (str): Name of the instance. Defaults to 'visualizer'. + image (np.ndarray, optional): the origin image to draw. The format + should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Defaults to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + classes (list, optional): Input classes for result rendering, as the + prediction of segmentation model is a segment map with label + indices, `classes` is a list which includes items responding to the + label indices. If classes is not defined, visualizer will take + `cityscapes` classes by default. Defaults to None. + palette (list, optional): Input palette for result rendering, which is + a list of color palette responding to the classes. Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. classes + and palette, but the `classes` and `palette` have higher priority. + Defaults to None. + alpha (int, float): The transparency of segmentation mask. + Defaults to 0.8. + + Examples: + >>> import numpy as np + >>> import torch + >>> from mmengine.structures import PixelData + >>> from mmseg.structures import SegDataSample + >>> from mmseg.visualization import SegLocalVisualizer + + >>> seg_local_visualizer = SegLocalVisualizer() + >>> image = np.random.randint(0, 256, + ... size=(10, 12, 3)).astype('uint8') + >>> gt_sem_seg_data = dict(data=torch.randint(0, 2, (1, 10, 12))) + >>> gt_sem_seg = PixelData(**gt_sem_seg_data) + >>> gt_seg_data_sample = SegDataSample() + >>> gt_seg_data_sample.gt_sem_seg = gt_sem_seg + >>> seg_local_visualizer.dataset_meta = dict( + >>> classes=('background', 'foreground'), + >>> palette=[[120, 120, 120], [6, 230, 230]]) + >>> seg_local_visualizer.add_datasample('visualizer_example', + ... image, gt_seg_data_sample) + >>> seg_local_visualizer.add_datasample( + ... 'visualizer_example', image, + ... gt_seg_data_sample, show=True) + """ # noqa + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None, + alpha: float = 0.8, + **kwargs): + super().__init__(name, image, vis_backends, save_dir, **kwargs) + self.alpha: float = alpha + self.set_dataset_meta(palette, classes, dataset_name) + + def _get_center_loc(self, mask: np.ndarray) -> np.ndarray: + """Get semantic seg center coordinate. + + Args: + mask: np.ndarray: get from sem_seg + """ + loc = np.argwhere(mask == 1) + + loc_sort = np.array( + sorted(loc.tolist(), key=lambda row: (row[0], row[1]))) + y_list = loc_sort[:, 0] + unique, indices, counts = np.unique( + y_list, return_index=True, return_counts=True) + y_loc = unique[counts.argmax()] + y_most_freq_loc = loc[loc_sort[:, 0] == y_loc] + center_num = len(y_most_freq_loc) // 2 + x = y_most_freq_loc[center_num][1] + y = y_most_freq_loc[center_num][0] + return np.array([x, y]) + + def _draw_sem_seg(self, + image: np.ndarray, + sem_seg: PixelData, + classes: Optional[List], + palette: Optional[List], + with_labels: Optional[bool] = True) -> np.ndarray: + """Draw semantic seg of GT or prediction. + + Args: + image (np.ndarray): The image to draw. + sem_seg (:obj:`PixelData`): Data structure for pixel-level + annotations or predictions. + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + with_labels(bool, optional): Add semantic labels in visualization + result, Default to True. + + Returns: + np.ndarray: the drawn image which channel is RGB. + """ + num_classes = len(classes) + + sem_seg = sem_seg.cpu().data + ids = np.unique(sem_seg)[::-1] + legal_indices = ids < num_classes + ids = ids[legal_indices] + labels = np.array(ids, dtype=np.int64) + + colors = [palette[label] for label in labels] + + mask = np.zeros_like(image, dtype=np.uint8) + for label, color in zip(labels, colors): + mask[sem_seg[0] == label, :] = color + + if with_labels: + font = cv2.FONT_HERSHEY_SIMPLEX + # (0,1] to change the size of the text relative to the image + scale = 0.05 + fontScale = min(image.shape[0], image.shape[1]) / (25 / scale) + fontColor = (255, 255, 255) + if image.shape[0] < 300 or image.shape[1] < 300: + thickness = 1 + rectangleThickness = 1 + else: + thickness = 2 + rectangleThickness = 2 + lineType = 2 + + if isinstance(sem_seg[0], torch.Tensor): + masks = sem_seg[0].numpy() == labels[:, None, None] + else: + masks = sem_seg[0] == labels[:, None, None] + masks = masks.astype(np.uint8) + for mask_num in range(len(labels)): + classes_id = labels[mask_num] + classes_color = colors[mask_num] + loc = self._get_center_loc(masks[mask_num]) + text = classes[classes_id] + (label_width, label_height), baseline = cv2.getTextSize( + text, font, fontScale, thickness) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + classes_color, -1) + mask = cv2.rectangle(mask, loc, + (loc[0] + label_width + baseline, + loc[1] + label_height + baseline), + (0, 0, 0), rectangleThickness) + mask = cv2.putText(mask, text, (loc[0], loc[1] + label_height), + font, fontScale, fontColor, thickness, + lineType) + color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype( + np.uint8) + self.set_image(color_seg) + return color_seg + + def _draw_depth_map(self, image: np.ndarray, + depth_map: PixelData) -> np.ndarray: + """Draws a depth map on a given image. + + This function takes an image and a depth map as input, + renders the depth map, and concatenates it with the original image. + Finally, it updates the internal image state of the visualizer with + the concatenated result. + + Args: + image (np.ndarray): The original image where the depth map will + be drawn. The array should be in the format HxWx3 where H is + the height, W is the width. + + depth_map (PixelData): Depth map to be drawn. The depth map + should be in the form of a PixelData object. It will be + converted to a torch tensor if it is a numpy array. + + Returns: + np.ndarray: The concatenated image with the depth map drawn. + + Example: + >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10)) + >>> image = np.random.randint(0, 256, + >>> size=(10, 10, 3)).astype('uint8') + >>> visualizer = SegLocalVisualizer() + >>> visualizer._draw_depth_map(image, depth_map_data) + """ + depth_map = depth_map.cpu().data + if isinstance(depth_map, np.ndarray): + depth_map = torch.from_numpy(depth_map) + if depth_map.ndim == 2: + depth_map = depth_map[None] + + depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2]) + out_image = np.concatenate((image, depth_map), axis=0) + self.set_image(out_image) + return out_image + + def set_dataset_meta(self, + classes: Optional[List] = None, + palette: Optional[List] = None, + dataset_name: Optional[str] = None) -> None: + """Set meta information to visualizer. + + Args: + classes (list, optional): Input classes for result rendering, as + the prediction of segmentation model is a segment map with + label indices, `classes` is a list which includes items + responding to the label indices. If classes is not defined, + visualizer will take `cityscapes` classes by default. + Defaults to None. + palette (list, optional): Input palette for result rendering, which + is a list of color palette responding to the classes. + Defaults to None. + dataset_name (str, optional): `Dataset name or alias `_ + visulizer will use the meta information of the dataset i.e. + classes and palette, but the `classes` and `palette` have + higher priority. Defaults to None. + """ # noqa + # Set default value. When calling + # `SegLocalVisualizer().dataset_meta=xxx`, + # it will override the default value. + if dataset_name is None: + dataset_name = 'cityscapes' + classes = classes if classes else get_classes(dataset_name) + palette = palette if palette else get_palette(dataset_name) + assert len(classes) == len( + palette), 'The length of classes should be equal to palette' + self.dataset_meta: dict = {'classes': classes, 'palette': palette} + + @master_only + def add_datasample( + self, + name: str, + image: np.ndarray, + data_sample: Optional[SegDataSample] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: float = 0, + # TODO: Supported in mmengine's Viusalizer. + out_file: Optional[str] = None, + step: int = 0, + with_labels: Optional[bool] = True) -> None: + """Draw datasample and save to all backends. + + - If GT and prediction are plotted at the same time, they are + displayed in a stitched image where the left image is the + ground truth and the right image is the prediction. + - If ``show`` is True, all storage backends are ignored, and + the images will be displayed in a local window. + - If ``out_file`` is specified, the drawn image will be + saved to ``out_file``. it is usually used when the display + is not available. + + Args: + name (str): The image identifier. + image (np.ndarray): The image to draw. + gt_sample (:obj:`SegDataSample`, optional): GT SegDataSample. + Defaults to None. + pred_sample (:obj:`SegDataSample`, optional): Prediction + SegDataSample. Defaults to None. + draw_gt (bool): Whether to draw GT SegDataSample. Default to True. + draw_pred (bool): Whether to draw Prediction SegDataSample. + Defaults to True. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_file (str): Path to output file. Defaults to None. + step (int): Global step value to record. Defaults to 0. + with_labels(bool, optional): Add semantic labels in visualization + result, Defaults to True. + """ + classes = self.dataset_meta.get('classes', None) + palette = self.dataset_meta.get('palette', None) + + gt_img_data = None + pred_img_data = None + + if draw_gt and data_sample is not None: + if 'gt_sem_seg' in data_sample: + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + gt_img_data = self._draw_sem_seg(image, data_sample.gt_sem_seg, + classes, palette, with_labels) + + if 'gt_depth_map' in data_sample: + gt_img_data = gt_img_data if gt_img_data is not None else image + gt_img_data = self._draw_depth_map(gt_img_data, + data_sample.gt_depth_map) + + if draw_pred and data_sample is not None: + + if 'pred_sem_seg' in data_sample: + + assert classes is not None, 'class information is ' \ + 'not provided when ' \ + 'visualizing semantic ' \ + 'segmentation results.' + pred_img_data = self._draw_sem_seg(image, + data_sample.pred_sem_seg, + classes, palette, + with_labels) + + if 'pred_depth_map' in data_sample: + pred_img_data = pred_img_data if pred_img_data is not None \ + else image + pred_img_data = self._draw_depth_map( + pred_img_data, data_sample.pred_depth_map) + + if gt_img_data is not None and pred_img_data is not None: + drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) + elif gt_img_data is not None: + drawn_img = gt_img_data + else: + drawn_img = pred_img_data + + if show: + self.show(drawn_img, win_name=name, wait_time=wait_time) + + if out_file is not None: + mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file) + else: + self.add_image(name, drawn_img, step) diff --git a/finetune/requirements.txt b/finetune/requirements.txt new file mode 100644 index 0000000..501bddc --- /dev/null +++ b/finetune/requirements.txt @@ -0,0 +1,4 @@ +-r requirements/optional.txt +-r requirements/runtime.txt +-r requirements/tests.txt +-r requirements/multimodal.txt diff --git a/finetune/requirements/albu.txt b/finetune/requirements/albu.txt new file mode 100644 index 0000000..f421fbb --- /dev/null +++ b/finetune/requirements/albu.txt @@ -0,0 +1 @@ +albumentations>=0.3.2 --no-binary qudida,albumentations diff --git a/finetune/requirements/docs.txt b/finetune/requirements/docs.txt new file mode 100644 index 0000000..19632d3 --- /dev/null +++ b/finetune/requirements/docs.txt @@ -0,0 +1,7 @@ +docutils==0.16.0 +myst-parser +-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme +sphinx==4.0.2 +sphinx_copybutton +sphinx_markdown_tables +urllib3<2.0.0 diff --git a/finetune/requirements/mminstall.txt b/finetune/requirements/mminstall.txt new file mode 100644 index 0000000..5732d34 --- /dev/null +++ b/finetune/requirements/mminstall.txt @@ -0,0 +1,2 @@ +mmcv>=2.0.0rc4,<2.2.0 +mmengine>=0.5.0,<1.0.0 diff --git a/finetune/requirements/multimodal.txt b/finetune/requirements/multimodal.txt new file mode 100644 index 0000000..2195d0d --- /dev/null +++ b/finetune/requirements/multimodal.txt @@ -0,0 +1,2 @@ +ftfy +regex diff --git a/finetune/requirements/optional.txt b/finetune/requirements/optional.txt new file mode 100644 index 0000000..b0310f5 --- /dev/null +++ b/finetune/requirements/optional.txt @@ -0,0 +1,22 @@ +cityscapesscripts +-e git+https://github.com/openai/CLIP.git@main#egg=clip + +# for vpd model +diffusers +einops==0.3.0 +imageio==2.9.0 +imageio-ffmpeg==0.4.2 +invisible-watermark +kornia==0.6 +-e git+https://github.com/CompVis/stable-diffusion@21f890f#egg=latent-diffusion +nibabel +omegaconf==2.1.1 +pudb==2019.2 +pytorch-lightning==1.4.2 +streamlit>=0.73.1 +-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +test-tube>=0.7.5 +timm +torch-fidelity==0.3.0 +torchmetrics==0.6.0 +transformers==4.19.2 diff --git a/finetune/requirements/readthedocs.txt b/finetune/requirements/readthedocs.txt new file mode 100644 index 0000000..9627504 --- /dev/null +++ b/finetune/requirements/readthedocs.txt @@ -0,0 +1,6 @@ +mmcv>=2.0.0rc1,<2.1.0 +mmengine>=0.4.0,<1.0.0 +prettytable +scipy +torch +torchvision diff --git a/finetune/requirements/runtime.txt b/finetune/requirements/runtime.txt new file mode 100644 index 0000000..3e24258 --- /dev/null +++ b/finetune/requirements/runtime.txt @@ -0,0 +1,5 @@ +matplotlib +numpy +packaging +prettytable +scipy diff --git a/finetune/requirements/tests.txt b/finetune/requirements/tests.txt new file mode 100644 index 0000000..3fff252 --- /dev/null +++ b/finetune/requirements/tests.txt @@ -0,0 +1,8 @@ +codecov +flake8 +ftfy +interrogate +pytest +regex +xdoctest>=0.10.0 +yapf diff --git a/finetune/setup.py b/finetune/setup.py new file mode 100644 index 0000000..45d923d --- /dev/null +++ b/finetune/setup.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import platform +import shutil +import sys +import warnings +from setuptools import find_packages, setup + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +version_file = 'mmseg/version.py' + + +def get_version(): + with open(version_file) as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def parse_requirements(fname='requirements.txt', with_version=True): + """Parse the package dependencies listed in a requirements file but strips + specific versioning information. + + Args: + fname (str): path to requirements file + with_version (bool, default=False): if True include version specs + + Returns: + List[str]: list of requirements items + + CommandLine: + python -c "import setup; print(setup.parse_requirements())" + """ + import re + import sys + from os.path import exists + require_fpath = fname + + def parse_line(line): + """Parse information from a line in a requirements text file.""" + if line.startswith('-r '): + # Allow specifying requirements in other files + target = line.split(' ')[1] + for info in parse_require_file(target): + yield info + else: + info = {'line': line} + if line.startswith('-e '): + info['package'] = line.split('#egg=')[1] + else: + # Remove versioning from the package + pat = '(' + '|'.join(['>=', '==', '>']) + ')' + parts = re.split(pat, line, maxsplit=1) + parts = [p.strip() for p in parts] + + info['package'] = parts[0] + if len(parts) > 1: + op, rest = parts[1:] + if ';' in rest: + # Handle platform specific dependencies + # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies + version, platform_deps = map(str.strip, + rest.split(';')) + info['platform_deps'] = platform_deps + else: + version = rest # NOQA + info['version'] = (op, version) + yield info + + def parse_require_file(fpath): + with open(fpath) as f: + for line in f.readlines(): + line = line.strip() + if line and not line.startswith('#'): + yield from parse_line(line) + + def gen_packages_items(): + if exists(require_fpath): + for info in parse_require_file(require_fpath): + parts = [info['package']] + if with_version and 'version' in info: + parts.extend(info['version']) + if not sys.version.startswith('3.4'): + # apparently package_deps are broken in 3.4 + platform_deps = info.get('platform_deps') + if platform_deps is not None: + parts.append(';' + platform_deps) + item = ''.join(parts) + yield item + + packages = list(gen_packages_items()) + return packages + + +def add_mim_extension(): + """Add extra files that are required to support MIM into the package. + + These files will be added by creating a symlink to the originals if the + package is installed in `editable` mode (e.g. pip install -e .), or by + copying from the originals otherwise. + """ + + # parse installment mode + if 'develop' in sys.argv: + # installed by `pip install -e .` + if platform.system() == 'Windows': + # set `copy` mode here since symlink fails on Windows. + mode = 'copy' + else: + mode = 'symlink' + elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv or \ + platform.system() == 'Windows': + # installed by `pip install .` + # or create source distribution by `python setup.py sdist` + # set `copy` mode here since symlink fails with WinError on Windows. + mode = 'copy' + else: + return + + filenames = ['tools', 'configs', 'model-index.yml', 'dataset-index.yml'] + repo_path = osp.dirname(__file__) + mim_path = osp.join(repo_path, 'mmseg', '.mim') + os.makedirs(mim_path, exist_ok=True) + + for filename in filenames: + if osp.exists(filename): + src_path = osp.join(repo_path, filename) + tar_path = osp.join(mim_path, filename) + + if osp.isfile(tar_path) or osp.islink(tar_path): + os.remove(tar_path) + elif osp.isdir(tar_path): + shutil.rmtree(tar_path) + + if mode == 'symlink': + src_relpath = osp.relpath(src_path, osp.dirname(tar_path)) + try: + os.symlink(src_relpath, tar_path) + except OSError: + # Creating a symbolic link on windows may raise an + # `OSError: [WinError 1314]` due to privilege. If + # the error happens, the src file will be copied + mode = 'copy' + warnings.warn( + f'Failed to create a symbolic link for {src_relpath}, ' + f'and it will be copied to {tar_path}') + else: + continue + + if mode == 'copy': + if osp.isfile(src_path): + shutil.copyfile(src_path, tar_path) + elif osp.isdir(src_path): + shutil.copytree(src_path, tar_path) + else: + warnings.warn(f'Cannot copy file {src_path}.') + else: + raise ValueError(f'Invalid mode {mode}') + + +if __name__ == '__main__': + add_mim_extension() + setup( + name='mmsegmentation', + version=get_version(), + description='Open MMLab Semantic Segmentation Toolbox and Benchmark', + long_description=readme(), + long_description_content_type='text/markdown', + author='MMSegmentation Contributors', + author_email='openmmlab@gmail.com', + keywords='computer vision, semantic segmentation', + url='https://github.com/open-mmlab/mmsegmentation', + packages=find_packages(exclude=('configs', 'tools', 'demo')), + include_package_data=True, + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + ], + license='Apache License 2.0', + install_requires=parse_requirements('requirements/runtime.txt'), + extras_require={ + 'all': parse_requirements('requirements.txt'), + 'tests': parse_requirements('requirements/tests.txt'), + 'optional': parse_requirements('requirements/optional.txt'), + 'mim': parse_requirements('requirements/mminstall.txt'), + 'multimodal': parse_requirements('requirements/multimodal.txt'), + }, + ext_modules=[], + zip_safe=False) diff --git a/finetune/tools/analysis_tools/analyze_logs.py b/finetune/tools/analysis_tools/analyze_logs.py new file mode 100644 index 0000000..7464d23 --- /dev/null +++ b/finetune/tools/analysis_tools/analyze_logs.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Modified from https://github.com/open- +mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" +import argparse +import json +from collections import defaultdict + +import matplotlib.pyplot as plt +import seaborn as sns + + +def plot_curve(log_dicts, args): + if args.backend is not None: + plt.switch_backend(args.backend) + sns.set_style(args.style) + # if legend is None, use {filename}_{key} as legend + legend = args.legend + if legend is None: + legend = [] + for json_log in args.json_logs: + for metric in args.keys: + legend.append(f'{json_log}_{metric}') + assert len(legend) == (len(args.json_logs) * len(args.keys)) + metrics = args.keys + + num_metrics = len(metrics) + for i, log_dict in enumerate(log_dicts): + epochs = list(log_dict.keys()) + for j, metric in enumerate(metrics): + print(f'plot curve of {args.json_logs[i]}, metric is {metric}') + plot_epochs = [] + plot_iters = [] + plot_values = [] + # In some log files exist lines of validation, + # `mode` list is used to only collect iter number + # of training line. + for epoch in epochs: + epoch_logs = log_dict[epoch] + if metric not in epoch_logs.keys(): + continue + if metric in ['mIoU', 'mAcc', 'aAcc']: + plot_epochs.append(epoch) + plot_values.append(epoch_logs[metric][0]) + else: + for idx in range(len(epoch_logs[metric])): + plot_iters.append(epoch_logs['step'][idx]) + plot_values.append(epoch_logs[metric][idx]) + ax = plt.gca() + label = legend[i * num_metrics + j] + if metric in ['mIoU', 'mAcc', 'aAcc']: + ax.set_xticks(plot_epochs) + plt.xlabel('step') + plt.plot(plot_epochs, plot_values, label=label, marker='o') + else: + plt.xlabel('iter') + plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) + plt.legend() + if args.title is not None: + plt.title(args.title) + if args.out is None: + plt.show() + else: + print(f'save curve to: {args.out}') + plt.savefig(args.out) + plt.cla() + + +def parse_args(): + parser = argparse.ArgumentParser(description='Analyze Json Log') + parser.add_argument( + 'json_logs', + type=str, + nargs='+', + help='path of train log in json format') + parser.add_argument( + '--keys', + type=str, + nargs='+', + default=['mIoU'], + help='the metric that you want to plot') + parser.add_argument('--title', type=str, help='title of figure') + parser.add_argument( + '--legend', + type=str, + nargs='+', + default=None, + help='legend of each plot') + parser.add_argument( + '--backend', type=str, default=None, help='backend of plt') + parser.add_argument( + '--style', type=str, default='dark', help='style of plt') + parser.add_argument('--out', type=str, default=None) + args = parser.parse_args() + return args + + +def load_json_logs(json_logs): + # load and convert json_logs to log_dict, key is step, value is a sub dict + # keys of sub dict is different metrics + # value of sub dict is a list of corresponding values of all iterations + log_dicts = [dict() for _ in json_logs] + prev_step = 0 + for json_log, log_dict in zip(json_logs, log_dicts): + with open(json_log) as log_file: + for line in log_file: + log = json.loads(line.strip()) + # the final step in json file is 0. + if 'step' in log and log['step'] != 0: + step = log['step'] + prev_step = step + else: + step = prev_step + if step not in log_dict: + log_dict[step] = defaultdict(list) + for k, v in log.items(): + log_dict[step][k].append(v) + return log_dicts + + +def main(): + args = parse_args() + json_logs = args.json_logs + for json_log in json_logs: + assert json_log.endswith('.json') + log_dicts = load_json_logs(json_logs) + plot_curve(log_dicts, args) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/analysis_tools/benchmark.py b/finetune/tools/analysis_tools/benchmark.py new file mode 100644 index 0000000..afaeaba --- /dev/null +++ b/finetune/tools/analysis_tools/benchmark.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import time + +import numpy as np +import torch +from mmengine import Config +from mmengine.fileio import dump +from mmengine.model.utils import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import Runner, load_checkpoint +from mmengine.utils import mkdir_or_exist + +from mmseg.registry import MODELS + + +def parse_args(): + parser = argparse.ArgumentParser(description='MMSeg benchmark a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--log-interval', type=int, default=50, help='interval of logging') + parser.add_argument( + '--work-dir', + help=('if specified, the results will be dumped ' + 'into the directory as json')) + parser.add_argument('--repeat-times', type=int, default=1) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + + init_default_scope(cfg.get('default_scope', 'mmseg')) + + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + if args.work_dir is not None: + mkdir_or_exist(osp.abspath(args.work_dir)) + json_file = osp.join(args.work_dir, f'fps_{timestamp}.json') + else: + # use config filename as default work_dir if cfg.work_dir is None + work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + mkdir_or_exist(osp.abspath(work_dir)) + json_file = osp.join(work_dir, f'fps_{timestamp}.json') + + repeat_times = args.repeat_times + # set cudnn_benchmark + torch.backends.cudnn.benchmark = False + cfg.model.pretrained = None + + benchmark_dict = dict(config=args.config, unit='img / s') + overall_fps_list = [] + cfg.test_dataloader.batch_size = 1 + for time_index in range(repeat_times): + print(f'Run {time_index + 1}:') + # build the dataloader + data_loader = Runner.build_dataloader(cfg.test_dataloader) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = MODELS.build(cfg.model) + + if 'checkpoint' in args and osp.exists(args.checkpoint): + load_checkpoint(model, args.checkpoint, map_location='cpu') + + if torch.cuda.is_available(): + model = model.cuda() + + model = revert_sync_batchnorm(model) + + model.eval() + + # the first several iterations may be very slow so skip them + num_warmup = 5 + pure_inf_time = 0 + total_iters = 200 + + # benchmark with 200 batches and take the average + for i, data in enumerate(data_loader): + data = model.data_preprocessor(data, True) + inputs = data['inputs'] + data_samples = data['data_samples'] + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + model(inputs, data_samples, mode='predict') + + if torch.cuda.is_available(): + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= num_warmup: + pure_inf_time += elapsed + if (i + 1) % args.log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time + print(f'Done image [{i + 1:<3}/ {total_iters}], ' + f'fps: {fps:.2f} img / s') + + if (i + 1) == total_iters: + fps = (i + 1 - num_warmup) / pure_inf_time + print(f'Overall fps: {fps:.2f} img / s\n') + benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2) + overall_fps_list.append(fps) + break + benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2) + benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4) + print(f'Average fps of {repeat_times} evaluations: ' + f'{benchmark_dict["average_fps"]}') + print(f'The variance of {repeat_times} evaluations: ' + f'{benchmark_dict["fps_variance"]}') + dump(benchmark_dict, json_file, indent=4) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/analysis_tools/browse_dataset.py b/finetune/tools/analysis_tools/browse_dataset.py new file mode 100644 index 0000000..925c14a --- /dev/null +++ b/finetune/tools/analysis_tools/browse_dataset.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.utils import ProgressBar + +from mmseg.registry import DATASETS, VISUALIZERS +from mmseg.utils import register_all_modules + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=2, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # register all modules in mmdet into the registries + register_all_modules() + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.metainfo + + progress_bar = ProgressBar(len(dataset)) + for item in dataset: + img = item['inputs'].permute(1, 2, 0).numpy() + img = img[..., [2, 1, 0]] # bgr to rgb + data_sample = item['data_samples'].numpy() + img_path = osp.basename(item['data_samples'].img_path) + + out_file = osp.join( + args.output_dir, + osp.basename(img_path)) if args.output_dir is not None else None + + visualizer.add_datasample( + name=osp.basename(img_path), + image=img, + data_sample=data_sample, + draw_gt=True, + draw_pred=False, + wait_time=args.show_interval, + out_file=out_file, + show=not args.not_show) + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/analysis_tools/confusion_matrix.py b/finetune/tools/analysis_tools/confusion_matrix.py new file mode 100644 index 0000000..39756cd --- /dev/null +++ b/finetune/tools/analysis_tools/confusion_matrix.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.ticker import MultipleLocator +from mmengine.config import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import mkdir_or_exist, progressbar +from PIL import Image + +from mmseg.registry import DATASETS + +init_default_scope('mmseg') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate confusion matrix from segmentation results') + parser.add_argument('config', help='test config file path') + parser.add_argument( + 'prediction_path', help='prediction path where test folder result') + parser.add_argument( + 'save_dir', help='directory where confusion matrix will be saved') + parser.add_argument( + '--show', action='store_true', help='show confusion matrix') + parser.add_argument( + '--color-theme', + default='winter', + help='theme of the matrix color map') + parser.add_argument( + '--title', + default='Normalized Confusion Matrix', + help='title of the matrix color map') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def calculate_confusion_matrix(dataset, results): + """Calculate the confusion matrix. + + Args: + dataset (Dataset): Test or val dataset. + results (list[ndarray]): A list of segmentation results in each image. + """ + n = len(dataset.METAINFO['classes']) + confusion_matrix = np.zeros(shape=[n, n]) + assert len(dataset) == len(results) + ignore_index = dataset.ignore_index + reduce_zero_label = dataset.reduce_zero_label + prog_bar = progressbar.ProgressBar(len(results)) + for idx, per_img_res in enumerate(results): + res_segm = per_img_res + gt_segm = dataset[idx]['data_samples'] \ + .gt_sem_seg.data.squeeze().numpy().astype(np.uint8) + gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten() + if reduce_zero_label: + gt_segm = gt_segm - 1 + to_ignore = gt_segm == ignore_index + + gt_segm, res_segm = gt_segm[~to_ignore], res_segm[~to_ignore] + inds = n * gt_segm + res_segm + mat = np.bincount(inds, minlength=n**2).reshape(n, n) + confusion_matrix += mat + prog_bar.update() + return confusion_matrix + + +def plot_confusion_matrix(confusion_matrix, + labels, + save_dir=None, + show=True, + title='Normalized Confusion Matrix', + color_theme='OrRd'): + """Draw confusion matrix with matplotlib. + + Args: + confusion_matrix (ndarray): The confusion matrix. + labels (list[str]): List of class names. + save_dir (str|optional): If set, save the confusion matrix plot to the + given path. Default: None. + show (bool): Whether to show the plot. Default: True. + title (str): Title of the plot. Default: `Normalized Confusion Matrix`. + color_theme (str): Theme of the matrix color map. Default: `winter`. + """ + # normalize the confusion matrix + per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis] + confusion_matrix = \ + confusion_matrix.astype(np.float32) / per_label_sums * 100 + + num_classes = len(labels) + fig, ax = plt.subplots( + figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=300) + cmap = plt.get_cmap(color_theme) + im = ax.imshow(confusion_matrix, cmap=cmap) + colorbar = plt.colorbar(mappable=im, ax=ax) + colorbar.ax.tick_params(labelsize=20) # 设置 colorbar 标签的字体大小 + + title_font = {'weight': 'bold', 'size': 20} + ax.set_title(title, fontdict=title_font) + label_font = {'size': 40} + plt.ylabel('Ground Truth Label', fontdict=label_font) + plt.xlabel('Prediction Label', fontdict=label_font) + + # draw locator + xmajor_locator = MultipleLocator(1) + xminor_locator = MultipleLocator(0.5) + ax.xaxis.set_major_locator(xmajor_locator) + ax.xaxis.set_minor_locator(xminor_locator) + ymajor_locator = MultipleLocator(1) + yminor_locator = MultipleLocator(0.5) + ax.yaxis.set_major_locator(ymajor_locator) + ax.yaxis.set_minor_locator(yminor_locator) + + # draw grid + ax.grid(True, which='minor', linestyle='-') + + # draw label + ax.set_xticks(np.arange(num_classes)) + ax.set_yticks(np.arange(num_classes)) + ax.set_xticklabels(labels, fontsize=20) + ax.set_yticklabels(labels, fontsize=20) + + ax.tick_params( + axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) + plt.setp( + ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor') + + # draw confusion matrix value + for i in range(num_classes): + for j in range(num_classes): + ax.text( + j, + i, + '{}%'.format( + round(confusion_matrix[i, j], 2 + ) if not np.isnan(confusion_matrix[i, j]) else -1), + ha='center', + va='center', + color='k', + size=20) + + ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1 + + fig.tight_layout() + if save_dir is not None: + mkdir_or_exist(save_dir) + plt.savefig( + os.path.join(save_dir, 'confusion_matrix.png'), format='png') + if show: + plt.show() + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + results = [] + for img in sorted(os.listdir(args.prediction_path)): + img = os.path.join(args.prediction_path, img) + image = Image.open(img) + image = np.copy(image) + results.append(image) + + assert isinstance(results, list) + if isinstance(results[0], np.ndarray): + pass + else: + raise TypeError('invalid type of prediction results') + + dataset = DATASETS.build(cfg.test_dataloader.dataset) + confusion_matrix = calculate_confusion_matrix(dataset, results) + plot_confusion_matrix( + confusion_matrix, + dataset.METAINFO['classes'], + save_dir=args.save_dir, + show=args.show, + title=args.title, + color_theme=args.color_theme) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/analysis_tools/get_flops.py b/finetune/tools/analysis_tools/get_flops.py new file mode 100644 index 0000000..78a7398 --- /dev/null +++ b/finetune/tools/analysis_tools/get_flops.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import tempfile +from pathlib import Path + +import torch +from mmengine import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope + +from mmseg.models import BaseSegmentor +from mmseg.registry import MODELS +from mmseg.structures import SegDataSample + +try: + from mmengine.analysis import get_model_complexity_info + from mmengine.analysis.print_helper import _format_size +except ImportError: + raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Get the FLOPs of a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[2048, 1024], + help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def inference(args: argparse.Namespace, logger: MMLogger) -> dict: + config_name = Path(args.config) + + if not config_name.exists(): + logger.error(f'Config file {config_name} does not exist') + + cfg: Config = Config.fromfile(config_name) + cfg.work_dir = tempfile.TemporaryDirectory().name + cfg.log_level = 'WARN' + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + init_default_scope(cfg.get('scope', 'mmseg')) + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + result = {} + + model: BaseSegmentor = MODELS.build(cfg.model) + if hasattr(model, 'auxiliary_head'): + model.auxiliary_head = None + if torch.cuda.is_available(): + model.cuda() + model = revert_sync_batchnorm(model) + result['ori_shape'] = input_shape[-2:] + result['pad_shape'] = input_shape[-2:] + data_batch = { + 'inputs': [torch.rand(input_shape)], + 'data_samples': [SegDataSample(metainfo=result)] + } + data = model.data_preprocessor(data_batch) + model.eval() + if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']: + # TODO: Support MaskFormer and Mask2Former + raise NotImplementedError('MaskFormer and Mask2Former are not ' + 'supported yet.') + outputs = get_model_complexity_info( + model, + input_shape=None, + inputs=data['inputs'], + show_table=False, + show_arch=False) + result['flops'] = _format_size(outputs['flops']) + result['params'] = _format_size(outputs['params']) + result['compute_type'] = 'direct: randomly generate a picture' + return result + + +def main(): + + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + + result = inference(args, logger) + split_line = '=' * 30 + ori_shape = result['ori_shape'] + pad_shape = result['pad_shape'] + flops = result['flops'] + params = result['params'] + compute_type = result['compute_type'] + + if pad_shape != ori_shape: + print(f'{split_line}\nUse size divisor set input shape ' + f'from {ori_shape} to {pad_shape}') + print(f'{split_line}\nCompute type: {compute_type}\n' + f'Input shape: {pad_shape}\nFlops: {flops}\n' + f'Params: {params}\n{split_line}') + print('!!!Please be cautious if you use the results in papers. ' + 'You may need to check if all ops are supported and verify ' + 'that the flops computation is correct.') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/analysis_tools/visualization_cam.py b/finetune/tools/analysis_tools/visualization_cam.py new file mode 100644 index 0000000..00cdb3e --- /dev/null +++ b/finetune/tools/analysis_tools/visualization_cam.py @@ -0,0 +1,127 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM). + +requirement: pip install grad-cam +""" + +from argparse import ArgumentParser + +import numpy as np +import torch +import torch.nn.functional as F +from mmengine import Config +from mmengine.model import revert_sync_batchnorm +from PIL import Image +from pytorch_grad_cam import GradCAM +from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image + +from mmseg.apis import inference_model, init_model, show_result_pyplot +from mmseg.utils import register_all_modules + + +class SemanticSegmentationTarget: + """wrap the model. + + requirement: pip install grad-cam + + Args: + category (int): Visualization class. + mask (ndarray): Mask of class. + size (tuple): Image size. + """ + + def __init__(self, category, mask, size): + self.category = category + self.mask = torch.from_numpy(mask) + self.size = size + if torch.cuda.is_available(): + self.mask = self.mask.cuda() + + def __call__(self, model_output): + model_output = torch.unsqueeze(model_output, dim=0) + model_output = F.interpolate( + model_output, size=self.size, mode='bilinear') + model_output = torch.squeeze(model_output, dim=0) + + return (model_output[self.category, :, :] * self.mask).sum() + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--out-file', + default='prediction.png', + help='Path to output prediction file') + parser.add_argument( + '--cam-file', default='vis_cam.png', help='Path to output cam file') + parser.add_argument( + '--target-layers', + default='backbone.layer4[2]', + help='Target layers to visualize CAM') + parser.add_argument( + '--category-index', default='7', help='Category to visualize CAM') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + register_all_modules() + model = init_model(args.config, args.checkpoint, device=args.device) + if args.device == 'cpu': + model = revert_sync_batchnorm(model) + + # test a single image + result = inference_model(model, args.img) + + # show the results + show_result_pyplot( + model, + args.img, + result, + draw_gt=False, + show=False if args.out_file is not None else True, + out_file=args.out_file) + + # result data conversion + prediction_data = result.pred_sem_seg.data + pre_np_data = prediction_data.cpu().numpy().squeeze(0) + + target_layers = args.target_layers + target_layers = [eval(f'model.{target_layers}')] + + category = int(args.category_index) + mask_float = np.float32(pre_np_data == category) + + # data processing + image = np.array(Image.open(args.img).convert('RGB')) + height, width = image.shape[0], image.shape[1] + rgb_img = np.float32(image) / 255 + config = Config.fromfile(args.config) + image_mean = config.data_preprocessor['mean'] + image_std = config.data_preprocessor['std'] + input_tensor = preprocess_image( + rgb_img, + mean=[x / 255 for x in image_mean], + std=[x / 255 for x in image_std]) + + # Grad CAM(Class Activation Maps) + # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM + targets = [ + SemanticSegmentationTarget(category, mask_float, (height, width)) + ] + with GradCAM( + model=model, + target_layers=target_layers, + use_cuda=torch.cuda.is_available()) as cam: + grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :] + cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True) + + # save cam file + Image.fromarray(cam_image).save(args.cam_file) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/chase_db1.py b/finetune/tools/dataset_converters/chase_db1.py new file mode 100644 index 0000000..f4fefbd --- /dev/null +++ b/finetune/tools/dataset_converters/chase_db1.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +CHASE_DB1_LEN = 28 * 3 +TRAINING_LEN = 60 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert CHASE_DB1 dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='path of CHASEDB1.zip') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'CHASE_DB1') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + print('Extracting CHASEDB1.zip...') + zip_file = zipfile.ZipFile(dataset_path) + zip_file.extractall(tmp_dir) + + print('Generating training dataset...') + + assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ + f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}' + + for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, img_name)) + if osp.splitext(img_name)[1] == '.jpg': + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(img_name)[0] + '.png')) + else: + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(img_name)[0] + '.png')) + + for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, img_name)) + if osp.splitext(img_name)[1] == '.jpg': + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(img_name)[0] + '.png')) + else: + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/cityscapes.py b/finetune/tools/dataset_converters/cityscapes.py new file mode 100644 index 0000000..0d6a801 --- /dev/null +++ b/finetune/tools/dataset_converters/cityscapes.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from cityscapesscripts.preparation.json2labelImg import json2labelImg +from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress, + track_progress) + + +def convert_json_to_label(json_file): + label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') + json2labelImg(json_file, label_file, 'trainIds') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert Cityscapes annotations to TrainIds') + parser.add_argument('cityscapes_path', help='cityscapes data path') + parser.add_argument('--gt-dir', default='gtFine', type=str) + parser.add_argument('-o', '--out-dir', help='output path') + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cityscapes_path = args.cityscapes_path + out_dir = args.out_dir if args.out_dir else cityscapes_path + mkdir_or_exist(out_dir) + + gt_dir = osp.join(cityscapes_path, args.gt_dir) + + poly_files = [] + for poly in scandir(gt_dir, '_polygons.json', recursive=True): + poly_file = osp.join(gt_dir, poly) + poly_files.append(poly_file) + if args.nproc > 1: + track_parallel_progress(convert_json_to_label, poly_files, args.nproc) + else: + track_progress(convert_json_to_label, poly_files) + + split_names = ['train', 'val', 'test'] + + for split in split_names: + filenames = [] + for poly in scandir( + osp.join(gt_dir, split), '_polygons.json', recursive=True): + filenames.append(poly.replace('_gtFine_polygons.json', '')) + with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: + f.writelines(f + '\n' for f in filenames) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/coco_stuff10k.py b/finetune/tools/dataset_converters/coco_stuff10k.py new file mode 100644 index 0000000..920127e --- /dev/null +++ b/finetune/tools/dataset_converters/coco_stuff10k.py @@ -0,0 +1,308 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import shutil +from functools import partial + +import numpy as np +from mmengine.utils import (mkdir_or_exist, track_parallel_progress, + track_progress) +from PIL import Image +from scipy.io import loadmat + +COCO_LEN = 10000 + +clsID_to_trID = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 11: 11, + 13: 12, + 14: 13, + 15: 14, + 16: 15, + 17: 16, + 18: 17, + 19: 18, + 20: 19, + 21: 20, + 22: 21, + 23: 22, + 24: 23, + 25: 24, + 27: 25, + 28: 26, + 31: 27, + 32: 28, + 33: 29, + 34: 30, + 35: 31, + 36: 32, + 37: 33, + 38: 34, + 39: 35, + 40: 36, + 41: 37, + 42: 38, + 43: 39, + 44: 40, + 46: 41, + 47: 42, + 48: 43, + 49: 44, + 50: 45, + 51: 46, + 52: 47, + 53: 48, + 54: 49, + 55: 50, + 56: 51, + 57: 52, + 58: 53, + 59: 54, + 60: 55, + 61: 56, + 62: 57, + 63: 58, + 64: 59, + 65: 60, + 67: 61, + 70: 62, + 72: 63, + 73: 64, + 74: 65, + 75: 66, + 76: 67, + 77: 68, + 78: 69, + 79: 70, + 80: 71, + 81: 72, + 82: 73, + 84: 74, + 85: 75, + 86: 76, + 87: 77, + 88: 78, + 89: 79, + 90: 80, + 92: 81, + 93: 82, + 94: 83, + 95: 84, + 96: 85, + 97: 86, + 98: 87, + 99: 88, + 100: 89, + 101: 90, + 102: 91, + 103: 92, + 104: 93, + 105: 94, + 106: 95, + 107: 96, + 108: 97, + 109: 98, + 110: 99, + 111: 100, + 112: 101, + 113: 102, + 114: 103, + 115: 104, + 116: 105, + 117: 106, + 118: 107, + 119: 108, + 120: 109, + 121: 110, + 122: 111, + 123: 112, + 124: 113, + 125: 114, + 126: 115, + 127: 116, + 128: 117, + 129: 118, + 130: 119, + 131: 120, + 132: 121, + 133: 122, + 134: 123, + 135: 124, + 136: 125, + 137: 126, + 138: 127, + 139: 128, + 140: 129, + 141: 130, + 142: 131, + 143: 132, + 144: 133, + 145: 134, + 146: 135, + 147: 136, + 148: 137, + 149: 138, + 150: 139, + 151: 140, + 152: 141, + 153: 142, + 154: 143, + 155: 144, + 156: 145, + 157: 146, + 158: 147, + 159: 148, + 160: 149, + 161: 150, + 162: 151, + 163: 152, + 164: 153, + 165: 154, + 166: 155, + 167: 156, + 168: 157, + 169: 158, + 170: 159, + 171: 160, + 172: 161, + 173: 162, + 174: 163, + 175: 164, + 176: 165, + 177: 166, + 178: 167, + 179: 168, + 180: 169, + 181: 170, + 182: 171 +} + + +def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir, + out_mask_dir, is_train): + imgpath, maskpath = tuple_path + shutil.copyfile( + osp.join(in_img_dir, imgpath), + osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join( + out_img_dir, 'test2014', imgpath)) + annotate = loadmat(osp.join(in_ann_dir, maskpath)) + mask = annotate['S'].astype(np.uint8) + mask_copy = mask.copy() + for clsID, trID in clsID_to_trID.items(): + mask_copy[mask == clsID] = trID + seg_filename = osp.join(out_mask_dir, 'train2014', + maskpath.split('.')[0] + + '_labelTrainIds.png') if is_train else osp.join( + out_mask_dir, 'test2014', + maskpath.split('.')[0] + '_labelTrainIds.png') + Image.fromarray(mask_copy).save(seg_filename, 'PNG') + + +def generate_coco_list(folder): + train_list = osp.join(folder, 'imageLists', 'train.txt') + test_list = osp.join(folder, 'imageLists', 'test.txt') + train_paths = [] + test_paths = [] + + with open(train_list) as f: + for filename in f: + basename = filename.strip() + imgpath = basename + '.jpg' + maskpath = basename + '.mat' + train_paths.append((imgpath, maskpath)) + + with open(test_list) as f: + for filename in f: + basename = filename.strip() + imgpath = basename + '.jpg' + maskpath = basename + '.mat' + test_paths.append((imgpath, maskpath)) + + return train_paths, test_paths + + +def parse_args(): + parser = argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=16, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + coco_path = args.coco_path + nproc = args.nproc + + out_dir = args.out_dir or coco_path + out_img_dir = osp.join(out_dir, 'images') + out_mask_dir = osp.join(out_dir, 'annotations') + + mkdir_or_exist(osp.join(out_img_dir, 'train2014')) + mkdir_or_exist(osp.join(out_img_dir, 'test2014')) + mkdir_or_exist(osp.join(out_mask_dir, 'train2014')) + mkdir_or_exist(osp.join(out_mask_dir, 'test2014')) + + train_list, test_list = generate_coco_list(coco_path) + assert (len(train_list) + + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( + len(train_list), len(test_list)) + + if args.nproc > 1: + track_parallel_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=True), + train_list, + nproc=nproc) + track_parallel_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=False), + test_list, + nproc=nproc) + else: + track_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=True), train_list) + track_progress( + partial( + convert_to_trainID, + in_img_dir=osp.join(coco_path, 'images'), + in_ann_dir=osp.join(coco_path, 'annotations'), + out_img_dir=out_img_dir, + out_mask_dir=out_mask_dir, + is_train=False), test_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/coco_stuff164k.py b/finetune/tools/dataset_converters/coco_stuff164k.py new file mode 100644 index 0000000..a13114a --- /dev/null +++ b/finetune/tools/dataset_converters/coco_stuff164k.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import shutil +from functools import partial +from glob import glob + +import numpy as np +from mmengine.utils import (mkdir_or_exist, track_parallel_progress, + track_progress) +from PIL import Image + +COCO_LEN = 123287 + +clsID_to_trID = { + 0: 0, + 1: 1, + 2: 2, + 3: 3, + 4: 4, + 5: 5, + 6: 6, + 7: 7, + 8: 8, + 9: 9, + 10: 10, + 12: 11, + 13: 12, + 14: 13, + 15: 14, + 16: 15, + 17: 16, + 18: 17, + 19: 18, + 20: 19, + 21: 20, + 22: 21, + 23: 22, + 24: 23, + 26: 24, + 27: 25, + 30: 26, + 31: 27, + 32: 28, + 33: 29, + 34: 30, + 35: 31, + 36: 32, + 37: 33, + 38: 34, + 39: 35, + 40: 36, + 41: 37, + 42: 38, + 43: 39, + 45: 40, + 46: 41, + 47: 42, + 48: 43, + 49: 44, + 50: 45, + 51: 46, + 52: 47, + 53: 48, + 54: 49, + 55: 50, + 56: 51, + 57: 52, + 58: 53, + 59: 54, + 60: 55, + 61: 56, + 62: 57, + 63: 58, + 64: 59, + 66: 60, + 69: 61, + 71: 62, + 72: 63, + 73: 64, + 74: 65, + 75: 66, + 76: 67, + 77: 68, + 78: 69, + 79: 70, + 80: 71, + 81: 72, + 83: 73, + 84: 74, + 85: 75, + 86: 76, + 87: 77, + 88: 78, + 89: 79, + 91: 80, + 92: 81, + 93: 82, + 94: 83, + 95: 84, + 96: 85, + 97: 86, + 98: 87, + 99: 88, + 100: 89, + 101: 90, + 102: 91, + 103: 92, + 104: 93, + 105: 94, + 106: 95, + 107: 96, + 108: 97, + 109: 98, + 110: 99, + 111: 100, + 112: 101, + 113: 102, + 114: 103, + 115: 104, + 116: 105, + 117: 106, + 118: 107, + 119: 108, + 120: 109, + 121: 110, + 122: 111, + 123: 112, + 124: 113, + 125: 114, + 126: 115, + 127: 116, + 128: 117, + 129: 118, + 130: 119, + 131: 120, + 132: 121, + 133: 122, + 134: 123, + 135: 124, + 136: 125, + 137: 126, + 138: 127, + 139: 128, + 140: 129, + 141: 130, + 142: 131, + 143: 132, + 144: 133, + 145: 134, + 146: 135, + 147: 136, + 148: 137, + 149: 138, + 150: 139, + 151: 140, + 152: 141, + 153: 142, + 154: 143, + 155: 144, + 156: 145, + 157: 146, + 158: 147, + 159: 148, + 160: 149, + 161: 150, + 162: 151, + 163: 152, + 164: 153, + 165: 154, + 166: 155, + 167: 156, + 168: 157, + 169: 158, + 170: 159, + 171: 160, + 172: 161, + 173: 162, + 174: 163, + 175: 164, + 176: 165, + 177: 166, + 178: 167, + 179: 168, + 180: 169, + 181: 170, + 255: 255 +} + + +def convert_to_trainID(maskpath, out_mask_dir, is_train): + mask = np.array(Image.open(maskpath)) + mask_copy = mask.copy() + for clsID, trID in clsID_to_trID.items(): + mask_copy[mask == clsID] = trID + seg_filename = osp.join( + out_mask_dir, 'train2017', + osp.basename(maskpath).split('.')[0] + + '_labelTrainIds.png') if is_train else osp.join( + out_mask_dir, 'val2017', + osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') + Image.fromarray(mask_copy).save(seg_filename, 'PNG') + + +def parse_args(): + parser = argparse.ArgumentParser( + description=\ + 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa + parser.add_argument('coco_path', help='coco stuff path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=16, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + coco_path = args.coco_path + nproc = args.nproc + + out_dir = args.out_dir or coco_path + out_img_dir = osp.join(out_dir, 'images') + out_mask_dir = osp.join(out_dir, 'annotations') + + mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) + mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) + + if out_dir != coco_path: + shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) + + train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) + train_list = [file for file in train_list if '_labelTrainIds' not in file] + test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) + test_list = [file for file in test_list if '_labelTrainIds' not in file] + assert (len(train_list) + + len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( + len(train_list), len(test_list)) + + if args.nproc > 1: + track_parallel_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + train_list, + nproc=nproc) + track_parallel_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + test_list, + nproc=nproc) + else: + track_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), + train_list) + track_progress( + partial( + convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), + test_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/drive.py b/finetune/tools/dataset_converters/drive.py new file mode 100644 index 0000000..076fd05 --- /dev/null +++ b/finetune/tools/dataset_converters/drive.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import cv2 +import mmcv +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert DRIVE dataset to mmsegmentation format') + parser.add_argument( + 'training_path', help='the training part of DRIVE dataset') + parser.add_argument( + 'testing_path', help='the testing part of DRIVE dataset') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + training_path = args.training_path + testing_path = args.testing_path + if args.out_dir is None: + out_dir = osp.join('data', 'DRIVE') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + print('Extracting training.zip...') + zip_file = zipfile.ZipFile(training_path) + zip_file.extractall(tmp_dir) + + print('Generating training dataset...') + now_dir = osp.join(tmp_dir, 'training', 'images') + for img_name in os.listdir(now_dir): + img = mmcv.imread(osp.join(now_dir, img_name)) + mmcv.imwrite( + img, + osp.join( + out_dir, 'images', 'training', + osp.splitext(img_name)[0].replace('_training', '') + + '.png')) + + now_dir = osp.join(tmp_dir, 'training', '1st_manual') + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(img_name)[0] + '.png')) + + print('Extracting test.zip...') + zip_file = zipfile.ZipFile(testing_path) + zip_file.extractall(tmp_dir) + + print('Generating validation dataset...') + now_dir = osp.join(tmp_dir, 'test', 'images') + for img_name in os.listdir(now_dir): + img = mmcv.imread(osp.join(now_dir, img_name)) + mmcv.imwrite( + img, + osp.join( + out_dir, 'images', 'validation', + osp.splitext(img_name)[0].replace('_test', '') + '.png')) + + now_dir = osp.join(tmp_dir, 'test', '1st_manual') + if osp.exists(now_dir): + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + now_dir = osp.join(tmp_dir, 'test', '2nd_manual') + if osp.exists(now_dir): + for img_name in os.listdir(now_dir): + cap = cv2.VideoCapture(osp.join(now_dir, img_name)) + ret, img = cap.read() + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(img_name)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/hrf.py b/finetune/tools/dataset_converters/hrf.py new file mode 100644 index 0000000..3bfd80c --- /dev/null +++ b/finetune/tools/dataset_converters/hrf.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +HRF_LEN = 15 +TRAINING_LEN = 5 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert HRF dataset to mmsegmentation format') + parser.add_argument('healthy_path', help='the path of healthy.zip') + parser.add_argument( + 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') + parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') + parser.add_argument( + 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') + parser.add_argument( + 'diabetic_retinopathy_path', + help='the path of diabetic_retinopathy.zip') + parser.add_argument( + 'diabetic_retinopathy_manualsegm_path', + help='the path of diabetic_retinopathy_manualsegm.zip') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + images_path = [ + args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path + ] + annotations_path = [ + args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, + args.diabetic_retinopathy_manualsegm_path + ] + if args.out_dir is None: + out_dir = osp.join('data', 'HRF') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + print('Generating images...') + for now_path in images_path: + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + zip_file = zipfile.ZipFile(now_path) + zip_file.extractall(tmp_dir) + + assert len(os.listdir(tmp_dir)) == HRF_LEN, \ + f'len(os.listdir(tmp_dir)) != {HRF_LEN}' + + for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(filename)[0] + '.png')) + for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Generating annotations...') + for now_path in annotations_path: + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + zip_file = zipfile.ZipFile(now_path) + zip_file.extractall(tmp_dir) + + assert len(os.listdir(tmp_dir)) == HRF_LEN, \ + f'len(os.listdir(tmp_dir)) != {HRF_LEN}' + + for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a + # threshold to convert the nonstandard annotation imgs. The + # value divided by 128 is equivalent to '1 if value >= 128 + # else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(tmp_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/isaid.py b/finetune/tools/dataset_converters/isaid.py new file mode 100644 index 0000000..1d5ccd9 --- /dev/null +++ b/finetune/tools/dataset_converters/isaid.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import os +import os.path as osp +import shutil +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist +from PIL import Image + +iSAID_palette = \ + { + 0: (0, 0, 0), + 1: (0, 0, 63), + 2: (0, 63, 63), + 3: (0, 63, 0), + 4: (0, 63, 127), + 5: (0, 63, 191), + 6: (0, 63, 255), + 7: (0, 127, 63), + 8: (0, 127, 127), + 9: (0, 0, 127), + 10: (0, 0, 191), + 11: (0, 0, 255), + 12: (0, 191, 127), + 13: (0, 127, 191), + 14: (0, 127, 255), + 15: (0, 100, 155) + } + +iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()} + + +def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette): + """RGB-color encoding to grayscale labels.""" + arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) + + for c, i in palette.items(): + m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) + arr_2d[m] = i + + return arr_2d + + +def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap): + img = np.asarray(Image.open(src_path).convert('RGB')) + + img_H, img_W, _ = img.shape + + if img_H < patch_H and img_W > patch_W: + + img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0) + + img_H, img_W, _ = img.shape + + elif img_H > patch_H and img_W < patch_W: + + img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0) + + img_H, img_W, _ = img.shape + + elif img_H < patch_H and img_W < patch_W: + + img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0) + + img_H, img_W, _ = img.shape + + for x in range(0, img_W, patch_W - overlap): + for y in range(0, img_H, patch_H - overlap): + x_str = x + x_end = x + patch_W + if x_end > img_W: + diff_x = x_end - img_W + x_str -= diff_x + x_end = img_W + y_str = y + y_end = y + patch_H + if y_end > img_H: + diff_y = y_end - img_H + y_str -= diff_y + y_end = img_H + + img_patch = img[y_str:y_end, x_str:x_end, :] + img_patch = Image.fromarray(img_patch.astype(np.uint8)) + image = osp.basename(src_path).split('.')[0] + '_' + str( + y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str( + x_end) + '.png' + # print(image) + save_path_image = osp.join(out_dir, 'img_dir', mode, str(image)) + img_patch.save(save_path_image, format='BMP') + + +def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap): + label = mmcv.imread(src_path, channel_order='rgb') + label = iSAID_convert_from_color(label) + img_H, img_W = label.shape + + if img_H < patch_H and img_W > patch_W: + + label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255) + + img_H = patch_H + + elif img_H > patch_H and img_W < patch_W: + + label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255) + + img_W = patch_W + + elif img_H < patch_H and img_W < patch_W: + + label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255) + + img_H = patch_H + img_W = patch_W + + for x in range(0, img_W, patch_W - overlap): + for y in range(0, img_H, patch_H - overlap): + x_str = x + x_end = x + patch_W + if x_end > img_W: + diff_x = x_end - img_W + x_str -= diff_x + x_end = img_W + y_str = y + y_end = y + patch_H + if y_end > img_H: + diff_y = y_end - img_H + y_str -= diff_y + y_end = img_H + + lab_patch = label[y_str:y_end, x_str:x_end] + lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P') + + image = osp.basename(src_path).split('.')[0].split( + '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str( + x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png' + lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image))) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert iSAID dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='iSAID folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + + parser.add_argument( + '--patch_width', + default=896, + type=int, + help='Width of the cropped image patch') + parser.add_argument( + '--patch_height', + default=896, + type=int, + help='Height of the cropped image patch') + parser.add_argument( + '--overlap_area', default=384, type=int, help='Overlap area') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + # image patch width and height + patch_H, patch_W = args.patch_width, args.patch_height + + overlap = args.overlap_area # overlap area + + if args.out_dir is None: + out_dir = osp.join('data', 'iSAID') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) + + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test')) + + assert os.path.exists(os.path.join(dataset_path, 'train')), \ + f'train is not in {dataset_path}' + assert os.path.exists(os.path.join(dataset_path, 'val')), \ + f'val is not in {dataset_path}' + assert os.path.exists(os.path.join(dataset_path, 'test')), \ + f'test is not in {dataset_path}' + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for dataset_mode in ['train', 'val', 'test']: + + # for dataset_mode in [ 'test']: + print(f'Extracting {dataset_mode}ing.zip...') + img_zipp_list = glob.glob( + os.path.join(dataset_path, dataset_mode, 'images', '*.zip')) + print('Find the data', img_zipp_list) + for img_zipp in img_zipp_list: + zip_file = zipfile.ZipFile(img_zipp) + zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img')) + src_path_list = glob.glob( + os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png')) + + src_prog_bar = ProgressBar(len(src_path_list)) + for i, img_path in enumerate(src_path_list): + if dataset_mode != 'test': + slide_crop_image(img_path, out_dir, dataset_mode, patch_H, + patch_W, overlap) + + else: + shutil.move(img_path, + os.path.join(out_dir, 'img_dir', dataset_mode)) + src_prog_bar.update() + + if dataset_mode != 'test': + label_zipp_list = glob.glob( + os.path.join(dataset_path, dataset_mode, 'Semantic_masks', + '*.zip')) + for label_zipp in label_zipp_list: + zip_file = zipfile.ZipFile(label_zipp) + zip_file.extractall( + os.path.join(tmp_dir, dataset_mode, 'lab')) + + lab_path_list = glob.glob( + os.path.join(tmp_dir, dataset_mode, 'lab', 'images', + '*.png')) + lab_prog_bar = ProgressBar(len(lab_path_list)) + for i, lab_path in enumerate(lab_path_list): + slide_crop_label(lab_path, out_dir, dataset_mode, patch_H, + patch_W, overlap) + lab_prog_bar.update() + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/levircd.py b/finetune/tools/dataset_converters/levircd.py new file mode 100644 index 0000000..8717f3e --- /dev/null +++ b/finetune/tools/dataset_converters/levircd.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import math +import os +import os.path as osp + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert levir-cd dataset to mmsegmentation format') + parser.add_argument('--dataset_path', help='potsdam folder path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--clip_size', + type=int, + help='clipped size of image after preparation', + default=256) + parser.add_argument( + '--stride_size', + type=int, + help='stride of clipping original images', + default=256) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + input_folder = args.dataset_path + png_files = glob.glob( + os.path.join(input_folder, '**/*.png'), recursive=True) + output_folder = args.out_dir + prog_bar = ProgressBar(len(png_files)) + for png_file in png_files: + new_path = os.path.join( + output_folder, + os.path.relpath(os.path.dirname(png_file), input_folder)) + os.makedirs(os.path.dirname(new_path), exist_ok=True) + label = False + if 'label' in png_file: + label = True + clip_big_image(png_file, new_path, args, label) + prog_bar.update() + + +def clip_big_image(image_path, clip_save_dir, args, to_label=False): + image = mmcv.imread(image_path) + + h, w, c = image.shape + clip_size = args.clip_size + stride_size = args.stride_size + + num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil( + (h - clip_size) / + stride_size) * stride_size + clip_size >= h else math.ceil( + (h - clip_size) / stride_size) + 1 + num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil( + (w - clip_size) / + stride_size) * stride_size + clip_size >= w else math.ceil( + (w - clip_size) / stride_size) + 1 + + x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) + xmin = x * clip_size + ymin = y * clip_size + + xmin = xmin.ravel() + ymin = ymin.ravel() + xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size, + np.zeros_like(xmin)) + ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size, + np.zeros_like(ymin)) + boxes = np.stack([ + xmin + xmin_offset, ymin + ymin_offset, + np.minimum(xmin + clip_size, w), + np.minimum(ymin + clip_size, h) + ], + axis=1) + + if to_label: + image[image == 255] = 1 + image = image[:, :, 0] + for box in boxes: + start_x, start_y, end_x, end_y = box + clipped_image = image[start_y:end_y, start_x:end_x] \ + if to_label else image[start_y:end_y, start_x:end_x, :] + idx = osp.basename(image_path).split('.')[0] + mmcv.imwrite( + clipped_image.astype(np.uint8), + osp.join(clip_save_dir, + f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/loveda.py b/finetune/tools/dataset_converters/loveda.py new file mode 100644 index 0000000..5b0ef4b --- /dev/null +++ b/finetune/tools/dataset_converters/loveda.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import shutil +import tempfile +import zipfile + +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert LoveDA dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='LoveDA folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'loveDA') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'img_dir')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + + assert 'Train.zip' in os.listdir(dataset_path), \ + f'Train.zip is not in {dataset_path}' + assert 'Val.zip' in os.listdir(dataset_path), \ + f'Val.zip is not in {dataset_path}' + assert 'Test.zip' in os.listdir(dataset_path), \ + f'Test.zip is not in {dataset_path}' + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for dataset in ['Train', 'Val', 'Test']: + zip_file = zipfile.ZipFile( + os.path.join(dataset_path, dataset + '.zip')) + zip_file.extractall(tmp_dir) + data_type = dataset.lower() + for location in ['Rural', 'Urban']: + for image_type in ['images_png', 'masks_png']: + if image_type == 'images_png': + dst = osp.join(out_dir, 'img_dir', data_type) + else: + dst = osp.join(out_dir, 'ann_dir', data_type) + if dataset == 'Test' and image_type == 'masks_png': + continue + else: + src_dir = osp.join(tmp_dir, dataset, location, + image_type) + src_lst = os.listdir(src_dir) + for file in src_lst: + shutil.move(osp.join(src_dir, file), dst) + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/nyu.py b/finetune/tools/dataset_converters/nyu.py new file mode 100644 index 0000000..49e09e7 --- /dev/null +++ b/finetune/tools/dataset_converters/nyu.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +import shutil +import tempfile +import zipfile + +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert NYU Depth dataset to mmsegmentation format') + parser.add_argument('raw_data', help='the path of raw data') + parser.add_argument( + '-o', '--out_dir', help='output path', default='./data/nyu') + args = parser.parse_args() + return args + + +def reorganize(raw_data_dir: str, out_dir: str): + """Reorganize NYU Depth dataset files into the required directory + structure. + + Args: + raw_data_dir (str): Path to the raw data directory. + out_dir (str): Output directory for the organized dataset. + """ + + def move_data(data_list, dst_prefix, fname_func): + """Move data files from source to destination directory. + + Args: + data_list (list): List of data file paths. + dst_prefix (str): Prefix to be added to destination paths. + fname_func (callable): Function to process file names + """ + for data_item in data_list: + data_item = data_item.strip().strip('/') + new_item = fname_func(data_item) + shutil.move( + osp.join(raw_data_dir, data_item), + osp.join(out_dir, dst_prefix, new_item)) + + def process_phase(phase): + """Process a dataset phase (e.g., 'train' or 'test').""" + with open(osp.join(raw_data_dir, f'nyu_{phase}.txt')) as f: + data = filter(lambda x: len(x.strip()) > 0, f.readlines()) + data = map(lambda x: x.split()[:2], data) + images, annos = zip(*data) + + move_data(images, f'images/{phase}', + lambda x: x.replace('/rgb', '')) + move_data(annos, f'annotations/{phase}', + lambda x: x.replace('/sync_depth', '')) + + process_phase('train') + process_phase('test') + + +def main(): + args = parse_args() + + print('Making directories...') + mkdir_or_exist(args.out_dir) + for subdir in [ + 'images/train', 'images/test', 'annotations/train', + 'annotations/test' + ]: + mkdir_or_exist(osp.join(args.out_dir, subdir)) + + print('Generating images and annotations...') + + if args.raw_data.endswith('.zip'): + with tempfile.TemporaryDirectory() as tmp_dir: + zip_file = zipfile.ZipFile(args.raw_data) + zip_file.extractall(tmp_dir) + reorganize(osp.join(tmp_dir, 'nyu'), args.out_dir) + else: + assert osp.isdir( + args.raw_data + ), 'the argument --raw-data should be either a zip file or directory.' + reorganize(args.raw_data, args.out_dir) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/pascal_context.py b/finetune/tools/dataset_converters/pascal_context.py new file mode 100644 index 0000000..a92d1dc --- /dev/null +++ b/finetune/tools/dataset_converters/pascal_context.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from functools import partial + +import numpy as np +from detail import Detail +from mmengine.utils import mkdir_or_exist, track_progress +from PIL import Image + +_mapping = np.sort( + np.array([ + 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, + 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, + 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, + 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 + ])) +_key = np.array(range(len(_mapping))).astype('uint8') + + +def generate_labels(img_id, detail, out_dir): + + def _class_to_index(mask, _mapping, _key): + # assert the values + values = np.unique(mask) + for i in range(len(values)): + assert (values[i] in _mapping) + index = np.digitize(mask.ravel(), _mapping, right=True) + return _key[index].reshape(mask.shape) + + mask = Image.fromarray( + _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) + filename = img_id['file_name'] + mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) + return osp.splitext(osp.basename(filename))[0] + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert PASCAL VOC annotations to mmsegmentation format') + parser.add_argument('devkit_path', help='pascal voc devkit path') + parser.add_argument('json_path', help='annoation json filepath') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + devkit_path = args.devkit_path + if args.out_dir is None: + out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') + else: + out_dir = args.out_dir + json_path = args.json_path + mkdir_or_exist(out_dir) + img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') + + train_detail = Detail(json_path, img_dir, 'train') + train_ids = train_detail.getImgs() + + val_detail = Detail(json_path, img_dir, 'val') + val_ids = val_detail.getImgs() + + mkdir_or_exist( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) + + train_list = track_progress( + partial(generate_labels, detail=train_detail, out_dir=out_dir), + train_ids) + with open( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', + 'train.txt'), 'w') as f: + f.writelines(line + '\n' for line in sorted(train_list)) + + val_list = track_progress( + partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) + with open( + osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', + 'val.txt'), 'w') as f: + f.writelines(line + '\n' for line in sorted(val_list)) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/potsdam.py b/finetune/tools/dataset_converters/potsdam.py new file mode 100644 index 0000000..f3c713e --- /dev/null +++ b/finetune/tools/dataset_converters/potsdam.py @@ -0,0 +1,158 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import math +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert potsdam dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='potsdam folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--clip_size', + type=int, + help='clipped size of image after preparation', + default=512) + parser.add_argument( + '--stride_size', + type=int, + help='stride of clipping original images', + default=256) + args = parser.parse_args() + return args + + +def clip_big_image(image_path, clip_save_dir, args, to_label=False): + # Original image of Potsdam dataset is very large, thus pre-processing + # of them is adopted. Given fixed clip size and stride size to generate + # clipped image, the intersection of width and height is determined. + # For example, given one 5120 x 5120 original image, the clip size is + # 512 and stride size is 256, thus it would generate 20x20 = 400 images + # whose size are all 512x512. + image = mmcv.imread(image_path) + + h, w, c = image.shape + clip_size = args.clip_size + stride_size = args.stride_size + + num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil( + (h - clip_size) / + stride_size) * stride_size + clip_size >= h else math.ceil( + (h - clip_size) / stride_size) + 1 + num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil( + (w - clip_size) / + stride_size) * stride_size + clip_size >= w else math.ceil( + (w - clip_size) / stride_size) + 1 + + x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) + xmin = x * clip_size + ymin = y * clip_size + + xmin = xmin.ravel() + ymin = ymin.ravel() + xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size, + np.zeros_like(xmin)) + ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size, + np.zeros_like(ymin)) + boxes = np.stack([ + xmin + xmin_offset, ymin + ymin_offset, + np.minimum(xmin + clip_size, w), + np.minimum(ymin + clip_size, h) + ], + axis=1) + + if to_label: + color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], + [255, 255, 0], [0, 255, 0], [0, 255, 255], + [0, 0, 255]]) + flatten_v = np.matmul( + image.reshape(-1, c), + np.array([2, 3, 4]).reshape(3, 1)) + out = np.zeros_like(flatten_v) + for idx, class_color in enumerate(color_map): + value_idx = np.matmul(class_color, + np.array([2, 3, 4]).reshape(3, 1)) + out[flatten_v == value_idx] = idx + image = out.reshape(h, w) + + for box in boxes: + start_x, start_y, end_x, end_y = box + clipped_image = image[start_y:end_y, + start_x:end_x] if to_label else image[ + start_y:end_y, start_x:end_x, :] + idx_i, idx_j = osp.basename(image_path).split('_')[2:4] + mmcv.imwrite( + clipped_image.astype(np.uint8), + osp.join( + clip_save_dir, + f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + + +def main(): + args = parse_args() + splits = { + 'train': [ + '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', + '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7', + '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9' + ], + 'val': [ + '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13', + '4_15', '2_14', '5_13', '4_13', '3_14', '7_13' + ] + } + + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'potsdam') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + + zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) + print('Find the data', zipp_list) + + for zipp in zipp_list: + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + zip_file = zipfile.ZipFile(zipp) + zip_file.extractall(tmp_dir) + src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + if not len(src_path_list): + sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0]) + src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif')) + + prog_bar = ProgressBar(len(src_path_list)) + for i, src_path in enumerate(src_path_list): + idx_i, idx_j = osp.basename(src_path).split('_')[2:4] + data_type = 'train' if f'{idx_i}_{idx_j}' in splits[ + 'train'] else 'val' + if 'label' in src_path: + dst_dir = osp.join(out_dir, 'ann_dir', data_type) + clip_big_image(src_path, dst_dir, args, to_label=True) + else: + dst_dir = osp.join(out_dir, 'img_dir', data_type) + clip_big_image(src_path, dst_dir, args, to_label=False) + prog_bar.update() + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/refuge.py b/finetune/tools/dataset_converters/refuge.py new file mode 100644 index 0000000..1186866 --- /dev/null +++ b/finetune/tools/dataset_converters/refuge.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert REFUGE dataset to mmsegmentation format') + parser.add_argument('--raw_data_root', help='the root path of raw data') + + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def extract_img(root: str, + cur_dir: str, + out_dir: str, + mode: str = 'train', + file_type: str = 'img') -> None: + """_summary_ + + Args: + Args: + root (str): root where the extracted data is saved + cur_dir (cur_dir): dir where the zip_file exists + out_dir (str): root dir where the data is saved + + mode (str, optional): Defaults to 'train'. + file_type (str, optional): Defaults to 'img',else to 'mask'. + """ + zip_file = zipfile.ZipFile(cur_dir) + zip_file.extractall(root) + for cur_dir, dirs, files in os.walk(root): + # filter child dirs and directories with "Illustration" and "MACOSX" + if len(dirs) == 0 and \ + cur_dir.split('\\')[-1].find('Illustration') == -1 and \ + cur_dir.find('MACOSX') == -1: + + file_names = [ + file for file in files + if file.endswith('.jpg') or file.endswith('.bmp') + ] + for filename in sorted(file_names): + img = mmcv.imread(osp.join(cur_dir, filename)) + + if file_type == 'annotations': + img = img[:, :, 0] + img[np.where(img == 0)] = 1 + img[np.where(img == 128)] = 2 + img[np.where(img == 255)] = 0 + mmcv.imwrite( + img, + osp.join(out_dir, file_type, mode, + osp.splitext(filename)[0] + '.png')) + + +def main(): + args = parse_args() + + raw_data_root = args.raw_data_root + if args.out_dir is None: + out_dir = osp.join('./data', 'REFUGE') + + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'images', 'test')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'test')) + + print('Generating images and annotations...') + # process data from the child dir on the first rank + cur_dir, dirs, files = list(os.walk(raw_data_root))[0] + print('====================') + + files = list(filter(lambda x: x.endswith('.zip'), files)) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for file in files: + # search data folders for training,validation,test + mode = list( + filter(lambda x: file.lower().find(x) != -1, + ['training', 'test', 'validation']))[0] + file_root = osp.join(tmp_dir, file[:-4]) + file_type = 'images' if file.find('Anno') == -1 and file.find( + 'GT') == -1 else 'annotations' + extract_img(file_root, osp.join(cur_dir, file), out_dir, mode, + file_type) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/stare.py b/finetune/tools/dataset_converters/stare.py new file mode 100644 index 0000000..4a23ba4 --- /dev/null +++ b/finetune/tools/dataset_converters/stare.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import gzip +import os +import os.path as osp +import tarfile +import tempfile + +import mmcv +from mmengine.utils import mkdir_or_exist + +STARE_LEN = 20 +TRAINING_LEN = 10 + + +def un_gz(src, dst): + g_file = gzip.GzipFile(src) + with open(dst, 'wb+') as f: + f.write(g_file.read()) + g_file.close() + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert STARE dataset to mmsegmentation format') + parser.add_argument('image_path', help='the path of stare-images.tar') + parser.add_argument('labels_ah', help='the path of labels-ah.tar') + parser.add_argument('labels_vk', help='the path of labels-vk.tar') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + image_path = args.image_path + labels_ah = args.labels_ah + labels_vk = args.labels_vk + if args.out_dir is None: + out_dir = osp.join('data', 'STARE') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(out_dir) + mkdir_or_exist(osp.join(out_dir, 'images')) + mkdir_or_exist(osp.join(out_dir, 'images', 'training')) + mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) + mkdir_or_exist(osp.join(out_dir, 'annotations')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) + mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting stare-images.tar...') + with tarfile.open(image_path) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img, + osp.join(out_dir, 'images', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting labels-ah.tar...') + with tarfile.open(labels_ah) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + # The annotation img should be divided by 128, because some of + # the annotation imgs are not standard. We should set a threshold + # to convert the nonstandard annotation imgs. The value divided by + # 128 equivalent to '1 if value >= 128 else 0' + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + mkdir_or_exist(osp.join(tmp_dir, 'gz')) + mkdir_or_exist(osp.join(tmp_dir, 'files')) + + print('Extracting labels-vk.tar...') + with tarfile.open(labels_vk) as f: + f.extractall(osp.join(tmp_dir, 'gz')) + + for filename in os.listdir(osp.join(tmp_dir, 'gz')): + un_gz( + osp.join(tmp_dir, 'gz', filename), + osp.join(tmp_dir, 'files', + osp.splitext(filename)[0])) + + now_dir = osp.join(tmp_dir, 'files') + + assert len(os.listdir(now_dir)) == STARE_LEN, \ + f'len(os.listdir(now_dir)) != {STARE_LEN}' + + for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'training', + osp.splitext(filename)[0] + '.png')) + + for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: + img = mmcv.imread(osp.join(now_dir, filename)) + mmcv.imwrite( + img[:, :, 0] // 128, + osp.join(out_dir, 'annotations', 'validation', + osp.splitext(filename)[0] + '.png')) + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/synapse.py b/finetune/tools/dataset_converters/synapse.py new file mode 100644 index 0000000..42dac6b --- /dev/null +++ b/finetune/tools/dataset_converters/synapse.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import nibabel as nib +import numpy as np +from mmengine.utils import mkdir_or_exist +from PIL import Image + + +def read_files_from_txt(txt_path): + with open(txt_path) as f: + files = f.readlines() + files = [file.strip() for file in files] + return files + + +def read_nii_file(nii_path): + img = nib.load(nii_path).get_fdata() + return img + + +def split_3d_image(img): + c, _, _ = img.shape + res = [] + for i in range(c): + res.append(img[i, :, :]) + return res + + +def label_mapping(label): + """Label mapping from TransUNet paper setting. It only has 9 classes, which + are 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney', + 'liver', 'pancreas', 'spleen', 'stomach', respectively. Other foreground + classes in original dataset are all set to background. + + More details could be found here: https://arxiv.org/abs/2102.04306 + """ + maped_label = np.zeros_like(label) + maped_label[label == 8] = 1 + maped_label[label == 4] = 2 + maped_label[label == 3] = 3 + maped_label[label == 2] = 4 + maped_label[label == 6] = 5 + maped_label[label == 11] = 6 + maped_label[label == 1] = 7 + maped_label[label == 7] = 8 + return maped_label + + +def pares_args(): + parser = argparse.ArgumentParser( + description='Convert synapse dataset to mmsegmentation format') + parser.add_argument( + '--dataset-path', type=str, help='synapse dataset path.') + parser.add_argument( + '--save-path', + default='data/synapse', + type=str, + help='save path of the dataset.') + args = parser.parse_args() + return args + + +def main(): + args = pares_args() + dataset_path = args.dataset_path + save_path = args.save_path + + if not osp.exists(dataset_path): + raise ValueError('The dataset path does not exist. ' + 'Please enter a correct dataset path.') + if not osp.exists(osp.join(dataset_path, 'img')) \ + or not osp.exists(osp.join(dataset_path, 'label')): + raise FileNotFoundError('The dataset structure is incorrect. ' + 'Please check your dataset.') + + train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt')) + train_id = [idx[3:7] for idx in train_id] + + test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt')) + test_id = [idx[3:7] for idx in test_id] + + mkdir_or_exist(osp.join(save_path, 'img_dir/train')) + mkdir_or_exist(osp.join(save_path, 'img_dir/val')) + mkdir_or_exist(osp.join(save_path, 'ann_dir/train')) + mkdir_or_exist(osp.join(save_path, 'ann_dir/val')) + + # It follows data preparation pipeline from here: + # https://github.com/Beckschen/TransUNet/tree/main/datasets + for i, idx in enumerate(train_id): + img_3d = read_nii_file( + osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz')) + label_3d = read_nii_file( + osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz')) + + img_3d = np.clip(img_3d, -125, 275) + img_3d = (img_3d + 125) / 400 + img_3d *= 255 + img_3d = np.transpose(img_3d, [2, 0, 1]) + img_3d = np.flip(img_3d, 2) + + label_3d = np.transpose(label_3d, [2, 0, 1]) + label_3d = np.flip(label_3d, 2) + label_3d = label_mapping(label_3d) + + for c in range(img_3d.shape[0]): + img = img_3d[c] + label = label_3d[c] + + img = Image.fromarray(img).convert('RGB') + label = Image.fromarray(label).convert('L') + img.save( + osp.join( + save_path, 'img_dir/train', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.jpg')) + label.save( + osp.join( + save_path, 'ann_dir/train', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.png')) + + for i, idx in enumerate(test_id): + img_3d = read_nii_file( + osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz')) + label_3d = read_nii_file( + osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz')) + + img_3d = np.clip(img_3d, -125, 275) + img_3d = (img_3d + 125) / 400 + img_3d *= 255 + img_3d = np.transpose(img_3d, [2, 0, 1]) + img_3d = np.flip(img_3d, 2) + + label_3d = np.transpose(label_3d, [2, 0, 1]) + label_3d = np.flip(label_3d, 2) + label_3d = label_mapping(label_3d) + + for c in range(img_3d.shape[0]): + img = img_3d[c] + label = label_3d[c] + + img = Image.fromarray(img).convert('RGB') + label = Image.fromarray(label).convert('L') + img.save( + osp.join( + save_path, 'img_dir/val', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.jpg')) + label.save( + osp.join( + save_path, 'ann_dir/val', 'case' + idx.zfill(4) + + '_slice' + str(c).zfill(3) + '.png')) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/dataset_converters/vaihingen.py b/finetune/tools/dataset_converters/vaihingen.py new file mode 100644 index 0000000..db98014 --- /dev/null +++ b/finetune/tools/dataset_converters/vaihingen.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import glob +import math +import os +import os.path as osp +import tempfile +import zipfile + +import mmcv +import numpy as np +from mmengine.utils import ProgressBar, mkdir_or_exist + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert vaihingen dataset to mmsegmentation format') + parser.add_argument('dataset_path', help='vaihingen folder path') + parser.add_argument('--tmp_dir', help='path of the temporary directory') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--clip_size', + type=int, + help='clipped size of image after preparation', + default=512) + parser.add_argument( + '--stride_size', + type=int, + help='stride of clipping original images', + default=256) + args = parser.parse_args() + return args + + +def clip_big_image(image_path, clip_save_dir, to_label=False): + # Original image of Vaihingen dataset is very large, thus pre-processing + # of them is adopted. Given fixed clip size and stride size to generate + # clipped image, the intersection of width and height is determined. + # For example, given one 5120 x 5120 original image, the clip size is + # 512 and stride size is 256, thus it would generate 20x20 = 400 images + # whose size are all 512x512. + image = mmcv.imread(image_path) + + h, w, c = image.shape + cs = args.clip_size + ss = args.stride_size + + num_rows = math.ceil((h - cs) / ss) if math.ceil( + (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1 + num_cols = math.ceil((w - cs) / ss) if math.ceil( + (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1 + + x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) + xmin = x * cs + ymin = y * cs + + xmin = xmin.ravel() + ymin = ymin.ravel() + xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin)) + ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin)) + boxes = np.stack([ + xmin + xmin_offset, ymin + ymin_offset, + np.minimum(xmin + cs, w), + np.minimum(ymin + cs, h) + ], + axis=1) + + if to_label: + color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], + [255, 255, 0], [0, 255, 0], [0, 255, 255], + [0, 0, 255]]) + flatten_v = np.matmul( + image.reshape(-1, c), + np.array([2, 3, 4]).reshape(3, 1)) + out = np.zeros_like(flatten_v) + for idx, class_color in enumerate(color_map): + value_idx = np.matmul(class_color, + np.array([2, 3, 4]).reshape(3, 1)) + out[flatten_v == value_idx] = idx + image = out.reshape(h, w) + + for box in boxes: + start_x, start_y, end_x, end_y = box + clipped_image = image[start_y:end_y, + start_x:end_x] if to_label else image[ + start_y:end_y, start_x:end_x, :] + area_idx = osp.basename(image_path).split('_')[3].strip('.tif') + mmcv.imwrite( + clipped_image.astype(np.uint8), + osp.join(clip_save_dir, + f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png')) + + +def main(): + splits = { + 'train': [ + 'area1', 'area11', 'area13', 'area15', 'area17', 'area21', + 'area23', 'area26', 'area28', 'area3', 'area30', 'area32', + 'area34', 'area37', 'area5', 'area7' + ], + 'val': [ + 'area6', 'area24', 'area35', 'area16', 'area14', 'area22', + 'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33', + 'area27', 'area38', 'area12', 'area29' + ], + } + + dataset_path = args.dataset_path + if args.out_dir is None: + out_dir = osp.join('data', 'vaihingen') + else: + out_dir = args.out_dir + + print('Making directories...') + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) + mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) + + zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) + print('Find the data', zipp_list) + + with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: + for zipp in zipp_list: + zip_file = zipfile.ZipFile(zipp) + zip_file.extractall(tmp_dir) + src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + if 'ISPRS_semantic_labeling_Vaihingen' in zipp: + src_path_list = glob.glob( + os.path.join(os.path.join(tmp_dir, 'top'), '*.tif')) + if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa + src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) + # delete unused area9 ground truth + for area_ann in src_path_list: + if 'area9' in area_ann: + src_path_list.remove(area_ann) + prog_bar = ProgressBar(len(src_path_list)) + for i, src_path in enumerate(src_path_list): + area_idx = osp.basename(src_path).split('_')[3].strip('.tif') + data_type = 'train' if area_idx in splits['train'] else 'val' + if 'noBoundary' in src_path: + dst_dir = osp.join(out_dir, 'ann_dir', data_type) + clip_big_image(src_path, dst_dir, to_label=True) + else: + dst_dir = osp.join(out_dir, 'img_dir', data_type) + clip_big_image(src_path, dst_dir, to_label=False) + prog_bar.update() + + print('Removing the temporary files...') + + print('Done!') + + +if __name__ == '__main__': + args = parse_args() + main() diff --git a/finetune/tools/dataset_converters/voc_aug.py b/finetune/tools/dataset_converters/voc_aug.py new file mode 100644 index 0000000..a536f42 --- /dev/null +++ b/finetune/tools/dataset_converters/voc_aug.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from functools import partial + +import numpy as np +from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress +from PIL import Image +from scipy.io import loadmat + +AUG_LEN = 10582 + + +def convert_mat(mat_file, in_dir, out_dir): + data = loadmat(osp.join(in_dir, mat_file)) + mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) + seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) + Image.fromarray(mask).save(seg_filename, 'PNG') + + +def generate_aug_list(merged_list, excluded_list): + return list(set(merged_list) - set(excluded_list)) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert PASCAL VOC annotations to mmsegmentation format') + parser.add_argument('devkit_path', help='pascal voc devkit path') + parser.add_argument('aug_path', help='pascal voc aug path') + parser.add_argument('-o', '--out_dir', help='output path') + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + devkit_path = args.devkit_path + aug_path = args.aug_path + nproc = args.nproc + if args.out_dir is None: + out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') + else: + out_dir = args.out_dir + mkdir_or_exist(out_dir) + in_dir = osp.join(aug_path, 'dataset', 'cls') + + track_parallel_progress( + partial(convert_mat, in_dir=in_dir, out_dir=out_dir), + list(scandir(in_dir, suffix='.mat')), + nproc=nproc) + + full_aug_list = [] + with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: + full_aug_list += [line.strip() for line in f] + with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: + full_aug_list += [line.strip() for line in f] + + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'train.txt')) as f: + ori_train_list = [line.strip() for line in f] + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'val.txt')) as f: + val_list = [line.strip() for line in f] + + aug_train_list = generate_aug_list(ori_train_list + full_aug_list, + val_list) + assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( + AUG_LEN) + + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', + 'trainaug.txt'), 'w') as f: + f.writelines(line + '\n' for line in aug_train_list) + + aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) + assert len(aug_list) == AUG_LEN - len( + ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - + len(ori_train_list)) + with open( + osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), + 'w') as f: + f.writelines(line + '\n' for line in aug_list) + + print('Done!') + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/deployment/pytorch2torchscript.py b/finetune/tools/deployment/pytorch2torchscript.py new file mode 100644 index 0000000..e69e705 --- /dev/null +++ b/finetune/tools/deployment/pytorch2torchscript.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import numpy as np +import torch +import torch._C +import torch.serialization +from mmengine import Config +from mmengine.runner import load_checkpoint +from torch import nn + +from mmseg.models import build_segmentor + +torch.manual_seed(3) + + +def digit_version(version_str): + digit_version = [] + for x in version_str.split('.'): + if x.isdigit(): + digit_version.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + digit_version.append(int(patch_version[0]) - 1) + digit_version.append(int(patch_version[1])) + return digit_version + + +def check_torch_version(): + torch_minimum_version = '1.8.0' + torch_version = digit_version(torch.__version__) + + assert (torch_version >= digit_version(torch_minimum_version)), \ + f'Torch=={torch.__version__} is not support for converting to ' \ + f'torchscript. Please install pytorch>={torch_minimum_version}.' + + +def _convert_batchnorm(module): + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, _convert_batchnorm(child)) + del module + return module_output + + +def _demo_mm_inputs(input_shape, num_classes): + """Create a superset of inputs needed to run test or train batches. + + Args: + input_shape (tuple): + input batch dimensions + num_classes (int): + number of semantic classes + """ + (N, C, H, W) = input_shape + rng = np.random.RandomState(0) + imgs = rng.rand(*input_shape) + segs = rng.randint( + low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) + img_metas = [{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': '.png', + 'scale_factor': 1.0, + 'flip': False, + } for _ in range(N)] + mm_inputs = { + 'imgs': torch.FloatTensor(imgs).requires_grad_(True), + 'img_metas': img_metas, + 'gt_semantic_seg': torch.LongTensor(segs) + } + return mm_inputs + + +def pytorch2libtorch(model, + input_shape, + show=False, + output_file='tmp.pt', + verify=False): + """Export Pytorch model to TorchScript model and verify the outputs are + same between Pytorch and TorchScript. + + Args: + model (nn.Module): Pytorch model we want to export. + input_shape (tuple): Use this input shape to construct + the corresponding dummy input and execute the model. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the + output TorchScript model. Default: `tmp.pt`. + verify (bool): Whether compare the outputs between + Pytorch and TorchScript. Default: False. + """ + if isinstance(model.decode_head, nn.ModuleList): + num_classes = model.decode_head[-1].num_classes + else: + num_classes = model.decode_head.num_classes + + mm_inputs = _demo_mm_inputs(input_shape, num_classes) + + imgs = mm_inputs.pop('imgs') + + # replace the original forword with forward_dummy + model.forward = model.forward_dummy + model.eval() + traced_model = torch.jit.trace( + model, + example_inputs=imgs, + check_trace=verify, + ) + + if show: + print(traced_model.graph) + + traced_model.save(output_file) + print(f'Successfully exported TorchScript model: {output_file}') + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert MMSeg to TorchScript') + parser.add_argument('config', help='test config file path') + parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument( + '--show', action='store_true', help='show TorchScript graph') + parser.add_argument( + '--verify', action='store_true', help='verify the TorchScript model') + parser.add_argument('--output-file', type=str, default='tmp.pt') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[512, 512], + help='input image size (height, width)') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + check_torch_version() + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = ( + 1, + 3, + ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + cfg = Config.fromfile(args.config) + cfg.model.pretrained = None + + # build the model and load checkpoint + cfg.model.train_cfg = None + segmentor = build_segmentor( + cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) + # convert SyncBN to BN + segmentor = _convert_batchnorm(segmentor) + + if args.checkpoint: + load_checkpoint(segmentor, args.checkpoint, map_location='cpu') + + # convert the PyTorch model to LibTorch model + pytorch2libtorch( + segmentor, + input_shape, + show=args.show, + output_file=args.output_file, + verify=args.verify) diff --git a/finetune/tools/dist_test.sh b/finetune/tools/dist_test.sh new file mode 100644 index 0000000..89711fd --- /dev/null +++ b/finetune/tools/dist_test.sh @@ -0,0 +1,20 @@ +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/test.py \ + $CONFIG \ + $CHECKPOINT \ + --launcher pytorch \ + ${@:4} diff --git a/finetune/tools/dist_train.sh b/finetune/tools/dist_train.sh new file mode 100644 index 0000000..a857df7 --- /dev/null +++ b/finetune/tools/dist_train.sh @@ -0,0 +1,17 @@ +CONFIG=$1 +GPUS=$2 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --launcher pytorch ${@:3} diff --git a/finetune/tools/misc/browse_dataset.py b/finetune/tools/misc/browse_dataset.py new file mode 100644 index 0000000..7863eb7 --- /dev/null +++ b/finetune/tools/misc/browse_dataset.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +from mmengine import Config, DictAction +from mmengine.registry import init_default_scope +from mmengine.utils import ProgressBar + +from mmseg.registry import DATASETS, VISUALIZERS + + +def parse_args(): + parser = argparse.ArgumentParser(description='Browse a dataset') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--output-dir', + default=None, + type=str, + help='If there is no display interface, you can save it') + parser.add_argument('--not-show', default=False, action='store_true') + parser.add_argument( + '--show-interval', + type=float, + default=2, + help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # register all modules in mmseg into the registries + init_default_scope('mmseg') + + dataset = DATASETS.build(cfg.train_dataloader.dataset) + cfg.visualizer['save_dir'] = args.output_dir + visualizer = VISUALIZERS.build(cfg.visualizer) + visualizer.dataset_meta = dataset.METAINFO + + progress_bar = ProgressBar(len(dataset)) + for item in dataset: + img = item['inputs'].permute(1, 2, 0).numpy() + data_sample = item['data_samples'].numpy() + img_path = osp.basename(item['data_samples'].img_path) + + img = img[..., [2, 1, 0]] # bgr to rgb + + visualizer.add_datasample( + osp.basename(img_path), + img, + data_sample, + show=not args.not_show, + wait_time=args.show_interval) + + progress_bar.update() + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/misc/print_config.py b/finetune/tools/misc/print_config.py new file mode 100644 index 0000000..2a1c024 --- /dev/null +++ b/finetune/tools/misc/print_config.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import warnings + +from mmengine import Config, DictAction + +from mmseg.apis import init_model + + +def parse_args(): + parser = argparse.ArgumentParser(description='Print the whole config') + parser.add_argument('config', help='config file path') + parser.add_argument( + '--graph', action='store_true', help='print the models graph') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help="--options is deprecated in favor of --cfg_options' and it will " + 'not be supported in version v0.22.0. Override some settings in the ' + 'used config, the key-value pair in xxx=yyy format will be merged ' + 'into config file. If the value to be overwritten is a list, it ' + 'should be like key="[a,b]" or key=a,b It also allows nested ' + 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' + 'marks are necessary and that no white space is allowed.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + args = parser.parse_args() + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options, ' + '--options will not be supported in version v0.22.0.') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + print(f'Config:\n{cfg.pretty_text}') + # dump config + cfg.dump('example.py') + # dump models graph + if args.graph: + model = init_model(args.config, device='cpu') + print(f'Model graph:\n{str(model)}') + with open('example-graph.txt', 'w') as f: + f.writelines(str(model)) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/misc/publish_model.py b/finetune/tools/misc/publish_model.py new file mode 100644 index 0000000..e035ad9 --- /dev/null +++ b/finetune/tools/misc/publish_model.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import subprocess +from hashlib import sha256 + +import torch + +BLOCK_SIZE = 128 * 1024 + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Process a checkpoint to be published') + parser.add_argument('in_file', help='input checkpoint filename') + parser.add_argument('out_file', help='output checkpoint filename') + args = parser.parse_args() + return args + + +def sha256sum(filename: str) -> str: + """Compute SHA256 message digest from a file.""" + hash_func = sha256() + byte_array = bytearray(BLOCK_SIZE) + memory_view = memoryview(byte_array) + with open(filename, 'rb', buffering=0) as file: + for block in iter(lambda: file.readinto(memory_view), 0): + hash_func.update(memory_view[:block]) + return hash_func.hexdigest() + + +def process_checkpoint(in_file, out_file): + checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size + if 'optimizer' in checkpoint: + del checkpoint['optimizer'] + # if it is necessary to remove some sensitive data in checkpoint['meta'], + # add the code here. + torch.save(checkpoint, out_file) + sha = sha256sum(in_file) + final_file = out_file.rstrip('.pth') + f'-{sha[:8]}.pth' + subprocess.Popen(['mv', out_file, final_file]) + + +def main(): + args = parse_args() + process_checkpoint(args.in_file, args.out_file) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/beit2mmseg.py b/finetune/tools/model_converters/beit2mmseg.py new file mode 100644 index 0000000..20f8f0f --- /dev/null +++ b/finetune/tools/model_converters/beit2mmseg.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_beit(ckpt): + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + if k.startswith('patch_embed'): + new_key = k.replace('patch_embed.proj', 'patch_embed.projection') + new_ckpt[new_key] = v + if k.startswith('blocks'): + new_key = k.replace('blocks', 'layers') + if 'norm' in new_key: + new_key = new_key.replace('norm', 'ln') + elif 'mlp.fc1' in new_key: + new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in new_key: + new_key = new_key.replace('mlp.fc2', 'ffn.layers.1') + new_ckpt[new_key] = v + else: + new_key = k + new_ckpt[new_key] = v + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained beit models to' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_beit(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/clip2mmseg.py b/finetune/tools/model_converters/clip2mmseg.py new file mode 100644 index 0000000..9a97e4b --- /dev/null +++ b/finetune/tools/model_converters/clip2mmseg.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_vitlayer(paras): + new_para_name = '' + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + paras[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + paras[-1:]) + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_translayer(paras): + new_para_name = '' + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + paras[2:]) + else: + print(f'Wrong for {paras}') + else: + print(f'Wrong for {paras}') + return new_para_name + + +def convert_key_name(ckpt, visual_split): + new_ckpt = OrderedDict() + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'visual': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'transformer': + new_layer_name = 'layers' + layer_index = key_list[3] + paras = key_list[4:] + if int(layer_index) < visual_split: + new_para_name = convert_vitlayer(paras) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + else: + new_para_name = convert_translayer(paras) + new_transform_name = 'decode_head.rec_with_attnbias' + new_layer_name = 'layers' + layer_index = str(int(layer_index) - visual_split) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'proj': + new_name = 'decode_head.rec_with_attnbias.proj.weight' + elif key_list[1] == 'ln_post': + new_name = k.replace('visual', 'decode_head.rec_with_attnbias') + else: + print(f'pop parameter: {k}') + continue + else: + text_encoder_name = 'text_encoder' + if key_list[0] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[2] + paras = key_list[3:] + new_para_name = convert_translayer(paras) + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[0] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = 'text_encoder.' + k + else: + print(f'pop parameter: {k}') + continue + new_ckpt[new_name] = v + + return new_ckpt + + +def convert_tensor(ckpt): + cls_token = ckpt['image_encoder.cls_token'] + new_cls_token = cls_token.unsqueeze(0).unsqueeze(0) + ckpt['image_encoder.cls_token'] = new_cls_token + pos_embed = ckpt['image_encoder.pos_embed'] + new_pos_embed = pos_embed.unsqueeze(0) + ckpt['image_encoder.pos_embed'] = new_pos_embed + proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight'] + new_proj_weight = proj_weight.transpose(1, 0) + ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight + return ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + if any([s in args.src for s in ['B-16', 'b16', 'base_patch16']]): + visual_split = 9 + elif any([s in args.src for s in ['L-14', 'l14', 'large_patch14']]): + visual_split = 18 + else: + print('Make sure the clip model is ViT-B/16 or ViT-L/14!') + visual_split = -1 + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if isinstance(checkpoint, torch.jit.RecursiveScriptModule): + state_dict = checkpoint.state_dict() + else: + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_key_name(state_dict, visual_split) + weight = convert_tensor(weight) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/mit2mmseg.py b/finetune/tools/model_converters/mit2mmseg.py new file mode 100644 index 0000000..f10cbbf --- /dev/null +++ b/finetune/tools/model_converters/mit2mmseg.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_mit(ckpt): + new_ckpt = OrderedDict() + # Process the concat between q linear weights and kv linear weights + for k, v in ckpt.items(): + if k.startswith('head'): + continue + # patch embedding conversion + elif k.startswith('patch_embed'): + stage_i = int(k.split('.')[0].replace('patch_embed', '')) + new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') + new_v = v + if 'proj.' in new_k: + new_k = new_k.replace('proj.', 'projection.') + # transformer encoder layer conversion + elif k.startswith('block'): + stage_i = int(k.split('.')[0].replace('block', '')) + new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') + new_v = v + if 'attn.q.' in new_k: + sub_item_k = k.replace('q.', 'kv.') + new_k = new_k.replace('q.', 'attn.in_proj_') + new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) + elif 'attn.kv.' in new_k: + continue + elif 'attn.proj.' in new_k: + new_k = new_k.replace('proj.', 'attn.out_proj.') + elif 'attn.sr.' in new_k: + new_k = new_k.replace('sr.', 'sr.') + elif 'mlp.' in new_k: + string = f'{new_k}-' + new_k = new_k.replace('mlp.', 'ffn.layers.') + if 'fc1.weight' in new_k or 'fc2.weight' in new_k: + new_v = v.reshape((*v.shape, 1, 1)) + new_k = new_k.replace('fc1.', '0.') + new_k = new_k.replace('dwconv.dwconv.', '1.') + new_k = new_k.replace('fc2.', '4.') + string += f'{new_k} {v.shape}-{new_v.shape}' + # norm layer conversion + elif k.startswith('norm'): + stage_i = int(k.split('.')[0].replace('norm', '')) + new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') + new_v = v + else: + new_k = k + new_v = v + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained segformer to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_mit(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/san2mmseg.py b/finetune/tools/model_converters/san2mmseg.py new file mode 100644 index 0000000..301a466 --- /dev/null +++ b/finetune/tools/model_converters/san2mmseg.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_key_name(ckpt): + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + key_list = k.split('.') + if key_list[0] == 'clip_visual_extractor': + new_transform_name = 'image_encoder' + if key_list[1] == 'class_embedding': + new_name = '.'.join([new_transform_name, 'cls_token']) + elif key_list[1] == 'positional_embedding': + new_name = '.'.join([new_transform_name, 'pos_embed']) + elif key_list[1] == 'conv1': + new_name = '.'.join([ + new_transform_name, 'patch_embed.projection', key_list[2] + ]) + elif key_list[1] == 'ln_pre': + new_name = '.'.join( + [new_transform_name, key_list[1], key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['ln1'] + key_list[4:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attn.attn'] + key_list[4:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['ln2'] + key_list[4:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffn.layers.0.0'] + + key_list[-1:]) + else: + new_para_name = '.'.join(['ffn.layers.1'] + + key_list[-1:]) + new_name = '.'.join([ + new_transform_name, new_layer_name, layer_index, + new_para_name + ]) + elif key_list[0] == 'side_adapter_network': + decode_head_name = 'decode_head' + module_name = 'side_adapter_network' + if key_list[1] == 'vit_model': + if key_list[2] == 'blocks': + layer_name = 'encode_layers' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'norm1': + new_para_name = '.'.join(['ln1'] + key_list[5:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(key_list[4:]) + new_para_name = new_para_name.replace( + 'attn.qkv.', 'attn.attn.in_proj_') + new_para_name = new_para_name.replace( + 'attn.proj', 'attn.attn.out_proj') + elif paras[0] == 'norm2': + new_para_name = '.'.join(['ln2'] + key_list[5:]) + elif paras[0] == 'mlp': + new_para_name = '.'.join(['ffn'] + key_list[5:]) + new_para_name = new_para_name.replace( + 'fc1', 'layers.0.0') + new_para_name = new_para_name.replace( + 'fc2', 'layers.1') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[2] == 'pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, 'pos_embed']) + elif key_list[2] == 'patch_embed': + new_name = '.'.join([ + decode_head_name, module_name, 'patch_embed', + 'projection', key_list[4] + ]) + else: + print(f'Wrong for {k}') + elif key_list[1] == 'query_embed' or key_list[ + 1] == 'query_pos_embed': + new_name = '.'.join( + [decode_head_name, module_name, key_list[1]]) + elif key_list[1] == 'fusion_layers': + layer_name = 'conv_clips' + layer_index = key_list[2][-1] + paras = '.'.join(key_list[3:]) + new_para_name = paras.replace('input_proj.0', '0') + new_para_name = new_para_name.replace('input_proj.1', '1.conv') + new_name = '.'.join([ + decode_head_name, module_name, layer_name, layer_index, + new_para_name + ]) + elif key_list[1] == 'mask_decoder': + new_name = 'decode_head.' + k + else: + print(f'Wrong for {k}') + elif key_list[0] == 'clip_rec_head': + module_name = 'rec_with_attnbias' + if key_list[1] == 'proj': + new_name = '.'.join( + [decode_head_name, module_name, 'proj.weight']) + elif key_list[1] == 'ln_post': + new_name = '.'.join( + [decode_head_name, module_name, 'ln_post', key_list[2]]) + elif key_list[1] == 'resblocks': + new_layer_name = 'layers' + layer_index = key_list[2] + paras = key_list[3:] + if paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + decode_head_name, module_name, new_layer_name, layer_index, + new_para_name + ]) + else: + print(f'Wrong for {k}') + elif key_list[0] == 'ov_classifier': + text_encoder_name = 'text_encoder' + if key_list[1] == 'transformer': + layer_name = 'transformer' + layer_index = key_list[3] + paras = key_list[4:] + if paras[0] == 'attn': + new_para_name = '.'.join(['attentions.0.attn'] + paras[1:]) + elif paras[0] == 'ln_1': + new_para_name = '.'.join(['norms.0'] + paras[1:]) + elif paras[0] == 'ln_2': + new_para_name = '.'.join(['norms.1'] + paras[1:]) + elif paras[0] == 'mlp': + if paras[1] == 'c_fc': + new_para_name = '.'.join(['ffns.0.layers.0.0'] + + paras[2:]) + elif paras[1] == 'c_proj': + new_para_name = '.'.join(['ffns.0.layers.1'] + + paras[2:]) + else: + print(f'Wrong for {k}') + else: + print(f'Wrong for {k}') + new_name = '.'.join([ + text_encoder_name, layer_name, layer_index, new_para_name + ]) + elif key_list[1] in [ + 'positional_embedding', 'text_projection', 'bg_embed', + 'attn_mask', 'logit_scale', 'token_embedding', 'ln_final' + ]: + new_name = k.replace('ov_classifier', 'text_encoder') + else: + print(f'Wrong for {k}') + elif key_list[0] == 'criterion': + new_name = k + else: + print(f'Wrong for {k}') + new_ckpt[new_name] = v + return new_ckpt + + +def convert_tensor(ckpt): + cls_token = ckpt['image_encoder.cls_token'] + new_cls_token = cls_token.unsqueeze(0).unsqueeze(0) + ckpt['image_encoder.cls_token'] = new_cls_token + pos_embed = ckpt['image_encoder.pos_embed'] + new_pos_embed = pos_embed.unsqueeze(0) + ckpt['image_encoder.pos_embed'] = new_pos_embed + proj_weight = ckpt['decode_head.rec_with_attnbias.proj.weight'] + new_proj_weight = proj_weight.transpose(1, 0) + ckpt['decode_head.rec_with_attnbias.proj.weight'] = new_proj_weight + return ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_key_name(state_dict) + weight = convert_tensor(weight) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/stdc2mmseg.py b/finetune/tools/model_converters/stdc2mmseg.py new file mode 100644 index 0000000..6ea3b83 --- /dev/null +++ b/finetune/tools/model_converters/stdc2mmseg.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_stdc(ckpt, stdc_type): + new_state_dict = {} + if stdc_type == 'STDC1': + stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1'] + else: + stage_lst = [ + '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3', + '3.4', '4.0', '4.1', '4.2' + ] + for k, v in ckpt.items(): + ori_k = k + flag = False + if 'cp.' in k: + k = k.replace('cp.', '') + if 'features.' in k: + num_layer = int(k.split('.')[1]) + feature_key_lst = 'features.' + str(num_layer) + '.' + stages_key_lst = 'stages.' + stage_lst[num_layer] + '.' + k = k.replace(feature_key_lst, stages_key_lst) + flag = True + if 'conv_list' in k: + k = k.replace('conv_list', 'layers') + flag = True + if 'avd_layer.' in k: + if 'avd_layer.0' in k: + k = k.replace('avd_layer.0', 'downsample.conv') + elif 'avd_layer.1' in k: + k = k.replace('avd_layer.1', 'downsample.bn') + flag = True + if flag: + new_state_dict[k] = ckpt[ori_k] + + return new_state_dict + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained STDC1/2 to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + parser.add_argument('type', help='model type: STDC1 or STDC2') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + assert args.type in ['STDC1', + 'STDC2'], 'STD type should be STDC1 or STDC2!' + weight = convert_stdc(state_dict, args.type) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/swin2mmseg.py b/finetune/tools/model_converters/swin2mmseg.py new file mode 100644 index 0000000..d434f94 --- /dev/null +++ b/finetune/tools/model_converters/swin2mmseg.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_swin(ckpt): + new_ckpt = OrderedDict() + + def correct_unfold_reduction_order(x): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, + 2).reshape(out_channel, in_channel) + return x + + def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + elif k.startswith('layers'): + new_v = v + if 'attn.' in k: + new_k = k.replace('attn.', 'attn.w_msa.') + elif 'mlp.' in k: + if 'mlp.fc1.' in k: + new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') + elif 'mlp.fc2.' in k: + new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') + else: + new_k = k.replace('mlp.', 'ffn.') + elif 'downsample' in k: + new_k = k + if 'reduction.' in k: + new_v = correct_unfold_reduction_order(v) + elif 'norm.' in k: + new_v = correct_unfold_norm_order(v) + else: + new_k = k + new_k = new_k.replace('layers', 'stages', 1) + elif k.startswith('patch_embed'): + new_v = v + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + else: + new_v = v + new_k = k + + new_ckpt[new_k] = new_v + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in official pretrained swin models to' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_swin(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/twins2mmseg.py b/finetune/tools/model_converters/twins2mmseg.py new file mode 100644 index 0000000..647d417 --- /dev/null +++ b/finetune/tools/model_converters/twins2mmseg.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_twins(args, ckpt): + + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('head'): + continue + elif k.startswith('patch_embeds'): + if 'proj.' in k: + new_k = k.replace('proj.', 'projection.') + else: + new_k = k + elif k.startswith('blocks'): + # Union + if 'attn.q.' in k: + new_k = k.replace('q.', 'attn.in_proj_') + new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]], + dim=0) + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + # Only pcpvt + elif args.model == 'pcpvt': + if 'attn.proj.' in k: + new_k = k.replace('proj.', 'attn.out_proj.') + else: + new_k = k + + # Only svt + else: + if 'attn.proj.' in k: + k_lst = k.split('.') + if int(k_lst[2]) % 2 == 1: + new_k = k.replace('proj.', 'attn.out_proj.') + else: + new_k = k + else: + new_k = k + new_k = new_k.replace('blocks.', 'layers.') + elif k.startswith('pos_block'): + new_k = k.replace('pos_block', 'position_encodings') + if 'proj.0.' in new_k: + new_k = new_k.replace('proj.0.', 'proj.') + else: + new_k = k + if 'attn.kv.' not in k: + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + parser.add_argument('model', help='model: pcpvt or svt') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_twins(args, state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/vit2mmseg.py b/finetune/tools/model_converters/vit2mmseg.py new file mode 100644 index 0000000..1d1f8a4 --- /dev/null +++ b/finetune/tools/model_converters/vit2mmseg.py @@ -0,0 +1,70 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def convert_vit(ckpt): + + new_ckpt = OrderedDict() + + for k, v in ckpt.items(): + if k.startswith('head'): + continue + if k.startswith('norm'): + new_k = k.replace('norm.', 'ln1.') + elif k.startswith('patch_embed'): + if 'proj' in k: + new_k = k.replace('proj', 'projection') + else: + new_k = k + elif k.startswith('blocks'): + if 'norm' in k: + new_k = k.replace('norm', 'ln') + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + elif 'attn.qkv' in k: + new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') + elif 'attn.proj' in k: + new_k = k.replace('attn.proj', 'attn.attn.out_proj') + else: + new_k = k + new_k = new_k.replace('blocks.', 'layers.') + else: + new_k = k + new_ckpt[new_k] = v + + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in timm pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + if 'state_dict' in checkpoint: + # timm checkpoint + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + # deit checkpoint + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + weight = convert_vit(state_dict) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/model_converters/vitjax2mmseg.py b/finetune/tools/model_converters/vitjax2mmseg.py new file mode 100644 index 0000000..81bc2ea --- /dev/null +++ b/finetune/tools/model_converters/vitjax2mmseg.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp + +import mmengine +import numpy as np +import torch + + +def vit_jax_to_torch(jax_weights, num_layer=12): + torch_weights = dict() + + # patch embedding + conv_filters = jax_weights['embedding/kernel'] + conv_filters = conv_filters.permute(3, 2, 0, 1) + torch_weights['patch_embed.projection.weight'] = conv_filters + torch_weights['patch_embed.projection.bias'] = jax_weights[ + 'embedding/bias'] + + # pos embedding + torch_weights['pos_embed'] = jax_weights[ + 'Transformer/posembed_input/pos_embedding'] + + # cls token + torch_weights['cls_token'] = jax_weights['cls'] + + # head + torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale'] + torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias'] + + # transformer blocks + for i in range(num_layer): + jax_block = f'Transformer/encoderblock_{i}' + torch_block = f'layers.{i}' + + # attention norm + torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[ + f'{jax_block}/LayerNorm_0/scale'] + torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[ + f'{jax_block}/LayerNorm_0/bias'] + + # attention + query_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel'] + query_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/query/bias'] + key_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel'] + key_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/key/bias'] + value_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel'] + value_bias = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/value/bias'] + + qkv_weight = torch.from_numpy( + np.stack((query_weight, key_weight, value_weight), 1)) + qkv_weight = torch.flatten(qkv_weight, start_dim=1) + qkv_bias = torch.from_numpy( + np.stack((query_bias, key_bias, value_bias), 0)) + qkv_bias = torch.flatten(qkv_bias, start_dim=0) + + torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight + torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias + to_out_weight = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel'] + to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1) + torch_weights[ + f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight + torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[ + f'{jax_block}/MultiHeadDotProductAttention_1/out/bias'] + + # mlp norm + torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[ + f'{jax_block}/LayerNorm_2/scale'] + torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[ + f'{jax_block}/LayerNorm_2/bias'] + + # mlp + torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_0/kernel'] + torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_0/bias'] + torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_1/kernel'] + torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[ + f'{jax_block}/MlpBlock_3/Dense_1/bias'] + + # transpose weights + for k, v in torch_weights.items(): + if 'weight' in k and 'patch_embed' not in k and 'ln' not in k: + v = v.permute(1, 0) + torch_weights[k] = v + + return torch_weights + + +def main(): + # stole refactoring code from Robin Strudel, thanks + parser = argparse.ArgumentParser( + description='Convert keys from jax official pretrained vit models to ' + 'MMSegmentation style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + jax_weights = np.load(args.src) + jax_weights_tensor = {} + for key in jax_weights.files: + value = torch.from_numpy(jax_weights[key]) + jax_weights_tensor[key] = value + if 'L_16-i21k' in args.src: + num_layer = 24 + else: + num_layer = 12 + torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer) + mmengine.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(torch_weights, args.dst) + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/slurm_test.sh b/finetune/tools/slurm_test.sh new file mode 100644 index 0000000..4e6f7bf --- /dev/null +++ b/finetune/tools/slurm_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +CHECKPOINT=$4 +GPUS=${GPUS:-4} +GPUS_PER_NODE=${GPUS_PER_NODE:-4} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} diff --git a/finetune/tools/slurm_train.sh b/finetune/tools/slurm_train.sh new file mode 100644 index 0000000..ab23210 --- /dev/null +++ b/finetune/tools/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-4} +GPUS_PER_NODE=${GPUS_PER_NODE:-4} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/finetune/tools/test.py b/finetune/tools/test.py new file mode 100644 index 0000000..0d7f39b --- /dev/null +++ b/finetune/tools/test.py @@ -0,0 +1,123 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + + +# TODO: support fuse_conv_bn, visualization, and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MMSeg test (and eval) a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help=('if specified, the evaluation metric results will be dumped' + 'into the directory as json')) + parser.add_argument( + '--out', + type=str, + help='The directory to save output prediction for offline evaluation') + parser.add_argument( + '--show', action='store_true', help='show prediction results') + parser.add_argument( + '--show-dir', + help='directory where painted images will be saved. ' + 'If specified, it will be automatically saved ' + 'to the work_dir/timestamp/show_dir') + parser.add_argument( + '--wait-time', type=float, default=2, help='the interval of show (s)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def trigger_visualization_hook(cfg, args): + default_hooks = cfg.default_hooks + if 'visualization' in default_hooks: + visualization_hook = default_hooks['visualization'] + # Turn on visualization + visualization_hook['draw'] = True + if args.show: + visualization_hook['show'] = True + visualization_hook['wait_time'] = args.wait_time + if args.show_dir: + visualizer = cfg.visualizer + visualizer['save_dir'] = args.show_dir + else: + raise RuntimeError( + 'VisualizationHook must be included in default_hooks.' + 'refer to usage ' + '"visualization=dict(type=\'VisualizationHook\')"') + + return cfg + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + cfg.load_from = args.checkpoint + + if args.show or args.show_dir: + cfg = trigger_visualization_hook(cfg, args) + + if args.tta: + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.tta_model.module = cfg.model + cfg.model = cfg.tta_model + + # add output_dir in metric + if args.out is not None: + cfg.test_evaluator['output_dir'] = args.out + cfg.test_evaluator['keep_results'] = True + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() diff --git a/finetune/tools/torchserve/mmseg2torchserve.py b/finetune/tools/torchserve/mmseg2torchserve.py new file mode 100644 index 0000000..23f9963 --- /dev/null +++ b/finetune/tools/torchserve/mmseg2torchserve.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser, Namespace +from pathlib import Path +from tempfile import TemporaryDirectory + +from mmengine import Config +from mmengine.utils import mkdir_or_exist + +try: + from model_archiver.model_packaging import package_model + from model_archiver.model_packaging_utils import ModelExportUtils +except ImportError: + package_model = None + + +def mmseg2torchserve( + config_file: str, + checkpoint_file: str, + output_folder: str, + model_name: str, + model_version: str = '1.0', + force: bool = False, +): + """Converts mmsegmentation model (config + checkpoint) to TorchServe + `.mar`. + + Args: + config_file: + In MMSegmentation config format. + The contents vary for each task repository. + checkpoint_file: + In MMSegmentation checkpoint format. + The contents vary for each task repository. + output_folder: + Folder where `{model_name}.mar` will be created. + The file created will be in TorchServe archive format. + model_name: + If not None, used for naming the `{model_name}.mar` file + that will be created under `output_folder`. + If None, `{Path(checkpoint_file).stem}` will be used. + model_version: + Model's version. + force: + If True, if there is an existing `{model_name}.mar` + file under `output_folder` it will be overwritten. + """ + mkdir_or_exist(output_folder) + + config = Config.fromfile(config_file) + + with TemporaryDirectory() as tmpdir: + config.dump(f'{tmpdir}/config.py') + + args = Namespace( + **{ + 'model_file': f'{tmpdir}/config.py', + 'serialized_file': checkpoint_file, + 'handler': f'{Path(__file__).parent}/mmseg_handler.py', + 'model_name': model_name or Path(checkpoint_file).stem, + 'version': model_version, + 'export_path': output_folder, + 'force': force, + 'requirements_file': None, + 'extra_files': None, + 'runtime': 'python', + 'archive_format': 'default' + }) + manifest = ModelExportUtils.generate_manifest_json(args) + package_model(args, manifest) + + +def parse_args(): + parser = ArgumentParser( + description='Convert mmseg models to TorchServe `.mar` format.') + parser.add_argument('config', type=str, help='config file path') + parser.add_argument('checkpoint', type=str, help='checkpoint file path') + parser.add_argument( + '--output-folder', + type=str, + required=True, + help='Folder where `{model_name}.mar` will be created.') + parser.add_argument( + '--model-name', + type=str, + default=None, + help='If not None, used for naming the `{model_name}.mar`' + 'file that will be created under `output_folder`.' + 'If None, `{Path(checkpoint_file).stem}` will be used.') + parser.add_argument( + '--model-version', + type=str, + default='1.0', + help='Number used for versioning.') + parser.add_argument( + '-f', + '--force', + action='store_true', + help='overwrite the existing `{model_name}.mar`') + args = parser.parse_args() + + return args + + +if __name__ == '__main__': + args = parse_args() + + if package_model is None: + raise ImportError('`torch-model-archiver` is required.' + 'Try: pip install torch-model-archiver') + + mmseg2torchserve(args.config, args.checkpoint, args.output_folder, + args.model_name, args.model_version, args.force) diff --git a/finetune/tools/torchserve/mmseg_handler.py b/finetune/tools/torchserve/mmseg_handler.py new file mode 100644 index 0000000..dbe5ded --- /dev/null +++ b/finetune/tools/torchserve/mmseg_handler.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import base64 +import os + +import cv2 +import mmcv +import torch +from mmengine.model.utils import revert_sync_batchnorm +from ts.torch_handler.base_handler import BaseHandler + +from mmseg.apis import inference_model, init_model + + +class MMsegHandler(BaseHandler): + + def initialize(self, context): + properties = context.system_properties + self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(self.map_location + ':' + + str(properties.get('gpu_id')) if torch.cuda. + is_available() else self.map_location) + self.manifest = context.manifest + + model_dir = properties.get('model_dir') + serialized_file = self.manifest['model']['serializedFile'] + checkpoint = os.path.join(model_dir, serialized_file) + self.config_file = os.path.join(model_dir, 'config.py') + + self.model = init_model(self.config_file, checkpoint, self.device) + self.model = revert_sync_batchnorm(self.model) + self.initialized = True + + def preprocess(self, data): + images = [] + + for row in data: + image = row.get('data') or row.get('body') + if isinstance(image, str): + image = base64.b64decode(image) + image = mmcv.imfrombytes(image) + images.append(image) + + return images + + def inference(self, data, *args, **kwargs): + results = [inference_model(self.model, img) for img in data] + return results + + def postprocess(self, data): + output = [] + + for image_result in data: + _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) + content = buffer.tobytes() + output.append(content) + return output diff --git a/finetune/tools/torchserve/test_torchserve.py b/finetune/tools/torchserve/test_torchserve.py new file mode 100644 index 0000000..b015b66 --- /dev/null +++ b/finetune/tools/torchserve/test_torchserve.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from argparse import ArgumentParser +from io import BytesIO + +import matplotlib.pyplot as plt +import mmcv +import requests + +from mmseg.apis import inference_model, init_model + + +def parse_args(): + parser = ArgumentParser( + description='Compare result of torchserve and pytorch,' + 'and visualize them.') + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument('model_name', help='The model name in the server') + parser.add_argument( + '--inference-addr', + default='127.0.0.1:8080', + help='Address and port of the inference server') + parser.add_argument( + '--result-image', + type=str, + default=None, + help='save server output in result-image') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + + args = parser.parse_args() + return args + + +def main(args): + url = 'http://' + args.inference_addr + '/predictions/' + args.model_name + with open(args.img, 'rb') as image: + tmp_res = requests.post(url, image) + content = tmp_res.content + if args.result_image: + with open(args.result_image, 'wb') as out_image: + out_image.write(content) + plt.imshow(mmcv.imread(args.result_image, 'grayscale')) + plt.show() + else: + plt.imshow(plt.imread(BytesIO(content))) + plt.show() + model = init_model(args.config, args.checkpoint, args.device) + image = mmcv.imread(args.img) + result = inference_model(model, image) + plt.imshow(result[0]) + plt.show() + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/finetune/tools/train.py b/finetune/tools/train.py new file mode 100644 index 0000000..10fdaa1 --- /dev/null +++ b/finetune/tools/train.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import logging +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.logging import print_log +from mmengine.runner import Runner + +from mmseg.registry import RUNNERS + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + action='store_true', + default=False, + help='resume from the latest checkpoint in the work_dir automatically') + parser.add_argument( + '--amp', + action='store_true', + default=False, + help='enable automatic-mixed-precision training') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # enable automatic-mixed-precision training + if args.amp is True: + optim_wrapper = cfg.optim_wrapper.type + if optim_wrapper == 'AmpOptimWrapper': + print_log( + 'AMP training is already enabled in your config.', + logger='current', + level=logging.WARNING) + else: + assert optim_wrapper == 'OptimWrapper', ( + '`--amp` is only supported when the optimizer wrapper type is ' + f'`OptimWrapper` but got {optim_wrapper}.') + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.loss_scale = 'dynamic' + + # resume training + cfg.resume = args.resume + + # build the runner from config + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) + + # start training + runner.train() + + +if __name__ == '__main__': + main() diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000..e88d170 --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1,5 @@ +from .datasets import * +from .models import * +from .predictors import * +from .trainer import * +from .task import * \ No newline at end of file diff --git a/lib/datasets/__init__.py b/lib/datasets/__init__.py new file mode 100644 index 0000000..44a0cd2 --- /dev/null +++ b/lib/datasets/__init__.py @@ -0,0 +1,3 @@ +from .builder import PretrainingBuilder + +__all__ = ["PretrainingBuilder"] \ No newline at end of file diff --git a/lib/datasets/builder.py b/lib/datasets/builder.py new file mode 100644 index 0000000..a52e03b --- /dev/null +++ b/lib/datasets/builder.py @@ -0,0 +1,18 @@ +from antmmf.common.registry import registry +from antmmf.datasets.base_dataset_builder import BaseDatasetBuilder +from .loader.pretraining_loader import PretrainingLoader + +@registry.register_builder("pretraining_loader") +class PretrainingBuilder(BaseDatasetBuilder): + def __init__(self): + super().__init__("pretraining_loader") + + def _build(self, dataset_type, config, *args, **kwargs): + return None + + def _load(self, dataset_type, config, *args, **kwargs): + self.dataset = PretrainingLoader(dataset_type, config) + return self.dataset + + def update_registry_for_model(self, config): + pass diff --git a/lib/datasets/loader/__init__.py b/lib/datasets/loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/datasets/loader/few_shot_flood3i_loader.py b/lib/datasets/loader/few_shot_flood3i_loader.py new file mode 100644 index 0000000..a961c4b --- /dev/null +++ b/lib/datasets/loader/few_shot_flood3i_loader.py @@ -0,0 +1,289 @@ +import os +import json +import datetime +import random +import itertools +import time + +import numpy as np +import torch +import torch.nn.functional as F + +from antmmf.structures import Sample +from antmmf.datasets.base_dataset import BaseDataset +from antmmf.common import Configuration + +from lib.datasets.utils.transforms import Compose, MSNormalize +from lib.datasets.utils.formatting import ToTensor +import lib.datasets.utils.pair_trainsforms as pair_transforms + +from skimage import io +from osgeo import gdal +from PIL import Image + + +class FewShotFloodLoader(BaseDataset): + DATASET_NAME = "few_shot_flood_loader" + + def __init__(self, dataset_type, config): + super().__init__(self.__class__.DATASET_NAME, dataset_type, config) + if dataset_type == 'train': + raise ValueError('train mode not support!!!') + + self.root = config.data_root_dir + self.dataset_type = dataset_type + self.img_dir = config.img_dir + self.tgt_dir = config.tgt_dir + with open(config.data_txt, 'r') as f: + test_list = f.readlines() + self.test_pairs = [] + self.cls2path = {} + for i in test_list: + i = i.strip() + if i == '': + continue + img_path = i[:-3] + cls = int(i[-2:]) + cls = int(cls) + self.test_pairs.append( + {'hr_path': img_path, + 'class': cls, + 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png') + }) + if cls in self.cls2path.keys(): + self.cls2path[cls].append({'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls}) + else: + self.cls2path[cls] = [{'hr_path': img_path, 'tgt_path': img_path.replace('_', '_lab_', 1).replace('.jpg', '.png'), 'class': cls}] + + self.seq_len = config.seq_len # ts + self.hr_size = config.image_size.hr + self.s2_size = config.image_size.s2 + self.s1_size = config.image_size.s1 + self.anno_size = config.image_size.anno + self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) + self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]) # 先不管 + self.config = config + self.pipeline = self._get_pipline() + # self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.99, 1.0), interpolation=3) + + def __len__(self) -> int: + return len(self.test_pairs) + + def _combine_two_images(self, image, image2): + dst = torch.cat([image, image2], dim=-2) + return dst + + def _get_pipline(self): + if self.dataset_type == 'val' or self.dataset_type == 'test': + pipeline = [ + pair_transforms.ToTensor(), + pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3), + pair_transforms.Normalize(), + ] + else: + raise ValueError('dataset_type not support') + return pair_transforms.Compose(pipeline) + + def _load_data(self, data_path): + file_name, file_extension = os.path.splitext(data_path) + if file_extension == '.npz' or file_extension == '.npy': + npz_key = self.config.get('npz_key', 'image') + data = np.load(data_path)[npz_key] + elif file_extension == '.png' or file_extension == '.jpg': + data = io.imread(data_path) + if len(data.shape) == 3: + data = data.transpose(2, 0, 1) + elif file_extension == '.tiff' or file_extension == '.tif': + dataset = gdal.Open(data_path) + if dataset is None: + raise IOError(f'can not open file: {data_path}') + data = dataset.ReadAsArray() + dataset = None + else: + raise ValueError(f'file type {data_path} not support') + # check nan + if np.isnan(data).any(): + print(f'{data_path} with nan, replace it to 0!') + data[np.isnan(data)] = 0 + return data + + def load_s2(self, pair): + if 'l8_path' in pair.keys(): + pair['s2_path'] = pair['l8_path'] + + if 's2_path' in pair.keys() and not self.config.get('masking_s2', False): + with_s2 = True + if isinstance(pair['s2_path'], list): + if True: # len(pair['s2_path']) > self.seq_len: + s2_path_list = np.random.choice(pair['s2_path'], self.seq_len) + s2_path_list = sorted(s2_path_list) + else: + s2_path_list = pair['s2_path'] + s2_list = [] + s2_ct_1 = [] + for s2_path in s2_path_list: + s2 = self._load_data(os.path.join(self.root, s2_path)) # [:10] + s2_list.append(s2) + ct = os.path.splitext(s2_path)[0].split('_') + ct = ct[3] # + ct[-3] + '01' + try: + ct = datetime.datetime.strptime(ct, '%Y%m%d') + except: + ct = datetime.datetime.strptime(ct, '%Y-%m-%d') + ct = ct.timetuple() + ct = ct.tm_yday - 1 + s2_ct_1.append(ct) + s2_1 = np.stack(s2_list, axis=1) + + else: + + s2 = np.load(os.path.join(self.root, pair['s2_path']))['image'] + date = np.load(os.path.join(self.root, pair['s2_path']))['date'] + if True: # s2.shape[0] > self.seq_len: + selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False) + selected_indices = sorted(selected_indices) + s2 = s2[selected_indices, :, :, :] + date = date[selected_indices] + s2_1 = s2.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w + s2_ct_1 = [] + for ct in date: + try: + ct = datetime.datetime.strptime(ct, '%Y%m%d') + except: + ct = datetime.datetime.strptime(ct, '%Y-%m-%d') + ct = ct.timetuple() + ct = ct.tm_yday - 1 + s2_ct_1.append(ct) + + else: + with_s2 = False + s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]), + dtype=np.int16) + s2_ct_1 = [0] * self.seq_len + + return with_s2, s2_1, s2_ct_1 + + def load_s1(self, pair): + if 's1_path' in pair.keys(): + with_s1 = True + if isinstance(pair['s1_path'], list): + if True: # len(pair['s1_path']) > self.seq_len: + s1_path_list = np.random.choice(pair['s1_path'], self.seq_len) + s1_path_list = sorted(s1_path_list) + else: + s1_path_list = pair['s1_path'] + s1_list = [] + for s1_path in s1_path_list: + s1 = self._load_data(os.path.join(self.root, s1_path)) + s1_list.append(s1) + s1_1 = np.stack(s1_list, axis=1) + else: + s1 = self._load_data(os.path.join(self.root, pair['s1_path'])) + if True: # s1.shape[0] > self.seq_len: + selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False) + selected_indices = sorted(selected_indices) + s1 = s1[selected_indices, :, :, :] + s1_1 = s1.transpose(1, 0, 2, 3) # ts, c, h, w -> c, ts, h, w + else: + with_s1 = False + s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]), + dtype=np.float32) + return with_s1, s1_1 + + def load_hr(self, pair): + if 'hr_path' in pair.keys(): + with_hr = True + hr = self._load_data(os.path.join(self.root, pair['hr_path'])) + else: + with_hr = False + hr = np.zeros((3, self.hr_size[0], self.hr_size[1]), + dtype=np.uint8) + return with_hr, hr + + def load_tgt(self, pair): + targets = self._load_data(os.path.join(self.root, pair['target_path'])) + return targets + + def get_item(self, idx): + pair = self.test_pairs[idx] + test_class = pair['class'] + + current_dataset = 'flood3i' + with_hr = True + with_s2 = False + with_s1 = False + + input_hr = io.imread(os.path.join(self.img_dir, pair['hr_path'])) + input_hr = input_hr.transpose(2,0,1) + _, input_s2,_ = self.load_s2(pair) + _, input_s1 = self.load_s1(pair) + input_tgt = io.imread(os.path.join(self.tgt_dir, pair['tgt_path'])) + modality_dict = { + 's2': with_s2, + 's1': with_s1, + 'hr': with_hr + } + + + input_tgt[input_tgt != test_class] = 0 + input_tgt[input_tgt == test_class] = 255 + input_tgt = np.concatenate((input_tgt[None, :,:],)*3, axis=0) + input_hr, input_s2, input_s1, input_tgt = self.pipeline(current_dataset, input_hr, input_s2, input_s1, + input_tgt) + + while True: + sel_prompt = random.choice(self.cls2path[test_class]) + if sel_prompt['hr_path'] != pair['hr_path']: + break + prompt_hr = io.imread(os.path.join(self.img_dir, sel_prompt['hr_path'])) + prompt_hr = prompt_hr.transpose(2,0,1) + _, prompt_s2,_ = self.load_s2(pair) + _, prompt_s1 = self.load_s1(pair) + prompt_tgt = io.imread(os.path.join(self.tgt_dir, sel_prompt['tgt_path'])) + + prompt_tgt[prompt_tgt != test_class] = 0 + prompt_tgt[prompt_tgt == test_class] = 255 + prompt_tgt = np.concatenate((prompt_tgt[None, :,:],)*3, axis=0) + + prompt_hr, prompt_s2, prompt_s1, prompt_tgt = self.pipeline(current_dataset, prompt_hr, prompt_s2, prompt_s1, prompt_tgt) + + targets_comb = self._combine_two_images(prompt_tgt, input_tgt) + hr_comb = self._combine_two_images(prompt_hr, input_hr) + s2_comb = self._combine_two_images(prompt_s2, input_s2) + s1_comb = self._combine_two_images(prompt_s1, input_s1) + + valid = torch.ones_like(targets_comb) + thres = torch.ones(3) * 1e-5 # ignore black + thres = (thres - self.imagenet_mean) / self.imagenet_std + valid[targets_comb < thres[:, None, None]] = 0 + + mask_shape = (int(self.config.mim.input_size[0] / self.config.mim.patch_size), + int(self.config.mim.input_size[1] / self.config.mim.patch_size)) + mask = np.zeros(mask_shape, dtype=np.int32) + mask[mask.shape[0] // 2:, :] = 1 + + geo_location = pair["location"] if "location" in pair.keys() else None + + modality_idx = 2 ** 0 * modality_dict['s2'] + 2 ** 1 * modality_dict['s1'] + 2 ** 2 * modality_dict['hr'] + modality_flag_s2 = modality_dict['s2'] + modality_flag_s1 = modality_dict['s1'] + modality_flag_hr = modality_dict['hr'] + + current_sample = Sample() + current_sample.img_name = pair["tgt_path"].split('/')[-1].split('.')[0] + '-' +str(test_class) + current_sample.hr_img = hr_comb + current_sample.dataset_name = 'flood3i' + current_sample.targets = targets_comb + current_sample.s2_img = s2_comb + current_sample.s2_ct = -1 + current_sample.s2_ct2 = -1 + current_sample.s1_img = s1_comb + current_sample.anno_mask = torch.from_numpy(mask) + current_sample.valid = valid + current_sample.location = geo_location + current_sample.modality_idx = modality_idx + current_sample.modality_flag_s2 = modality_flag_s2 + current_sample.modality_flag_s1 = modality_flag_s1 + current_sample.modality_flag_hr = modality_flag_hr + current_sample.task_type = self.dataset_type + return current_sample diff --git a/lib/datasets/loader/pretraining_loader.py b/lib/datasets/loader/pretraining_loader.py new file mode 100644 index 0000000..7e7748c --- /dev/null +++ b/lib/datasets/loader/pretraining_loader.py @@ -0,0 +1,494 @@ +import os +import json +import datetime +import random + +import torch +import numpy as np +from osgeo import gdal +from skimage import io +from skimage.transform import resize + +from antmmf.structures import Sample +from antmmf.datasets.base_dataset import BaseDataset + +import lib.datasets.utils.pair_trainsforms as pair_transforms +from lib.datasets.utils.masking_generator import MaskingGenerator +from lib.datasets.utils.dataset_colors import dataset_color_dict, get_painter_color_map_list, get_real_random_color_list + + +class PretrainingLoader(BaseDataset): + DATASET_NAME = "pretraining_loader" + + def __init__(self, dataset_type, config): + super().__init__(self.__class__.DATASET_NAME, dataset_type, config) + self.root = config.data_root_dir + if dataset_type == 'train': + self.json_path_list = config.train_json_path_list + if dataset_type == 'val': + self.json_path_list = config.val_json_path_list + if dataset_type == 'test': + self.json_path_list = config.val_json_path_list + self.dataset_type = dataset_type + self.pairs = [] + self.cls_repeat_cnt = config.cls_repeat_cnt + num_datasets = len(self.json_path_list) + for idx, json_path in enumerate(self.json_path_list): + print(os.path.join(config.data_root_dir, json_path)) + cur_pairs = json.load(open(os.path.join(config.data_root_dir, json_path))) + self.pairs.extend(cur_pairs) + cur_num = len(cur_pairs) + + if dataset_type == 'test' and config.prompt_json: + cur_pairs = json.load(open(config.prompt_json)) + self.prompt = cur_pairs[0] + print(f'prompt:{self.prompt}') + + self.use_multi_pairs = config.use_multi_pairs + + if self.use_multi_pairs: + self.pair_type_dict = {} + if dataset_type == 'train' or dataset_type == 'val': + for idx, pair in enumerate(self.pairs): + if pair["type"] not in self.pair_type_dict: + new_subset = {} + classes = pair["classes"] + for cls in classes: + if cls not in new_subset.keys(): + new_subset[cls] = [idx] + else: + new_subset[cls].append(idx) + self.pair_type_dict[pair["type"]] = new_subset + else: + classes = pair["classes"] + for cls in classes: + if cls not in self.pair_type_dict[pair["type"]].keys(): + self.pair_type_dict[pair["type"]][cls] = [idx] + else: + self.pair_type_dict[pair["type"]][cls].append(idx) + + cnt = 0 + self.idx_to_cls = {} + for k, v in self.pair_type_dict.items(): + for vv in v: + self.idx_to_cls[cnt] = { + 'type': k, + 'classes_id': vv + } + cnt = cnt + 1 + + print(self.idx_to_cls) + self.idx_to_cls_list = [] + for i in self.idx_to_cls.keys(): + self.idx_to_cls_list.append(self.idx_to_cls[i]) + print(self.idx_to_cls_list) + if self.dataset_type == 'train': + self.idx_to_cls_list = self.idx_to_cls_list * self.cls_repeat_cnt + self.masked_position_generator = MaskingGenerator( + input_size=config.mim.input_size, + patch_size=config.mim.patch_size, + mask_ratio=config.mim.mask_ratio + ) + if dataset_type == 'train': + self.half_mask_ratio = config.half_mask_ratio + else: + self.half_mask_ratio = 1. + + self.seq_len = config.seq_len # ts + self.hr_size = config.image_size.hr + self.s2_size = config.image_size.s2 + self.s1_size = config.image_size.s1 + self.anno_size = config.image_size.anno + self.min_random_scale = config.min_random_scale + self.imagenet_mean=torch.tensor([0.485, 0.456, 0.406]) + self.imagenet_std=torch.tensor([0.229, 0.224, 0.225]) + + self.pipeline = self._get_pipline() + self.crop_resize = pair_transforms.RandomResizedCropComb(512, scale=(0.3, 1.0), interpolation=3) + self.num_samples = 8 + + def __len__(self) -> int: + return len(self.idx_to_cls_list) + + def _convert_colors_pairs(self, images, original_colors, new_colors, current_color): + if len(original_colors) != len(new_colors): + raise ValueError("The length of original_colors and new_colors must be the same.") + unique_colors_list = [] + for image in images: + if len(image.shape) == 3: + image_hwc = image.transpose(1,2,0) # chw -> hwc + elif len(image.shape) == 2: + image_hwc = image[:,:,None] + else: + raise ValueError('image shape is {image_hwc.shape}, which is not support to change color!') + + image_2d = image_hwc.reshape(-1, image_hwc.shape[-1]) + unique_colors = np.unique(image_2d, axis=0) + unique_colors_list.append(unique_colors) + unique_colors_list.append(original_colors) + + sets_of_tuples = [set(map(tuple, a)) for a in unique_colors_list] + common_tuples = set.intersection(*sets_of_tuples) + unique_old_colors = np.array(list(common_tuples), dtype=np.uint8) + if len(unique_old_colors) == 0: + unique_old_colors = [current_color] + new_colors_coverted = new_colors[:len(unique_old_colors)] + images_converted_list = [] + + for image in images: + image_convered = self._convert_colors(image, unique_old_colors, new_colors_coverted) + images_converted_list.append(image_convered) + + return images_converted_list + + def _convert_colors(self, image, original_colors, new_colors): + """ + Remap colors in an image to new colors. + + Parameters: + image (numpy.ndarray): The image as a numpy array (channel x height x width). + original_colors (list of tuples): The list of original colors to be replaced. + new_colors (list of tuples): The list of new colors to replace the original colors. + + Returns: + numpy.ndarray: The image with remapped colors. (channel x height x width) + """ + + if len(original_colors) != len(new_colors): + raise ValueError("The length of original_colors and new_colors must be the same.") + + # Convert lists of tuples to numpy arrays for faster processing + original_colors = np.array(original_colors) + new_colors = np.array(new_colors) + if len(original_colors.shape) == 1: + original_colors = original_colors[:,None] + + # check image shape + if len(image.shape) == 3: + remapped_image = image.transpose(1,2,0) # chw -> hwc + elif len(image.shape) == 2: + remapped_image = image[:,:,None] + else: + raise ValueError('image shape is {image.shape}, which is not support to change color!') + + # generate new image for return + new_image = np.zeros((remapped_image.shape[0], remapped_image.shape[1], 3), dtype=np.uint8) + + for orig_color, new_color in zip(original_colors, new_colors): + mask = np.all(remapped_image == orig_color, axis=-1) + new_image[mask] = new_color + + new_image = new_image.transpose(2,0,1) # hwc -> chw + return new_image + + def _combine_images(self, images, interpolation='bicubic'): + # images 8, c, h, w -> c, 4h, 2w + group1 = images[:4] + group2 = images[4:] + stacked1 = torch.cat(group1, dim=-2) + stacked2 = torch.cat(group2, dim=-2) + result = torch.cat((stacked1, stacked2), dim=-1) + + return result + + def _get_pipline(self): + if self.dataset_type == 'train': + pipeline = [ + pair_transforms.ToTensor(), + pair_transforms.RandomResizedCrop(512, scale=(0.8, 1.0), interpolation=3), # 3 is bicubic + pair_transforms.RandomHorizontalFlip(), + pair_transforms.Normalize(), + ] + elif self.dataset_type == 'val' or self.dataset_type == 'test': + pipeline = [ + pair_transforms.ToTensor(), + pair_transforms.RandomResizedCrop(512, scale=(0.9999, 1.0), interpolation=3), # 3 is bicubic + pair_transforms.Normalize(), + ] + else: + raise ValueError('dataset_type not support') + return pair_transforms.Compose(pipeline) + + def _load_data(self, data_path): + file_name, file_extension = os.path.splitext(data_path) + if file_extension == '.npz' or file_extension == '.npy': + data = np.load(data_path)['image'] + elif file_extension == '.png' or file_extension == '.jpg': + data = io.imread(data_path) + if len(data.shape) == 3: + data = data.transpose(2,0,1) + elif file_extension == '.tiff' or file_extension == '.tif': + dataset = gdal.Open(data_path) + if dataset is None: + raise IOError(f'无法打开文件{data_path}') + data = dataset.ReadAsArray() + dataset = None + else: + raise ValueError(f'file type {data_path} not support') + if np.isnan(data).any(): + print(f'{data_path} with nan, replace it to 0!') + data[np.isnan(data)] = 0 + return data + + def load_s2(self, pair): + if pair['type'] == 'flair-mm' and 's2_path' in pair.keys(): + with_s2 =True + s2 = np.load(os.path.join(self.root, pair['s2_path'])) + idx_centroid = pair['s2_cut_points'] + s2_patch_size = 40 + subset_sp = s2[:,:,idx_centroid[0]-int(s2_patch_size/2):idx_centroid[0] + \ + int(s2_patch_size/2),idx_centroid[1] - int(s2_patch_size/2):idx_centroid[1] + \ + int(s2_patch_size/2)] + ts, c, h, w = subset_sp.shape + subset_sp = subset_sp.reshape(-1, h, w).transpose(1,2,0) + s2 = resize(subset_sp, (16, 16), anti_aliasing=True).transpose(2,0,1) + s2 = s2.reshape(ts, c, 16, 16) + if True: + selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False) + selected_indices = sorted(selected_indices) + s2 = s2[selected_indices, :, :, :] + + s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w + s2_ct_1 = [0] * self.seq_len + + elif 's2_path' in pair.keys(): + with_s2 =True + if isinstance(pair['s2_path'], list): + if True: + s2_path_list = np.random.choice(pair['s2_path'], self.seq_len) + s2_path_list = sorted(s2_path_list) + else: + s2_path_list = pair['s2_path'] + s2_list = [] + s2_ct_1 = [] + for s2_path in s2_path_list: + s2 = self._load_data(os.path.join(self.root, s2_path))#[:10] + s2_list.append(s2) + ct = os.path.splitext(s2_path)[0].split('_') + ct = ct[-4] + ct[-3] + '01' + try: + ct = datetime.datetime.strptime(ct, '%Y%m%d') + except: + ct = datetime.datetime.strptime(ct, '%Y-%m-%d') + ct = ct.timetuple() + ct = ct.tm_yday - 1 + s2_ct_1.append(ct) + s2_1 = np.stack(s2_list, axis=1) + + else: + s2 = np.load(os.path.join(self.root, pair['s2_path']))['image'] + date = np.load(os.path.join(self.root, pair['s2_path']))['date'] + if True: + selected_indices = np.random.choice(s2.shape[0], size=self.seq_len, replace=False) + selected_indices = sorted(selected_indices) + s2 = s2[selected_indices, :, :, :] + date = date[selected_indices] + s2_1 = s2.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w + s2_ct_1 = [] + for ct in date: + try: + ct = datetime.datetime.strptime(ct, '%Y%m%d') + except: + ct = datetime.datetime.strptime(ct, '%Y-%m-%d') + ct = ct.timetuple() + ct = ct.tm_yday - 1 + s2_ct_1.append(ct) + else: + with_s2 = False + s2_1 = np.zeros((10, self.seq_len, self.s2_size[0], self.s2_size[1]), + dtype=np.int16) + s2_ct_1 = [0] * self.seq_len + + return with_s2, s2_1, s2_ct_1 + + def load_s1(self, pair): + if 's1_path' in pair.keys(): + with_s1 = True + if isinstance(pair['s1_path'], list): + if True: + s1_path_list = np.random.choice(pair['s1_path'], self.seq_len) + s1_path_list = sorted(s1_path_list) + else: + s1_path_list = pair['s1_path'] + s1_list = [] + for s1_path in s1_path_list: + s1 = self._load_data(os.path.join(self.root, s1_path)) + s1_list.append(s1) + s1_1 = np.stack(s1_list, axis=1) + else: + s1 = self._load_data(os.path.join(self.root, pair['s1_path'])) + if True: + selected_indices = np.random.choice(s1.shape[0], size=self.seq_len, replace=False) + selected_indices = sorted(selected_indices) + s1 = s1[selected_indices, :, :, :] + s1_1 = s1.transpose(1,0,2,3) # ts, c, h, w -> c, ts, h, w + else: + with_s1 = False + s1_1 = np.zeros((2, self.seq_len, self.s1_size[0], self.s1_size[1]), + dtype=np.float32) + return with_s1, s1_1 + + def load_hr(self, pair): + if 'hr_path' in pair.keys(): + if pair['type'] == 'flair-mm': + with_hr = True + hr = self._load_data(os.path.join(self.root, pair['hr_path']))[:3,:,:] + else: + with_hr = True + hr = self._load_data(os.path.join(self.root, pair['hr_path'])) + else: + with_hr = False + hr = np.zeros((3, self.hr_size[0], self.hr_size[1]), + dtype=np.uint8) + return with_hr, hr + + def load_tgt(self, pair): + if self.dataset_type == 'test': + targets = np.zeros((3, self.anno_size[0], self.anno_size[1]), + dtype=np.uint8) + else: + targets = self._load_data(os.path.join(self.root, pair['target_path'])) + return targets + + def find_random_position(self, matrix, current_color): + if matrix.ndim == 2: + matrix = matrix[None, :, :] + current_color = np.array(current_color) + C, H, W = matrix.shape + + if len(current_color) != C: + raise ValueError("current_color unmatch with matrix!") + + matches = np.where(np.all(matrix == current_color[:, None, None], axis=0)) + + if len(matches[0]) > 0: + index = np.random.choice(range(len(matches[0]))) + return (matches[0][index], matches[1][index]) + else: + return None + + def get_item(self, idx): + dataset_cls_infos = self.idx_to_cls_list[idx] + current_dataset = dataset_cls_infos['type'] + current_classes_id = dataset_cls_infos['classes_id'] + pair_idx_list = self.pair_type_dict[current_dataset][current_classes_id] + + old_colors = dataset_color_dict[current_dataset] + current_color = old_colors[current_classes_id] + class_num = len(old_colors) + if self.dataset_type == 'train': + new_colors = get_real_random_color_list(class_num) + else: + new_colors = get_painter_color_map_list(class_num) # fix colors mapping when testing + + num_samples = self.num_samples + if len(pair_idx_list) < num_samples: + selected_samples = [random.choice(pair_idx_list) for _ in range(num_samples)] + else: + selected_samples = random.sample(pair_idx_list, num_samples) + hr_imgs = [] + tgt_imgs = [] + s2_imgs = [] + s1_imgs = [] + s2_cts = [] + for sample_idx in selected_samples: + pair = self.pairs[sample_idx] + with_hr, hr = self.load_hr(pair) + with_s2, s2, s2_ct_1 = self.load_s2(pair) + with_s1, s1 = self.load_s1(pair) + tgt = self.load_tgt(pair) + modality_dict = { + 's2' : with_s2, + 's1' : with_s1, + 'hr' : with_hr + } + + if (hr.shape[-2:] != tuple(self.hr_size)) and (hr.shape[-2:] == tgt.shape[-2:]) and (self.hr_size == self.anno_size): + point_pos = self.find_random_position(tgt, current_color) + upper_left_raw = [point_pos[0] - self.hr_size[0] // 2, point_pos[1] - self.hr_size[1] // 2] + upper_left = [i - i%32 + 16 for i in upper_left_raw] + upper_left_sentinel = [i // 32 for i in upper_left_raw] + upper_left[0] = np.clip(np.array(upper_left[0]), 0, hr.shape[-2] - self.hr_size[0]) + upper_left[1] = np.clip(np.array(upper_left[1]), 0, hr.shape[-1] - self.hr_size[1]) + + upper_left_sentinel[0] = np.clip(np.array(upper_left_sentinel[0]), 0, s1.shape[-2] - self.s1_size[0]) + upper_left_sentinel[1] = np.clip(np.array(upper_left_sentinel[1]), 0, s1.shape[-1] - self.s1_size[1]) + hr = hr[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]] + if with_s1: + s1 = s1[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s1_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s1_size[1]] + if with_s2: + s2 = s2[:, :, upper_left_sentinel[0]:upper_left_sentinel[0]+self.s2_size[0], upper_left_sentinel[1]:upper_left_sentinel[1]+self.s2_size[1]] + if tgt.ndim == 3: + tgt = tgt[:, upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]] + elif tgt.ndim == 2: + tgt = tgt[upper_left[0]:upper_left[0]+self.hr_size[0], upper_left[1]:upper_left[1]+self.hr_size[1]] + else: + raise ValueError("tgt dim unsupport!") + hr_imgs.append(hr) + tgt_imgs.append(tgt) + s2_imgs.append(s2) + s1_imgs.append(s1) + s2_cts.append(s2_ct_1) + + + cvt_hr_imgs = [] + cvt_tgt_imgs = [] + cvt_s2_imgs = [] + cvt_s1_imgs = [] + + tgt_imgs = self._convert_colors_pairs(tgt_imgs, old_colors, new_colors, current_color) + for i in range(len(tgt_imgs)): + hr, s2, s1, tgt = self.pipeline(current_dataset, hr_imgs[i], s2_imgs[i], s1_imgs[i], tgt_imgs[i]) + cvt_hr_imgs.append(hr) + cvt_s2_imgs.append(s2) + cvt_s1_imgs.append(s1) + cvt_tgt_imgs.append(tgt) + + targets_comb = self._combine_images(cvt_tgt_imgs) + hr_comb = self._combine_images(cvt_hr_imgs) + s2_comb = self._combine_images(cvt_s2_imgs) + s1_comb = self._combine_images(cvt_s1_imgs) + hr_comb, s2_comb, s1_comb, targets_comb = self.crop_resize(current_dataset, hr_comb, s2_comb, s1_comb, targets_comb) + use_half_mask = torch.rand(1)[0] < self.half_mask_ratio + valid = torch.ones_like(targets_comb) + + thres = torch.ones(3) * (1e-5) # ignore black + thres = (thres - self.imagenet_mean) / self.imagenet_std + valid[targets_comb < thres[:, None, None]] = 0 + + if use_half_mask: + num_patches = self.masked_position_generator.num_patches + mask = np.zeros(self.masked_position_generator.get_shape(), dtype=np.int32) + mask[mask.shape[0]//2:, :] = 1 + else: + mask = self.masked_position_generator() + + # location + geo_location = pair["location"] if "location" in pair.keys() else None + + # get modality index + modality_idx = 2**0 * modality_dict['s2'] + 2**1 * modality_dict['s1'] + 2**2 * modality_dict['hr'] + modality_flag_s2 = modality_dict['s2'] + modality_flag_s1 = modality_dict['s1'] + modality_flag_hr = modality_dict['hr'] + + + current_sample = Sample() + current_sample.img_name = pair["hr_path"].split('/')[-1].split('.')[0] + current_sample.hr_img = hr_comb + current_sample.dataset_name = pair["type"] + current_sample.targets = targets_comb + current_sample.s2_img = s2_comb + current_sample.s2_ct = s2_cts[0] + current_sample.s2_ct2 = s2_cts[4] + current_sample.s1_img = s1_comb + current_sample.anno_mask = torch.from_numpy(mask) + current_sample.valid = valid + current_sample.location = geo_location + current_sample.modality_idx = modality_idx + current_sample.modality_flag_s2 = modality_flag_s2 + current_sample.modality_flag_s1 = modality_flag_s1 + current_sample.modality_flag_hr = modality_flag_hr + current_sample.task_type = self.dataset_type + + return current_sample \ No newline at end of file diff --git a/lib/datasets/utils/dataset_colors.py b/lib/datasets/utils/dataset_colors.py new file mode 100644 index 0000000..75f21cc --- /dev/null +++ b/lib/datasets/utils/dataset_colors.py @@ -0,0 +1,77 @@ +import random +from functools import lru_cache +import numpy as np + +dataset_color_dict = { + "potsdam" : [[1], [2], [3], [4], [5]], + "vaihingen" : [[255, 255, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255], [255, 255, 255]], + "deepglobe" : [[255,255,255], [0,0,255], [0,255,0],[255,0,255], [255,255,0], [0,255,255]], + "fbp" : [[i+1] for i in range(24)], + "loveda" : [[i+2, i+2, i+2] for i in range(6)], + "isaid" : [[i+1] for i in range(15)], + "pastis-mm" : [[i+1] for i in range(18)], + "dynamic-mm" : [[i] for i in range(7)], + "c2seg-ab" : [[i+1] for i in range(13)], + "flood3i": [[i+1] for i in range(9)], + "jl16-mm": [[i] for i in range(16)], + "flair-mm": [[i+1] for i in range(18)], + "dfc20": [[i+1] for i in range(10)] +} + + +modal_norm_dict = { + 'hr' : { + 'div' : 255., + 'mean' : [0.485, 0.456, 0.406], + 'std' : [0.229, 0.224, 0.225] + }, + 'anno' : { + 'div' : 255., + 'mean' : [0.485, 0.456, 0.406], + 'std' : [0.229, 0.224, 0.225] + }, + 's2' : { + 'div' : 1., + 'mean' : [884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2368.51236873, 1805.06846033], + 'std' : [1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1455.52084939, 1343.48379601] + }, + 's1' : { + 'div' : 1., + 'mean' : [-12.54847273, -20.19237134], + 'std' : [5.25697717, 5.91150917] + }, +} + +@lru_cache() +def get_painter_color_map_list(num_locations = 300): + + num_sep_per_channel = int(num_locations ** (1 / 3)) + 1 # 19 + separation_per_channel = 256 // num_sep_per_channel + + color_list = [] + for location in range(num_locations): + num_seq_r = location // num_sep_per_channel ** 2 + num_seq_g = (location % num_sep_per_channel ** 2) // num_sep_per_channel + num_seq_b = location % num_sep_per_channel + assert (num_seq_r <= num_sep_per_channel) and (num_seq_g <= num_sep_per_channel) \ + and (num_seq_b <= num_sep_per_channel) + + R = 255 - num_seq_r * separation_per_channel + G = 255 - num_seq_g * separation_per_channel + B = 255 - num_seq_b * separation_per_channel + assert (R < 256) and (G < 256) and (B < 256) + assert (R >= 0) and (G >= 0) and (B >= 0) + assert (R, G, B) not in color_list + + color_list.append((R, G, B)) + + return color_list + + +def get_real_random_color_list(num_locations): + random_color_list = np.random.randint(0, 256, (num_locations, 3)) + while np.sum(random_color_list) == 0: + print('random_color_list is 0!') + random_color_list = np.random.randint(0, 256, (num_locations, 3)) + random_color_list = random_color_list.tolist() + return random_color_list # [:num_locations] diff --git a/lib/datasets/utils/formatting.py b/lib/datasets/utils/formatting.py new file mode 100644 index 0000000..a0967ff --- /dev/null +++ b/lib/datasets/utils/formatting.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +import mmcv +import numpy as np +import torch + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +class ToTensor(object): + """Convert some sample to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, sample): + """Call function to convert data in sample to :obj:`torch.Tensor`. + + Args: + sample (Sample): sample data contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + + for key in self.keys: + if isinstance(sample[key], list): + for i in range(len(sample[key])): + sample[key][i] = to_tensor(sample[key][i]) + else: + sample[key] = to_tensor(sample[key]) + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + + diff --git a/lib/datasets/utils/masking_generator.py b/lib/datasets/utils/masking_generator.py new file mode 100644 index 0000000..b87eedc --- /dev/null +++ b/lib/datasets/utils/masking_generator.py @@ -0,0 +1,84 @@ +import random +import math +import numpy as np + +class MaskingGenerator: + def __init__( + self, input_size, patch_size, mask_ratio=0.5, min_num_patches=4, max_num_patches=None, + min_aspect=0.3, max_aspect=None): + if not isinstance(input_size, list): + input_size = [input_size,] * 2 + self.height = input_size[0] // patch_size + self.width = input_size[1] // patch_size + + self.num_patches = self.height * self.width + self.num_masking_patches = int(self.num_patches * mask_ratio) + + self.min_num_patches = min_num_patches + self.max_num_patches = self.num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, self.width, self.min_num_patches, self.max_num_patches, + self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for attempt in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top: top + h, left: left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self): + mask = np.zeros(shape=self.get_shape(), dtype=np.int32) + mask_count = 0 + while mask_count < self.num_masking_patches: + max_mask_patches = self.num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + # maintain a fix number {self.num_masking_patches} + if mask_count > self.num_masking_patches: + delta = mask_count - self.num_masking_patches + mask_x, mask_y = mask.nonzero() + to_vis = np.random.choice(mask_x.shape[0], delta, replace=False) + mask[mask_x[to_vis], mask_y[to_vis]] = 0 + + elif mask_count < self.num_masking_patches: + delta = self.num_masking_patches - mask_count + mask_x, mask_y = (mask == 0).nonzero() + to_mask = np.random.choice(mask_x.shape[0], delta, replace=False) + mask[mask_x[to_mask], mask_y[to_mask]] = 1 + + assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}" + + return mask diff --git a/lib/datasets/utils/pair_trainsforms.py b/lib/datasets/utils/pair_trainsforms.py new file mode 100644 index 0000000..5e49bbe --- /dev/null +++ b/lib/datasets/utils/pair_trainsforms.py @@ -0,0 +1,532 @@ +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import List, Optional, Tuple +import numpy as np + +import torch +from torch import Tensor +import torchvision.transforms as transforms +from skimage import io + + +try: + import accimage +except ImportError: + accimage = None + +import torchvision.transforms.functional as F +from torchvision.transforms.functional import _interpolation_modes_from_int, InterpolationMode +from PIL import Image, ImageFilter, ImageOps + +from .dataset_colors import modal_norm_dict + +__all__ = [ + "Compose", + "ToTensor", + "Normalize", + "RandomHorizontalFlip", + "RandomResizedCrop", +] + + + +class Compose(transforms.Compose): + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + """ + + def __init__(self, transforms): + super().__init__(transforms) + + def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None): + # i = 0 + for t in self.transforms: + # i = i+1 + # print(f'dataset_name:{dataset_name}') + # print(f'step:{i}') + # print(f'hr_img shape:{hr_img.shape}') + # print(f's2_img shape:{s2_img.shape}') + # print(f's1_img shape:{s1_img.shape}') + # print(f'tgt shape:{tgt.shape}') + + hr_img, s2_img, s1_img, tgt = t(dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=interpolation1, interpolation2=interpolation2) + return hr_img, s2_img, s1_img, tgt + + +class ToTensor(transforms.ToTensor): + """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. + Converts a PIL Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] + if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) + or if the numpy.ndarray has dtype = np.uint8 + In the other cases, tensors are returned without scaling. + .. note:: + Because the input image is scaled to [0.0, 1.0], this transformation should not be used when + transforming target image masks. See the `references`_ for implementing the transforms for image masks. + .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation + """ + + def __init__(self) -> None: + super().__init__() + + def __call__(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + + + # print(f'hr dtype:{hr_img.dtype}') + # print(f's2_img dtype:{s2_img.dtype}') + # print(f's1_img dtype:{s1_img.dtype}') + # print(f'tgt dtype:{tgt.dtype}') + if dataset_name == 'dynamic-mm' or dataset_name == 'guizhou-mm': + hr_img = hr_img.astype(np.int32)[:3,:,:] + hr_img = hr_img[::-1,:,:].copy() + else: + hr_img = hr_img.astype(np.int32) + tgt = tgt.astype(np.uint8) + s1_img = s1_img.astype(np.float32) + s2_img = s2_img.astype(np.int16) + + return torch.tensor(hr_img), torch.tensor(s2_img), torch.tensor(s1_img),torch.tensor(tgt) + + +class Normalize(transforms.Normalize): + """Normalize a tensor image with mean and standard deviation. + This transform does not support PIL Image. + Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` + channels, this transform will normalize each channel of the input + ``torch.*Tensor`` i.e., + ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + """ + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False): + super().__init__(mean, std, inplace) + + def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + tensor (Tensor): Tensor image to be normalized. + Returns: + Tensor: Normalized Tensor image. + """ + # TODO 查询对应的mean和std + + # 处理一些mean和std + if dataset_name == 'dynamic-mm': + hr_std = [1008.4052, 760.9586, 631.4754] + hr_mean = [1085.2941, 944.2718, 689.2493] + hr_div = 1. + else: + hr_mean = modal_norm_dict['hr']['mean'] + hr_std = modal_norm_dict['hr']['std'] + hr_div = modal_norm_dict['hr']['div'] + + if dataset_name == 'l8activefire': + # if False: + s2_mean = modal_norm_dict['l8']['mean'] + s2_std = modal_norm_dict['l8']['std'] + s2_div = modal_norm_dict['l8']['div'] + else: + s2_mean = modal_norm_dict['s2']['mean'] + s2_std = modal_norm_dict['s2']['std'] + s2_div = modal_norm_dict['s2']['div'] + + s1_mean = modal_norm_dict['s1']['mean'] + s1_std = modal_norm_dict['s1']['std'] + s1_div = modal_norm_dict['s1']['div'] + + anno_mean = [0.485, 0.456, 0.406] + anno_std = [0.229, 0.224, 0.225] + ann_div = 255. + + # 存在问题:时间序列这样处理是否会出错 + + #mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) + #std = torch.as_tensor(std, dtype=dtype, device=tensor.device) + # print(s2_img.shape) + # import pdb; pdb.set_trace() + # print(s2_img) + try: + ch, ts, h, w = s2_img.shape + except: + print(f's2: {s2_img.shape}, s1: {s1_img.shape}') + s2_img = s2_img.view(ch, ts*h, w) + s2_img = self.normalize(s2_img.type(torch.float32), s2_mean, s2_std, self.inplace) + s2_img = s2_img.view(ch, ts, h, w) + + ch, ts, h, w = s1_img.shape + s1_img = s1_img.view(ch, ts*h, w) + s1_img = self.normalize(s1_img.type(torch.float32), s1_mean, s1_std, self.inplace) + s1_img = s1_img.view(ch, ts, h, w) + + # import pdb; pdb.set_trace() + # print(s2_img.shape, s2_img[:,0,:,:]) + # print(s1_img.shape, s1_img[:,0,:,:]) + # print(hr_mean, hr_std, hr_div) + return self.normalize(hr_img.type(torch.float32).div_(hr_div), hr_mean, hr_std, self.inplace), \ + s2_img, \ + s1_img, \ + self.normalize(tgt.type(torch.float32).div_(ann_div) , anno_mean, anno_std, self.inplace) + + def normalize(self, tensor, mean, std, inplace): + dtype = tensor.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=dtype, device=tensor.device) + if (std == 0).any(): + raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") + mean = mean.view(-1, 1, 1) + std = std.view(-1, 1, 1) + # print(f'tensor shape: {tensor.shape}') + # print(f'mean shape: {mean.shape}') + # print(f'std shape: {std.shape}') + return tensor.sub_(mean).div_(std) + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """Crop a random portion of image and resize it to a given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions + A crop of the original image is made: the crop has a random area (H * W) + and a random aspect ratio. This crop is finally resized to the given + size. This is popularly used to train the Inception networks. + Args: + size (int or sequence): expected output size of the crop, for each edge. If size is an + int instead of sequence like (h, w), a square output size ``(size, size)`` is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, + before resizing. The scale is defined with respect to the area of the original image. + ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before + resizing. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and + ``InterpolationMode.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, + but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=InterpolationMode.BILINEAR, + mode='small' + ): + super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation) + self.cnt=0 + self.mode = mode + + def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None, mode='small'): + """ + Args: + img (PIL Image or Tensor): Image to be cropped and resized. + Returns: + PIL Image or Tensor: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(s2_img, self.scale, self.ratio) + size_hr = hr_img.shape[-1] + size_s2 = s2_img.shape[-1] + size_anno = tgt.shape[-1] + # 映射到其他模态 + ratio_s2_hr = size_s2 / size_hr + i_hr = int(i / ratio_s2_hr) + j_hr = int(j / ratio_s2_hr) + h_hr = int(h / ratio_s2_hr) + w_hr = int(w / ratio_s2_hr) + + ratio_s2_anno = size_s2 / size_anno + i_anno = int(i / ratio_s2_anno) + j_anno = int(j / ratio_s2_anno) + h_anno = int(h / ratio_s2_anno) + w_anno = int(w / ratio_s2_anno) + + if interpolation1 == 'nearest': + interpolation1 = InterpolationMode.NEAREST + else: + interpolation1 = InterpolationMode.BICUBIC + if interpolation2 == 'nearest': + interpolation2 = InterpolationMode.NEAREST + else: + interpolation2 = InterpolationMode.BICUBIC + # import pdb;pdb.set_trace() + if self.scale[0]>0.99 and self.scale[0]<1.0: + if self.mode=='small': + resized_s2_img = F.resize(s2_img, (16,16), interpolation=InterpolationMode.BICUBIC) + resized_hr_img = F.resize(hr_img, (512, 512), interpolation=InterpolationMode.BICUBIC) + resized_s1_img = F.resize(s1_img, (16,16), interpolation=InterpolationMode.BICUBIC) + resized_tgt = F.resize(tgt, (512,512), interpolation=InterpolationMode.NEAREST) + else: + resized_s2_img = F.resize(s2_img, (64,64), interpolation=InterpolationMode.BICUBIC) + resized_hr_img = F.resize(hr_img, (2048, 2048), interpolation=InterpolationMode.BICUBIC) + resized_s1_img = F.resize(s1_img, (64,64), interpolation=InterpolationMode.BICUBIC) + resized_tgt = F.resize(tgt, (2048,2048), interpolation=InterpolationMode.NEAREST) + return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt + + if self.mode=='small': + resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC) + resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (512, 512), InterpolationMode.BICUBIC) + resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (16, 16), InterpolationMode.BICUBIC) + resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (512, 512), InterpolationMode.NEAREST) + else: + resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC) + resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (2048,2048), InterpolationMode.BICUBIC) + resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (512, 512), InterpolationMode.BICUBIC) + resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (2048, 2048), InterpolationMode.NEAREST) + + # import pdb; pdb.set_trace() + # 将resize后的结果保存为concat的img + # self.cnt = self.cnt+1 + # from torchvision.utils import save_image + # save_hr = resized_hr_img[:3, :, :] / resized_hr_img[:3, :, :].max() + # save_s2 = resized_s2_img[:3,0,:,:] / resized_s2_img[:3,0,:,:].max() + # print(f'{save_hr.shape}, {save_s2.shape}') + # save_image(save_s2, f'FoundationModel/debug/output2/resized_s2_{self.cnt}.png') + # save_image(save_hr, f'FoundationModel/debug/output2/resized_hr_{self.cnt}.png') + + return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt + +class RandomResizedCropComb(transforms.RandomResizedCrop): + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=InterpolationMode.BILINEAR, + ): + super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation) + self.cnt=0 + + def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + img (PIL Image or Tensor): Image to be cropped and resized. + Returns: + PIL Image or Tensor: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(s2_img, self.scale, self.ratio) + # print(f'i, j, h, w: {i, j, h, w}') + # print(f's2_img shape: {s2_img.shape}') + size_hr = hr_img.shape[-1] + size_s2 = s2_img.shape[-1] + size_anno = tgt.shape[-1] + # 映射到其他模态 + ratio_s2_hr = size_s2 / size_hr + i_hr = int(i / ratio_s2_hr) + j_hr = int(j / ratio_s2_hr) + h_hr = int(h / ratio_s2_hr) + w_hr = int(w / ratio_s2_hr) + + ratio_s2_anno = size_s2 / size_anno + i_anno = int(i / ratio_s2_anno) + j_anno = int(j / ratio_s2_anno) + h_anno = int(h / ratio_s2_anno) + w_anno = int(w / ratio_s2_anno) + + if interpolation1 == 'nearest': + interpolation1 = InterpolationMode.NEAREST + else: + interpolation1 = InterpolationMode.BICUBIC + if interpolation2 == 'nearest': + interpolation2 = InterpolationMode.NEAREST + else: + interpolation2 = InterpolationMode.BICUBIC + + resized_s2_img = F.resized_crop(s2_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC) + resized_hr_img = F.resized_crop(hr_img, i_hr, j_hr, h_hr, w_hr, (1024, 512), InterpolationMode.BICUBIC) + resized_s1_img = F.resized_crop(s1_img, i, j, h, w, (32, 16), InterpolationMode.BICUBIC) + resized_tgt = F.resized_crop(tgt, i_anno, j_anno, h_anno, w_anno, (1024, 512), InterpolationMode.NEAREST) + + return resized_hr_img, resized_s2_img, resized_s1_img, resized_tgt + + +class RandomHorizontalFlip(transforms.RandomHorizontalFlip): + """Horizontally flip the given image randomly with a given probability. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading + dimensions + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + super().__init__(p=p) + + def forward(self, dataset_name, hr_img, s2_img, s1_img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + img (PIL Image or Tensor): Image to be flipped. + Returns: + PIL Image or Tensor: Randomly flipped image. + """ + if torch.rand(1) < self.p: + return F.hflip(hr_img), F.hflip(s2_img), F.hflip(s1_img), F.hflip(tgt) + return hr_img, s2_img, s1_img, tgt + + +class RandomApply(transforms.RandomApply): + """Apply randomly a list of transformations with a given probability. + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=0.3) + >>> scripted_transforms = torch.jit.script(transforms) + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super().__init__(transforms, p=p) + + def forward(self, img, tgt, interpolation1=None, interpolation2=None): + if self.p < torch.rand(1): + return img, tgt + for t in self.transforms: + img, tgt = t(img, tgt) + return img, tgt + +class ColorJitter(transforms.ColorJitter): + """Randomly change the brightness, contrast, saturation and hue of an image. + If the image is torch Tensor, it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; + thus it does not work if you normalize your image to an interval with negative values, + or use an interpolation that generates negative values before using this function. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) + + def forward(self, img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + img (PIL Image or Tensor): Input image. + Returns: + PIL Image or Tensor: Color jittered image. + """ + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) + + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return img, tgt + + +class RandomErasing(transforms.RandomErasing): + """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels. + This transform does not support PIL Image. + 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 + Args: + p: probability that the random erasing operation will be performed. + scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + value: erasing value. Default is 0. If a single int, it is used to + erase all pixels. If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace. Default set to False. + Returns: + Erased Image. + Example: + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.PILToTensor(), + >>> transforms.ConvertImageDtype(torch.float), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace) + + def forward(self, img, tgt, interpolation1=None, interpolation2=None): + """ + Args: + img (Tensor): Tensor image to be erased. + Returns: + img (Tensor): Erased Tensor image. + """ + if torch.rand(1) < self.p: + + # cast self.value to script acceptable type + if isinstance(self.value, (int, float)): + value = [self.value] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, tuple): + value = list(self.value) + else: + value = self.value + + if value is not None and not (len(value) in (1, img.shape[-3])): + raise ValueError( + "If value is a sequence, it should have either a single value or " + f"{img.shape[-3]} (number of input channels)" + ) + + x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value) + return F.erase(img, x, y, h, w, v, self.inplace), tgt + return img, tgt + + + +class GaussianBlur(object): + """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, img, tgt, interpolation1=None, interpolation2=None): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + img = img.filter(ImageFilter.GaussianBlur(radius=sigma)) + return img, tgt + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}( sigma={self.sigma})" + return s + diff --git a/lib/datasets/utils/transforms.py b/lib/datasets/utils/transforms.py new file mode 100644 index 0000000..314c276 --- /dev/null +++ b/lib/datasets/utils/transforms.py @@ -0,0 +1,1558 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from packaging.version import Version +import numpy as np +from numpy import random +import math +from PIL import Image +import torch +import torchvision +import torchvision.transforms as T +import torchvision.transforms.functional as F +import mmcv +import copy +from mmcv.utils import deprecated_api_warning + + +class Compose(object): + """Compose multiple transforms sequentially. + """ + + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, sample): + """Call function to apply transforms sequentially. + + Args: + sample (Sample): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + sample = t(sample) + if sample is None: + return None + return sample + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string + + +class SegResize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. + + ``img_scale`` can be None, a tuple (single-scale) or a list of tuple + (multi-scale). There are 4 multiscale modes: + + - ``ratio_range is not None``: + 1. When img_scale is None, img_scale is the shape of image in sample + (img_scale = sample.img.shape[:2]) and the image is resized based + on the original size. (mode 1) + 2. When img_scale is a tuple (single-scale), randomly sample a ratio from + the ratio range and multiply it with the image scale. (mode 2) + + - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a + scale from the a range. (mode 3) + + - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a + scale from multiple scales. (mode 4) + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + Default:None. + multiscale_mode (str): Either "range" or "value". + Default: 'range' + ratio_range (tuple[float]): (min_ratio, max_ratio). + Default: None + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Default: True + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given img_scale=None and a range of image ratio + # mode 2: given a scale and a range of image ratio + assert self.img_scale is None or len(self.img_scale) == 1 + else: + # mode 3 and 4: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint(min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint(min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, sample): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + sample (Sample): Sample data from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``sample``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = sample.img.shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + sample.scale = scale + sample.scale_idx = scale_idx + + def _resize_img(self, sample): + """Resize images with ``sample['scale']``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale(sample[sample.img_field], + sample.scale, + return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = sample[sample.img_field].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize(sample[sample.img_field], + sample.scale, + return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + sample[sample.img_field] = img + sample.img_shape = img.shape + sample.pad_shape = img.shape # in case that there is no padding + sample.scale_factor = scale_factor + sample.keep_ratio = self.keep_ratio + + def _resize_seg(self, sample): + """Resize semantic segmentation map with ``sample.scale``.""" + for key in sample.get('ann_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale(sample[key], + sample.scale, + interpolation='nearest') + else: + gt_seg = mmcv.imresize(sample[key], + sample.scale, + interpolation='nearest') + sample[key] = gt_seg + + def __call__(self, sample): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + sample (Sample): Sample dict from loading pipeline. + + Returns: + dict: Resized sample, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in sample: + self._random_scale(sample) + self._resize_img(sample) + self._resize_seg(sample) + return sample + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +class SegRandomFlip(object): + """Flip the image & seg. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + prob (float, optional): The flipping probability. Default: None. + direction(str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. + """ + + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='SegRandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob + self.direction = direction + if prob is not None: + assert prob >= 0 and prob <= 1 + assert direction in ['horizontal', 'vertical'] + + def __call__(self, sample): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + sample (Sample): Sample data from loading pipeline. + + Returns: + dict: Flipped sample, 'flip', 'flip_direction' keys are added into + result dict. + """ + + if 'flip' not in sample: + flip = True if np.random.rand() < self.prob else False + sample.flip = flip + if 'flip_direction' not in sample: + sample.flip_direction = self.direction + if sample.flip: + # flip image + sample[sample.img_field] = mmcv.imflip( + sample[sample.img_field], direction=sample.flip_direction) + + # flip segs + for key in sample.get('ann_fields', []): + # use copy() to make numpy stride positive + sample[key] = mmcv.imflip( + sample[key], direction=sample.flip_direction).copy() + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(prob={self.prob})' + + +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. + """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, sample): + """Call function to normalize images. + + Args: + sample (Sample): Sample data from loading pipeline. + + Returns: + dict: Normalized sample, 'img_norm_cfg' key is added into + sample. + """ + + sample[sample.img_field] = mmcv.imnormalize(sample[sample.img_field], + self.mean, self.std, + self.to_rgb) + sample.img_norm_cfg = dict(mean=self.mean, + std=self.std, + to_rgb=self.to_rgb) + return sample + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ + f'{self.to_rgb})' + return repr_str + + +class MSNormalize(object): + + def __init__(self, configs): + self.configs = configs + self.keys = configs.keys() + + def normalize_(self, img, config): + if isinstance(img, np.ndarray) and img.dtype != np.float32: + img = img.astype(np.float32) + if isinstance(img, torch.Tensor): + img = img.float() + div_value = config.div_value + mean = config.mean + std = config.std + img /= div_value + for t, m, s in zip(img, mean, std): + t -= m + t /= s + return img + + def __call__(self, sample): + for key in self.keys: + if isinstance(sample[key], list): + for i in range(len(sample[key])): + sample[key][i] = self.normalize_(sample[key][i], + self.configs[key]) + else: + sample[key] = self.normalize_(sample[key], self.configs[key]) + return sample + + +class MSRandomCrop(object): + """Random crop the hr_img s2_img targets. + Args: + crop_size (tuple): Expected size ratio after cropping, (h, w). + """ + + def __init__(self, crop_size, keys): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.keys = keys + + def get_crop_bbox(self): + """Randomly get a crop bounding box.""" + margin_h = max(1.0 - self.crop_size[0], 0) + margin_w = max(1.0 - self.crop_size[1], 0) + offset_h = np.random.uniform(0, margin_h) + offset_w = np.random.uniform(0, margin_w) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + h, w = img.shape[-2:] + crop_y1, crop_y2, crop_x1, crop_x2 = int(crop_y1 * h), int( + crop_y2 * h), int(crop_x1 * w), int(crop_x2 * w) + img = img[..., crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def __call__(self, sample): + """Call function to randomly crop images, semantic segmentation maps. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + crop_bbox = self.get_crop_bbox() + for key in self.keys: + if isinstance(sample[key], list): + for i in range(len(sample[key])): + sample[key][i] = self.crop(sample[key][i], crop_bbox) + else: + sample[key] = self.crop(sample[key], crop_bbox) + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +class MSRandomRangeCrop(object): + """Random crop the hr_img s2_img targets. + Args: + crop_size (tuple): Expected size ratio after cropping, (min, max). + """ + + def __init__(self, crop_size, keys): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.keys = keys + + def get_crop_bbox(self): + """Randomly get a crop bounding box.""" + crop_size_ = np.random.uniform(self.crop_size[0], self.crop_size[1]) + margin_h = max(1.0 - crop_size_, 0) + margin_w = max(1.0 - crop_size_, 0) + offset_h = np.random.uniform(0, margin_h) + offset_w = np.random.uniform(0, margin_w) + crop_y1, crop_y2 = offset_h, offset_h + crop_size_ + crop_x1, crop_x2 = offset_w, offset_w + crop_size_ + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + h, w = img.shape[-2:] + crop_y1, crop_y2, crop_x1, crop_x2 = int(crop_y1 * h), int( + crop_y2 * h), int(crop_x1 * w), int(crop_x2 * w) + img = img[..., crop_y1:crop_y2, crop_x1:crop_x2] + return img + + def __call__(self, sample): + """Call function to randomly crop images, semantic segmentation maps. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + crop_bbox = self.get_crop_bbox() + for key in self.keys: + if isinstance(sample[key], list): + for i in range(len(sample[key])): + sample[key][i] = self.crop(sample[key][i], crop_bbox) + else: + sample[key] = self.crop(sample[key], crop_bbox) + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +class MSResize(object): + + def __init__(self, target_size, keys): + assert target_size[0] > 0 and target_size[1] > 0 + self.target_size = target_size + self.keys = keys + + def __call__(self, sample): + for key in self.keys: + if key == 'targets': + sample[key] = F.resize( + sample[key], + self.target_size, + interpolation=T.InterpolationMode.NEAREST + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.NEAREST) + else: + sample[key] = F.resize(sample[key], self.target_size) + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(target_size={self.target_size})' + + +class MSSSLRandomResizedCrop(object): + + def __init__(self, configs, global_crops_number, local_crops_number): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + + @staticmethod + def get_params(scale: tuple, ratio: tuple): + """Get parameters for ``crop`` for a random sized crop. + Args: + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect + ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + origin_h, origin_w = 1.0, 1.0 + area = 1.0 + + while True: + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = math.sqrt(target_area * aspect_ratio) + h = math.sqrt(target_area / aspect_ratio) + + if w <= origin_w and h <= origin_h: + i = random.uniform(0, origin_h - h) + j = random.uniform(0, origin_w - w) + return i, j, h, w + + def __call__(self, sample): + for scope_view in self.configs.keys(): + for index in range(eval(f'self.{scope_view}_crops_number')): + i, j, h, w = self.get_params(self.configs[scope_view].scale, + self.configs[scope_view].ratio) + for source in self.configs[scope_view]['size'].keys(): + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + i_img, h_img = int(round(i * img.shape[-2])), int( + round(h * img.shape[-2])) + j_img, w_img = int(round(j * img.shape[-1])), int( + round(w * img.shape[-1])) + img = F.resized_crop( + img, + i_img, + j_img, + h_img, + w_img, + self.configs[scope_view]['size'][source], + interpolation=T.InterpolationMode.BICUBIC + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.BICUBIC) + sample[img_key][index] = img + img_key = f'{scope_view}_{source}_distance' + img = sample[img_key][index] + img = F.resized_crop( + img, + i_img, + j_img, + h_img, + w_img, + self.configs[scope_view]['size'][source], + interpolation=T.InterpolationMode.BICUBIC + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.BICUBIC) + sample[img_key][index] = img + img_key = f'{scope_view}_lc' + img = sample[img_key][index] + i_img, h_img = int(round(i * img.shape[-2])), int( + round(h * img.shape[-2])) + j_img, w_img = int(round(j * img.shape[-1])), int( + round(w * img.shape[-1])) + img = F.resized_crop( + img, + i_img, + j_img, + h_img, + w_img, + self.configs[scope_view]['size']['s2'], + interpolation=T.InterpolationMode.NEAREST + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.NEAREST) + sample[img_key][index] = img + return sample + + +class MSSSLRandomFlip(object): + + def __init__(self, configs, global_crops_number, local_crops_number, + scope_views, sources): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + self.scope_views = scope_views + self.sources = sources + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + hflip = False + vflip = False + for direction, prob in zip(self.configs['directions'], + self.configs['probs']): + p = torch.rand(1) + if direction == 'horizontal' and p < prob: + hflip = True + if direction == 'vertical' and p < prob: + vflip = True + for source in self.sources: + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + if hflip: + img = F.hflip(img) + if vflip: + img = F.vflip(img) + sample[img_key][index] = img + img_key = f'{scope_view}_{source}_distance' + img = sample[img_key][index] + if hflip: + img = F.hflip(img) + if vflip: + img = F.vflip(img) + sample[img_key][index] = img + img_key = f'{scope_view}_lc' + img = sample[img_key][index] + if hflip: + img = F.hflip(img) + if vflip: + img = F.vflip(img) + sample[img_key][index] = img + return sample + + +class MSSSLRandomRotate(object): + + def __init__(self, configs, global_crops_number, local_crops_number, + scope_views, sources): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + self.scope_views = scope_views + self.sources = sources + self.angle_set = [90, 180, 270] + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + p = torch.rand(1) + if p > self.configs['probs']: + continue + angle = self.angle_set[torch.randint(0, 3, (1,)).item()] + for source in self.sources: + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + img = F.rotate( + img, + angle, + interpolation=T.InterpolationMode.BILINEAR + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.BILINEAR) + sample[img_key][index] = img + img_key = f'{scope_view}_{source}_distance' + img = sample[img_key][index] + img = F.rotate( + img, + angle, + interpolation=T.InterpolationMode.BILINEAR + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.BILINEAR) + sample[img_key][index] = img + img_key = f'{scope_view}_lc' + img = sample[img_key][index] + img = F.rotate( + img, + angle, + interpolation=T.InterpolationMode.NEAREST + if Version(torchvision.__version__) >= Version('0.9.0') + else Image.NEAREST) + sample[img_key][index] = img + return sample + + +class MSSSLRandomColorJitter(object): + + def __init__(self, configs, global_crops_number, local_crops_number, + scope_views, sources): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + self.scope_views = scope_views + self.sources = sources + self.color_prob = configs['color']['probs'] + self.brightness = configs['color']['brightness'] + self.contrast = configs['color']['contrast'] + self.saturation = configs['color']['saturation'] + self.hue = configs['color']['hue'] + self.gray_prob = configs['gray']['probs'] + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + for source in self.sources: + p = torch.rand(1) + if p >= self.color_prob: + continue + brightness_factor = random.uniform( + max(0, 1 - self.brightness), 1 + self.brightness) + contrast_factor = random.uniform(max(0, 1 - self.contrast), + 1 + self.contrast) + saturation_factor = random.uniform( + max(0, 1 - self.saturation), 1 + self.saturation) + hue_factor = random.uniform(-self.hue, self.hue) + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + img = F.adjust_brightness(img, brightness_factor) + img = F.adjust_contrast(img, contrast_factor) + img = F.adjust_saturation(img, saturation_factor) + img = F.adjust_hue(img, hue_factor) + p = torch.rand(1) + if p >= self.gray_prob: + continue + num_output_channels, _, _ = img.shape + img = F.rgb_to_grayscale( + img, num_output_channels=num_output_channels) + sample[img_key][index] = img + return sample + + +class MSSSLRandomGaussianBlur(object): + + def __init__(self, configs, global_crops_number, local_crops_number, + scope_views, sources): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + self.scope_views = scope_views + self.sources = sources + self.prob = configs['probs'] + self.sigma = configs['sigma'] + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + for source in self.sources: + p = self.prob[scope_view] + if scope_view == 'global': + p = p[index] + if torch.rand(1) >= p: + continue + sigma = random.uniform(self.sigma[0], self.sigma[1]) + kernel_size = max(int(2 * ((sigma - 0.8) / 0.3 + 1) + 1), + 1) + if kernel_size % 2 == 0: + kernel_size += 1 + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + img = F.gaussian_blur(img, kernel_size, sigma) + sample[img_key][index] = img + return sample + + +class MSSSLRandomSolarize(object): + + def __init__(self, configs, global_crops_number, scope_views, sources): + self.configs = configs + self.global_crops_number = global_crops_number + self.scope_views = scope_views + self.sources = sources + self.prob = configs['probs'] + self.threshold = 130 + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + for source in self.sources: + if index != 1: + continue + if torch.rand(1) >= self.prob: + continue + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + img = F.solarize(img, self.threshold) + sample[img_key][index] = img + return sample + + +class MSSSLRandomChannelOut(object): + + def __init__(self, configs, global_crops_number, local_crops_number, + scope_views, sources, mean): + self.configs = configs + self.global_crops_number = global_crops_number + self.local_crops_number = local_crops_number + self.scope_views = scope_views + self.sources = sources + self.mean = mean + + def __call__(self, sample): + for scope_view in self.scope_views: + for index in range(eval(f'self.{scope_view}_crops_number')): + for source in self.sources: + out_num = self.configs[scope_view]['out_num'] + out_num = random.randint(out_num[0], out_num[1] + 1) + out_index = sorted( + random.choice(len(self.mean), out_num, replace=False)) + img_key = f'{scope_view}_{source}_img' + img = sample[img_key][index] + for i in out_index: + img[i] = int(self.mean[i]) + sample[img_key][index] = img + return sample + + +class MaskGenerator: + + def __init__(self, + input_size=192, + mask_patch_size=32, + model_patch_size=4, + mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size**2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask_idx = np.random.permutation(self.token_count)[:self.mask_count] + mask = np.zeros(self.token_count, dtype=int) + mask[mask_idx] = 1 + + mask = mask.reshape((self.rand_size, self.rand_size)) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + + return mask + + +class DetResize: + """Resize images & bbox & mask. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. If the input dict contains the key + "scale", then the scale in the input dict is used, otherwise the specified + scale in the init method is used. If the input dict contains the key + "scale_factor" (if MultiScaleFlipAug does not give img_scale but + scale_factor), the actual scale will be computed by image shape and + scale_factor. + + `img_scale` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio \ + range and multiply it with the image scale. + - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \ + sample a scale from the multiscale range. + - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \ + sample a scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different sample. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + override (bool, optional): Whether to override `scale` and + `scale_factor` so as to call resize twice. Default False. If True, + after the first resizing, the existed `scale` and `scale_factor` + will be ignored so the second resizing can be allowed. + This option is a work-around for multiple times of resize in DETR. + Defaults to False. + """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + bbox_clip_border=True, + backend='cv2', + interpolation='bilinear', + override=False): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.backend = backend + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + self.interpolation = interpolation + self.override = override + self.bbox_clip_border = bbox_clip_border + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \ + where ``img_scale`` is the selected image scale and \ + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and upper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where \ + ``img_scale`` is sampled scale and None is just a placeholder \ + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint(min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint(min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where \ + ``scale`` is sampled ratio multiplied with ``img_scale`` and \ + None is just a placeholder to be consistent with \ + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, sample): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + sample (Sample): Result from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into \ + ``sample``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + sample.scale = scale + sample.scale_idx = scale_idx + + def _resize_img(self, sample): + """Resize images with ``sample.scale``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + sample[sample.img_field], + sample.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = sample[sample.img_field].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + sample[sample.img_field], + sample.scale, + return_scale=True, + interpolation=self.interpolation, + backend=self.backend) + sample[sample.img_field] = img + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + sample.img_shape = img.shape + # in case that there is no padding + sample.pad_shape = img.shape + sample.scale_factor = scale_factor + sample.keep_ratio = self.keep_ratio + + def _resize_bboxes(self, sample): + """Resize bounding boxes with ``sample.scale_factor``.""" + for key in sample.get('bbox_fields', []): + bboxes = sample[key] * sample.scale_factor + if self.bbox_clip_border: + img_shape = sample.img_shape + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0]) + sample[key] = bboxes + + def _resize_masks(self, sample): + """Resize masks with ``sample.scale``""" + for key in sample.get('mask_fields', []): + if sample[key] is None: + continue + if self.keep_ratio: + sample[key] = sample[key].rescale(sample.scale) + else: + sample[key] = sample[key].resize(sample.img_shape[:2]) + + def _resize_seg(self, sample): + """Resize semantic segmentation map with ``sample.scale``.""" + for key in sample.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale(sample[key], + sample.scale, + interpolation='nearest', + backend=self.backend) + else: + gt_seg = mmcv.imresize(sample[key], + sample.scale, + interpolation='nearest', + backend=self.backend) + sample[key] = gt_seg + + def __call__(self, sample): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + sample (dict): Result dict from loading pipeline. + + Returns: + dict: Resized sample, 'img_shape', 'pad_shape', 'scale_factor', \ + 'keep_ratio' keys are added into result dict. + """ + + sample.scale_idx = None + if 'scale' not in sample: + if 'scale_factor' in sample: + img_shape = sample.hr_img.shape[:2] + scale_factor = sample.scale_factor + assert isinstance(scale_factor, float) + sample.scale = tuple( + [int(x * scale_factor) for x in img_shape][::-1]) + else: + self._random_scale(sample) + else: + if not self.override: + assert 'scale_factor' not in sample, ( + 'scale and scale_factor cannot be both set.') + else: + sample.pop('scale') + if 'scale_factor' in sample: + sample.pop('scale_factor') + self._random_scale(sample) + + self._resize_img(sample) + self._resize_bboxes(sample) + self._resize_masks(sample) + self._resize_seg(sample) + return sample + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(img_scale={self.img_scale}, ' + repr_str += f'multiscale_mode={self.multiscale_mode}, ' + repr_str += f'ratio_range={self.ratio_range}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +class DetRandomFlip: + """Flip the image & bbox & mask. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + When random flip is enabled, ``flip_ratio``/``direction`` can either be a + float/string or tuple of float/string. There are 3 flip modes: + + - ``flip_ratio`` is float, ``direction`` is string: the image will be + ``direction``ly flipped with probability of ``flip_ratio`` . + E.g., ``flip_ratio=0.5``, ``direction='horizontal'``, + then image will be horizontally flipped with probability of 0.5. + - ``flip_ratio`` is float, ``direction`` is list of string: the image will + be ``direction[i]``ly flipped with probability of + ``flip_ratio/len(direction)``. + E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``, + then image will be horizontally flipped with probability of 0.25, + vertically with probability of 0.25. + - ``flip_ratio`` is list of float, ``direction`` is list of string: + given ``len(flip_ratio) == len(direction)``, the image will + be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``. + E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal', + 'vertical']``, then image will be horizontally flipped with probability + of 0.3, vertically with probability of 0.5. + + Args: + flip_ratio (float | list[float], optional): The flipping probability. + Default: None. + direction(str | list[str], optional): The flipping direction. Options + are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'. + If input is a list, the length must equal ``flip_ratio``. Each + element in ``flip_ratio`` indicates the flip probability of + corresponding direction. + """ + + def __init__(self, flip_ratio=None, direction='horizontal'): + if isinstance(flip_ratio, list): + assert mmcv.is_list_of(flip_ratio, float) + assert 0 <= sum(flip_ratio) <= 1 + elif isinstance(flip_ratio, float): + assert 0 <= flip_ratio <= 1 + elif flip_ratio is None: + pass + else: + raise ValueError('flip_ratios must be None, float, ' + 'or list of float') + self.flip_ratio = flip_ratio + + valid_directions = ['horizontal', 'vertical', 'diagonal'] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert mmcv.is_list_of(direction, str) + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError('direction must be either str or list of str') + self.direction = direction + + if isinstance(flip_ratio, list): + assert len(self.flip_ratio) == len(self.direction) + + def bbox_flip(self, bboxes, img_shape, direction): + """Flip bboxes horizontally. + + Args: + bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k) + img_shape (tuple[int]): Image shape (height, width) + direction (str): Flip direction. Options are 'horizontal', + 'vertical'. + + Returns: + numpy.ndarray: Flipped bounding boxes. + """ + + assert bboxes.shape[-1] % 4 == 0 + flipped = bboxes.copy() + if direction == 'horizontal': + w = img_shape[1] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + elif direction == 'vertical': + h = img_shape[0] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + elif direction == 'diagonal': + w = img_shape[1] + h = img_shape[0] + flipped[..., 0::4] = w - bboxes[..., 2::4] + flipped[..., 1::4] = h - bboxes[..., 3::4] + flipped[..., 2::4] = w - bboxes[..., 0::4] + flipped[..., 3::4] = h - bboxes[..., 1::4] + else: + raise ValueError(f"Invalid flipping direction '{direction}'") + return flipped + + def __call__(self, sample): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + sample (Sample): Result from loading pipeline. + + Returns: + Sample: Flipped sample, 'flip', 'flip_direction' keys are added \ + into sample. + """ + + if 'flip' not in sample: + if isinstance(self.direction, list): + # None means non-flip + direction_list = self.direction + [None] + else: + # None means non-flip + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) - + 1) + [non_flip_ratio] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + sample.flip = cur_dir is not None + if 'flip_direction' not in sample: + sample.flip_direction = cur_dir + if sample.flip: + # flip image + sample[sample.img_field] = mmcv.imflip( + sample[sample.img_field], direction=sample.flip_direction) + # flip bboxes + for key in sample.bbox_fields: + sample[key] = self.bbox_flip(sample[key], sample.img_shape, + sample.flip_direction) + # flip masks + for key in sample.mask_fields: + sample[key] = sample[key].flip(sample.flip_direction) + + # flip segs + for key in sample.seg_fields: + sample[key] = mmcv.imflip(sample[key], + direction=sample.flip_direction) + return sample + + def __repr__(self): + return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' + + +class DetRandomCrop: + """Random crop the image & bboxes & masks. + + The absolute `crop_size` is sampled based on `crop_type` and `image_size`, + then the cropped sample are generated. + + Args: + crop_size (tuple): The relative ratio or absolute pixels of + height and width. + crop_type (str, optional): one of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. Default "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Default False. + recompute_bbox (bool, optional): Whether to re-compute the boxes based + on cropped instance masks. Default False. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + - If the image is smaller than the absolute crop size, return the + original image. + - The keys for bboxes, labels and masks must be aligned. That is, + `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and + `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and + `gt_masks_ignore`. + - If the crop does not contain any gt-bbox region and + `allow_negative_crop` is set to False, skip this image. + """ + + def __init__(self, + crop_size, + crop_type='absolute', + allow_negative_crop=False, + recompute_bbox=False, + bbox_clip_border=True): + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.allow_negative_crop = allow_negative_crop + self.bbox_clip_border = bbox_clip_border + self.recompute_bbox = recompute_bbox + # The key correspondence from bboxes to labels and masks. + self.bbox2label = { + 'gt_bboxes': 'gt_labels', + 'gt_bboxes_ignore': 'gt_labels_ignore' + } + self.bbox2mask = { + 'gt_bboxes': 'gt_masks', + 'gt_bboxes_ignore': 'gt_masks_ignore' + } + + def _crop_data(self, sample, crop_size, allow_negative_crop): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + sample (Sample): Result from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. + + Returns: + dict: Randomly cropped Sample data, 'img_shape' key in sample is + updated according to crop size. + """ + max_try_times = 20 + crop_times = 0 + while True: + crop_times += 1 + assert crop_size[0] > 0 and crop_size[1] > 0 + img = sample[sample.img_field] + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + # crop bboxes accordingly and clip to the image boundary + is_valid = False + for key in sample.get('bbox_fields', []): + # e.g. gt_bboxes and gt_bboxes_ignore + bbox_offset = np.array( + [offset_w, offset_h, offset_w, offset_h], dtype=np.float32) + bboxes = sample[key] - bbox_offset + if self.bbox_clip_border: + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, crop_size[1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, crop_size[0]) + valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (bboxes[:, 3] > + bboxes[:, 1]) + sample[key] = bboxes[valid_inds, :] + # label fields. e.g. gt_labels and gt_labels_ignore + label_key = self.bbox2label.get(key) + if label_key in sample: + sample[label_key] = sample[label_key][valid_inds] + # mask fields, e.g. gt_masks and gt_masks_ignore + mask_key = self.bbox2mask.get(key) + if mask_key in sample: + sample[mask_key] = sample[mask_key][ + valid_inds.nonzero()[0]].crop( + np.asarray([crop_x1, crop_y1, crop_x2, crop_y2])) + if self.recompute_bbox: + sample[key] = sample[mask_key].get_bboxes() + if valid_inds.any() and key == 'gt_bboxes': + is_valid = True + if (crop_times + == max_try_times) or is_valid or allow_negative_crop: + # crop the image + img = sample[sample.img_field] + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + img_shape = img.shape + sample[sample.img_field] = img + sample.img_shape = img_shape + # crop semantic seg + for key in sample.get('seg_fields', []): + sample[key] = sample[key][crop_y1:crop_y2, crop_x1:crop_x2] + break + return sample + + def _get_crop_size(self, image_size): + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (tuple): (h, w). + + Returns: + crop_size (tuple): (crop_h, crop_w) in absolute pixels. + """ + h, w = image_size + if self.crop_type == 'absolute': + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == 'absolute_range': + assert self.crop_size[0] <= self.crop_size[1] + crop_h = np.random.randint(min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint(min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_h, crop_w = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + elif self.crop_type == 'relative_range': + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def __call__(self, sample): + """Call function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + sample (Sample): Result from loading pipeline. + + Returns: + Sample: Randomly cropped Sample data, 'img_shape' key in sample is + updated according to crop size. + """ + image_size = sample[sample.img_field].shape[:2] + crop_size = self._get_crop_size(image_size) + sample = self._crop_data(sample, crop_size, self.allow_negative_crop) + sample.bbox_num = sample.gt_bboxes.shape[0] + return sample + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(crop_size={self.crop_size}, ' + repr_str += f'crop_type={self.crop_type}, ' + repr_str += f'allow_negative_crop={self.allow_negative_crop}, ' + repr_str += f'bbox_clip_border={self.bbox_clip_border})' + return repr_str + + +class AutoAugment: + """Auto augmentation. + + This data augmentation is proposed in `Learning Data Augmentation + Strategies for Object Detection `_. + + TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms + + Args: + policies (list[list[transformer]]): The policies of auto augmentation. Each + policy in ``policies`` is a specific augmentation policy, and is + composed by several augmentations (dict). When AutoAugment is + called, a random policy in ``policies`` will be selected to + augment images. + """ + + def __init__(self, policies): + assert isinstance(policies, list) and len(policies) > 0, \ + 'Policies must be a non-empty list.' + for policy in policies: + assert isinstance(policy, list) and len(policy) > 0, \ + 'Each policy in policies must be a non-empty list.' + + self.policies = copy.deepcopy(policies) + self.transforms = [Compose(policy) for policy in self.policies] + + def __call__(self, sample): + transform = np.random.choice(self.transforms) + return transform(sample) diff --git a/lib/evaluation/segm_eval_base.py b/lib/evaluation/segm_eval_base.py new file mode 100644 index 0000000..528e0c5 --- /dev/null +++ b/lib/evaluation/segm_eval_base.py @@ -0,0 +1,208 @@ +import os +import argparse + +import numpy as np +import pandas as pd +from skimage.transform import resize + +from skimage import io +from multiprocessing import Pool +from functools import partial +import logging +import datetime + +def get_confusion_matrix(pred, gt, num_class): + assert pred.shape == gt.shape, f"pred.shape: {pred.shape} != gt.shape: {gt.shape}" + mask = (gt >= 0) & (gt < num_class) # 去掉为0的背景类别 + label = num_class * gt[mask] + pred[mask] + count = np.bincount(label, minlength=num_class**2) + confusion_matrix = count.reshape(num_class, num_class) + return confusion_matrix + +def get_miou(confusion_matrix): + diagonal_elements = np.diag(confusion_matrix) + column_sums = np.sum(confusion_matrix, axis=0) + row_sums = np.sum(confusion_matrix, axis=1) + ious = diagonal_elements/(column_sums + row_sums - diagonal_elements) + m_iou = np.nanmean(ious) + return m_iou + +def get_mprecison(confusion_matrix): + diagonal_elements = np.diag(confusion_matrix) + column_sums = np.sum(confusion_matrix, axis=0) + precisions = diagonal_elements / (column_sums + 1e-06) + m_precision = np.nanmean(precisions) + return m_precision + +def get_mrecall(confusion_matrix): + diagonal_elements = np.diag(confusion_matrix) + row_sums = np.sum(confusion_matrix, axis=1) + recalls= diagonal_elements / (row_sums + 1e-06) + m_recall = np.nanmean(recalls) + return m_recall + +def get_macc(confusion_matrix): + ''' + acc = tp/tp+fn 就是recall + ''' + m_recall = get_mrecall(confusion_matrix) + return m_recall + +def get_per_class_iou(confusion_matrix): + intersection = np.diag(confusion_matrix) + union = np.sum(confusion_matrix, axis=0) + np.sum(confusion_matrix, axis=1) - intersection + iou = intersection / (union.astype(np.float32) + 1e-6) + return iou + +def get_per_class_acc(confusion_matrix): + total_acc = np.diag(confusion_matrix) / (np.sum(confusion_matrix, axis=1).astype(np.float32) + 1e-6) + return total_acc + + +def post_process_segm_output(segm, colors, dist_type='abs'): + """ + Post-processing to turn output segm image to class index map using NumPy + Args: + segm: (H, W, 3) + Returns: + class_map: (H, W) + """ + palette = np.array(colors) + segm = segm.astype(np.float32) # (h, w, 3) + h, w, k = segm.shape[0], segm.shape[1], palette.shape[0] + if dist_type == 'abs': + dist = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k) + elif dist_type == 'square': + dist = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k) + elif dist_type == 'mean': + dist_abs = np.abs(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3)) # (h, w, k) + dist_square = np.power(segm.reshape(h, w, 1, 3) - palette.reshape(1, 1, k, 3), 2) # (h, w, k) + dist = (dist_abs + dist_square) / 2. + else: + raise NotImplementedError + + dist = np.sum(dist, axis=-1) + pred = np.argmin(dist, axis=-1).astype(np.int) + return pred + +def get_args_parser(): + parser = argparse.ArgumentParser('semantic segmentation evaluation', add_help=False) + parser.add_argument('--pred_dir', type=str, help='dir to pred', required=True) + parser.add_argument('--gt_dir', type=str, help='dir to gt', required=True) + parser.add_argument('--gt_list_path', type=str, help='dir to gt_list_path', required=True) + parser.add_argument('--gt_suffix', type=str, help='suffix to gt', required=True) + parser.add_argument('--dataset_name', type=str, help='dataset name', required=True) + parser.add_argument('--model_name', type=str, help='model name', required=True) + parser.add_argument('--dist_type', type=str, help='dist type', + default='abs', choices=['abs', 'square', 'mean']) + return parser.parse_args() + +def process_file(file_dict, pred_dir, gt_dir, args, num_class): + filename = file_dict['file_name'] + file_cls = file_dict['file_cls'] + + gt = io.imread(os.path.join(gt_dir, filename)) + gt_index = gt.copy() + gt_index[gt_index != file_cls] = 0 + gt_index[gt_index == file_cls] = 1 + + try: + pred = io.imread(os.path.join(pred_dir, filename.replace('.png', f'-{file_cls}.png'))) + pred = resize(pred, gt.shape[-2:], anti_aliasing=False, mode='reflect', order=0) + + if len(pred.shape) == 3: + pred_index = pred[:,:,0].copy() + else: + pred_index = pred.copy() + pred_index[pred_index<=127] = 0 + pred_index[pred_index>127] = 1 + except: + logging.info(filename.replace('.png', f'_{file_cls}.png'), 'not found!') + pred_index = gt_index.copy() + + pred_index = pred_index.flatten() + gt_index = gt_index.flatten() + confusion_matrix = get_confusion_matrix(pred_index, gt_index, num_class) + return file_cls, confusion_matrix + +if __name__ == '__main__': + args = get_args_parser() + dataset_name = args.dataset_name + pred_dir = args.pred_dir + gt_dir = args.gt_dir + gt_list_path = args.gt_list_path + dist_type = args.dist_type + model_name = args.model_name + + current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + + os.makedirs('logs/eval', exist_ok=True) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(f'logs/eval/eval_{model_name}_{dataset_name}_{current_time}.log'), + logging.StreamHandler() + ] + ) + + output_folder = os.path.join(pred_dir, f'eval_{dataset_name}') + os.makedirs(output_folder, exist_ok=True) + + + num_class = 2 + + with open(gt_list_path, 'r') as f: + file_cls_list = f.readlines() + + file_list = [] + for i in file_cls_list: + i = i.strip() + file_name = i[:-3] + file_cls = i[-2:] + file_list.append({'file_name': file_name, 'file_cls': int(file_cls)}) + + all_pred_labels = [] + all_gt_labels = [] + + process_file_partial = partial(process_file, pred_dir=pred_dir,gt_dir=gt_dir, args=args, num_class=num_class) + + pool = Pool() + + outputs = pool.map(process_file_partial, file_list) + + pool.close() + pool.join() + logging.info(f'len outputs: {len(outputs)}') + confusion_matrix_dict = {} + for cls, confusion_matrix in outputs: + if cls in confusion_matrix_dict.keys(): + confusion_matrix_dict[cls] += confusion_matrix + else: + confusion_matrix_dict[cls] = confusion_matrix + + class_list = [] + iou_list = [] + acc_list = [] + for cls, confusion_matrix in confusion_matrix_dict.items(): + ious = get_per_class_iou(confusion_matrix) + accs = get_per_class_acc(confusion_matrix) + logging.info(f'cls: {cls}, ious: {ious}, accs: {accs}') + class_list.append(cls) + iou_list.append(ious[1]) + acc_list.append(accs[1]) + + miou = np.mean(iou_list) + macc = np.mean(acc_list) + + df_metrics = pd.DataFrame({ + 'Class': class_list + ['Mean'], + 'IoU': iou_list + [miou], + 'Accuracy': acc_list + [macc], + }) + pd.set_option('display.float_format', '{:.4f}%'.format) + logging.info(df_metrics) + pd.reset_option('display.float_format') + df_metrics.to_csv(os.path.join(output_folder, 'eval.csv'), index=False, float_format='%.4f') + + diff --git a/lib/models/__init__.py b/lib/models/__init__.py new file mode 100644 index 0000000..e9aeb60 --- /dev/null +++ b/lib/models/__init__.py @@ -0,0 +1,7 @@ +from .segmentors import SkySensePP +from .losses import (ModalityVAELoss, RecLoss) +from .metrics import (SemMetric) + +__all__ = [ + 'SkySensePP', 'ModalityVAELoss', 'RecLoss', 'SemMetric' +] diff --git a/lib/models/backbones/__init__.py b/lib/models/backbones/__init__.py new file mode 100644 index 0000000..b545c09 --- /dev/null +++ b/lib/models/backbones/__init__.py @@ -0,0 +1,14 @@ +from .swin_v2 import SwinTransformerV2MSL +from .vit import VisionTransformerMSL + +__all__ = [ + 'SwinTransformerV2MSL', 'VisionTransformerMSL' +] + +type_mapping = { + 'SwinTransformerV2MSL': SwinTransformerV2MSL, + 'VisionTransformerMSL': VisionTransformerMSL +} + +def build_backbone(type, **kwargs): + return type_mapping[type](**kwargs) diff --git a/lib/models/backbones/swin_v2.py b/lib/models/backbones/swin_v2.py new file mode 100644 index 0000000..25b5f16 --- /dev/null +++ b/lib/models/backbones/swin_v2.py @@ -0,0 +1,702 @@ +from copy import deepcopy +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, PatchEmbed +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner.base_module import BaseModule, ModuleList +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmcls.models.utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2, + resize_pos_embed, to_2tuple) +from mmcls.models.backbones.base_backbone import BaseBackbone +from mmcv.runner import (CheckpointLoader, + load_state_dict) +from mmcv.cnn.bricks.transformer import MultiheadAttention + +class SwinBlockV2(BaseModule): + """Swin Transformer V2 block. Use post normalization. + + Args: + embed_dims (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + shift (bool): Shift the attention window or not. Defaults to False. + extra_norm (bool): Whether add extra norm at the end of main branch. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + drop_path (float): The drop path rate after attention and ffn. + Defaults to 0. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + attn_cfgs (dict): The extra config of Shift Window-MSA. + Defaults to empty dict. + ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict. + norm_cfg (dict): The config of norm layers. + Defaults to ``dict(type='LN')``. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + num_heads, + window_size=8, + shift=False, + extra_norm=False, + ffn_ratio=4., + drop_path=0., + pad_small_map=False, + attn_cfgs=dict(), + ffn_cfgs=dict(), + norm_cfg=dict(type='LN'), + with_cp=False, + pretrained_window_size=0, + init_cfg=None): + + super(SwinBlockV2, self).__init__(init_cfg) + self.with_cp = with_cp + self.extra_norm = extra_norm + + _attn_cfgs = { + 'embed_dims': embed_dims, + 'num_heads': num_heads, + 'shift_size': window_size // 2 if shift else 0, + 'window_size': window_size, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'pad_small_map': pad_small_map, + **attn_cfgs + } + # use V2 attention implementation + _attn_cfgs.update( + window_msa=WindowMSAV2, + msa_cfg=dict( + pretrained_window_size=to_2tuple(pretrained_window_size))) + self.attn = ShiftWindowMSA(**_attn_cfgs) + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + _ffn_cfgs = { + 'embed_dims': embed_dims, + 'feedforward_channels': int(embed_dims * ffn_ratio), + 'num_fcs': 2, + 'ffn_drop': 0, + 'dropout_layer': dict(type='DropPath', drop_prob=drop_path), + 'act_cfg': dict(type='GELU'), + 'add_identity': False, + **ffn_cfgs + } + self.ffn = FFN(**_ffn_cfgs) + self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + # add extra norm for every n blocks in huge and giant model + if self.extra_norm: + self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape): + + def _inner_forward(x): + # Use post normalization + identity = x + x = self.attn(x, hw_shape) + x = self.norm1(x) + x = x + identity + + identity = x + x = self.ffn(x) + x = self.norm2(x) + x = x + identity + + if self.extra_norm: + x = self.norm3(x) + + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class SwinBlockV2Sequence(BaseModule): + """Module with successive Swin Transformer blocks and downsample layer. + + Args: + embed_dims (int): Number of input channels. + depth (int): Number of successive swin transformer blocks. + num_heads (int): Number of attention heads. + window_size (int): The height and width of the window. Defaults to 7. + downsample (bool): Downsample the output of blocks by patch merging. + Defaults to False. + downsample_cfg (dict): The extra config of the patch merging layer. + Defaults to empty dict. + drop_paths (Sequence[float] | float): The drop path rate in each block. + Defaults to 0. + block_cfgs (Sequence[dict] | dict): The extra config of each block. + Defaults to empty dicts. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + extra_norm_every_n_blocks (int): Add extra norm at the end of main + branch every n blocks. Defaults to 0, which means no needs for + extra norm layer. + pretrained_window_size (int): Window size in pretrained. + init_cfg (dict, optional): The extra config for initialization. + Defaults to None. + """ + + def __init__(self, + embed_dims, + depth, + num_heads, + window_size=8, + downsample=False, + downsample_cfg=dict(), + drop_paths=0., + block_cfgs=dict(), + with_cp=False, + pad_small_map=False, + extra_norm_every_n_blocks=0, + pretrained_window_size=0, + init_cfg=None): + super().__init__(init_cfg) + + if not isinstance(drop_paths, Sequence): + drop_paths = [drop_paths] * depth + + if not isinstance(block_cfgs, Sequence): + block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)] + + if downsample: + self.out_channels = 2 * embed_dims + _downsample_cfg = { + 'in_channels': embed_dims, + 'out_channels': self.out_channels, + 'norm_cfg': dict(type='LN'), + **downsample_cfg + } + self.downsample = PatchMerging(**_downsample_cfg) + else: + self.out_channels = embed_dims + self.downsample = None + + self.blocks = ModuleList() + for i in range(depth): + extra_norm = True if extra_norm_every_n_blocks and \ + (i + 1) % extra_norm_every_n_blocks == 0 else False + _block_cfg = { + 'embed_dims': self.out_channels, + 'num_heads': num_heads, + 'window_size': window_size, + 'shift': False if i % 2 == 0 else True, + 'extra_norm': extra_norm, + 'drop_path': drop_paths[i], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'pretrained_window_size': pretrained_window_size, + **block_cfgs[i] + } + block = SwinBlockV2(**_block_cfg) + self.blocks.append(block) + + def forward(self, x, in_shape): + if self.downsample: + x, out_shape = self.downsample(x, in_shape) + else: + out_shape = in_shape + + for block in self.blocks: + x = block(x, out_shape) + + return x, out_shape + + +class SwinTransformerV2(BaseBackbone): + """Swin Transformer V2. + + A PyTorch implement of : `Swin Transformer V2: + Scaling Up Capacity and Resolution + `_ + + Inspiration from + https://github.com/microsoft/Swin-Transformer + + Args: + arch (str | dict): Swin Transformer architecture. If use string, choose + from 'tiny', 'small', 'base' and 'large'. If use dict, it should + have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (List[int]): The number of blocks in each stage. + - **num_heads** (List[int]): The number of heads in attention + modules of each stage. + - **extra_norm_every_n_blocks** (int): Add extra norm at the end + of main branch every n blocks. + + Defaults to 'tiny'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 4. + in_channels (int): The num of input channels. Defaults to 3. + window_size (int | Sequence): The height and width of the window. + Defaults to 7. + drop_rate (float): Dropout rate after embedding. Defaults to 0. + drop_path_rate (float): Stochastic depth rate. Defaults to 0.1. + use_abs_pos_embed (bool): If True, add absolute position embedding to + the patch embedding. Defaults to False. + interpolate_mode (str): Select the interpolate mode for absolute + position embeding vector resize. Defaults to "bicubic". + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + pad_small_map (bool): If True, pad the small feature map to the window + size, which is common used in detection and segmentation. If False, + avoid shifting window and shrink the window size to the size of + feature map, which is common used in classification. + Defaults to False. + norm_cfg (dict): Config dict for normalization layer for all output + features. Defaults to ``dict(type='LN')`` + stage_cfgs (Sequence[dict] | dict): Extra config dict for each + stage. Defaults to an empty dict. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + pretrained_window_sizes (tuple(int)): Pretrained window sizes of + each layer. + init_cfg (dict, optional): The Config for initialization. + Defaults to None. + + Examples: + >>> from mmcls.models import SwinTransformerV2 + >>> import torch + >>> extra_config = dict( + >>> arch='tiny', + >>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3, + >>> 'padding': 'same'})) + >>> self = SwinTransformerV2(**extra_config) + >>> inputs = torch.rand(1, 3, 224, 224) + >>> output = self.forward(inputs) + >>> print(output.shape) + (1, 2592, 4) + """ + arch_zoo = { + **dict.fromkeys(['t', 'tiny'], + {'embed_dims': 96, + 'depths': [2, 2, 6, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['s', 'small'], + {'embed_dims': 96, + 'depths': [2, 2, 18, 2], + 'num_heads': [3, 6, 12, 24], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['b', 'base'], + {'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32], + 'extra_norm_every_n_blocks': 0}), + **dict.fromkeys(['l', 'large'], + {'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48], + 'extra_norm_every_n_blocks': 0}), + # head count not certain for huge, and is employed for another + # parallel study about self-supervised learning. + **dict.fromkeys(['h', 'huge'], + {'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [8, 16, 32, 64], + 'extra_norm_every_n_blocks': 6}), + **dict.fromkeys(['g', 'giant'], + {'embed_dims': 512, + 'depths': [2, 2, 42, 4], + 'num_heads': [16, 32, 64, 128], + 'extra_norm_every_n_blocks': 6}), + } # yapf: disable + + _version = 1 + num_extra_tokens = 0 + + def __init__(self, + arch='tiny', + img_size=256, + patch_size=4, + in_channels=3, + vocabulary_size=128, + window_size=8, + drop_rate=0., + drop_path_rate=0.1, + out_indices=(3, ), + use_abs_pos_embed=False, + interpolate_mode='bicubic', + with_cp=False, + frozen_stages=-1, + norm_eval=False, + pad_small_map=False, + norm_cfg=dict(type='LN'), + stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)), + patch_cfg=dict(), + pretrained_window_sizes=[0, 0, 0, 0], + init_cfg=None): + super(SwinTransformerV2, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'depths', 'num_heads', + 'extra_norm_every_n_blocks' + } + assert isinstance(arch, dict) and set(arch) == essential_keys, \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.vocabulary_size = vocabulary_size + 1 # 增加ignore类别 + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + self.extra_norm_every_n_blocks = self.arch_settings[ + 'extra_norm_every_n_blocks'] + self.num_layers = len(self.depths) + self.out_indices = out_indices + self.use_abs_pos_embed = use_abs_pos_embed + self.interpolate_mode = interpolate_mode + self.frozen_stages = frozen_stages + + if isinstance(window_size, int): + self.window_sizes = [window_size for _ in range(self.num_layers)] + elif isinstance(window_size, Sequence): + assert len(window_size) == self.num_layers, \ + f'Length of window_sizes {len(window_size)} is not equal to '\ + f'length of stages {self.num_layers}.' + self.window_sizes = window_size + else: + raise TypeError('window_size should be a Sequence or int.') + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + self.patch_size = patch_size + + if self.use_abs_pos_embed: + num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dims)) + self._register_load_state_dict_pre_hook( + self._prepare_abs_pos_embed) + + self._register_load_state_dict_pre_hook(self._delete_reinit_params) + + self.drop_after_pos = nn.Dropout(p=drop_rate) + self.norm_eval = norm_eval + + # stochastic depth + total_depth = sum(self.depths) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, total_depth) + ] # stochastic depth decay rule + + self.stages = ModuleList() + embed_dims = [self.embed_dims] + for i, (depth, num_heads) in enumerate(zip(self.depths, + self.num_heads)): + if isinstance(stage_cfgs, Sequence): + stage_cfg = stage_cfgs[i] + else: + stage_cfg = deepcopy(stage_cfgs) + downsample = True if i > 0 else False + _stage_cfg = { + 'embed_dims': embed_dims[-1], + 'depth': depth, + 'num_heads': num_heads, + 'window_size': self.window_sizes[i], + 'downsample': downsample, + 'drop_paths': dpr[:depth], + 'with_cp': with_cp, + 'pad_small_map': pad_small_map, + 'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks, + 'pretrained_window_size': pretrained_window_sizes[i], + **stage_cfg + } + + stage = SwinBlockV2Sequence(**_stage_cfg) + self.stages.append(stage) + + dpr = dpr[depth:] + embed_dims.append(stage.out_channels) + + for i in out_indices: + if norm_cfg is not None: + norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1] + else: + norm_layer = nn.Identity() + + self.add_module(f'norm{i}', norm_layer) + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + from mmcls.utils import get_root_logger + logger = get_root_logger() + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + # print(self.state_dict().keys()) + # print('---') + # print(state_dict.keys()) + # import pdb; pdb.set_trace() + load_state_dict(self, state_dict, strict=False, logger=logger) + return + else: + super(SwinTransformerV2, self).init_weights() + if self.use_abs_pos_embed: + trunc_normal_(self.absolute_pos_embed, std=0.02) + + def forward(self, x): + x, hw_shape = self.patch_embed(x) + + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + x = self.drop_after_pos(x) + + outs = [] + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, + stage.out_channels).permute(0, 3, 1, + 2).contiguous() + outs.append(out) + + return outs + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(0, self.frozen_stages + 1): + m = self.stages[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + for i in self.out_indices: + if i <= self.frozen_stages: + for param in getattr(self, f'norm{i}').parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(SwinTransformerV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'absolute_pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.absolute_pos_embed.shape != ckpt_pos_embed_shape: + from mmcls.utils import get_root_logger + logger = get_root_logger() + logger.info( + 'Resize the absolute_pos_embed shape from ' + f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed(state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + self.num_extra_tokens) + + def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs): + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_position_index' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + + # delete relative_coords_table since we always re-init it + relative_position_index_keys = [ + k for k in state_dict.keys() if 'relative_coords_table' in k + ] + for k in relative_position_index_keys: + del state_dict[k] + +class Proj_MHSA(nn.Module): + + def __init__( + self, + embed_dims, + proj_dims, + num_heads=16, + batch_first=True, + bias = True + ): + super().__init__() + self.proj_in = nn.Linear(in_features=embed_dims, out_features=proj_dims) + self.attn = MultiheadAttention( + embed_dims=proj_dims, + num_heads=num_heads, + batch_first=batch_first, + bias=bias + ) + self.proj_out = nn.Linear(in_features=proj_dims, out_features=embed_dims) + def forward(self, x): + x = self.proj_in(x) + x = self.attn(x, x, x) + x = self.proj_out(x) + return x + +class SwinTransformerV2MSL(SwinTransformerV2): + + def __init__(self, **kwargs): + if 'use_attn' in kwargs: + self.use_attn = kwargs.pop('use_attn') + else: + self.use_attn = False + if 'merge_stage' in kwargs: + self.merge_stage = kwargs.pop('merge_stage') + else: + self.merge_stage = 0 + if 'with_cls_pos' in kwargs: + self.with_cls_pos = kwargs.pop('with_cls_pos') + else: + self.with_cls_pos = False + super().__init__(**kwargs) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + #self.vocabulary_token = nn.Parameter(torch.zeros(1, 1, 1, self.vocabulary_size, self.embed_dims)) + self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims)) + self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size)) + trunc_normal_(self.mask_token, mean=0., std=.02) + trunc_normal_(self.vocabulary_token, mean=0., std=.02) + + if self.use_attn: + self.attn1 = Proj_MHSA(embed_dims=352, proj_dims=256, num_heads=16, batch_first=True, bias = True) + self.attn2 = Proj_MHSA(embed_dims=704, proj_dims=512, num_heads=16, batch_first=True, bias = True) + self.attn3 = Proj_MHSA( embed_dims=1408, proj_dims=1024, num_heads=16, batch_first=True, bias = True) + self.attention_blocks = [self.attn1, self.attn2, self.attn3] + self.norm_attn = build_norm_layer(dict(type='LN'), 1408)[1] + + def create_ann_token(self, anno_img): + B, H, W = anno_img.shape + ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1) + assert H % self.patch_size == 0 and W % self.patch_size == 0 + nph, npw = H // self.patch_size, W // self.patch_size + weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size + weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1) + ann_token = ann_token * weight + ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size) + ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C + return ann_token + + def forward(self, hr_img, anno_img, mask=None): + x, hw_shape = self.patch_embed(hr_img) + y = self.create_ann_token(anno_img) + assert x.shape == y.shape + B, L, C = y.shape + if mask is not None: + mask_tokens = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) + y = y * (1. - w) + mask_tokens * w + + if self.merge_stage == 0: + x = (x + y) * 0.5 + else: + x = x.reshape(B, *hw_shape, C) + y = y.reshape(B, *hw_shape, C) + x = torch.cat((x, y), dim=2) + hw_shape = (hw_shape[0], hw_shape[1] * 2) + x = x.reshape(B, -1, C) + if self.use_abs_pos_embed: + x = x + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape, + self.interpolate_mode, self.num_extra_tokens) + if self.with_cls_pos: + hw_shape_half = [hw_shape[0], hw_shape[1] // 2] + x = x.reshape(B, *hw_shape, C) + x1 = x[:, :, :x.shape[2]//2, :].reshape(B, -1, C) + x2 = x[:, :, x.shape[2]//2:, :].reshape(B, -1, C) + x1 = x1 + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape_half, + self.interpolate_mode, self.num_extra_tokens) + x2 = x2 + resize_pos_embed( + self.absolute_pos_embed, self.patch_resolution, hw_shape_half, + self.interpolate_mode, self.num_extra_tokens) + x1 = x1.reshape(B, *hw_shape_half, C) + x2 = x2.reshape(B, *hw_shape_half, C) + x = torch.cat((x1, x2), dim=2).reshape(B, -1, C) + x = self.drop_after_pos(x) + outs = [] + merge_idx = self.merge_stage - 1 + for i, stage in enumerate(self.stages): + x, hw_shape = stage(x, hw_shape) + if i == merge_idx: + x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c + x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5 + x = x.reshape(x.shape[0], -1, x.shape[-1]) + hw_shape = (hw_shape[0], hw_shape[1] // 2) + if self.use_attn: + if i <= len(self.attention_blocks) - 1: + x = x + self.attention_blocks[i](x) + if i == len(self.attention_blocks) - 1: + x = self.norm_attn(x) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + out = norm_layer(x) + out = out.view(-1, *hw_shape, stage.out_channels).permute(0, 3, 1, 2).contiguous() + outs.append(out) + return outs \ No newline at end of file diff --git a/lib/models/backbones/vit.py b/lib/models/backbones/vit.py new file mode 100644 index 0000000..4fa5ff0 --- /dev/null +++ b/lib/models/backbones/vit.py @@ -0,0 +1,611 @@ +# Copyright (c) Ant Group. All rights reserved. +import math +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, + load_state_dict) +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.utils import _pair as to_2tuple + +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from mmseg.models.utils.embed import PatchEmbed +import torch.nn.functional as F +import numpy as np + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Vision Transformer. + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed + after the feed forward layer. Default: 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default: 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + qkv_bias (bool): enable bias for qkv if True. Default: True + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_fcs=2, + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + attn_cfg=dict(), + ffn_cfg=dict(), + with_cp=False): + super(TransformerEncoderLayer, self).__init__() + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, + embed_dims, + postfix=1) + self.add_module(self.norm1_name, norm1) + + attn_cfg.update( + dict(embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + batch_first=batch_first, + bias=qkv_bias)) + + self.build_attn(attn_cfg) + + self.norm2_name, norm2 = build_norm_layer(norm_cfg, + embed_dims, + postfix=2) + self.add_module(self.norm2_name, norm2) + + ffn_cfg.update( + dict(embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + num_fcs=num_fcs, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate) + if drop_path_rate > 0 else None, + act_cfg=act_cfg)) + self.build_ffn(ffn_cfg) + self.with_cp = with_cp + + def build_attn(self, attn_cfg): + self.attn = MultiheadAttention(**attn_cfg) + + def build_ffn(self, ffn_cfg): + self.ffn = FFN(**ffn_cfg) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + @property + def norm2(self): + return getattr(self, self.norm2_name) + + def forward(self, x): + + def _inner_forward(x): + x = self.attn(self.norm1(x), identity=x) + x = self.ffn(self.norm2(x), identity=x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class VisionTransformer(BaseModule): + """Vision Transformer. + This backbone is the implementation of `An Image is Worth 16x16 Words: + Transformers for Image Recognition at + Scale `_. + Args: + img_size (int | tuple): Input image size. Default: 224. + patch_size (int): The patch size. Default: 16. + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): embedding dimension. Default: 768. + num_layers (int): depth of transformer. Default: 12. + num_heads (int): number of attention heads. Default: 12. + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + out_indices (list | tuple | int): Output from which stages. + Default: -1. + qkv_bias (bool): enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Default: True. + output_cls_token (bool): Whether output the cls_token. If set True, + `with_cls_token` must be True. Default: False. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Default: dict(type='GELU'). + patch_norm (bool): Whether to add a norm in PatchEmbed Block. + Default: False. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Default: False. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Default: bicubic. + num_fcs (int): The number of fully-connected layers for FFNs. + Default: 2. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Default: False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save + some memory while slowing down the training speed. Default: False. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=-1, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + patch_norm=False, + final_norm=False, + interpolate_mode='bicubic', + num_fcs=2, + norm_eval=False, + with_cp=False, + use_ccd=False, + ccd_num=0, + pretrained=None, + init_cfg=None): + super(VisionTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, tuple): + if len(img_size) == 1: + img_size = to_2tuple(img_size[0]) + assert len(img_size) == 2, \ + f'The size of image should have length 1 or 2, ' \ + f'but got {len(img_size)}' + + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be set at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is not None: + raise TypeError('pretrained must be a str or None') + + self.img_size = img_size + self.patch_size = patch_size + self.interpolate_mode = interpolate_mode + self.norm_eval = norm_eval + self.with_cp = with_cp + self.pretrained = pretrained + self.embed_dims = embed_dims + self.patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + padding='corner', + norm_cfg=norm_cfg if patch_norm else None, + init_cfg=None, + ) + + self.use_ccd = use_ccd + self.ccd_num = ccd_num + if self.use_ccd: + self.ccd_embed = nn.Parameter( + torch.rand(1, self.ccd_num, embed_dims)) + + num_patches = (img_size[0] // patch_size) * \ + (img_size[1] // patch_size) + + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + # self.pos_embed = nn.Parameter( + # torch.zeros(1, num_patches, embed_dims)) + # 原来是 + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + if isinstance(out_indices, int): + if out_indices == -1: + out_indices = num_layers - 1 + self.out_indices = [out_indices] + elif isinstance(out_indices, list) or isinstance(out_indices, tuple): + self.out_indices = out_indices + else: + raise TypeError('out_indices must be type of int, list or tuple') + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer(embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * + embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + self.final_norm = final_norm + if final_norm: + self.norm1_name, norm1 = build_norm_layer(norm_cfg, + embed_dims, + postfix=1) + self.add_module(self.norm1_name, norm1) + + @property + def norm1(self): + return getattr(self, self.norm1_name) + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + if 'pos_embed' in state_dict.keys(): + if self.pos_embed.shape != state_dict['pos_embed'].shape: + logger.info(msg=f'Resize the pos_embed shape from ' + f'{state_dict["pos_embed"].shape} to ' + f'{self.pos_embed.shape}') + h, w = self.img_size + pos_size = int( + math.sqrt(state_dict['pos_embed'].shape[1] - 1)) + state_dict['pos_embed'] = self.resize_pos_embed( + state_dict['pos_embed'], + (h // self.patch_size, w // self.patch_size), + (pos_size, pos_size), self.interpolate_mode) + + load_state_dict(self, state_dict, strict=False, logger=logger) + elif self.init_cfg is not None: + super(VisionTransformer, self).init_weights() + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + if self.use_ccd: + trunc_normal_(self.ccd_embed, std=0.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size) + 1: + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + # keep dim for easy deployment + cls_token_weight = pos_embed[:, 0:1] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize(pos_embed_weight, + size=input_shpae, + align_corners=False, + mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, inputs, ccd_index=None): + B = inputs.shape[0] + + x, hw_shape = self.patch_embed(inputs) + + if self.use_ccd: + _ccd_idx = np.concatenate(ccd_index, axis=0) + _ccd_embed = self.ccd_embed[:, _ccd_idx, :].permute(1, 0, 2) + x = x + _ccd_embed + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == len(self.layers) - 1: + if self.final_norm: + x = self.norm1(x) + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder heads + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + super(VisionTransformer, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() + + +class VisionTransformerMSL(VisionTransformer): + + def __init__(self, **kwargs): + if 'use_attn' in kwargs: + self.use_attn = kwargs.pop('use_attn') + else: + self.use_attn = False + if 'merge_stage' in kwargs: + self.merge_stage = kwargs.pop('merge_stage') + else: + self.merge_stage = 0 + if 'with_cls_pos' in kwargs: + self.with_cls_pos = kwargs.pop('with_cls_pos') + else: + self.with_cls_pos = False + + self.vocabulary_size = kwargs.pop('vocabulary_size') + 1 # 增加ignore类别 + super().__init__(**kwargs) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + img_size = kwargs.pop('img_size') + patch_size = kwargs.pop('patch_size') + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dims)) + + self.vocabulary_token = nn.Parameter(torch.zeros(self.vocabulary_size, self.embed_dims)) + self.vocabulary_weight = nn.Parameter(torch.zeros(1, self.patch_size * self.patch_size)) + trunc_normal_(self.mask_token, mean=0., std=.02) + trunc_normal_(self.vocabulary_token, mean=0., std=.02) + + if self.use_attn: + self.attn1 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True) + self.attn2 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True) + self.attn3 = MultiheadAttention(embed_dims=1024, num_heads=16, batch_first=True, bias = True) + self.attention_blocks = [self.attn1, self.attn2, self.attn3] + self.norm_attn = build_norm_layer(dict(type='LN'), 1024)[1] + + def create_ann_token(self, anno_img): + B, H, W = anno_img.shape + ann_token = torch.index_select(self.vocabulary_token, 0, anno_img.reshape(-1)).reshape(B, H, W, -1) + assert H % self.patch_size == 0 and W % self.patch_size == 0 + nph, npw = H // self.patch_size, W // self.patch_size + weight = F.softmax(self.vocabulary_weight, dim=1) * self.patch_size * self.patch_size + weight = weight.reshape(1, 1, self.patch_size, 1, self.patch_size).repeat(1, nph, 1, npw, 1).reshape(1, H, W, 1) + ann_token = ann_token * weight + ann_token = F.avg_pool2d(torch.einsum('BHWC->BCHW', ann_token), self.patch_size, self.patch_size) + ann_token = torch.einsum('BCHW->BHWC', ann_token).reshape(B, nph * npw, self.embed_dims) # shape B, L, C + return ann_token + def _pos_embeding(self, patched_img, hw_shape, pos_embed): + """Positioning embeding method. + Resize the pos_embed, if the input image size doesn't match + the training size. + Args: + patched_img (torch.Tensor): The patched image, it should be + shape of [B, L1, C]. + hw_shape (tuple): The downsampled image resolution. + pos_embed (torch.Tensor): The pos_embed weighs, it should be + shape of [B, L2, c]. + Return: + torch.Tensor: The pos encoded image feature. + """ + assert patched_img.ndim == 3 and pos_embed.ndim == 3, \ + 'the shapes of patched_img and pos_embed must be [B, L, C]' + x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] + if x_len != pos_len: + if pos_len == (self.img_size[0] // self.patch_size) * ( + self.img_size[1] // self.patch_size): + pos_h = self.img_size[0] // self.patch_size + pos_w = self.img_size[1] // self.patch_size + else: + raise ValueError( + 'Unexpected shape of pos_embed, got {}.'.format( + pos_embed.shape)) + pos_embed = self.resize_pos_embed(pos_embed, hw_shape, + (pos_h, pos_w), + self.interpolate_mode) + return self.drop_after_pos(patched_img + pos_embed) + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]' + pos_h, pos_w = pos_shape + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):] + pos_embed_weight = pos_embed_weight.reshape( + 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize(pos_embed_weight, + size=input_shpae, + align_corners=False, + mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = pos_embed_weight # torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed + + def forward(self, x, y, mask=None): + x, hw_shape = self.patch_embed(x) + y = self.create_ann_token(y) + assert x.shape == y.shape + B, L, C = y.shape + if mask is not None: + mask_tokens = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) + y = y * (1. - w) + mask_tokens * w + if self.merge_stage == 0: + x = (x + y) * 0.5 + else: + x = x.reshape(B, *hw_shape, C) + y = y.reshape(B, *hw_shape, C) + x = torch.cat((x, y), dim=2) + hw_shape = (hw_shape[0], hw_shape[1] * 2) + x = x.reshape(B, -1, C) + x = self._pos_embeding(x, hw_shape, self.pos_embed) + merge_idx = self.merge_stage - 1 + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i == merge_idx: + x = x.reshape(x.shape[0], *hw_shape, x.shape[-1]) # b,l,c -> b, h, w, c + x = (x[:, :, :x.shape[2]//2] + x[:, :, x.shape[2]//2:]) * 0.5 + x = x.reshape(x.shape[0], -1, x.shape[-1]) + hw_shape = (hw_shape[0], hw_shape[1] // 2) + if self.use_attn: + if i <= len(self.attention_blocks) - 1: + x = x + self.attention_blocks[i](x) + if i == len(self.attention_blocks) - 1: + x = self.norm_attn(x) # 会不会有冲突 + if (not self.use_attn) and (i == len(self.layers) - 1): + if self.final_norm: + x = self.norm1(x) # 会不会有冲突 + if i in self.out_indices: + if self.with_cls_token: + # Remove class token and reshape token for decoder heads + out = x[:, 1:] + else: + out = x + B, _, C = out.shape + out = out.reshape(B, hw_shape[0], hw_shape[1], + C).permute(0, 3, 1, 2).contiguous() + if self.output_cls_token: + out = [out, x[:, 0]] + outs.append(out) + + return tuple(outs) \ No newline at end of file diff --git a/lib/models/heads/__init__.py b/lib/models/heads/__init__.py new file mode 100644 index 0000000..5569219 --- /dev/null +++ b/lib/models/heads/__init__.py @@ -0,0 +1,15 @@ +from .uper_head import UPerHead +from .up_head import UPHead + +__all__ = [ + 'UPerHead', 'UPHead' +] + +type_mapping = { + 'UPerHead': UPerHead, + 'UPHead': UPHead +} + + +def build_head(type, **kwargs): + return type_mapping[type](**kwargs) \ No newline at end of file diff --git a/lib/models/heads/decode_head.py b/lib/models/heads/decode_head.py new file mode 100644 index 0000000..172522d --- /dev/null +++ b/lib/models/heads/decode_head.py @@ -0,0 +1,201 @@ +import warnings +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16 + +from mmseg.core import build_pixel_sampler +from mmseg.ops import resize + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + out_channels (int): Output channels of conv_seg. + threshold (float): Threshold for binary segmentation in the case of + `out_channels==1`. Default: None. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict | Sequence[dict]): Config of decode loss. + The `loss_name` is property of corresponding loss function which + could be shown in training log. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + e.g. dict(type='CrossEntropyLoss'), + [dict(type='CrossEntropyLoss', loss_name='loss_ce'), + dict(type='DiceLoss', loss_name='loss_dice')] + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255. + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + out_channels=None, + threshold=None, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + sampler=None, + align_corners=False, + init_cfg=dict(type='Normal', + std=0.01, + override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + + self.align_corners = align_corners + + if out_channels is None: + if num_classes == 2: + warnings.warn('For binary segmentation, we suggest using' + '`out_channels = 1` to define the output' + 'channels of segmentor, and use `threshold`' + 'to convert seg_logist into a prediction' + 'applying a threshold') + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + 'out_channels should be equal to num_classes,' + 'except binary segmentation set out_channels == 1 and' + f'num_classes == 2, but got out_channels={out_channels}' + f'and num_classes={num_classes}') + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn('threshold is not defined for binary, and defaults' + 'to 0.3') + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize(input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output \ No newline at end of file diff --git a/lib/models/heads/psp_head.py b/lib/models/heads/psp_head.py new file mode 100644 index 0000000..d402228 --- /dev/null +++ b/lib/models/heads/psp_head.py @@ -0,0 +1,60 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + + +class PPM(nn.ModuleList): + """Pooling Pyramid Module used in PSPNet. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module. + in_channels (int): Input channels. + channels (int): Channels after modules, before conv_seg. + conv_cfg (dict|None): Config of conv layers. + norm_cfg (dict|None): Config of norm layers. + act_cfg (dict): Config of activation layers. + align_corners (bool): align_corners argument of F.interpolate. + """ + + def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, + act_cfg, align_corners, **kwargs): + super(PPM, self).__init__() + self.pool_scales = pool_scales + self.align_corners = align_corners + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + for pool_scale in pool_scales: + self.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(pool_scale), + ConvModule( + self.in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + **kwargs))) + + def forward(self, x): + """Forward function.""" + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + ppm_out = ppm_out.to(torch.float32) + upsampled_ppm_out = resize( + ppm_out, + size=x.size()[2:], + mode='bilinear', + align_corners=self.align_corners) + upsampled_ppm_out = upsampled_ppm_out.to(torch.bfloat16) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs diff --git a/lib/models/heads/up_head.py b/lib/models/heads/up_head.py new file mode 100644 index 0000000..ee048b6 --- /dev/null +++ b/lib/models/heads/up_head.py @@ -0,0 +1,52 @@ +import torch.nn as nn +from collections import OrderedDict +from mmcv.cnn.utils.weight_init import (kaiming_init, trunc_normal_) +from mmcv.runner import (CheckpointLoader, load_state_dict) +from mmseg.utils import get_root_logger + + +class UPHead(nn.Module): + + def __init__(self, in_dim, out_dim, up_scale, init_cfg=None): + super().__init__() + self.decoder = nn.Sequential( + nn.Conv2d(in_channels=in_dim, + out_channels=up_scale**2 * out_dim, + kernel_size=1), + nn.PixelShuffle(up_scale), + ) + self.init_cfg = init_cfg + self.apply(self._init_weights) + + def _init_weights(self, m): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + _state_dict = checkpoint['state_dict'] + else: + _state_dict = checkpoint + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + print(f'loading weight: {self.init_cfg["checkpoint"]}') + load_state_dict(self, state_dict, strict=False, logger=logger) + else: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + + def forward(self, x): + x = self.decoder(x) + return x \ No newline at end of file diff --git a/lib/models/heads/uper_head.py b/lib/models/heads/uper_head.py new file mode 100644 index 0000000..c8fdbaa --- /dev/null +++ b/lib/models/heads/uper_head.py @@ -0,0 +1,130 @@ +# coding: utf-8 +# Copyright (c) Ant Group. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.ops import resize +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from .psp_head import PPM + + +class UPerHead(BaseDecodeHead): + """Unified Perceptual Parsing for Scene Understanding. + + This head is the implementation of `UPerNet + `_. + + Args: + pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid + Module applied on the last feature. Default: (1, 2, 3, 6). + """ + + def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): + super(UPerHead, self).__init__( + input_transform='multiple_select', **kwargs) + # PSP Module + self.psp_modules = PPM( + pool_scales, + self.in_channels[-1], + self.channels, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + align_corners=self.align_corners) + self.bottleneck = ConvModule( + self.in_channels[-1] + len(pool_scales) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # FPN Module + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + for in_channels in self.in_channels[:-1]: # skip the top layer + l_conv = ConvModule( + in_channels, + self.channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + fpn_conv = ConvModule( + self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + self.fpn_convs.append(fpn_conv) + + self.fpn_bottleneck = ConvModule( + len(self.in_channels) * self.channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def psp_forward(self, inputs): + """Forward function of PSP module.""" + x = inputs[-1] + psp_outs = [x] + # breakpoint() + psp_outs.extend(self.psp_modules(x)) + psp_outs = torch.cat(psp_outs, dim=1) + output = self.bottleneck(psp_outs) + + return output + + def forward(self, inputs): + """Forward function.""" + # breakpoint() + inputs = self._transform_inputs(inputs) + + # build laterals + laterals = [ + lateral_conv(inputs[i]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + laterals.append(self.psp_forward(inputs)) + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + prev_shape = laterals[i - 1].shape[2:] + laterals[i] = laterals[i].type(torch.float32) + laterals[i - 1] = laterals[i - 1] + resize( + laterals[i], + size=prev_shape, + mode='bilinear', + align_corners=self.align_corners) + + # build outputs + fpn_outs = [ + self.fpn_convs[i](laterals[i]) + for i in range(used_backbone_levels - 1) + ] + # append psp feature + fpn_outs.append(laterals[-1]) + + for i in range(used_backbone_levels - 1, 0, -1): + fpn_outs[i] = fpn_outs[i].type(torch.float32) + fpn_outs[i] = resize( + fpn_outs[i], + size=fpn_outs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) + fpn_outs = torch.cat(fpn_outs, dim=1) + output = self.fpn_bottleneck(fpn_outs) + output = self.cls_seg(output) + return output diff --git a/lib/models/losses/__init__.py b/lib/models/losses/__init__.py new file mode 100644 index 0000000..3e0b929 --- /dev/null +++ b/lib/models/losses/__init__.py @@ -0,0 +1,4 @@ +from .modality_vae_loss import ModalityVAELoss +from .recon_anno_loss import RecLoss + +__all__ = [ "ModalityVAELoss", "RecLoss" ] \ No newline at end of file diff --git a/lib/models/losses/modality_vae_loss.py b/lib/models/losses/modality_vae_loss.py new file mode 100644 index 0000000..a399e10 --- /dev/null +++ b/lib/models/losses/modality_vae_loss.py @@ -0,0 +1,46 @@ +# Copyright (c) Ant Group and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from antmmf.common.registry import registry + + +@registry.register_loss("ModalityVAELoss") +class ModalityVAELoss(nn.Module): + def __init__(self, **params): + super().__init__() + self.weight = params.pop("weight") + + def compute_rec_loss(self, x_in, x_out, modal_flag): + loss_per_pixel = F.mse_loss(x_in, x_out, reduction='none') + loss_b = torch.mean(loss_per_pixel, dim=[1, 2, 3]) + return torch.sum(loss_b * modal_flag)/ (modal_flag.sum() + 1e-6) + + def forward(self, sample_list, output, *args, **kwargs): + vae_out = output["vae_out"] + feat_hr = vae_out['input_hr'] + feat_s2 = vae_out['input_s2'] + feat_s1 = vae_out['input_s1'] + + g_hr = vae_out['g_hr'] + g_s2 = vae_out['g_s2'] + g_s1 = vae_out['g_s1'] + + # process modality flags + modality_info = vae_out['modality_info'] + B_M, L_M = modality_info.shape + + modality_hr = modality_info[:,0] + modality_s2 = modality_info[:,1] + modality_s1 = modality_info[:,2] + + ######## rec losses ######## + loss_xent = self.compute_rec_loss(g_hr, feat_hr, modality_hr) \ + + self.compute_rec_loss(g_s2, feat_s2, modality_s2) \ + + self.compute_rec_loss(g_s1, feat_s1, modality_s1) + + + loss_quant = vae_out["loss_quant"] + total_loss = loss_xent / 3 + loss_quant + return total_loss * self.weight diff --git a/lib/models/losses/recon_anno_loss.py b/lib/models/losses/recon_anno_loss.py new file mode 100644 index 0000000..5c73df6 --- /dev/null +++ b/lib/models/losses/recon_anno_loss.py @@ -0,0 +1,89 @@ +# Copyright (c) Ant Group and its affiliates. +import torch +import torch.nn as nn +from antmmf.common.registry import registry +import torch.nn.functional as F + +@registry.register_loss("RecLoss") +class RecLoss(nn.Module): + def __init__(self, **params): + super().__init__() + self.weight = params.pop("weight") + self.patch_size = params.pop("patch_size") + self.eps = torch.finfo(torch.bfloat16).eps + self.pred_key = params.pop("pred_key") + self.vocabulary_size = params.pop("vocabulary_size") + 1 + self.mask_key = params.pop("mask_key") + self.target_key = params.pop("target_key") + self.feature_merged = params.pop("feature_merged") + self.cnt_train = 0 + self.cnt_val = 0 + self.use_bg = params.pop("use_bg") + if "use_all_patch" in params: + self.use_all_patch = params.pop("use_all_patch") + else: + self.use_all_patch = False + if "balance" in params: + self.balance = params.pop("balance") + else: + self.balance = False + if "sim_regularization" in params: + self.sim_regularization = params.pop("sim_regularization") + else: + self.sim_regularization = False + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_size + w = int((x.shape[1]*0.5)**.5) + h = w * 2 + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p)) + x = torch.einsum('nhwpq->nhpwq', x) + imgs = x.reshape(shape=(x.shape[0], h * p, w * p)) + return imgs + + + def forward(self, sample_list, output, *args, **kwargs): + pred = output[self.pred_key] # B, C, H, W + target = output[self.target_key] # B, H, W + mask = output[self.mask_key] + b_mask, h_mask, w_mask = mask.shape + mask = mask.reshape((b_mask, h_mask*w_mask)) + mask = mask[:, :, None].repeat(1, 1, self.patch_size**2) + mask = self.unpatchify(mask) + + if not self.use_bg: + valid = sample_list['valid'] + mask = mask * valid + + loss = F.cross_entropy(pred, target, reduction="none") + + if self.balance: + if self.use_all_patch: + loss_pos = loss[target > 0].sum() / ((target > 0).sum() + 1e-6) + loss_neg = loss[target == 0].sum() / ((target == 0).sum() + 1e-6) + loss = (loss_pos + loss_neg) * 0.5 + else: + loss_pos = loss[(target > 0) & (mask == 1)].sum() / (((target > 0) & (mask == 1)).sum() + 1e-6) + loss_neg = loss[(target == 0) & (mask == 1)].sum() / (((target == 0) & (mask == 1)).sum() + 1e-6) + loss = (loss_pos + loss_neg) * 0.5 + else: + if self.use_all_patch: + loss = loss.mean() + else: + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + if self.sim_regularization: + vocabulary_token = output['vocabulary_token'] + voca_normed = F.normalize(vocabulary_token, 2, 1) + similarity_matrix = 1 + torch.einsum('nd,md->nm', voca_normed, voca_normed) + num = voca_normed.shape[0] + index = torch.triu(voca_normed.new_ones(num, num), diagonal=1).type(torch.bool) + loss_reg = similarity_matrix[index].mean() + return loss * self.weight + loss_reg * 0.05 + return loss * self.weight + diff --git a/lib/models/metrics/__init__.py b/lib/models/metrics/__init__.py new file mode 100644 index 0000000..97257e8 --- /dev/null +++ b/lib/models/metrics/__init__.py @@ -0,0 +1,4 @@ +from .sem_metrics import SemMetric + +__all__ = ["SemMetric"] + diff --git a/lib/models/metrics/sem_metrics.py b/lib/models/metrics/sem_metrics.py new file mode 100644 index 0000000..1368f56 --- /dev/null +++ b/lib/models/metrics/sem_metrics.py @@ -0,0 +1,93 @@ +# coding: utf-8 +# Copyright (c) Ant Group. All rights reserved. +import torch +from torch.distributed import all_reduce, ReduceOp +from antmmf.common.registry import registry +from antmmf.modules.metrics.base_metric import BaseMetric + +@registry.register_metric("sem_metric") +class SemMetric(BaseMetric): + """Segmentation metrics used in evaluation phase. + + Args: + name (str): Name of the metric. + eval_type(str): 3 types are supported: 'mIoU', 'mDice', 'mFscore' + result_field(str): key of predicted results in output dict + target_field(str): key of ground truth in output dict + ignore_index(int): class value will be ignored in evaluation + num_cls(int): total number of categories in evaluation + """ + + def __init__(self, + name="dummy_metric", **kwargs + ): + super().__init__(name) + self.reset() + + def calculate(self, sample_list, model_output, *args, **kwargs): + """Calculate Intersection and Union for a batch. + + Args: + sample_list (Sample_List): data which contains ground truth segmentation maps + model_output (dict): data which contains prediction segmentation maps + Returns: + torch.Tensor: The intersection of prediction and ground truth histogram + on all classes. + torch.Tensor: The union of prediction and ground truth histogram on all + classes. + torch.Tensor: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + + return torch.tensor(0).float() + + def reset(self): + """ initialized all attributes value before evaluation + + """ + self.total_mask_mae = 0 + self.total_num = torch.tensor(0) + + def collect(self, sample_list, model_output, *args, **kwargs): + """ + Args: + sample_list(Sample_List): data which contains ground truth segmentation maps + model_output (Dict): Dict returned by model, that contains two modalities + Returns: + torch.FloatTensor: Accuracy + """ + batch_mask_mae = \ + self.calculate(sample_list, model_output, *args, **kwargs) + self.total_mask_mae += batch_mask_mae + self.total_num += 1 + + def format(self, *args): + """ Format evaluated metrics for profile. + + Returns: + dict: dict of all evaluated metrics. + """ + output_metric = dict() + # if self.eval_type == 'mae': + mae = args[0] + output_metric['mae'] = mae.item() + return output_metric + + def summarize(self, *args, **kwargs): + """This method is used to calculate the overall metric. + + Returns: + dict: dict of all evaluated metrics. + + """ + # if self.eval_type == 'mae': + mae = self.total_mask_mae / (self.total_num) + return self.format(mae) + + def all_reduce(self): + total_number = torch.stack([ + self.total_mask_mae, self.total_num + ]).cuda() + all_reduce(total_number, op=ReduceOp.SUM) + self.total_mask_mae = total_number[0].cpu() + self.total_num = total_number[1].cpu() \ No newline at end of file diff --git a/lib/models/necks/__init__.py b/lib/models/necks/__init__.py new file mode 100644 index 0000000..e57f6d9 --- /dev/null +++ b/lib/models/necks/__init__.py @@ -0,0 +1,13 @@ +from .transformer_encoder import TransformerEncoder +from .modality_completion import ModalityCompletion + +__all__ = ['TransformerEncoder', 'ModalityCompletion'] + +type_mapping = { + 'TransformerEncoder': TransformerEncoder, + 'ModalityCompletion': ModalityCompletion +} + + +def build_neck(type, **kwargs): + return type_mapping[type](**kwargs) \ No newline at end of file diff --git a/lib/models/necks/modality_completion.py b/lib/models/necks/modality_completion.py new file mode 100644 index 0000000..14cbd50 --- /dev/null +++ b/lib/models/necks/modality_completion.py @@ -0,0 +1,212 @@ +# Copyright (c) AntGroup. All rights reserved. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BFloat16UpsampleNearest2d(nn.Module): + def __init__(self, scale_factor, mode='bilinear'): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + x_float = x.float() + upsampled_x = F.interpolate(x_float, scale_factor=self.scale_factor, mode=self.mode) + return upsampled_x.to(x.dtype) + +class ConvVQVAEv2(nn.Module): + def __init__(self, input_shape, conv_dim, z_dim, num_tokens=8192, temp=0.9): + super().__init__() + self.z_dim = z_dim + self.conv_dim = conv_dim # 256 + self.input_shape = input_shape # 256 + self.temp = temp + # code book + self.codebook = nn.Embedding(num_tokens, z_dim) + # encoder + self.relu = nn.LeakyReLU() + self.pool = nn.AvgPool2d(2) + self.conv1 = nn.Conv2d(input_shape[0], conv_dim, 5, stride=1, padding=2) + self.enc_block1 = nn.Sequential( + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + ) + self.gamma_1 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1))) + self.enc_block2 = nn.Sequential( + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + ) + self.gamma_2 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1))) + self.logit_conv = nn.Conv2d(conv_dim, num_tokens, 1) + # decoder + self.unpool = BFloat16UpsampleNearest2d(scale_factor=2) + self.conv2 = nn.Conv2d(z_dim, conv_dim, 3, stride=1, padding=1) + self.dec_block1 = nn.Sequential( + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + ) + self.gamma_3 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1))) + self.dec_block2 = nn.Sequential( + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1), + nn.LeakyReLU(), + ) + self.gamma_4 = nn.Parameter(0.001 * torch.ones((1, conv_dim, 1, 1))) + self.rec_conv = nn.Conv2d(conv_dim, input_shape[0], 3, stride=1, padding=1) + + def forward_encoder(self, x): + x = self.relu(self.conv1(x)) + x = x + self.gamma_1 * self.enc_block1(x) + x = self.pool(x) + x = x + self.gamma_2 * self.enc_block2(x) + x = self.pool(x) + logits = self.logit_conv(x) + return logits + + def forward_decoder(self, logits): + soft_one_hot = F.softmax(logits * (self.temp*10), dim=1) + sampled = torch.einsum('bnhw,nd->bdhw', soft_one_hot, self.codebook.weight) + x = self.relu(self.conv2(sampled)) + x = self.unpool(x) + x = x + self.gamma_3 * self.dec_block1(x) + x = self.unpool(x) + x = x + self.gamma_4 * self.dec_block2(x) + rec_feats = self.rec_conv(x) + return rec_feats, soft_one_hot + + def forward(self, x): + print(x.shape) + logits = self.forward_encoder(x) + images_p, soft_one_hot = self.forward_decoder(logits) + return [logits, images_p] + +class ModalityCompletion(nn.Module): + def __init__(self, + input_shape_hr=(2816, 16, 16), + input_shape_s2=(2816, 16, 16), + input_shape_s1=(2816, 16, 16), + conv_dim=256, + z_dim=256, + n_codebook=8192, + init_cfg=None + ): + super(ModalityCompletion, self).__init__() + self.vae_hr = ConvVQVAEv2(input_shape=input_shape_hr, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook) + self.vae_s2 = ConvVQVAEv2(input_shape=input_shape_s2, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook) + self.vae_s1 = ConvVQVAEv2(input_shape=input_shape_s1, conv_dim=conv_dim, z_dim=z_dim, num_tokens=n_codebook) + self.kl_div_loss = torch.nn.KLDivLoss(reduction="none", log_target=True) + self.init_cfg=init_cfg + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + from mmcls.utils import get_root_logger + from mmcv.runner import CheckpointLoader, load_state_dict + logger = get_root_logger() + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + load_state_dict(self, state_dict, strict=False, logger=logger) + else: + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def kl_loss(self, logits_hr, logits_s2, logits_s1, modality_info): + prob_hr = F.log_softmax(logits_hr, dim=1) + prob_s2 = F.log_softmax(logits_s2, dim=1) + prob_s1 = F.log_softmax(logits_s1, dim=1) + flag_hr = modality_info[:,0][:, None, None, None] + flag_s2 = modality_info[:,1][:, None, None, None] + flag_s1 = modality_info[:,2][:, None, None, None] + loss_hr_s2 = self.kl_div_loss(prob_hr, prob_s2) + self.kl_div_loss(prob_s2, prob_hr) + loss_hr_s2 = (loss_hr_s2 * flag_hr * flag_s2).sum((1, 2, 3)).mean() + loss_hr_s1 = self.kl_div_loss(prob_hr, prob_s1) + self.kl_div_loss(prob_s1, prob_hr) + loss_hr_s1 = (loss_hr_s1 * flag_hr * flag_s1).sum((1, 2, 3)).mean() + loss_s2_s1 = self.kl_div_loss(prob_s2, prob_s1) + self.kl_div_loss(prob_s1, prob_s2) + loss_s2_s1 = (loss_s2_s1 * flag_s2 * flag_s1).sum((1, 2, 3)).mean() + loss = (loss_hr_s2 + loss_hr_s1 + loss_s2_s1) / 6.0 + + return loss + + def forward(self, feat_hr, feat_s2, feat_s1, modality_info): + # encoders,add noise + # each modality + # 2816, 16, 16 => conv 256, 4, 4 => flatten 4096(256*4*4) => linear mu 256, log_var 256 + B, C, H, W = feat_hr.shape + B_M, L_M = modality_info.shape + assert B == B_M, f'feat_hr batch: {B}, modality_info batch: {B_M}' + + # quant, emb_loss, info + # hr input flow + logits_hr = self.vae_hr.forward_encoder(feat_hr) + logits_s2 = self.vae_s2.forward_encoder(feat_s2) + logits_s1 = self.vae_s1.forward_encoder(feat_s1) + modality_hr = modality_info[:,0] + modality_s2 = modality_info[:,1] + modality_s1 = modality_info[:,2] + flag_hr = modality_hr[:, None, None, None] # B => B, C, H, W + flag_s2 = modality_s2[:, None, None, None] + flag_s1 = modality_s1[:, None, None, None] + + mean_logits_hr_s2 = logits_hr * flag_hr + logits_s2 * flag_s2 + mean_logits_hr_s1 = logits_hr * flag_hr + logits_s1 * flag_s1 + mean_logits_s1_s2 = logits_s1 * flag_s1 + logits_s2 * flag_s2 + + logits_hr_rec = logits_hr * flag_hr + mean_logits_s1_s2 * (~flag_hr) + logits_s2_rec = logits_s2 * flag_s2 + mean_logits_hr_s1 * (~flag_s2) + logits_s1_rec = logits_s1 * flag_s1 + mean_logits_hr_s2 * (~flag_s1) + g_hr, soft_one_hot_hr = self.vae_hr.forward_decoder(logits_hr_rec) + g_s2, soft_one_s2 = self.vae_s2.forward_decoder(logits_s2_rec) + g_s1, soft_one_s1 = self.vae_s1.forward_decoder(logits_s1_rec) + + hr_out = feat_hr * flag_hr + g_hr * (~flag_hr) + s2_out = feat_s2 * flag_s2 + g_s2 * (~flag_s2) + s1_out = feat_s1 * flag_s1 + g_s1 * (~flag_s1) + + output = {} + + output['hr_out'] = hr_out + output['s2_out'] = s2_out + output['s1_out'] = s1_out + + output['modality_info'] = modality_info + + output['input_hr'] = feat_hr + output['input_s2'] = feat_s2 + output['input_s1'] = feat_s1 + + output['logits_hr'] = logits_hr + output['logits_s2'] = logits_s2 + output['logits_s1'] = logits_s1 + + output['soft_one_hot_hr'] = soft_one_hot_hr + output['soft_one_hot_s2'] = soft_one_s2 + output['soft_one_hot_s1'] = soft_one_s1 + + output['g_hr'] = g_hr + output['g_s2'] = g_s2 + output['g_s1'] = g_s1 + output['loss_quant'] = self.kl_loss(logits_hr, logits_s2, logits_s1, modality_info) + + return output + diff --git a/lib/models/necks/transformer_encoder.py b/lib/models/necks/transformer_encoder.py new file mode 100644 index 0000000..fcdd699 --- /dev/null +++ b/lib/models/necks/transformer_encoder.py @@ -0,0 +1,144 @@ +# Copyright (c) Ant Group. All rights reserved. +from collections import OrderedDict +import torch +import torch.nn as nn +from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, + trunc_normal_) +from mmcv.runner import (CheckpointLoader, load_state_dict) +from torch.nn.modules.batchnorm import _BatchNorm +from mmseg.models.backbones.vit import TransformerEncoderLayer + +from mmseg.utils import get_root_logger + + +class TransformerEncoder(nn.Module): + + def __init__(self, + input_dims=768, + embed_dims=768, + num_layers=4, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + with_cls_token=True, + output_cls_token=False, + norm_cfg=dict(type='LN'), + act_cfg=dict(type='GELU'), + num_fcs=2, + norm_eval=False, + with_cp=False, + init_cfg=None, + *args, + **kwargs): + super(TransformerEncoder, self).__init__() + + self.porj_linear = nn.Linear(input_dims, embed_dims) + if output_cls_token: + assert with_cls_token is True, f'with_cls_token must be True if' \ + f'set output_cls_token to True, but got {with_cls_token}' + + self.init_cfg = init_cfg + self.norm_eval = norm_eval + self.with_cp = with_cp + self.with_cls_token = with_cls_token + self.output_cls_token = output_cls_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.drop_after_pos = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, num_layers) + ] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i in range(num_layers): + self.layers.append( + TransformerEncoderLayer(embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=mlp_ratio * + embed_dims, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + drop_path_rate=dpr[i], + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + with_cp=with_cp, + batch_first=True)) + + def init_weights(self): + if (isinstance(self.init_cfg, dict) + and self.init_cfg.get('type') == 'Pretrained'): + logger = get_root_logger() + checkpoint = CheckpointLoader.load_checkpoint( + self.init_cfg['checkpoint'], logger=logger, map_location='cpu') + + if 'state_dict' in checkpoint: + _state_dict = checkpoint['state_dict'] + else: + _state_dict = checkpoint + + state_dict = OrderedDict() + for k, v in _state_dict.items(): + if k.startswith('backbone.'): + state_dict[k[9:]] = v + else: + state_dict[k] = v + + load_state_dict(self, state_dict, strict=False, logger=logger) + else: + # We only implement the 'jax_impl' initialization implemented at + # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501 + trunc_normal_(self.cls_token, std=.02) + for n, m in self.named_modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + if 'ffn' in n: + nn.init.normal_(m.bias, mean=0., std=1e-6) + else: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + kaiming_init(m, mode='fan_in', bias=0.) + elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): + constant_init(m, val=1.0, bias=0.) + + def forward(self, inputs, require_feat: bool = False, require_two: bool = False): + inputs = self.porj_linear(inputs) + B, N, C = inputs.shape + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, inputs), dim=1) + if not self.with_cls_token: + # Remove class token for transformer encoder input + x = x[:, 1:] + + # add hidden and atten state + block_outs = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if require_feat: + block_outs.append(x) + + if self.output_cls_token: + if require_two: + x = x[:, :2] + else: + x = x[:, 0] + elif not self.output_cls_token and self.with_cls_token: + x = x # [:, :] + + if require_feat: + return x, block_outs + else: + return x + + def train(self, mode=True): + super(TransformerEncoder, self).train(mode) + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, nn.LayerNorm): + m.eval() \ No newline at end of file diff --git a/lib/models/segmentors/__init__.py b/lib/models/segmentors/__init__.py new file mode 100644 index 0000000..d7ef4c6 --- /dev/null +++ b/lib/models/segmentors/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Ant Financial Service Group and its affiliates. +from .skysense_pp_pipeline import SkySensePP + +__all__ = ['SkySensePP'] diff --git a/lib/models/segmentors/skysense_pp_pipeline.py b/lib/models/segmentors/skysense_pp_pipeline.py new file mode 100644 index 0000000..c3d0307 --- /dev/null +++ b/lib/models/segmentors/skysense_pp_pipeline.py @@ -0,0 +1,458 @@ +# coding: utf-8 +# Copyright (c) Ant Group. All rights reserved. +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F + +import math +import random +from antmmf.common.registry import registry +from antmmf.models.base_model import BaseModel +from lib.models.backbones import build_backbone +from lib.models.necks import build_neck +from lib.models.heads import build_head +from lib.utils.utils import LayerDecayValueAssigner + + +@registry.register_model("SkySensePP") +class SkySensePP(BaseModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.sources = config.sources + assert len(self.sources) > 0, 'at least one data source is required' + if 's2' in self.sources: + self.use_ctpe = config.use_ctpe + self.use_modal_vae = config.use_modal_vae + self.use_cls_token_uper_head = config.use_cls_token_uper_head + self.target_mean=[0.485, 0.456, 0.406] + self.target_std=[0.229, 0.224, 0.225] + self.vocabulary_size = config.vocabulary_size + self.vocabulary = list(range(1, config.vocabulary_size + 1)) # 0 for ignore + + + def build(self): + if 'hr' in self.sources: + self.backbone_hr = self._build_backbone('hr') + if 's2' in self.sources: + self.backbone_s2 = self._build_backbone('s2') + if self.use_ctpe: + self.ctpe = nn.Parameter( + torch.zeros(1, self.config.calendar_time, + self.config.necks.input_dims)) + if 'head_s2' in self.config.keys(): + self.head_s2 = self._build_head('head_s2') + self.fusion = self._build_neck('necks') + if 's1' in self.sources: + self.backbone_s1 = self._build_backbone('s1') + if 'head_s1' in self.config.keys(): + self.head_s1 = self._build_head('head_s1') + self.head_rec_hr = self._build_head('rec_head_hr') + + self.with_aux_head = False + + if self.use_modal_vae: + self.modality_vae = self._build_neck('modality_vae') + if 'auxiliary_head' in self.config.keys(): + self.with_aux_head = True + self.aux_head = self._build_head('auxiliary_head') + if 'init_cfg' in self.config.keys( + ) and self.config.init_cfg is not None and self.config.init_cfg.checkpoint is not None and self.config.init_cfg.key is not None: + self.load_pretrained(self.config.init_cfg.checkpoint, + self.config.init_cfg.key) + + def _build_backbone(self, key): + config_dict = self.config[f'backbone_{key}'].to_dict() + backbone_type = config_dict.pop('type') + backbone = build_backbone(backbone_type, **config_dict) + backbone.init_weights() + return backbone + + def _build_neck(self, key): + config_dict = self.config[key].to_dict() + neck_type = config_dict.pop('type') + neck = build_neck(neck_type, **config_dict) + neck.init_weights() + return neck + + def _build_head(self, key): + head_config = self.config[key].to_dict() + head_type = head_config.pop('type') + head = build_head(head_type, **head_config) + return head + + def get_optimizer_parameters(self, config): + optimizer_grouped_parameters = [ + { + "params": [], + "lr": config.optimizer_attributes.params.lr, + "weight_decay": config.optimizer_attributes.params.weight_decay, + }, + { + "params": [], + "lr": config.optimizer_attributes.params.lr, + "weight_decay": 0.0, + }, + ] + layer_decay_value_assigner_hr = LayerDecayValueAssigner( + config.lr_parameters.layer_decay, None, + config.optimizer_attributes.params.lr, 'swin', + config.model_attributes.SkySensePP.backbone_hr.arch + ) + layer_decay_value_assigner_s2 = LayerDecayValueAssigner( + config.lr_parameters.layer_decay, 24, + config.optimizer_attributes.params.lr, 'vit', + ) + layer_decay_value_assigner_s1 = LayerDecayValueAssigner( + config.lr_parameters.layer_decay, 24, + config.optimizer_attributes.params.lr, 'vit', + ) + layer_decay_value_assigner_fusion = LayerDecayValueAssigner( + config.lr_parameters.layer_decay, 24, + config.optimizer_attributes.params.lr, 'vit', + ) + num_frozen_params = 0 + if 'hr' in self.sources: + print('hr'.center(60, '-')) + num_frozen_params += layer_decay_value_assigner_hr.fix_param( + self.backbone_hr, + config.lr_parameters.frozen_blocks, + ) + optimizer_grouped_parameters.extend( + layer_decay_value_assigner_hr.get_parameter_groups( + self.backbone_hr, config.optimizer_attributes.params.weight_decay + ) + ) + if 's2' in self.sources: + print('s2'.center(60, '-')) + num_frozen_params += layer_decay_value_assigner_s2.fix_param( + self.backbone_s2, + config.lr_parameters.frozen_blocks, + ) + optimizer_grouped_parameters.extend( + layer_decay_value_assigner_s2.get_parameter_groups( + self.backbone_s2, config.optimizer_attributes.params.weight_decay + ) + ) + no_decay = [".bn.", "bias"] + optimizer_grouped_parameters[0]["params"] += [ + p for n, p in self.head_s2.named_parameters() + if not any(nd in n for nd in no_decay) + ] + optimizer_grouped_parameters[1]["params"] += [ + p for n, p in self.head_s2.named_parameters() + if any(nd in n for nd in no_decay) + ] + if self.use_ctpe: + optimizer_grouped_parameters[1]["params"] += [self.ctpe] + + if 's1' in self.sources: + print('s1'.center(60, '-')) + num_frozen_params += layer_decay_value_assigner_s1.fix_param( + self.backbone_s1, + config.lr_parameters.frozen_blocks, + ) + optimizer_grouped_parameters.extend( + layer_decay_value_assigner_s1.get_parameter_groups( + self.backbone_s1, config.optimizer_attributes.params.weight_decay + ) + ) + no_decay = [".bn.", "bias"] + optimizer_grouped_parameters[0]["params"] += [ + p for n, p in self.head_s1.named_parameters() + if not any(nd in n for nd in no_decay) + ] + optimizer_grouped_parameters[1]["params"] += [ + p for n, p in self.head_s1.named_parameters() + if any(nd in n for nd in no_decay) + ] + + if len(self.sources) > 1: + print('fusion'.center(60, '-')) + num_frozen_params += layer_decay_value_assigner_fusion.fix_param_deeper( + self.fusion, + config.lr_parameters.frozen_fusion_blocks_start, # 冻结后面所有的stage + ) + optimizer_grouped_parameters.extend( + layer_decay_value_assigner_fusion.get_parameter_groups( + self.fusion, config.optimizer_attributes.params.weight_decay + ) + ) + + if self.use_modal_vae: + no_decay = [".bn.", "bias"] + optimizer_grouped_parameters[0]["params"] += [ + p for n, p in self.modality_vae.named_parameters() + if not any(nd in n for nd in no_decay) + ] + optimizer_grouped_parameters[1]["params"] += [ + p for n, p in self.modality_vae.named_parameters() + if any(nd in n for nd in no_decay) + ] + + no_decay = [".bn.", "bias"] + optimizer_grouped_parameters[0]["params"] += [ + p for n, p in self.head_rec_hr.named_parameters() + if not any(nd in n for nd in no_decay) + ] + optimizer_grouped_parameters[1]["params"] += [ + p for n, p in self.head_rec_hr.named_parameters() + if any(nd in n for nd in no_decay) + ] + + if self.with_aux_head: + no_decay = [".bn.", "bias"] + optimizer_grouped_parameters[0]["params"] += [ + p for n, p in self.aux_head.named_parameters() + if not any(nd in n for nd in no_decay) + ] + optimizer_grouped_parameters[1]["params"] += [ + p for n, p in self.aux_head.named_parameters() + if any(nd in n for nd in no_decay) + ] + num_params = [len(x['params']) for x in optimizer_grouped_parameters] + print(len(list(self.parameters())), sum(num_params), num_frozen_params) + assert len(list(self.parameters())) == sum(num_params) + num_frozen_params + return optimizer_grouped_parameters + + def get_custom_scheduler(self, trainer): + optimizer = trainer.optimizer + num_training_steps = trainer.config.training_parameters.max_iterations + num_warmup_steps = trainer.config.training_parameters.num_warmup_steps + + if "train" in trainer.run_type: + if num_training_steps == math.inf: + epoches = trainer.config.training_parameters.max_epochs + assert epoches != math.inf + num_training_steps = trainer.config.training_parameters.max_epochs * trainer.epoch_iterations + + def linear_with_wram_up(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max( + 1, num_warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) / + float(max(1, num_training_steps - num_warmup_steps)), + ) + + def cos_with_wram_up(current_step): + num_cycles = 0.5 + if current_step < num_warmup_steps: + return float(current_step) / float(max( + 1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps)) + return max( + 0.0, + 0.5 * + (1.0 + + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + lr_lambda = cos_with_wram_up if trainer.config.training_parameters.cos_lr else linear_with_wram_up + + else: + def lr_lambda(current_step): + return 0.0 # noqa + + return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, -1) + + def convert_target(self, target): + mean = target.new_tensor(self.target_mean).reshape(1, 3, 1, 1) + std = target.new_tensor(self.target_std).reshape(1, 3, 1, 1) + target = ((target * std + mean)*255).to(torch.long) + target[:, 0] = target[:, 0] * 256 * 256 + target[:, 1] = target[:, 1] * 256 + target = target.sum(1).type(torch.long) + unique_target = target.unique() + target_index = torch.searchsorted(unique_target, target) + no_bg = False + if unique_target[0].item() > 0: + target_index += 1 + no_bg = True + target_index_unique = target_index.unique().tolist() + random.shuffle(self.vocabulary) + value = target.new_tensor([0] + self.vocabulary) + mapped_target = target_index.clone() + idx_2_color = {} + for v in target_index_unique: + mapped_target[target_index == v] = value[v] + idx_2_color[value[v].item()] = unique_target[v - 1 if no_bg else v].item() + return mapped_target, idx_2_color + + def forward(self, sample_list): + output = dict() + modality_flag_hr = sample_list["modality_flag_hr"] + modality_flag_s2 = sample_list["modality_flag_s2"] + modality_flag_s1 = sample_list["modality_flag_s1"] + modalities = [modality_flag_hr, modality_flag_s2, modality_flag_s1] + modalities = torch.tensor(modalities).permute(1,0).contiguous() # L, B => B, L + + anno_img = sample_list["targets"] + anno_img, idx_2_color = self.convert_target(anno_img) + output["mapped_targets"] = anno_img + output["idx_2_color"] = idx_2_color + anno_mask = sample_list["anno_mask"] + anno_s2 = anno_img[:, 15::32, 15::32] + anno_s1 = anno_s2 + + output["anno_hr"] = anno_img + output["anno_s2"] = anno_s2 + + ### 1. backbone + if 'hr' in self.sources: + hr_img = sample_list["hr_img"] + B_MASK, H_MASK, W_MASK = anno_mask.shape + block_size = 32 + anno_mask_hr = anno_mask.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, block_size, block_size) + anno_mask_hr = anno_mask_hr.permute(0, 1, 3, 2, 4).reshape(B_MASK, H_MASK*block_size, W_MASK*block_size).contiguous() + B, C_G, H_G, W_G = hr_img.shape + hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr) + output['mask_hr'] = anno_mask_hr + output['target_hr'] = anno_img + + if 's2' in self.sources: + s2_img = sample_list["s2_img"] + B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape + s2_img = s2_img.permute(0, 2, 1, 3, + 4).reshape(B * S_S2, C_S2, H_S2, W_S2).contiguous() # ts time to batch + anno_mask_s2 = anno_mask + s2_features = self.backbone_s2(s2_img, anno_s2, anno_mask_s2) + if 'head_s2' in self.config.keys(): + s2_features = self.head_s2(s2_features[-1]) + s2_features = [s2_features] + + if 's1' in self.sources: + s1_img = sample_list["s1_img"] + B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape + s1_img = s1_img.permute(0, 2, 1, 3, + 4).reshape(B * S_S1, C_S1, H_S1, W_S1).contiguous() + + anno_mask_s1 = anno_mask + s1_features = self.backbone_s1(s1_img, anno_s1, anno_mask_s1) + if 'head_s1' in self.config.keys(): + s1_features = self.head_s1(s1_features[-1]) + s1_features = [s1_features] + + ### 2. prepare features for fusion + hr_features_stage3 = hr_features[-1] + s2_features_stage3 = s2_features[-1] + s1_features_stage3 = s1_features[-1] + modalities = modalities.to(hr_features_stage3.device) + if self.use_modal_vae: + vae_out = self.modality_vae(hr_features_stage3, s2_features_stage3, s1_features_stage3, modalities) + hr_features_stage3 = vae_out['hr_out'] + s2_features_stage3 = vae_out['s2_out'] + s1_features_stage3 = vae_out['s1_out'] + output['vae_out'] = vae_out + + features_stage3 = [] + if 'hr' in self.sources: + B, C3_G, H3_G, W3_G = hr_features_stage3.shape + hr_features_stage3 = hr_features_stage3.permute( + 0, 2, 3, 1).reshape(B * H3_G * W3_G, C3_G).unsqueeze(1).contiguous() # B * H3_G * W3_G, 1, C3_G + features_stage3 = hr_features_stage3 + + if 's2' in self.sources: + # s2_features_stage3 = s2_features[-1] + _, C3_S2, H3_S2, W3_S2 = s2_features_stage3.shape + s2_features_stage3 = s2_features_stage3.reshape( + B, S_S2, C3_S2, H3_S2, + W3_S2).permute(0, 3, 4, 1, 2).reshape(B, H3_S2 * W3_S2, S_S2, + C3_S2).contiguous() + if self.use_ctpe: + ct_index = sample_list["s2_ct"] + ctpe = self.ctpe[:, ct_index, :].contiguous().permute(1, 0, 2, 3).contiguous() + ctpe = ctpe.expand(-1, 256, -1, -1) + + ct_index_2 = sample_list["s2_ct2"] + ctpe2 = self.ctpe[:, ct_index_2, :].contiguous().permute(1, 0, 2, 3).contiguous() + ctpe2 = ctpe2.expand(-1, 256, -1, -1) + + ctpe_comb = torch.cat([ctpe, ctpe2], 1) + # import pdb;pdb.set_trace() + s2_features_stage3 = (s2_features_stage3 + ctpe_comb).reshape( + B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous() + else: + s2_features_stage3 = s2_features_stage3.reshape( + B * H3_S2 * W3_S2, S_S2, C3_S2).contiguous() + + if len(features_stage3) > 0: + assert H3_G == H3_S2 and W3_G == W3_S2 and C3_G == C3_S2 + features_stage3 = torch.cat((features_stage3, s2_features_stage3), dim=1) + else: + features_stage3 = s2_features_stage3 + + if 's1' in self.sources: + # s1_features_stage3 = s1_features[-1] + _, C3_S1, H3_S1, W3_S1 = s1_features_stage3.shape + s1_features_stage3 = s1_features_stage3.reshape( + B, S_S1, C3_S1, H3_S1, + W3_S1).permute(0, 3, 4, 1, 2).reshape(B, H3_S1 * W3_S1, S_S1, + C3_S1).contiguous() + s1_features_stage3 = s1_features_stage3.reshape( + B * H3_S1 * W3_S1, S_S1, C3_S1).contiguous() + + if len(features_stage3) > 0: + assert H3_S1 == H3_S2 and W3_S1 == W3_S2 and C3_S1 == C3_S2 + features_stage3 = torch.cat((features_stage3, s1_features_stage3), + dim=1) + else: + features_stage3 = s1_features_stage3 + + ### 3. fusion + if self.config.necks.output_cls_token: + if self.config.necks.get('require_feat', False): + cls_token, block_outs = self.fusion(features_stage3 , True) + else: + cls_token = self.fusion(features_stage3) + _, C3_G = cls_token.shape + cls_token = cls_token.reshape(B, H3_G, W3_G, + C3_G).contiguous().permute(0, 3, 1, 2).contiguous() # b, c, h, w + else: + assert self.config.necks.with_cls_token is False + if self.config.necks.get('require_feat', False): + features_stage3, block_outs = self.fusion(features_stage3, True) + else: + features_stage3 = self.fusion(features_stage3) + features_stage3 = features_stage3.reshape( + B, H3_S2, W3_S2, S_S2, + C3_S2).permute(0, 3, 4, 1, 2).reshape(B * S_S2, C3_S2, H3_S2, + W3_S2).contiguous() + ### 4. decoder for rec + hr_rec_inputs = hr_features + feat_stage1 = hr_rec_inputs[0] + + if feat_stage1.shape[-1] == feat_stage1.shape[-2]: + feat_stage1_left, feat_stage1_right = torch.split(feat_stage1, feat_stage1.shape[-1] // 2, dim=-1) + feat_stage1 = torch.cat((feat_stage1_left, feat_stage1_right), dim=1) + hr_rec_inputs = list(hr_features) + hr_rec_inputs[0] = feat_stage1 + + rec_feats = [*hr_rec_inputs, cls_token] + logits_hr = self.head_rec_hr(rec_feats) + if self.config.get('upsacle_results', True): + logits_hr = logits_hr.to(torch.float32) + logits_hr = F.interpolate(logits_hr, scale_factor=4, mode='bilinear', align_corners=True) + output["logits_hr"] = logits_hr + return output + + def load_pretrained(self, ckpt_path, key): + pretrained_dict = torch.load(ckpt_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = pretrained_dict[key] + for k, v in pretrained_dict.items(): + if k == 'backbone_s2.patch_embed.projection.weight': + pretrained_in_channels = v.shape[1] + if self.config.backbone_s2.in_channels == 4: + new_weight = v[:, [0, 1, 2, 6]] + new_weight = new_weight * ( + pretrained_in_channels / + self.config.backbone_s2.in_channels) + pretrained_dict[k] = new_weight + missing_keys, unexpected_keys = self.load_state_dict(pretrained_dict, + strict=False) + print('missing_keys:', missing_keys) + print('unexpected_keys:', unexpected_keys) + diff --git a/lib/predictors/flood3i_1shot.py b/lib/predictors/flood3i_1shot.py new file mode 100644 index 0000000..1a01cad --- /dev/null +++ b/lib/predictors/flood3i_1shot.py @@ -0,0 +1,244 @@ +import os +import glob +import numpy as np +import yaml +import argparse +import oss2 +import torch +from PIL import Image +from tqdm import tqdm +import concurrent.futures +from torchvision.transforms import functional as F +import random + + +from antmmf.common.registry import registry +from antmmf.common.report import Report, default_result_formater +from antmmf.structures import Sample, SampleList +from antmmf.predictors.base_predictor import BasePredictor +from antmmf.utils.timer import Timer +from antmmf.predictors.build import build_predictor +from antmmf.common.task_loader import build_collate_fn +from antmmf.datasets.samplers import SequentialSampler +from antmmf.common.build import build_config + +from lib.utils.checkpoint import SegCheckpoint +from lib.datasets.loader.few_shot_flood3i_loader import FewShotFloodLoader + + +def seed_everything(seed=0): + # 为了确保CUDA卷积的确定性 + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + +@registry.register_predictor("OneshotPredictor") +class OneshotPredictor(BasePredictor, FewShotFloodLoader): + def __init__(self, config): + self.config = config + self.predictor_parameters = self.config.predictor_parameters + + def _predict(self, sample_list): + # with torch.no_grad(): + if True: + sample_list = sample_list.to(self.device) + report = self._forward_pass(sample_list) + return report, sample_list + + def _forward_pass(self, samplelist): + autocast_dtype = torch.bfloat16 + with torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype): + model_output = self.model(samplelist) + report = Report(samplelist, model_output) + return report + + def predict(self, data=None): + if data is None: + data = self.dummy_request() + sample = self._build_sample(data) + if not isinstance(sample, Sample): + raise Exception( + f"Method _build_sample is expected to return a instance of antmmf.structures.sample.Sample," + f"but got type {type(sample)} instead.") + result, sample_list = self._predict(SampleList([sample])) + np_result = default_result_formater(result) + result = self.format_result(np_result) + assert isinstance( + result, dict + ), f"Result should be instance of Dict,but got f{type(result)} instead" + return result, sample_list + + def load_checkpoint(self): + self.resume_file = self.config.predictor_parameters.model_dir + self.checkpoint = SegCheckpoint(self, load_only=True) + self.checkpoint.load_model_weights(self.resume_file, force=True) + + def covert_speedup_op(self): + if self.config.predictor_parameters.replace_speedup_op: + from lib.utils.optim_utils import replace_speedup_op + self.model = replace_speedup_op(self.model) + +def save_image(output_path, image_np_array): + image = Image.fromarray(image_np_array) + image.save(output_path) + +def build_predictor_from_args(args, *rest, **kwargs): + config = build_config( + args.config, + config_override=args.config_override, + opts_override=args.opts, + specific_override=args, + ) + predictor_obj = build_predictor(config) + setattr(predictor_obj, "args", args) + return predictor_obj + + +def build_online_predictor(model_dir=None, config_yaml=None): + assert model_dir or config_yaml + from antmmf.utils.flags import flags + + # if config_yaml not indicated, there must be a `config.yaml` file under `model_dir` + config_path = config_yaml if config_yaml else os.path.join(model_dir, "config.yaml") + input_args = ["--config", config_path] + if model_dir is not None: + input_args += ["predictor_parameters.model_dir", model_dir] + parser = flags.get_parser() + args = parser.parse_args(input_args) + predictor = build_predictor_from_args(args) + return predictor + +def profile(profiler, text): + print(f'{text}: {profiler.get_time_since_start()}') + profiler.reset() + +def cvt_colors(img_2d, idx_2_color_rgb): + img_rgb = np.zeros((img_2d.shape[0], img_2d.shape[1], 3), dtype=np.uint8) + for idx, color in idx_2_color_rgb.items(): + img_rgb[img_2d==idx] = color + return img_rgb + +def process_results(preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color): + imagenet_std = np.array([0.229, 0.224, 0.225]) + imagenet_mean = np.array([0.485, 0.456, 0.406]) + idx_2_color_rgb = {} + + for idx, color in idx_2_color.items(): + r = color // (256 * 256) + g = (color % (256 * 256)) // 256 + b = color % 256 + idx_2_color_rgb[idx] = (r, g, b) + + for i in range(preds.size(0)): + output1 = preds[i].argmax(0) # h, w + output1_total = output1.clone() + output1 = output1[output1.shape[0]//2:, :] + output1 = output1.numpy().astype(np.uint8) + output1 = cvt_colors(output1, idx_2_color_rgb) + + output1_total = output1_total.numpy().astype(np.uint8) + output1_total = cvt_colors(output1_total, idx_2_color_rgb) + + # for visualization + output2 = targets[i] + output2 = output2.numpy().astype(np.uint8) + output2 = cvt_colors(output2, idx_2_color_rgb) + + input_img = torch.einsum('chw->hwc', input_imgs[i]) + input_img = torch.clip((input_img * imagenet_std + imagenet_mean) * 255, 0, 255) + input_img = input_img.numpy().astype(np.uint8) + + output_comb = np.concatenate((input_img, output1_total, output2), axis=1) + + # save result + save_path = os.path.join(save_dir, f'{img_names[i]}.png') + save_image(save_path, output1) + save_path_vis = os.path.join(save_dir_vis, f'{img_names[i]}.png') + save_image(save_path_vis, output_comb) + +def test(args): + model_path = args.model_path + config_path = args.config + global_seed = args.seed + predictor = build_online_predictor(model_path, config_path) + seed_everything(global_seed) + + dataset = FewShotFloodLoader( + "test", predictor.config.task_attributes.segmentation.dataset_attributes.few_shot_flood_segmentation) + + loader = torch.utils.data.DataLoader( + dataset, + batch_size=8, + shuffle=False, + sampler=SequentialSampler(dataset), + collate_fn=build_collate_fn(dataset), + num_workers=16, + pin_memory=True, + drop_last=False, + ) + + print(len(loader)) + + predictor.load(with_ckpt=True) + + predictor.covert_speedup_op() + + save_dir = args.save_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + save_dir_vis = os.path.join(args.save_dir, 'vis_full') + if not os.path.exists(save_dir_vis): + os.makedirs(save_dir_vis) + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + profiler2 = Timer() + profiler2.reset() + for sample_batched in tqdm(loader): + profile(profiler2, "Build sample time") + result, sample_list = predictor._predict(SampleList(sample_batched)) + profile(profiler2, "Infer time") + preds = result["logits_hr"].to(torch.float32).detach().cpu() + targets = result['mapped_targets'].to(torch.float32).detach().cpu() + idx_2_color = result['idx_2_color'] + input_imgs = sample_list['hr_img'].to(torch.float32).detach().cpu() + img_names = sample_list["img_name"] + + executor.submit(process_results, preds, targets, input_imgs, img_names, save_dir, save_dir_vis, idx_2_color) + profile(profiler2, "Save results time") + try: + del predictor.model + except Exception as e: + print('delete model error: ', e) + +def parse_args(): + desc = '1-shot predictor' + parser = argparse.ArgumentParser(description=desc) + parser.add_argument('--model_path', + required=True, + type=str, + help='model directory') + parser.add_argument('--seed', + default=0, + type=int, + help='seed') + parser.add_argument('--config', + required=True, + type=str, + help='config path') + parser.add_argument('--save_dir', + required=False, + type=str, + help='save directory') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + test(args) diff --git a/lib/task/__init__.py b/lib/task/__init__.py new file mode 100644 index 0000000..eafc3ec --- /dev/null +++ b/lib/task/__init__.py @@ -0,0 +1,3 @@ +from .segmentation import SegmentationTask + +__all__ = ['SegmentationTask'] diff --git a/lib/task/segmentation.py b/lib/task/segmentation.py new file mode 100644 index 0000000..949afcd --- /dev/null +++ b/lib/task/segmentation.py @@ -0,0 +1,18 @@ +# coding: utf-8 +# Copyright (c) Ant Group. All rights reserved. + +from antmmf.common.registry import registry +from antmmf.tasks import BaseTask + + +@registry.register_task("segmentation") +class SegmentationTask(BaseTask): + + def __init__(self): + super(SegmentationTask, self).__init__("segmentation") + + def _get_available_datasets(self): + return ["pretraining_loader"] + + def _preprocess_item(self, item): + return item diff --git a/lib/trainer/__init__.py b/lib/trainer/__init__.py new file mode 100644 index 0000000..79fa5ec --- /dev/null +++ b/lib/trainer/__init__.py @@ -0,0 +1,4 @@ +from .seg_trainer import SEGTrainer + +__all__ = ['SEGTrainer'] + diff --git a/lib/trainer/seg_trainer.py b/lib/trainer/seg_trainer.py new file mode 100644 index 0000000..82b9a70 --- /dev/null +++ b/lib/trainer/seg_trainer.py @@ -0,0 +1,399 @@ +# Copyright (c) Ant Group. and its affiliates. +import gc +import math +from itertools import chain + +import torch +from torch import nn +from tqdm import tqdm + +from antmmf.common.registry import registry +from antmmf.common.report import Report +from antmmf.common.meter import Meter +from antmmf.modules.metrics import Metrics +from antmmf.optimizer.combine_optimizers import CombinedOptimizer +from antmmf.utils.distributed_utils import (broadcast_scalar, is_main_process) +from antmmf.utils.early_stopping import EarlyStopping +from antmmf.utils.general import clip_gradients, count_parameters, nullcontext +from antmmf.utils.timer import Timer +from antmmf.trainers.base_trainer import BaseTrainer + +from lib.utils.utils import cancel_gradients_backbone, EMA +from lib.utils.checkpoint import SegCheckpoint + +try: + import atorch + from atorch import amp +except ImportError: + pass + + +@registry.register_trainer("seg_trainer") +class SEGTrainer(BaseTrainer): + + def __init__(self, config): + super().__init__(config) + self.enable_torch_amp=True + self.enable_atorch_amp=False + + def load(self, has_check_point=True): + super().load(has_check_point) + torch.backends.cuda.matmul.allow_tf32 = self.config.training_parameters.get( + "enable_tf32", False) + if hasattr( + self.config.training_parameters, "freeze_backbone" + ) and self.config.training_parameters.freeze_backbone is True: + for n, p in self.model.named_parameters(): + if "backbone_hr." in n or 'backbone_s2.' in n or 'head_s2.' in n or 'backbone_s1.' in n or 'head_s1.' in n or 'fusion.' in n or 'ctpe' in n or 'glbank.' in n: + p.requires_grad = False + else: + print(n, '-->', p.requires_grad) + if hasattr(self.config.training_parameters, + "ema") and self.config.training_parameters.ema is True: + self.ema = EMA(self.model, 0.96) + self.ema.register() + + def load_extras(self, has_check_point=True): + self.checkpoint = None if has_check_point is False else SegCheckpoint( + self) + self.meter = Meter() + + self.training_parameters = self.config.training_parameters + + monitored_metric = self.training_parameters.monitored_metric + metric_minimize = self.training_parameters.metric_minimize + should_early_stop = self.training_parameters.should_early_stop + patience = self.training_parameters.patience + + self.log_interval = self.training_parameters.log_interval + self.snapshot_interval = self.training_parameters.snapshot_interval + self.max_iterations = self.training_parameters.max_iterations + self.should_clip_gradients = self.training_parameters.clip_gradients + self.max_epochs = self.training_parameters.max_epochs + self.gradient_accumulation_steps = int( + self.training_parameters.gradient_accumulation_steps) + assert self.gradient_accumulation_steps >= 1 + for t_type in self.task_loader.task_type: + if t_type == "train": + self.dataset_train_order = self.training_parameters.get( + "dataset_train_order", self.train_task.datasets_name) + if t_type == "val": + self.dataset_val_order = self.training_parameters.get( + "dataset_val_order", self.val_task.datasets_name) + if t_type == "test": + self.dataset_test_order = self.training_parameters.get( + "dataset_test_order", self.test_task.datasets_name) + if t_type == "interpret": + self.dataset_interpret_order = self.training_parameters.get( + "dataset_interpret_order", + self.interpret_task.datasets_name) + + self.early_stopping = EarlyStopping( + self.model, + self.checkpoint, + monitored_metric, + patience=patience, + minimize=metric_minimize, + should_stop=should_early_stop, + ) + self.current_epoch = 1 + self.current_iteration = 0 + + self.not_debug = self.training_parameters.logger_level != "debug" + + self.lr_scheduler = None + self.setup_lr_scheduler() + + if self.checkpoint is not None: + self.checkpoint.load_state_dict() + + if "overall_metrics" in self.training_parameters: + self.overall_metric_evaluator = Metrics( + self.config.training_parameters.get("overall_metrics", [])) + self.synchronized_loss = self.config.training_parameters.synchronized_loss + + def train(self): + self.writer.write("===== Model =====") + self.writer.write(self.model) + self.writer.write( + "Model Params: Trainable {Trainable:.3f}M Total {Total:.3f}M". + format(**count_parameters(self.model))) + + if "train" not in self.run_type: + self.inference() + return + + should_break = False + + if self.max_epochs is None: + self.max_epochs = math.inf + else: + self.max_iterations = min(self.max_iterations, + self.max_epochs * self.epoch_iterations) + + self.model.train() + self.train_timer = Timer() + + self.profile("Setup Time") + + if self.enable_torch_amp: + self.writer.write("Using Automatic mixed precision training") + if hasattr(self.config, "amp_attributes") and hasattr( + self.config.amp_attributes, "growth_interval"): + growth_interval = self.config.amp_attributes.growth_interval + else: + growth_interval = 2000 + self.scaler = torch.cuda.amp.GradScaler( + init_scale=self.config.amp_attributes.init_scale, + enabled=False, + growth_interval=growth_interval) + self.writer.write("Using Init scale:%s" % + self.config.amp_attributes.init_scale) + + self.optimizer.zero_grad() + + self.writer.write("Starting training...") + while self.current_iteration < self.max_iterations and not should_break: + registry.register("current_epoch", self.current_epoch) + self.task_loader.seed_sampler("train", self.current_epoch) + + if self.current_epoch > self.max_epochs: + break + + for batch in tqdm( + chain(*self.train_loader_list), + total=self._len_of_loader_list(self.train_loader_list), + disable=self.disable_tqdm or (not is_main_process()), + ): + self.profile("Batch load time") + report, _, _ = self._forward_pass( + batch, enable_amp=self.enable_torch_amp) + if report is None: + continue + + self._update_meter(report, self.meter) + + loss = self._extract_loss(report) + self._backward(loss) + if hasattr( + self.config.training_parameters, + "ema") and self.config.training_parameters.ema is True: + self.ema.update() + should_break = self._logistics() + + self._run_scheduler() + + self.current_iteration += 1 + self.writer.write(self.current_iteration, "debug") + registry.register("current_iteration", self.current_iteration) + if self.current_iteration >= self.max_iterations: + break + if should_break: + break + + self.current_epoch += 1 + + self.finalize() + + def _forward_pass(self, batch, enable_amp=False): + if not batch: # Samplelist might be empty dict + return None, None, None + prepared_batch = self.task_loader.prepare_batch(batch) + + self.profile("Batch prepare time") + forward_context = torch.cuda.amp.autocast( + enabled=True, + dtype=torch.bfloat16) if enable_amp else nullcontext() + + with forward_context: + # Arguments should be a dict at this point + model_output = self.model(prepared_batch) + + if self.synchronized_loss: + is_parallel = isinstance( + self.model, nn.DataParallel) or isinstance( + self.model, nn.parallel.DistributedDataParallel) + if "losses" not in model_output: + loss_func = getattr( + self.model.module if is_parallel else self.model, + "losses") + model_output["losses"] = loss_func( + prepared_batch, + model_output, + iteration=self.current_iteration) + if "metrics" not in model_output: + metric_func = getattr( + self.model.module if is_parallel else self.model, + "metrics") + model_output["metrics"] = metric_func( + prepared_batch, model_output) + + report = Report(prepared_batch, model_output) + self.profile("Forward time") + + return report, model_output, prepared_batch + + def _backward(self, loss): + loss = loss / self.gradient_accumulation_steps + + if self.enable_torch_amp: + self.scaler.scale(loss).backward() + + # Unscales the gradients of optimizer's assigned params in-place, this should + # be called first so that clip_gradients can take effect as usual. + self.scaler.unscale_(self.optimizer) + elif self.enable_atorch_amp: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + self.profile("Backward time") + + if self.current_iteration % self.gradient_accumulation_steps != 0: + return + + if self.should_clip_gradients: + if self.enable_atorch_amp: + clip_gradients(amp.master_params(self.optimizer), + self.current_iteration, self.writer, + self.config) + else: + clip_gradients(self.model, self.current_iteration, self.writer, + self.config) + + if hasattr( + self.config.training_parameters, "freeze_backbone_steps" + ) and self.config.training_parameters.freeze_backbone_steps is not None: + cancel_gradients_backbone( + self.current_iteration, self.model, + self.config.training_parameters.freeze_backbone_steps) + + if self.enable_torch_amp: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + self.optimizer.zero_grad() + self.profile("Optimizer time") + + def _logistics(self): + should_print = self.current_iteration and self.current_iteration % self.log_interval == 0 + extra = {} + prefix = "" + + if should_print is True: + if "cuda" in str(self.device): + extra["max mem"] = torch.cuda.max_memory_allocated() / 1024 + extra["max mem"] //= 1024 + + # display lr + if isinstance(self.optimizer, CombinedOptimizer): + extra["lr"] = self.optimizer.get_optimizers_lr_str() + else: + extra["lr"] = "|".join([ + "{:.8f}".format(x["lr"]).rstrip("0") + for x in self.optimizer.param_groups + ]) + + extra.update({ + "time": self.train_timer.get_time_since_start(), + "eta": self._calculate_time_left(), + }) + + self.train_timer.reset() + + self._summarize_meter( + self.meter, + prefix=prefix, + extra=extra, + should_print=should_print, + ) + + should_break = self._try_full_validation() + + return should_break + + def _try_full_validation(self, force=False): + should_break = False + + if self.current_iteration and self.current_iteration % self.snapshot_interval == 0 or force: + self.writer.write( + "Evaluation time. Running on full validation set...") + + validation_timer = Timer() + dataset_name, meter = self.evaluate_set(self.val_loader_list) + extra = { + "validation time": validation_timer.get_time_since_start() + } + + overall_metric = self.overall_metric_evaluator.summarize() + stop = self.early_stopping(self.current_iteration, overall_metric, + meter) + if hasattr(self.config.training_parameters, + "ema") and self.config.training_parameters.ema is True: + self.ema.restore() + stop = bool(broadcast_scalar(stop, src=0, device=self.device)) + + extra.update(self.early_stopping.get_info()) + + prefix = "{}: full val".format(dataset_name) + self._summarize_overall(overall_metric, + meter, + prefix=prefix, + extra=extra) + gc.collect() + + if "cuda" in str(self.device): + with torch.cuda.device(self.device): + torch.cuda.empty_cache() + + if stop > 0: # `stop` is now `int`, NCCL does not support `boolean` type's broadcasting + self.writer.write("Early stopping activated") + should_break = True + + return should_break + + def evaluate_set(self, loader_list): + from antmmf.structures import SampleList + + meter = Meter() + torch.cuda.empty_cache() + with torch.no_grad(): + self.model.eval() + if hasattr(self.config.training_parameters, + "ema") and self.config.training_parameters.ema is True: + self.ema.apply_shadow() + if self.config.training_parameters.get('fp16', False): + self.model.half() + self.overall_metric_evaluator.reset() + for idx, batch in tqdm( + enumerate(chain(*loader_list)), + total=self._len_of_loader_list(loader_list), + disable=not is_main_process() or self.disable_tqdm, + ): + # report, model_output, prepared_batch = self._forward_pass( + # batch, enable_amp=self.enable_torch_amp) + if idx >= self.config.training_parameters.get('num_eval', 1e7): + break + if self.config.training_parameters.get('fp16', False): + input_dict = SampleList() + for k, v in batch.items(): + if isinstance(v, torch.cuda.FloatTensor): + input_dict[k] = v.half() + else: + input_dict[k] = v + report, model_output, prepared_batch = self._forward_pass( + input_dict, enable_amp=self.enable_torch_amp) + else: + report, model_output, prepared_batch = self._forward_pass( + batch, enable_amp=self.enable_torch_amp) + self._update_meter(report, meter) + self.overall_metric_evaluator.collect(prepared_batch, + model_output) + for _, metric_object in self.overall_metric_evaluator.metrics.items( + ): + metric_object.all_reduce() + self.model.train() + + return report.dataset_name, meter \ No newline at end of file diff --git a/lib/utils/checkpoint.py b/lib/utils/checkpoint.py new file mode 100644 index 0000000..cd8849f --- /dev/null +++ b/lib/utils/checkpoint.py @@ -0,0 +1,135 @@ +# Copyright (c) Ant Financial Service Group. and its affiliates. +import os +import warnings + +import torch + +from antmmf.common import constants +from antmmf.common.registry import registry +from antmmf.common.checkpoint import Checkpoint +from antmmf.utils.distributed_utils import is_main_process + +class SegCheckpoint(Checkpoint): + def __init__(self, trainer, load_only=False): + super().__init__(trainer, load_only=False) + + def load_model_weights(self, file, force=False): + self.trainer.writer.write("Loading checkpoint") + ckpt = self._torch_load(file) + if registry.get(constants.STATE) is constants.STATE_ONLINE_SERVING: + data_parallel = False + else: + data_parallel = registry.get("data_parallel") or registry.get( + "distributed") + + if "model" in ckpt: + ckpt_model = ckpt["model"] + else: + ckpt_model = ckpt + ckpt = {"model": ckpt} + + new_dict = {} + + # TODO: Move to separate function + for attr in ckpt_model: + if "fa_history" in attr: + new_dict[attr.replace("fa_history", + "fa_context")] = ckpt_model[attr] + elif data_parallel is False and attr.startswith("module."): + new_k = attr.replace("module.", "", 1) + if '.Wqkv.' in new_k: + new_k = new_k.replace('.Wqkv.', '.in_proj_') + + new_dict[new_k] = ckpt_model[attr] + elif data_parallel is not False and not attr.startswith("module."): + new_dict["module." + attr] = ckpt_model[attr] + elif data_parallel is False and not attr.startswith("module."): + print('data_parallel is False and not attr!!!') + new_k = attr + if '.Wqkv.' in new_k: + new_k = new_k.replace('.Wqkv.', '.in_proj_') + new_dict[new_k] = ckpt_model[attr] + else: + new_dict[attr] = ckpt_model[attr] + print(new_dict.keys()) + self._load_state_dict(new_dict) + self._load_model_weights_with_mapping(new_dict, force=force) + print(f'load weight: {file} done!') + return ckpt + + def _load(self, file, force=False, resume_state=False): + ckpt = self.load_model_weights(file, force=force) + + # skip loading training state + if resume_state is False: + return + + if "optimizer" in ckpt: + try: + self.trainer.optimizer.load_state_dict(ckpt["optimizer"]) + # fix the bug of checkpoint in the pytorch with version higher than 1.11 + if "capturable" in self.trainer.optimizer.param_groups[0]: + self.trainer.optimizer.param_groups[0]["capturable"] = True + except Exception as e: + print(e) + + else: + warnings.warn( + "'optimizer' key is not present in the checkpoint asked to be loaded. Skipping." + ) + + if "lr_scheduler" in ckpt: + self.trainer.lr_scheduler.load_state_dict(ckpt["lr_scheduler"]) + else: + warnings.warn( + "'lr_scheduler' key is not present in the checkpoint asked to be loaded. Skipping." + ) + + self.trainer.early_stopping.init_from_checkpoint(ckpt) + + self.trainer.writer.write("Checkpoint {} loaded".format(file)) + + if "current_iteration" in ckpt: + self.trainer.current_iteration = ckpt["current_iteration"] + registry.register("current_iteration", + self.trainer.current_iteration) + + if "current_epoch" in ckpt: + self.trainer.current_epoch = ckpt["current_epoch"] + registry.register("current_epoch", self.trainer.current_epoch) + + def save(self, iteration, update_best=False): + if not is_main_process(): + return + + ckpt_filepath = os.path.join(self.models_foldername, + "model_%d.ckpt" % iteration) + best_ckpt_filepath = os.path.join(self.ckpt_foldername, + self.ckpt_prefix + "best.ckpt") + + best_iteration = self.trainer.early_stopping.best_monitored_iteration + best_metric = self.trainer.early_stopping.best_monitored_value + current_iteration = self.trainer.current_iteration + current_epoch = self.trainer.current_epoch + model = self.trainer.model + data_parallel = registry.get("data_parallel") or registry.get( + "distributed") + + if data_parallel is True: + model = model.module + + ckpt = { + "model": model.state_dict(), + "optimizer": self.trainer.optimizer.state_dict(), + "lr_scheduler": self.trainer.lr_scheduler.state_dict(), + "current_iteration": current_iteration, + "current_epoch": current_epoch, + "best_iteration": best_iteration, + "best_metric_value": best_metric, + } + + torch.save(ckpt, ckpt_filepath) + self.remove_redundant_ckpts() + + if update_best: + torch.save(ckpt, best_ckpt_filepath) \ No newline at end of file diff --git a/lib/utils/optim_utils.py b/lib/utils/optim_utils.py new file mode 100644 index 0000000..2780d46 --- /dev/null +++ b/lib/utils/optim_utils.py @@ -0,0 +1,122 @@ +import torch +from torch.nn import LayerNorm, Linear, GELU +from torch.nn import MultiheadAttention, Sequential +import warnings +try: + from atorch.normalization import LayerNorm as FastLayerNorm + from atorch.modules.transformer.inject import replace_module + from atorch.modules.transformer.layers import MultiheadAttentionFA, BertAttentionFA +except (ImportError, ModuleNotFoundError) as e: + warnings.warn("Using replace_speedup_op but no atorch/apex installed:%s" % e) +try: + from transformers.models.bert.modeling_bert import BertAttention + replace_transformer_bert = True + +except ImportError: + replace_transformer_bert = False + + +class DefaultStrategy: + replace_mha = True + replace_layernorm = True + replace_linear_gelu = False # TODO: numerical consistency + + +def replace_layer_norm(module: torch.nn.Module, cur_name: str): + + for name, child in module.named_children(): + child_name = cur_name + "." + name + if isinstance(child, LayerNorm): + new_module = FastLayerNorm(child.normalized_shape, eps=child.eps) + new_module.load_state_dict(child.state_dict()) + setattr(module, name, new_module) + else: + replace_layer_norm(child, child_name) + +def is_atorch_available(raise_error=True, log=None): + try: + import atorch # noqa: F401 + + return True + except ImportError as e: + if raise_error is True: + raise ImportError(e, log) + else: + return False + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) + + +def _fused_dense_gelu_dense(input, weight1, bias1, weight2, bias2): + batch, seq_length, hidden_size = input.size() + input = input.view(batch * seq_length, hidden_size) + args = _cast_if_autocast_enabled(input, weight1, bias1, weight2, bias2) + from apex.fused_dense import FusedDenseGeluDenseFunc # with cast + with torch.cuda.amp.autocast(enabled=False): + out = FusedDenseGeluDenseFunc.apply(*args) + out = out.view(batch, seq_length, -1) + return out + + +def linear_gelu_forward(input_, weight1, bias1, weight2, bias2): + return _fused_dense_gelu_dense(input_, weight1, bias1, weight2, bias2) + + +def replace_linear_gelu(module, cur_name: str): + """ + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1536, out_features=6144, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=6144, out_features=1536, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + """ + for name, child in module.named_children(): + child_name = cur_name + "." + name + if isinstance(child, Sequential): + if len(child) >= 2 and isinstance( + child[0], Sequential + ) and isinstance( + child[1], Linear + ) and len(child[0] + ) >= 2 and isinstance( + child[0][0], Linear + ) and isinstance( + child[0][1], GELU + ): # Sequential+Linear + linear0 = child[0][0] + linear1 = child[1] + if getattr(child, "replace_linear_gelu", False): + continue + child.forward = lambda x: linear_gelu_forward( + x, linear0.weight, linear0.bias, linear1.weight, linear1.bias) + child.replace_linear_gelu = True + print("REPLACE linear+gelu:%s" % child_name) + # setattr(module, name, new_module) + else: + replace_linear_gelu(child, child_name) + + +def replace_speedup_op(model, strategy=DefaultStrategy): + if not is_atorch_available(raise_error=False): + raise ImportError("Install Atorch/apex before using speedup op") + if strategy.replace_mha: + model = replace_module(model, MultiheadAttention, MultiheadAttentionFA, need_scr_module=True) + if replace_transformer_bert: + model = replace_module(model, BertAttention, BertAttentionFA, need_scr_module=True) + root_name = model.__class__.__name__ + if strategy.replace_layernorm: + replace_layer_norm(model, root_name) # inplace + if strategy.replace_linear_gelu: + replace_linear_gelu(model, root_name) + return model + +# TODO: +# 1. SyncBatchNorm diff --git a/lib/utils/utils.py b/lib/utils/utils.py new file mode 100644 index 0000000..fdcdd19 --- /dev/null +++ b/lib/utils/utils.py @@ -0,0 +1,243 @@ +import numpy as np + + +def cosine_scheduler(base_value, + final_value, + all_iters, + warmup_iters=0, + start_warmup_value=0): + warmup_schedule = np.array([]) + if warmup_iters > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, + warmup_iters) + + iters = np.arange(all_iters - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == all_iters + return schedule + + +def cancel_gradients_last_layer(epoch, model, freeze_last_layer): + if epoch >= freeze_last_layer: + return + for n, p in model.named_parameters(): + if "last_layer" in n: + p.grad = None + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def cancel_gradients_backbone(iteration, model, freeze_backbone_steps): + if iteration >= freeze_backbone_steps: + return + for n, p in model.named_parameters(): + if "backbon_hr" in n or 'backbon_s2' in n or 'head_s2' in n or 'fusion' in n or 'ctpe' in n: + p.grad = None + + +class EMA(): + + def __init__(self, model, decay): + self.model = model + self.decay = decay + self.shadow = {} + self.backup = {} + + def register(self): + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.clone() + + def update(self): + for name, param in self.model.named_parameters(): + if param.requires_grad: + assert name in self.shadow + new_average = (1.0 - self.decay + ) * param.data + self.decay * self.shadow[name] + self.shadow[name] = new_average.clone() + + def apply_shadow(self): + for name, param in self.model.named_parameters(): + if param.requires_grad: + assert name in self.shadow + self.backup[name] = param.data + param.data = self.shadow[name] + + def restore(self): + for name, param in self.model.named_parameters(): + if param.requires_grad: + assert name in self.backup + param.data = self.backup[name] + self.backup = {} + + +class LayerDecayValueAssigner(object): + + def __init__(self, layer_decay, num_layers, base_lr, net_type, arch='huge'): + assert net_type in ['swin', 'vit'] + assert 0 < layer_decay <= 1 + depths_dict = { + 'tiny': [2, 2, 6, 2], + 'small': [2, 2, 18, 2], + 'base': [2, 2, 18, 2], + 'large': [2, 2, 18, 2], + 'huge': [2, 2, 18, 2], + 'giant': [2, 2, 42, 4], + } + num_layers = num_layers if net_type == 'vit' else sum(depths_dict[arch]) + self.layer_decay = layer_decay + self.values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)) + self.depths = depths_dict[arch] + self.base_lr = base_lr + self.net_type = net_type + + def get_num_layer_for_vit(self, var_name): + if var_name in ("cls_token", "mask_token", "pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("layers"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + else: + return len(self.values) - 1 + + def get_num_layer_for_swin(self, var_name): + if var_name in ("mask_token", "pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("stages"): + layer_id = int(var_name.split('.')[1]) + if 'blocks' in var_name: + block_id = int(var_name.split('.')[3]) + else: + block_id = self.depths[layer_id] - 1 + layer_id = sum(self.depths[:layer_id]) + block_id + return layer_id + 1 + else: + return len(self.values) - 1 + + def get_layer_id(self, var_name): + if self.net_type == 'swin': + return self.get_num_layer_for_swin(var_name) + if self.net_type == 'vit': + return self.get_num_layer_for_vit(var_name) + + def fix_param(self, model, num_block=4): + if num_block < 1: + return 0 + frozen_num = 0 + if self.net_type == 'swin': + for name, param in model.named_parameters(): + if name.startswith("patch_embed"): + param.requires_grad = False + frozen_num += 1 + if name.startswith("stages") and self.get_layer_id(name) <= num_block: + param.requires_grad = False + frozen_num += 1 + if self.net_type == 'vit': + for name, param in model.named_parameters(): + if name.startswith("patch_embed"): + param.requires_grad = False + frozen_num += 1 + if name.startswith("layers") and self.get_layer_id(name) <= num_block: + param.requires_grad = False + frozen_num += 1 + return frozen_num + + def fix_param_deeper(self, model, num_block=4): + if num_block < 1: + return 0 + frozen_num = 0 + if self.net_type == 'swin': + raise ValueError('Not Support') + if self.net_type == 'vit': + for name, param in model.named_parameters(): + if name.startswith("patch_embed"): + param.requires_grad = False + frozen_num += 1 + if name.startswith("layers") and self.get_layer_id(name) >= num_block: + param.requires_grad = False + frozen_num += 1 + return frozen_num + + def get_parameter_groups(self, model, weight_decay): + parameter_groups_with_wd, parameter_groups_without_wd = [], [] + print_info_with_wd, print_info_without_wd = [], [] + no_decay = [ + "absolute_pos_embed", "relative_position_bias_table", "norm", "bias" + ] + if self.layer_decay == 1: + parameter_groups_with_wd.append( + {"params": [], "weight_decay": weight_decay, "lr": self.base_lr} + ) + print_info_with_wd.append( + {"params": [], "weight_decay": weight_decay, "lr": self.base_lr} + ) + parameter_groups_without_wd.append( + {"params": [], "weight_decay": 0, "lr": self.base_lr} + ) + print_info_without_wd.append( + {"params": [], "weight_decay": 0, "lr": self.base_lr} + ) + else: + for scale in self.values: + parameter_groups_with_wd.append( + {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr} + ) + print_info_with_wd.append( + {"params": [], "weight_decay": weight_decay, "lr": scale * self.base_lr} + ) + parameter_groups_without_wd.append( + {"params": [], "weight_decay": 0, "lr": scale * self.base_lr} + ) + print_info_without_wd.append( + {"params": [], "weight_decay": 0, "lr": scale * self.base_lr} + ) + for name, param in model.named_parameters(): + if not param.requires_grad: + print(f'frozen param: {name}') + continue # frozen weights + layer_id = self.get_layer_id(name) if self.layer_decay < 1 else 0 + if any(nd in name for nd in no_decay): + parameter_groups_without_wd[layer_id]['params'].append(param) + print_info_without_wd[layer_id]['params'].append(name) + else: + parameter_groups_with_wd[layer_id]['params'].append(param) + print_info_with_wd[layer_id]['params'].append(name) + parameter_groups_with_wd = [x for x in parameter_groups_with_wd if len(x['params']) > 0] + parameter_groups_without_wd = [x for x in parameter_groups_without_wd if len(x['params']) > 0] + print_info_with_wd = [x for x in print_info_with_wd if len(x['params']) > 0] + print_info_without_wd = [x for x in print_info_without_wd if len(x['params']) > 0] + if self.layer_decay < 1: + for wd, nwd in zip(print_info_with_wd, print_info_without_wd): + print(wd) + print(nwd) + parameter_groups = [] + parameter_groups.extend(parameter_groups_with_wd) + parameter_groups.extend(parameter_groups_without_wd) + return parameter_groups + diff --git a/tools/pretraining_data_builder/.gitignore b/tools/pretraining_data_builder/.gitignore new file mode 100644 index 0000000..2874db5 --- /dev/null +++ b/tools/pretraining_data_builder/.gitignore @@ -0,0 +1,9 @@ +.venv/ +.env +**/*.pyc +**/.pytest_cache +search_data.json +.idea +*.zip +*.jp2 +.ruff_cache diff --git a/tools/pretraining_data_builder/README.md b/tools/pretraining_data_builder/README.md new file mode 100644 index 0000000..7038922 --- /dev/null +++ b/tools/pretraining_data_builder/README.md @@ -0,0 +1,36 @@ +# Pretraining Data Builder +This code is for building pretraining data for the self-supervised learning of SkySense++. + +## Install +Prepare the environment: +``` +conda create -n data_builder python=3.12 +conda activate data_builder +pip install -r requirements.txt +``` +Download pretraining data list in lmdb format from [Zenodo](https://zenodo.org/records/14994430) + +## Download Data +``` +python -m rsi_download --username --password --api_key +``` +Notes: +1. `username` and `password` can be created in the [Copernicus Data Space Ecosystem](https://data.copernicus.eu/cdsapp/#!/home), +`api_key` can be created in the [Maxar](https://ard.maxar.com/docs/about/). +2. `X` `Y` `Z` are coordinates in the Web Mercator coordinate system. +3. `date_min` and `date_max` are in the format of `YYYY-MM`. + +## Process Data +``` +python -m rsi_process --platform --fn_img path/to/image.zip --save_dir output_/ +``` +Notes: +1. `platform` can be `s1`, `s2`, `wv`. +2. `fn_img` is the path to the downloaded zip file. +3. `save_dir` is the directory to save the processed data. + +## Automatic Script +``` +sh run_data_builder.sh +``` +This script will first read the pretraining list, then download the data according to the list, and proceed them automatically. diff --git a/tools/pretraining_data_builder/requirements.txt b/tools/pretraining_data_builder/requirements.txt new file mode 100644 index 0000000..02b072f --- /dev/null +++ b/tools/pretraining_data_builder/requirements.txt @@ -0,0 +1,17 @@ +httpx>=0.27.2 +python-dotenv>=1.0.0 +orjson>=3.9.10 +rich>=13.7.0 +click>=8.1.7 +msgspec>=0.18.4 +asyncclick>=8.1.3.4 +numpy +gdal +pyproj +mercantile +Pillow +shapely +imageio +geopandas +pyresample +lmdb \ No newline at end of file diff --git a/tools/pretraining_data_builder/rsi_download/__init__.py b/tools/pretraining_data_builder/rsi_download/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/pretraining_data_builder/rsi_download/__main__.py b/tools/pretraining_data_builder/rsi_download/__main__.py new file mode 100644 index 0000000..2fd326a --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/__main__.py @@ -0,0 +1,76 @@ +import click +from rsi_download.download_async import download_core +import asyncio + +@click.command() +@click.argument("x", type=click.STRING) +@click.argument("y", type=click.STRING) +@click.argument("z", type=click.STRING) +@click.argument("date_min", type=click.STRING) +@click.argument("date_max", type=click.STRING) +@click.option( + "--username", + "-u", + type=click.STRING, + help="Username for Copernicus Data Space Ecosystem", +) +@click.option( + "--password", "-p", prompt=True, hide_input=True, confirmation_prompt=False +) +@click.option( + "--api_key", "-k", prompt=True, hide_input=True, confirmation_prompt=False +) +@click.option( + "--max", + "-m", + "max_", + default=100, + type=click.INT, + show_default=True, + help="maximum number of results returned", +) +@click.option( + "--cloud-coverage", + "-c", + "cloud_coverage", + default=10.00, + type=click.FLOAT, + show_default=True, + help="Get only results with a cloud coverage percentage less then the argument given.", +) + +@click.option( + "--platform-name", + "-n", + "platform_name", + default="S2", + type=click.Choice(["S2", "S1", "WV3"]), + show_default=True, + help="Get only results with a platform name.", +) + +@click.option( + "--debug", + default=False, + is_flag=True, + type=click.BOOL, + show_default=True, + help="Debug the http requests and extra debug logging", +) +@click.option( + "--tci", + default=False, + is_flag=True, + type=click.BOOL, + show_default=True, + help="Download only True Color Image (TCI)", +) + +def main(x, y, z, date_min, date_max, username, password, api_key, max_, cloud_coverage, debug, tci, platform_name): + return asyncio.run(download_core(x, y, z, date_min, date_max, username, password, api_key, max_, cloud_coverage, debug, tci, platform_name)) + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\n程序已终止") \ No newline at end of file diff --git a/tools/pretraining_data_builder/rsi_download/auth.py b/tools/pretraining_data_builder/rsi_download/auth.py new file mode 100644 index 0000000..db79810 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/auth.py @@ -0,0 +1,36 @@ +import httpx +import msgspec + + +class CDSETokens(msgspec.Struct): + """Copernicus Data Space Ecosystem Tokens""" + + access_token: str + refresh_token: str + expires_in: int + refresh_expires_in: int + token_type: str + not_before_policy: int = msgspec.field(name="not-before-policy") + session_state: str + scope: str + + +def get_access_token(username: str, password: str) -> CDSETokens: + data = { + "client_id": "cdse-public", + "username": username, + "password": password, + "grant_type": "password", + } + try: + with httpx.Client() as client: + r = client.post( + "https://identity.dataspace.copernicus.eu/auth/realms/CDSE/protocol/openid-connect/token", + data=data, + ) + r.raise_for_status() + except Exception as e: + raise Exception( + f"Access token creation failed: {e}. Reponse from the server was: {r.json()}" + ) + return msgspec.json.decode(r.content, type=CDSETokens) diff --git a/tools/pretraining_data_builder/rsi_download/cli.py b/tools/pretraining_data_builder/rsi_download/cli.py new file mode 100644 index 0000000..f387c7b --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/cli.py @@ -0,0 +1,133 @@ +from typing import Tuple, List + +from rich.table import Table +from rich.console import Console +import re +import msgspec +from datetime import datetime +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) +from rsi_download.exceptions import InvalidWktPointArgument, InvalidDateRangeArgument +from rsi_download.download.search import SearchContent, SearchResult + + +class Preview(msgspec.Struct): + id: str + productid: str + url: str + origin_date: str + name: str + + +progress = Progress( + TextColumn("[bold blue]{task.fields[filename]}", justify="right"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + "•", + TimeRemainingColumn(), +) + + +# "2022-05-03T00:00:00.000Z" +ESA_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def convert_to_timestamp(datestring="", dateformat="%d-%m-%Y %H:%M:%S") -> str: + if len(datestring) > 10: + source = datetime.strptime(datestring, dateformat) + else: + source = datetime.strptime(datestring, "%d-%m-%Y") + return source.strftime(ESA_DATE_FORMAT) + + +def daterange_to_timestamp(daterange: str) -> Tuple[str, str]: + if "," not in daterange: + raise InvalidDateRangeArgument( + f'Give a valid daterange string. for example: "11-08-2023 00:00:00,11-09-2023 00:00:00" \n Daterange received: {daterange}' + ) + gt, lt = daterange.split(",") + try: + time_gt = convert_to_timestamp(datestring=gt) + except ValueError: + raise InvalidDateRangeArgument( + f"Invalid dateformat encountered for time_gt: {gt}. Dateformat expected: %d-%m-%Y or %d-%m-%Y %H:%M:%S" + ) + try: + time_lt = convert_to_timestamp(datestring=lt) + except ValueError: + raise InvalidDateRangeArgument( + f"Invalid dateformat encountered for time_lt: {lt}. Dateformat expected: %d-%m-%Y or %d-%m-%Y %H:%M:%S" + ) + return time_gt, time_lt + + +def wkt_to_point(wktstring: str) -> Tuple[float, ...]: + nums = re.findall(r"[-+]?\d*\.\d+|\d+", wktstring) + if len(nums) != 2: + raise InvalidWktPointArgument( + f"Give a valid WKT string. for example: POINT(-9.1372 38.7000). WKT received: {wktstring}" + ) + return tuple(float(n) for n in nums) + + +def show_preview_urls(search_json: SearchContent, platform_name: str) -> List[Preview]: + """ + Show a list of preview urls for downloading in the terminal + + :param search_json: SearchContent object + """ + # print(search_json.value) + preview_urls = [ + Preview( + id=str(i), + productid=v.id, + url=v.assets[0].download_link, + origin_date=v.content_date.start, + name=v.name, + ) + for i, v in enumerate(search_json.value) + ] + table = Table(title=f"RSI Preview Url's") + table.add_column("ID", justify="left", style="magenta") + table.add_column("Acquisition Time", justify="left", style="blue") + table.add_column("Name", justify="left", style="magenta") + + for entry in preview_urls: + table.add_row( + entry.id, + f'[link={entry.url.replace("(", "%28").replace(")", "%29")}]{entry.origin_date}[/link]', + entry.name, + ) + + console = Console() + console.print(table) + return preview_urls + + +def get_selected_products( + search_json: SearchContent, preview_urls: List[Preview], product_ids: str +) -> List[SearchResult]: + """ + Return the selected items from the search_json by the preview url id. + + :param search_json: SearchContent + :param preview_urls: List[Preview] + :param product_ids: string of preview ids + :return: List[SearchResult] + """ + download_product_ids = [ + item.productid + for item in preview_urls + if item.id in [str(n) for n in product_ids] + ] + return [x for x in search_json.value if x.id in download_product_ids] diff --git a/tools/pretraining_data_builder/rsi_download/download/__init__.py b/tools/pretraining_data_builder/rsi_download/download/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/pretraining_data_builder/rsi_download/download/product.py b/tools/pretraining_data_builder/rsi_download/download/product.py new file mode 100644 index 0000000..0c45590 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/download/product.py @@ -0,0 +1,97 @@ +import asyncio +from typing import List +import signal +import httpx +from rich.progress import TaskID, Event +from rsi_download.cli import progress +from rsi_download.download.search import SearchResult +from rsi_download.cli import Preview +import os + +done_event = Event() + + +def handle_sigint(signum, frame): + done_event.set() + + +signal.signal(signal.SIGINT, handle_sigint) + + +async def download_tci_products_data( + task_id: TaskID, product: SearchResult, access_token: str, mm_band: str = "R10m" +): + headers = {"Authorization": f"Bearer {access_token}"} + progress.start_task(task_id) + async with httpx.AsyncClient() as client: + client.headers.update(headers) + # create the tci image url + granule_url = f"https://zipper.dataspace.copernicus.eu/odata/v1/Products({product.id})/Nodes({product.name})/Nodes(GRANULE)/Nodes" + granule_resp = await client.get( + f"{granule_url}", follow_redirects=True, headers=headers + ) + granule_folder = granule_resp.json() + img_data_url = f"{granule_url}({granule_folder['result'][0]['Name']})/Nodes(IMG_DATA)/Nodes({mm_band})/Nodes" + img_data_resp = await client.get(img_data_url, follow_redirects=True) + img_data = img_data_resp.json() + tci_name = [img["Name"] for img in img_data["result"] if "TCI" in img["Name"]][ + 0 + ] + tci_url = f"{img_data_url}({tci_name})/$value" + async with client.stream( + method="GET", + url=tci_url, + headers=headers, + ) as response: + progress.update(task_id, total=int(response.headers["Content-length"])) + with open(f"{tci_name}", "wb") as file: + progress.start_task(task_id) + async for chunk in response.aiter_bytes(): + if chunk: + file.write(chunk) + progress.update(task_id, advance=len(chunk)) + if done_event.is_set(): + return + progress.console.log(f"Downloaded {tci_name}") + + +async def download_data(task_id: TaskID, product: SearchResult, preview: Preview, access_token: str): + headers = {"Authorization": f"Bearer {access_token}"} + async with httpx.AsyncClient() as client: + client.headers.update(headers) + async with client.stream( + "GET", + url=f"https://zipper.dataspace.copernicus.eu/odata/v1/Products({product.id})/$value", + headers=headers, + ) as response: + progress.update(task_id, total=int(response.headers["Content-length"])) + with open(f"out_raw/{preview.name.replace('.SAFE', '.zip')}", "wb") as file: + progress.start_task(task_id) + async for chunk in response.aiter_bytes(): + if chunk: + file.write(chunk) + progress.update(task_id, advance=len(chunk)) + if done_event.is_set(): + return + progress.console.log(f"Downloaded {preview.name.replace('.SAFE', '.zip')}") + +async def download_products_data( + products: List[SearchResult], previews: List[Preview], access_token: str, tci_only: bool = False +): + with progress: + download_tasks = [] + for product, preview in zip(products, previews): + task_id = progress.add_task( + f"{preview.name.replace('.SAFE', '.zip')}", + filename=f"{preview.name.replace('.SAFE', '.zip')}", + start=False, + ) + if tci_only: + download_tasks.append( + download_tci_products_data(task_id, product, access_token) + ) + else: + download_tasks.append(download_data(task_id, product, preview, access_token)) + # os.rename(f"product-{product.id}.zip", f"{preview.name.replace('.SAFE', '.zip')}") + await asyncio.gather(*download_tasks) + diff --git a/tools/pretraining_data_builder/rsi_download/download/search.py b/tools/pretraining_data_builder/rsi_download/download/search.py new file mode 100644 index 0000000..361bdba --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/download/search.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import List + +import msgspec +import httpx + +from rsi_download.exceptions import SearchException +from rsi_download.geo.geo_types import GeoJsonPolygon + +ESA_SEARCH_URL = r"https://catalogue.dataspace.copernicus.eu/odata/v1/Products" + + +class ContentData(msgspec.Struct, rename="pascal"): + """Odata search result start and end date""" + + start: str + end: str + + +class Asset(msgspec.Struct, rename="pascal"): + """Odata search Asset""" + + type_: str = msgspec.field(name="Type") + id: str + download_link: str + s3_path: str + + +class SearchResult(msgspec.Struct, rename="pascal"): + """Odata search Result""" + + id: str + name: str + content_length: int + origin_date: str + s3_path: str + content_date: ContentData + geo_footprint: GeoJsonPolygon + assets: List[Asset] + + +class SearchContent(msgspec.Struct): + value: List[SearchResult] + next_link: str | None = msgspec.field(default=None, name="@odata.nextLink") + + +async def search_odata( + long: float, + lat: float, + cloud_coverage: float, + time_lt: str, + time_gt: str, + max_: int, + platform_name: str, +) -> SearchContent: + # filter voor zoeken op cloudCover, Productype en orbitDirection. + # lt = less then + # eq = equal to + # gt = greater then + # sentinel-2 + if platform_name == "S2": + search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.DoubleAttribute/any(att:att/Name eq 'cloudCover' and att/OData.CSC.DoubleAttribute/Value lt {cloud_coverage:.2f}) and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq 'S2MSI2A') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}" + elif platform_name == "S1": + search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq 'IW_GRDH_1S') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}" + elif platform_name == "WV3": + search_filter = f"OData.CSC.Intersects(area=geography'SRID=4326;POINT ({long:.4f} {lat:.4f})') and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'platformName' and att/OData.CSC.StringAttribute/Value eq 'WorldView-3') and ContentDate/Start gt {time_gt} and ContentDate/Start lt {time_lt}" + else: + raise ValueError(f"Invalid platform name: {platform_name}") + + async with httpx.AsyncClient() as client: + r = await client.get( + url=f"{ESA_SEARCH_URL}?$filter={search_filter}&$top={max_}&$expand=Assets", + timeout=60, + ) + if not r.status_code == 200: + raise SearchException(f"Error getting data: {r.text}") + return msgspec.json.decode(r.content, type=SearchContent) diff --git a/tools/pretraining_data_builder/rsi_download/download_async.py b/tools/pretraining_data_builder/rsi_download/download_async.py new file mode 100644 index 0000000..a87f572 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/download_async.py @@ -0,0 +1,100 @@ +from typing import List, Tuple + +import msgspec +import asyncio +from rich import print +from rsi_download.auth import get_access_token +from rsi_download.download.product import download_products_data +from rsi_download.cli import ( + show_preview_urls, + Preview, + get_selected_products, +) +from rsi_download.download.search import search_odata +import math + + + +async def download_core( + x: str, + y: str, + z: str, + date_min: str, + date_max: str, + username: str, + password: str, + api_key: str = None, + max_: int = 100, + cloud_coverage: float = 20.0, + debug: bool = False, + tci: bool = False, + platform_name: str = "S2", +): + """ + X tile x coordinate + Y tile y coordinate + Z zoom level + DATE_MIN start date in format YYYYMM + DATE_MAX end date in format YYYYMM + """ + lat, long = tile_to_latlon(float(x), float(y), float(z)) + time_gt = f"{date_min[:4]}-{date_min[4:6]}-01T00:00:00.000Z" + year = int(date_max[:4]) + month = int(date_max[4:]) + if month == 12: + next_year = year + 1 + next_month = 1 + else: + next_year = year + next_month = month + 1 + time_lt = f"{next_year}-{next_month:02d}-01T00:00:00.000Z" + + print(f"coordinates: lat: {lat:.4f}, long: {long:.4f}") + print(f"maximum results: {max_}") + print(f"cloud coverage percentage less then: {cloud_coverage:.2f}") + print(f"time_gt: {time_gt}, time_lt: {time_lt}") + search_data = await search_odata(long, lat, cloud_coverage, time_lt, time_gt, max_, platform_name) + if debug: + print("DEBUG: Search request data is saved to disk.") + with open("search_data.json", "wb") as f: + f.write(msgspec.json.encode(search_data)) + preview_urls: List[Preview] = show_preview_urls(search_data, platform_name) + print("start downloading all data ...") + products_to_download = get_selected_products( + search_json=search_data, preview_urls=preview_urls, product_ids=list(range(len(preview_urls))) + ) + tokens = get_access_token(username, password) + + try: + for i, (product, preview) in enumerate(zip(products_to_download, preview_urls)): + print(f"[{i+1}/{len(products_to_download)}] downloading {product.id} ...") + await asyncio.shield(download_products_data( + [product], [preview], tokens.access_token, tci_only=tci + )) + except asyncio.CancelledError: + print("\nDownload cancelled, exiting...") + return + +def tile_to_latlon(x: int, y: int, z: int, get_center: bool = True) -> Tuple[float, float]: + """ + Convert XYZ tile coordinates to latitude/longitude + + Args: + x: Tile X coordinate + y: Tile Y coordinate + z: Zoom level + get_center: If True, returns the center point coordinates. If False, returns the top-left corner. + + Returns: + Tuple of (latitude, longitude) + """ + n = 2.0 ** z + if get_center: + x += 0.5 + y += 0.5 + + lon_deg = x / n * 360.0 - 180.0 + lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * y / n))) + lat_deg = math.degrees(lat_rad) + return lat_deg, lon_deg + diff --git a/tools/pretraining_data_builder/rsi_download/exceptions.py b/tools/pretraining_data_builder/rsi_download/exceptions.py new file mode 100644 index 0000000..99ea2c3 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/exceptions.py @@ -0,0 +1,16 @@ +class InvalidWktPointArgument(Exception): + """Raised when the WKT string is not a valid point""" + + pass + + +class InvalidDateRangeArgument(Exception): + """Raised when the daterange string is not valid""" + + pass + + +class SearchException(Exception): + """Raised when search endpoint returned a non 200 statuscode""" + + pass diff --git a/tools/pretraining_data_builder/rsi_download/geo/__init__.py b/tools/pretraining_data_builder/rsi_download/geo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/pretraining_data_builder/rsi_download/geo/geo_types.py b/tools/pretraining_data_builder/rsi_download/geo/geo_types.py new file mode 100644 index 0000000..3089f96 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/geo/geo_types.py @@ -0,0 +1,13 @@ +from typing import List + +import msgspec + + +class Coordinate(msgspec.Struct): + long: float + lat: float + + +class GeoJsonPolygon(msgspec.Struct): + type: str + coordinates: List[List[List[float]]] diff --git a/tools/pretraining_data_builder/rsi_download/sort.py b/tools/pretraining_data_builder/rsi_download/sort.py new file mode 100644 index 0000000..d4f0248 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_download/sort.py @@ -0,0 +1,12 @@ +def sort_by_cloudcover(search_result): + entries = search_result["feed"]["entry"] + sorted_entries = [] + for entry in entries: + sorted_entries.append( + [ + float(e["content"]) + for e in entry["double"] + if e["name"] == "cloudcoverpercentage" + ][0] + ) + return sorted(sorted_entries, key=float) diff --git a/tools/pretraining_data_builder/rsi_pipeline/data_builder.py b/tools/pretraining_data_builder/rsi_pipeline/data_builder.py new file mode 100644 index 0000000..bf9099d --- /dev/null +++ b/tools/pretraining_data_builder/rsi_pipeline/data_builder.py @@ -0,0 +1,70 @@ +import lmdb +import os +import json +from rich import print +from rsi_download.download_async import download_core +from rsi_process.adapter import process_adapter +import asyncclick as click + +@click.command() +@click.argument("lmdb_path", type=click.STRING) +async def read_lmdb_file(lmdb_path): + """ + Read the LMDB file and print all key-value pairs + + Args: + lmdb_path: LMDB file path + """ + if not os.path.exists(lmdb_path): + print(f"Error: LMDB path '{lmdb_path}' does not exist") + return + + try: + print(f"Reading Pretraining List from LMDB file from {lmdb_path}...") + env = lmdb.open(lmdb_path, readonly=True) + total_length = 0 + with env.begin() as txn: + key = b'length' + total_length = int(txn.get(key)) + print(f"Total length of the Pretraining Data: {total_length:,}") + print("Example Data:") + for i in range(10): + print(txn.get(f"{i}".encode()).decode('utf-8')) + for i in range(total_length): + key = f"{i}".encode() + data = json.loads(txn.get(key).decode('utf-8')) + print("*"* 116 + "\n" + f"* Current Data [{i+1} / {total_length}]: {data} *" + "\n" + "*"* 116 ) + print(f"Downloading: {data}") + await download_core( + x=data['x'], + y=data['y'], + z=data['z'], + date_min=data['date_min'], + date_max=data['date_max'], + username=os.getenv("USERNAME"), + password=os.getenv("PASSWORD"), + cloud_coverage=20.0, + tci=False + ) + print('-'* 40) + print(f"Processing: {data}") + process_list = os.listdir('out_raw/') + total_len_process = len(process_list) + for fn in process_list: + print(f"Processing: {fn} [{i+1} / {total_len_process}]...") + process_adapter( + fn_img=f'out_raw/{fn}', + save_dir='out_processed/', + verbose=True, + use_gcj02=False + ) + print('-'* 40) + print("Done!") + + except lmdb.Error as e: + print(f"Error reading LMDB file: {str(e)}") + finally: + env.close() + +if __name__ == "__main__": + read_lmdb_file() \ No newline at end of file diff --git a/tools/pretraining_data_builder/rsi_process/__init__.py b/tools/pretraining_data_builder/rsi_process/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/pretraining_data_builder/rsi_process/__main__.py b/tools/pretraining_data_builder/rsi_process/__main__.py new file mode 100644 index 0000000..e9ebc10 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/__main__.py @@ -0,0 +1,19 @@ +import argparse +from rsi_process.adapter import process_adapter + +def get_main_parser(): + parser = argparse.ArgumentParser(description='RSI Processing Pipeline') + parser.add_argument('--fn_img', help='input zip file') + parser.add_argument('--save_dir', default='output/', help='prefix on oss bucket') + parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info') + parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use GCJ02 coordinate system') + return parser + + +def main(): + parser = get_main_parser() + args = parser.parse_args() + process_adapter(args.fn_img, args.save_dir, args.verbose, args.use_gcj02) + +if __name__ == '__main__': + main() diff --git a/tools/pretraining_data_builder/rsi_process/adapter.py b/tools/pretraining_data_builder/rsi_process/adapter.py new file mode 100644 index 0000000..7f3aff4 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/adapter.py @@ -0,0 +1,21 @@ +from rsi_process.script_s1_tiles import process_s1 +from rsi_process.script_s2_tiles import process_s2 +from rsi_process.script_wv_tiles import process_wv +import EasyDict as edict + + +def process_adapter(fn_img, save_dir, verbose, use_gcj02): + satellite_info = fn_img.split('/')[-1].split('_')[0] + if 'S2' in satellite_info: + satellite = 'S2' + elif 'S1' in satellite_info: + satellite = 'S1' + elif 'WV' in satellite_info: + satellite = 'WV' + args = edict(fn_img=fn_img, save_dir=save_dir, verbose=verbose, use_gcj02=use_gcj02) + if satellite == 'S1': + process_s1(args) + elif satellite == 'S2': + process_s2(args) + elif satellite == 'WV': + process_wv(args) diff --git a/tools/pretraining_data_builder/rsi_process/script_s1_tiles.py b/tools/pretraining_data_builder/rsi_process/script_s1_tiles.py new file mode 100644 index 0000000..a9abf93 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/script_s1_tiles.py @@ -0,0 +1,238 @@ +import os +import uuid +import numpy as np +import pyproj as prj +from osgeo import gdal +from time import time +import mercantile +from PIL import Image +import utils_s1 +import imageio.v2 as iio +from tile_resample import ( + get_tile_array, + transfer +) + +import argparse +from rich import print +from rich.progress import track + +def get_args_parser(): + parser = argparse.ArgumentParser(description='Sentinel-1 to GCJ02 tiles') + parser.add_argument('--fn_img', help='input zip file of Sentinel-1 L1C') + parser.add_argument('--save_dir', default='output_s1/', help='prefix on oss bucket') + parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info') + parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use GCJ02 coordinate system') + return parser + +def process_s1(args): + t_start = time() + fn_img = args.fn_img + max_target_file = fn_img.split('_')[2][0:8] + verbose = args.verbose + save_rgb = True + nodata = 0 + + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + thumb_save_dir = os.path.join(save_dir, 'thumb') + os.makedirs(thumb_save_dir, exist_ok=True) + + print(f"converting {fn_img}...") + + z = 14 + bands = ['VV', 'VH'] + buf = 1 + + def get_image_by_approximate_boundary(boundary): + ''' + boundary: iterable of (lng, lat) in wgs84 + ''' + arr_lnglat = np.array(boundary) + xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1]) + row_min = int((tr[3] - yy.max()) / yres) + row_max = int((tr[3] - yy.min()) / yres) + col_min = int((xx.min() - tr[0]) / xres) + col_max = int((xx.max() - tr[0]) / xres) + row_min = max(0, row_min - buf) + row_max = min(ny - 1, row_max + buf) + col_min = max(0, col_min - buf) + col_max = min(nx - 1, col_max + buf) + if row_min > row_max or col_min > col_max: + return None + + arr_image = np.stack([ + ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1) + for ds in list_arr + ]) + + for iband in range(arr_image.shape[0]): + if np.any(arr_image[iband] != nodata): + break + else: + return None + arr_image = arr_image.transpose((1, 2, 0)) + if arr_image.shape[2] == 1: + arr_image = arr_image[:, :, 0] + arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres + arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres + arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy) + arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy) + return arr_image, arr_lngs, arr_lats + + + rec = utils_s1.zip2rec(fn_img) + # import pdb; pdb.set_trace() + os.makedirs(os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', '')), exist_ok=True) + thumb_save_path = os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', ''), rec['product_uri'].replace('SAFE', 'png')) + iio.imwrite(thumb_save_path, rec['thumb']) + + list_arr = [] + for band in bands: + fn_jp2 = utils_s1.make_full_name(rec, band=band) + # import pdb; pdb.set_trace() + fn_jp2 = '/vsizip/' + os.path.join(fn_img, fn_jp2) + ds = gdal.Open(fn_jp2) + list_arr.append(ds) + if band == bands[0]: + nx, ny = ds.RasterXSize, ds.RasterYSize + if verbose: print('input size:', nx, ny) + tr = ds.GetGeoTransform() + if verbose: + print(gdal.Info(ds, format='json')) + # import pdb; pdb.set_trace() + try: + proj_wkt = ds.GetProjectionRef() + if proj_wkt: + srs = prj.CRS.from_wkt(proj_wkt) + epsg = int(srs.to_epsg()) + else: + proj_wkt = ds.GetGCPProjection() + if proj_wkt: + srs = prj.CRS.from_wkt(proj_wkt) + epsg = int(srs.to_epsg()) + else: + print("Warning: No projection information found, using default value 4326 (WGS84)") + epsg = 4326 + except Exception as e: + print(f"Warning: Unable to get EPSG code, using default value 4326 (WGS84). Error: {e}") + epsg = 4326 + + if verbose: + print(f"Used EPSG code: {epsg}") + + size_pixel = mercantile.CE / 2 ** z / 256 + radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5) + + buf_ext = buf + xmin = tr[0] - buf_ext * tr[1] + ymin = tr[3] + (ny + buf_ext) * tr[5] + xmax = tr[0] + (nx + buf_ext) * tr[1] + ymax = tr[3] - buf_ext * tr[5] + xres = tr[1] + yres = - tr[5] + if verbose: + print( + f'input extent, WGS84, buffered by {buf_ext} pixels: {xmin}, {ymin}, {xmax}, {ymax}' + ) + + tr_to_4326 = prj.Transformer.from_crs(epsg, 4326, always_xy=True) + tr_from_4326 = prj.Transformer.from_crs(4326, epsg, always_xy=True) + arr_lng, arr_lat = tr_to_4326.transform( + np.array([xmin, xmin, xmax, xmax]), + np.array([ymax, ymin, ymin, ymax]) + ) + # import pdb; pdb.set_trace() + if args.use_gcj02: + arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat) + else: + arr_lng_final, arr_lat_final = arr_lng, arr_lat + + box = ( + arr_lng_final.min(), + arr_lat_final.min(), + arr_lng_final.max(), + arr_lat_final.max() + ) + + if verbose: + coord_system = "GCJ02" if args.use_gcj02 else "WGS84" + print(f'input extent, {coord_system}: {box}') + + tile_ul = mercantile.tile(box[0], box[3], z) + tile_lr = mercantile.tile(box[2], box[1], z) + + if verbose: + print('Upperleft ', str(tile_ul)) + print('Lowerright ', str(tile_lr)) + + def work(x, y, z, save_rgb): + arr_tile = get_tile_array( + x, y, z, + method='nearest', + func_source=get_image_by_approximate_boundary, + radius=radius, + use_gc02=args.use_gcj02 + ) + y_str = str(y) + if arr_tile is not None: + indi_gap = arr_tile[:, :, 0] == 0 + + dict_arr = { + band: arr_tile[:, :, i_band] + for i_band, band in enumerate(bands) + } + save_path = os.path.join(save_dir, str(z), str(x)) + os.makedirs(save_path, exist_ok=True) + + npz_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.npz') + + if indi_gap.any(): + if os.path.exists(npz_filename): + try: + fp = np.load(npz_filename) + for band in bands: + dict_arr[band][indi_gap] = fp[band][indi_gap] + + except Exception as e: + print(e) + print("datasize is 0", npz_filename) + pass + + np.savez_compressed(npz_filename, **dict_arr) + if verbose: + print(f"npz file for X={str(x)}, Y={y_str}, Z={str(z)} date={max_target_file} generated!") + if save_rgb: + arr_rgb = np.stack([dict_arr['B4'], dict_arr['B3'], dict_arr['B2']], axis=-1) + arr_rgb = np.clip(arr_rgb / 3000. * 255, 0, 255).astype(np.uint8) + image_tile = Image.fromarray(arr_rgb) + + png_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.png') + image_tile.save(png_filename, format='png') + + diff_list = [] + + tasks = [ + (x, y) for x in range(tile_ul.x, tile_lr.x + 1) + for y in range(tile_ul.y, tile_lr.y + 1) + ] + + for x, y in track(tasks, description="converting tiles..."): + work(x, y, z, save_rgb) + diff_list.append(os.path.join(str(z), str(x), f'{y}_{max_target_file}.npz')) + + diff_path = os.path.join(save_dir, 'diff', 'new') + os.makedirs(diff_path, exist_ok=True) + diff_filename = os.path.join(diff_path, f"{z}-{os.path.splitext(os.path.basename(fn_img))[0]}-{uuid.uuid1()}.txt") + with open(diff_filename, 'w') as f: + f.write('\n'.join(diff_list)) + + print("time cost :", time() - t_start) + +def main(): + args = get_args_parser().parse_args() + process_s1(args) + +if __name__ == '__main__': + main() + diff --git a/tools/pretraining_data_builder/rsi_process/script_s2_tiles.py b/tools/pretraining_data_builder/rsi_process/script_s2_tiles.py new file mode 100644 index 0000000..1697b1d --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/script_s2_tiles.py @@ -0,0 +1,226 @@ +import os +import uuid +import numpy as np +import pyproj as prj +from osgeo import gdal +from time import time +import mercantile +from PIL import Image +import utils_s2 +import imageio.v2 as iio +from tile_resample import ( + get_tile_array, + transfer +) + +import argparse +from rich import print +from rich.progress import track + +def get_args_parser(): + parser = argparse.ArgumentParser(description='Sentinel-2 to GCJ02 tiles') + parser.add_argument('--fn_img', help='input zip file of Sentinel-2 L2A') # /Users/wukang/Projects/sentinel2-downloader/S2A_MSIL2A_20220615T024601_N0400_R132_T50SNA_20220615T062308.zip + parser.add_argument('--resolution', type=int, help='10 or 20 meter resolution bands') + parser.add_argument('--save_dir', default='output_s2/', help='prefix on oss bucket') + parser.add_argument('--verbose', action='store_true', default=True, help='whether to print info') + parser.add_argument('--use_gcj02', action='store_true', default=False, help='whether to use GCJ02 coordinate system') + return parser.parse_args() + +def process_s2(args): + t_start = time() + fn_img = args.fn_img + max_target_file = fn_img.split('_')[2][0:8] + resolution = args.resolution + verbose = args.verbose + save_rgb = True + nodata = 0 + + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + thumb_save_dir = os.path.join(save_dir, 'thumb') + os.makedirs(thumb_save_dir, exist_ok=True) + + print(f"converting {fn_img}...") + if resolution == 10: + z = 14 + bands = ['B2', 'B3', 'B4', 'B8'] + buf = 1 + elif resolution == 20: + z = 13 + bands = ['B5', 'B6', 'B7', 'B8A', 'B11', 'B12', 'SCL'] + buf = 1 + save_rgb = False + else: + raise Exception(f'Unknown resoluiton: {resolution}') + + + def get_image_by_approximate_boundary(boundary): + ''' + boundary: iterable of (lng, lat) in wgs84 + ''' + arr_lnglat = np.array(boundary) + xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1]) + row_min = int((tr[3] - yy.max()) / yres) + row_max = int((tr[3] - yy.min()) / yres) + col_min = int((xx.min() - tr[0]) / xres) + col_max = int((xx.max() - tr[0]) / xres) + row_min = max(0, row_min - buf) + row_max = min(ny - 1, row_max + buf) + col_min = max(0, col_min - buf) + col_max = min(nx - 1, col_max + buf) + if row_min > row_max or col_min > col_max: + return None + + arr_image = np.stack([ + ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1) + for ds in list_arr + ]) + + for iband in range(arr_image.shape[0]): + if np.any(arr_image[iband] != nodata): + break + else: + return None + arr_image = arr_image.transpose((1, 2, 0)) + if arr_image.shape[2] == 1: + arr_image = arr_image[:, :, 0] + arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres + arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres + arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy) + arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy) + return arr_image, arr_lngs, arr_lats + + + rec = utils_s2.zip2rec(fn_img) + + os.makedirs(os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', '')), exist_ok=True) + thumb_save_path = os.path.join(thumb_save_dir, rec['sensing_start'].replace('-', ''), rec['product_uri'].replace('SAFE', 'png')) + iio.imwrite(thumb_save_path, rec['thumb']) + + list_arr = [] + for band in bands: + fn_jp2 = utils_s2.make_full_name(rec, band=band) + fn_jp2 = '/vsizip/' + os.path.join(fn_img, fn_jp2) + ds = gdal.Open(fn_jp2) + list_arr.append(ds) + if band == bands[0]: + nx, ny = ds.RasterXSize, ds.RasterYSize + if verbose: print('input size:', nx, ny) + tr = ds.GetGeoTransform() + if verbose: + print(gdal.Info(ds, format='json')) + epsg = int( + gdal.Info(ds, format='json')['coordinateSystem']['wkt'].rsplit('"EPSG",', 1)[-1][:-2] + ) + + size_pixel = mercantile.CE / 2 ** z / 256 + radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5) + + buf_ext = buf + xmin = tr[0] - buf_ext * tr[1] + ymin = tr[3] + (ny + buf_ext) * tr[5] + xmax = tr[0] + (nx + buf_ext) * tr[1] + ymax = tr[3] - buf_ext * tr[5] + xres = tr[1] + yres = - tr[5] + if verbose: + print( + f'input extent, WGS84, buffered by {buf_ext} pixels: {xmin}, {ymin}, {xmax}, {ymax}' + ) + + tr_to_4326 = prj.Transformer.from_crs(epsg, 4326, always_xy=True) + tr_from_4326 = prj.Transformer.from_crs(4326, epsg, always_xy=True) + arr_lng, arr_lat = tr_to_4326.transform( + np.array([xmin, xmin, xmax, xmax]), + np.array([ymax, ymin, ymin, ymax]) + ) + + if args.use_gcj02: + arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat) + else: + arr_lng_final, arr_lat_final = arr_lng, arr_lat + + box = ( + arr_lng_final.min(), + arr_lat_final.min(), + arr_lng_final.max(), + arr_lat_final.max() + ) + + if verbose: + coord_system = "GCJ02" if args.use_gcj02 else "WGS84" + print(f'input extent, {coord_system}: {box}') + + tile_ul = mercantile.tile(box[0], box[3], z) + tile_lr = mercantile.tile(box[2], box[1], z) + + if verbose: + print('Upperleft ', str(tile_ul)) + print('Lowerright ', str(tile_lr)) + + def work(x, y, z, save_rgb): + arr_tile = get_tile_array( + x, y, z, + method='nearest', + func_source=get_image_by_approximate_boundary, + radius=radius, + use_gc02=args.use_gcj02 + ) + y_str = str(y) + if arr_tile is not None: + indi_gap = arr_tile[:, :, 0] == 0 + + dict_arr = { + band: arr_tile[:, :, i_band] + for i_band, band in enumerate(bands) + } + save_path = os.path.join(save_dir, str(z), str(x)) + os.makedirs(save_path, exist_ok=True) + + npz_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.npz') + + if indi_gap.any(): + if os.path.exists(npz_filename): + try: + fp = np.load(npz_filename) + for band in bands: + dict_arr[band][indi_gap] = fp[band][indi_gap] + + except Exception as e: + print(e) + print("datasize is 0", npz_filename) + pass + + np.savez_compressed(npz_filename, **dict_arr) + if verbose: + print(f"npz file for X={str(x)}, Y={y_str}, Z={str(z)} date={max_target_file} generated!") + if save_rgb: + arr_rgb = np.stack([dict_arr['B4'], dict_arr['B3'], dict_arr['B2']], axis=-1) + arr_rgb = np.clip(arr_rgb / 3000. * 255, 0, 255).astype(np.uint8) + image_tile = Image.fromarray(arr_rgb) + + png_filename = os.path.join(save_path, f'{y_str}_{max_target_file}.png') + image_tile.save(png_filename, format='png') + + diff_list = [] + + tasks = [ + (x, y) for x in range(tile_ul.x, tile_lr.x + 1) + for y in range(tile_ul.y, tile_lr.y + 1) + ] + + for x, y in track(tasks, description="converting tiles..."): + work(x, y, z, save_rgb) + diff_list.append(os.path.join(str(z), str(x), f'{y}_{max_target_file}.npz')) + + diff_path = os.path.join(save_dir, 'diff', 'new') + os.makedirs(diff_path, exist_ok=True) + diff_filename = os.path.join(diff_path, f"{z}-{os.path.splitext(os.path.basename(fn_img))[0]}-{uuid.uuid1()}.txt") + with open(diff_filename, 'w') as f: + f.write('\n'.join(diff_list)) + + print("time cost :", time() - t_start) + +if __name__ == '__main__': + args = get_args_parser() + process_s2(args) diff --git a/tools/pretraining_data_builder/rsi_process/script_wv_tiles.py b/tools/pretraining_data_builder/rsi_process/script_wv_tiles.py new file mode 100644 index 0000000..cb6ee4a --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/script_wv_tiles.py @@ -0,0 +1,183 @@ +import os +import uuid +import numpy as np +import pyproj as prj +from osgeo import gdal +from time import time +import mercantile +from PIL import Image +import imageio.v2 as iio +from tile_resample import ( + get_tile_array, + transfer +) +import argparse +from rich import print +from rich.progress import track + +def get_args_parser(): + parser = argparse.ArgumentParser(description='WorldView to tiles') + parser.add_argument('--fn_img', help='input file of WorldView image') + parser.add_argument('--save_dir', default='output_wv/', help='output directory') + parser.add_argument('--zoom', type=int, default=16, help='zoom level') + parser.add_argument('--verbose', action='store_true', default=True) + parser.add_argument('--use_gcj02', action='store_true', default=False) + return parser.parse_args() + +def get_image_by_approximate_boundary(ds_list, boundary, tr, buf=1): + '''Get image data within a specified boundary + + Args: + ds_list: List of GDAL datasets + boundary: List of (lng, lat) coordinates + tr: Geotransformation parameters + buf: Buffer size + ''' + arr_lnglat = np.array(boundary) + tr_from_4326 = prj.Transformer.from_crs(4326, ds_list[0].GetProjection(), always_xy=True) + + xx, yy = tr_from_4326.transform(arr_lnglat[:, 0], arr_lnglat[:, 1]) + + nx = ds_list[0].RasterXSize + ny = ds_list[0].RasterYSize + xres = tr[1] + yres = -tr[5] + + row_min = int((tr[3] - yy.max()) / yres) + row_max = int((tr[3] - yy.min()) / yres) + col_min = int((xx.min() - tr[0]) / xres) + col_max = int((xx.max() - tr[0]) / xres) + + row_min = max(0, row_min - buf) + row_max = min(ny - 1, row_max + buf) + col_min = max(0, col_min - buf) + col_max = min(nx - 1, col_max + buf) + + if row_min > row_max or col_min > col_max: + return None + + arr_image = np.stack([ + ds.ReadAsArray(col_min, row_min, col_max - col_min + 1, row_max - row_min + 1) + for ds in ds_list + ]) + + if np.all(arr_image == 0): + return None + + arr_image = arr_image.transpose((1, 2, 0)) + + arr_xx = tr[0] + np.arange(col_min, col_max + 1) * xres + arr_yy = tr[3] - np.arange(row_min, row_max + 1) * yres + arr_xx, arr_yy = np.meshgrid(arr_xx, arr_yy) + + tr_to_4326 = prj.Transformer.from_crs(ds_list[0].GetProjection(), 4326, always_xy=True) + arr_lngs, arr_lats = tr_to_4326.transform(arr_xx, arr_yy) + + return arr_image, arr_lngs, arr_lats + +def process_wv(args): + t_start = time() + + fn_img = args.fn_img + save_dir = args.save_dir + z = args.zoom + verbose = args.verbose + + os.makedirs(save_dir, exist_ok=True) + + ds = gdal.Open(fn_img) + if ds is None: + raise Exception(f"Cannot open {fn_img}") + + bands = [ds.GetRasterBand(i+1) for i in range(ds.RasterCount)] + list_arr = [ds] + + nx, ny = ds.RasterXSize, ds.RasterYSize + tr = ds.GetGeoTransform() + + if verbose: + print('Input size:', nx, ny) + print(gdal.Info(ds, format='json')) + + # Calculate the image range + size_pixel = mercantile.CE / 2 ** z / 256 + radius = np.ceil(max(tr[1], -tr[5]) / size_pixel * 1.5) + + buf_ext = 1 + xmin = tr[0] - buf_ext * tr[1] + ymin = tr[3] + (ny + buf_ext) * tr[5] + xmax = tr[0] + (nx + buf_ext) * tr[1] + ymax = tr[3] - buf_ext * tr[5] + + tr_to_4326 = prj.Transformer.from_crs(ds.GetProjection(), 4326, always_xy=True) + arr_lng, arr_lat = tr_to_4326.transform( + np.array([xmin, xmin, xmax, xmax]), + np.array([ymax, ymin, ymin, ymax]) + ) + + if args.use_gcj02: + arr_lng_final, arr_lat_final = transfer.WGS84_to_GCJ02(arr_lng, arr_lat) + else: + arr_lng_final, arr_lat_final = arr_lng, arr_lat + + box = ( + arr_lng_final.min(), + arr_lat_final.min(), + arr_lng_final.max(), + arr_lat_final.max() + ) + + if verbose: + coord_system = "GCJ02" if args.use_gcj02 else "WGS84" + print(f'Input extent, {coord_system}: {box}') + + # Calculate the tile range to be processed + tile_ul = mercantile.tile(box[0], box[3], z) + tile_lr = mercantile.tile(box[2], box[1], z) + + if verbose: + print('Upperleft ', str(tile_ul)) + print('Lowerright ', str(tile_lr)) + + def work(x, y, z): + arr_tile = get_tile_array( + x, y, z, + method='nearest', + func_source=lambda boundary: get_image_by_approximate_boundary(list_arr, boundary, tr), + radius=radius, + use_gc02=args.use_gcj02 + ) + + if arr_tile is not None: + save_path = os.path.join(save_dir, str(z), str(x)) + os.makedirs(save_path, exist_ok=True) + + # Save as PNG + if arr_tile.shape[2] >= 3: + arr_rgb = arr_tile[:, :, :3] + arr_rgb = np.clip(arr_rgb / 2000. * 255, 0, 255).astype(np.uint8) + image_tile = Image.fromarray(arr_rgb) + png_filename = os.path.join(save_path, f'{y}.png') + image_tile.save(png_filename, format='png') + + # Save as NPZ + dict_arr = {f'B{i+1}': arr_tile[:, :, i] for i in range(arr_tile.shape[2])} + npz_filename = os.path.join(save_path, f'{y}.npz') + np.savez_compressed(npz_filename, **dict_arr) + + tasks = [ + (x, y) for x in range(tile_ul.x, tile_lr.x + 1) + for y in range(tile_ul.y, tile_lr.y + 1) + ] + + for x, y in track(tasks, description="Converting tiles..."): + work(x, y, z) + + print("Time cost:", time() - t_start) + +def main(): + args = get_args_parser() + process_wv(args) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/pretraining_data_builder/rsi_process/tile_resample.py b/tools/pretraining_data_builder/rsi_process/tile_resample.py new file mode 100644 index 0000000..ba3e3e2 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/tile_resample.py @@ -0,0 +1,230 @@ +import numpy as np +import mercantile +from pyresample import bilinear, kd_tree, geometry + +TILE_SIZE = 256 + +class LngLatTransfer(): + + def __init__(self): + self.x_pi = 3.14159265358979324 * 3000.0 / 180.0 + self.pi = np.pi # π + self.a = 6378245.0 + self.es = 0.00669342162296594323 + pass + + def GCJ02_to_BD09(self, gcj_lng, gcj_lat): + """ + Convert coordinates from GCJ02 to BD09 coordinate system + :param lng: Longitude in GCJ02 coordinate system + :param lat: Latitude in GCJ02 coordinate system + :return: Converted longitude and latitude in BD09 + """ + z = np.sqrt(gcj_lng * gcj_lng + gcj_lat * gcj_lat) + 0.00002 * np.sin(gcj_lat * self.x_pi) + theta = np.arctan2(gcj_lat, gcj_lng) + 0.000003 * np.cos(gcj_lng * self.x_pi) + bd_lng = z * np.cos(theta) + 0.0065 + bd_lat = z * np.sin(theta) + 0.006 + return bd_lng, bd_lat + + + def BD09_to_GCJ02(self, bd_lng, bd_lat): + ''' + Convert coordinates from BD09 to GCJ02 coordinate system + :param bd_lng: Longitude in BD09 coordinate system + :param bd_lat: Latitude in BD09 coordinate system + :return: Converted longitude and latitude in GCJ02 + ''' + x = bd_lng - 0.0065 + y = bd_lat - 0.006 + z = np.sqrt(x * x + y * y) - 0.00002 * np.sin(y * self.x_pi) + theta = np.arctan2(y, x) - 0.000003 * np.cos(x * self.x_pi) + gcj_lng = z * np.cos(theta) + gcj_lat = z * np.sin(theta) + return gcj_lng, gcj_lat + + + def WGS84_to_GCJ02(self, lng, lat): + ''' + Convert coordinates from WGS84 to GCJ02 coordinate system + :param lng: Longitude in WGS84 coordinate system + :param lat: Latitude in WGS84 coordinate system + :return: Converted longitude and latitude in GCJ02 + ''' + dlat = self._transformlat(lng - 105.0, lat - 35.0) + dlng = self._transformlng(lng - 105.0, lat - 35.0) + radlat = lat / 180.0 * self.pi + magic = np.sin(radlat) + magic = 1 - self.es * magic * magic + sqrtmagic = np.sqrt(magic) + dlat = (dlat * 180.0) / ((self.a * (1 - self.es)) / (magic * sqrtmagic) * self.pi) + dlng = (dlng * 180.0) / (self.a / sqrtmagic * np.cos(radlat) * self.pi) + gcj_lng = lng + dlng + gcj_lat = lat + dlat + return gcj_lng, gcj_lat + + + def GCJ02_to_WGS84(self, gcj_lng, gcj_lat): + ''' + Convert coordinates from GCJ02 to WGS84 coordinate system + :param gcj_lng: Longitude in GCJ02 coordinate system + :param gcj_lat: Latitude in GCJ02 coordinate system + :return: Converted longitude and latitude in WGS84 + ''' + dlat = self._transformlat(gcj_lng - 105.0, gcj_lat - 35.0) + dlng = self._transformlng(gcj_lng - 105.0, gcj_lat - 35.0) + radlat = gcj_lat / 180.0 * self.pi + magic = np.sin(radlat) + magic = 1 - self.es * magic * magic + sqrtmagic = np.sqrt(magic) + dlat = (dlat * 180.0) / ((self.a * (1 - self.es)) / (magic * sqrtmagic) * self.pi) + dlng = (dlng * 180.0) / (self.a / sqrtmagic * np.cos(radlat) * self.pi) + mglat = gcj_lat + dlat + mglng = gcj_lng + dlng + lng = gcj_lng * 2 - mglng + lat = gcj_lat * 2 - mglat + return lng, lat + + + def BD09_to_WGS84(self, bd_lng, bd_lat): + ''' + Convert coordinates from BD09 to WGS84 coordinate system + :param bd_lng: Longitude in BD09 coordinate system + :param bd_lat: Latitude in BD09 coordinate system + :return: Converted longitude and latitude in WGS84 + ''' + lng, lat = self.BD09_to_GCJ02(bd_lng, bd_lat) + return self.GCJ02_to_WGS84(lng, lat) + + + def WGS84_to_BD09(self, lng, lat): + ''' + Convert coordinates from WGS84 to BD09 coordinate system + :param lng: Longitude in WGS84 coordinate system + :param lat: Latitude in WGS84 coordinate system + :return: Converted longitude and latitude in BD09 + ''' + lng, lat = self.WGS84_to_GCJ02(lng, lat) + return self.GCJ02_to_BD09(lng, lat) + + + def _transformlat(self, lng, lat): + ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \ + 0.1 * lng * lat + 0.2 * np.sqrt(np.fabs(lng)) + ret += (20.0 * np.sin(6.0 * lng * self.pi) + 20.0 * + np.sin(2.0 * lng * self.pi)) * 2.0 / 3.0 + ret += (20.0 * np.sin(lat * self.pi) + 40.0 * + np.sin(lat / 3.0 * self.pi)) * 2.0 / 3.0 + ret += (160.0 * np.sin(lat / 12.0 * self.pi) + 320 * + np.sin(lat * self.pi / 30.0)) * 2.0 / 3.0 + return ret + + + def _transformlng(self, lng, lat): + ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \ + 0.1 * lng * lat + 0.1 * np.sqrt(np.fabs(lng)) + ret += (20.0 * np.sin(6.0 * lng * self.pi) + 20.0 * + np.sin(2.0 * lng * self.pi)) * 2.0 / 3.0 + ret += (20.0 * np.sin(lng * self.pi) + 40.0 * + np.sin(lng / 3.0 * self.pi)) * 2.0 / 3.0 + ret += (150.0 * np.sin(lng / 12.0 * self.pi) + 300.0 * + np.sin(lng / 30.0 * self.pi)) * 2.0 / 3.0 + return ret + + def WGS84_to_WebMercator(self, lng, lat): + ''' + Convert coordinates from WGS84 to Web Mercator + :param lng: Longitude in WGS84 + :param lat: Latitude in WGS84 + :return: Converted Web Mercator coordinates + ''' + x = lng * 20037508.342789 / 180 + y = np.log(np.tan((90 + lat) * self.pi / 360)) / (self.pi / 180) + y = y * 20037508.34789 / 180 + return x, y + + def WebMercator_to_WGS84(self, x, y): + ''' + Convert coordinates from Web Mercator to WGS84 + :param x: Web Mercator x coordinate + :param y: Web Mercator y coordinate + :return: Converted longitude and latitude in WGS84 + ''' + lng = x / 20037508.34 * 180 + lat = y / 20037508.34 * 180 + lat = 180 / self.pi * (2 * np.arctan(np.exp(lat * self.pi / 180)) - self.pi / 2) + return lng, lat + + +transfer = LngLatTransfer() +def get_tile_array(x, y, z, method='nearest', func_source=None, radius=2, fill_value=0, use_gc02=True): + """Resample source image data to map tile + + Args: + x, y, z: Tile coordinates + method: Resampling method ('nearest' or 'bilinear') + func_source: Function to get source image data + radius: Search radius in pixels + fill_value: Value for no data areas + gc02: Whether the coordinates are in GCJ02 system (True) or WGS84 (False) + + Returns: + ndarray: Resampled tile data + """ + bounds = mercantile.bounds(x, y, z) + + if use_gc02: + # Convert coordinates from GCJ02 to WGS84 + wgs84_lngs, wgs84_lats = transfer.GCJ02_to_WGS84( + gcj_lng=np.array([bounds.west, bounds.west, bounds.east, bounds.east]), + gcj_lat=np.array([bounds.north, bounds.south, bounds.south, bounds.north]) + ) + boundary = list(zip(wgs84_lngs, wgs84_lats)) + else: + boundary = list(zip( + [bounds.west, bounds.west, bounds.east, bounds.east], + [bounds.north, bounds.south, bounds.south, bounds.north] + )) + + source_data = func_source(boundary) + + if source_data is None: + return None + + arr_image, arr_lngs, arr_lats = source_data + + if use_gc02: + gcj02_lngs, gcj02_lats = transfer.WGS84_to_GCJ02(arr_lngs, arr_lats) + else: + gcj02_lngs, gcj02_lats = arr_lngs, arr_lats + + # Define source and target geometries + source_def = geometry.SwathDefinition(lons=gcj02_lngs, lats=gcj02_lats) + + xy_bounds = mercantile.xy_bounds(x, y, z) + target_def = geometry.AreaDefinition( + 'tile', 'tile', 'tile', + 'EPSG:3857', + TILE_SIZE, TILE_SIZE, + (xy_bounds.left, xy_bounds.bottom, xy_bounds.right, xy_bounds.top) + ) + + # Resample + pixel_size = mercantile.CE / 2 ** z / TILE_SIZE + if method == 'nearest': + result = kd_tree.resample_nearest( + source_def, arr_image, target_def, + radius_of_influence=radius * pixel_size, + fill_value=fill_value + ) + elif method == 'bilinear': + resampler = bilinear.NumpyBilinearResampler( + source_def, target_def, + radius_of_influence=radius * pixel_size, + neighbours=8 + ) + result = resampler.resample(arr_image).astype(arr_image.dtype) + else: + raise ValueError(f'Unknown resampling method: {method}') + + return result + diff --git a/tools/pretraining_data_builder/rsi_process/utils_s1.py b/tools/pretraining_data_builder/rsi_process/utils_s1.py new file mode 100644 index 0000000..d7a7f57 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/utils_s1.py @@ -0,0 +1,133 @@ +import xml.dom.minidom +import os +from glob import glob +import zipfile +from shapely import wkt +import geopandas as gpd +from osgeo import gdal +import imageio.v2 as iio + +def parse_metadata(meta_xml_file): + """Parse Sentinel-1 metadata XML file + + Args: + meta_xml_file: Metadata XML file path + + Returns: + dict: Dictionary containing key metadata information + """ + record = {} + + dom = xml.dom.minidom.parse(meta_xml_file) # Get sensing start time + sensing_start = dom.getElementsByTagName('startTime')[0].firstChild.data + + product_uri = meta_xml_file.name.split('/')[0] + + record.update({ + 'product_uri': product_uri, + 'sensing_start': sensing_start, + }) + + + return record + +def convert_footprint_to_wkt(footprint): + """Convert footprint string to WKT format""" + coords = footprint.strip().split(' ') + wkt_coords = [] + for coord in coords: + lat, lon = coord.split(',') + wkt_coords.append(f"{lon} {lat}") + return f"MULTIPOLYGON ((({','.join(wkt_coords)})))" + +def zip2rec(fn_zip): + id_img = os.path.splitext(os.path.basename(fn_zip))[0] + archive = zipfile.ZipFile(fn_zip, 'r') + xml_files = [f for f in archive.namelist() if f.endswith('-001.xml')] + if not xml_files: + raise FileNotFoundError(f"No XML file ending with '-001.xml' found in {fn_zip}") + fn_xml = archive.open(xml_files[0]) + rec = parse_metadata(fn_xml) + import pdb; pdb.set_trace() + # rec['geometry'] = wkt.loads(rec['geom_wkt']) + thumb = archive.open(os.path.join(f'{id_img}.SAFE', 'preview', 'quick-look.png')) + thumb = iio.imread(thumb) + rec['thumb'] = thumb + return rec + +def build_catalog(path, fn='catalog'): + ''' + fn: filename or None + ''' + list_fnames = glob(os.path.join(path, 'S2*.zip')) + + list_rec = [] + for fn_zip in list_fnames: + rec = zip2rec(fn_zip) + list_rec.append(rec) + + gdf = gpd.GeoDataFrame(list_rec, crs='EPSG:4326').drop(columns='geom_wkt') + if fn is not None: + fn_geojson = os.path.join(path, f"{fn}.geojson") + gdf.to_file(fn_geojson, driver='GeoJSON') + return fn_geojson + else: + return gdf + +def make_full_name(rec, band): + dict_bands = { + 'VV': '001', + 'VH': '002', + } + parts = rec['product_uri'].split('_') + + satellite = parts[0].lower() # S1A -> s1a + mode = parts[1].lower() # IW -> iw + product_type = parts[2][:3].lower() # GRDH -> grd + polarization = band.lower() # Assume polarization mode is VV + start_time = parts[4].lower() # Start time + end_time = parts[5].lower() # End time + id1 = parts[6].lower() # 058175 + id2 = parts[7].lower() # 072FF2 + fixed_part = dict_bands[band] # Replace fixed part with 001 + + # Concatenate to target format + file_name = f"{satellite}-{mode}-{product_type}-{polarization}-{start_time}-{end_time}-{id1}-{id2}-{fixed_part}.tiff" + + fn_template = os.path.join( + rec['product_uri'], 'measurement', file_name + ) + return fn_template + +def warp( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs, +): + options_warp = gdal.WarpOptions( + format="MEM", + outputBounds=outputBounds, + outputBoundsSRS=outputBoundsSRS, + xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels, + **kwargs, + ) + ds_warp = gdal.Warp('', ds, options=options_warp) + return ds_warp + +def get_ndarray( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs, +): + ds_warp = warp( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs + ) + arr = ds_warp.ReadAsArray() + ds_warp = None + return arr + diff --git a/tools/pretraining_data_builder/rsi_process/utils_s2.py b/tools/pretraining_data_builder/rsi_process/utils_s2.py new file mode 100644 index 0000000..c19a4aa --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/utils_s2.py @@ -0,0 +1,158 @@ +import xml.dom.minidom +import os +from glob import glob +import zipfile +from shapely import wkt +import geopandas as gpd +from osgeo import gdal +import imageio.v2 as iio + +def parse_metadata(meta_xml_file): + """Parse Sentinel-2 metadata XML file + + Args: + meta_xml_file: Path to metadata XML file + + Returns: + dict: Metadata information including sensing time, product URI, etc. + """ + record = {} + try: + dom = xml.dom.minidom.parse(meta_xml_file) + + # Get sensing start time + sensing_start = dom.getElementsByTagName('DATATAKE_SENSING_START')[0].firstChild.data[0:10] + + # Get product URI and image paths + product_uri = dom.getElementsByTagName('PRODUCT_URI')[0].firstChild.data + + image_file = dom.getElementsByTagName('IMAGE_FILE')[0].firstChild.data + items = image_file.split('/') + granule_path = items[1] + img_name = items[4].split('_')[0] + '_' + items[4].split('_')[1] + + # Get footprint + footprint = dom.getElementsByTagName('EXT_POS_LIST')[0].firstChild.data + geom_wkt = convert_footprint_to_wkt(footprint) + + # Get cloud coverage info + cloud_coverage = float(dom.getElementsByTagName('Cloud_Coverage_Assessment')[0].firstChild.data) + cloud_shadow = float(dom.getElementsByTagName('CLOUD_SHADOW_PERCENTAGE')[0].firstChild.data) + medium_clouds = float(dom.getElementsByTagName('MEDIUM_PROBA_CLOUDS_PERCENTAGE')[0].firstChild.data) + high_clouds = float(dom.getElementsByTagName('HIGH_PROBA_CLOUDS_PERCENTAGE')[0].firstChild.data) + + record.update({ + 'product_uri': product_uri, + 'sensing_start': sensing_start, + 'granule_path': granule_path, + 'img_name': img_name, + 'cloud_cover': cloud_coverage, + 'cloud_shadow': cloud_shadow, + 'medium_clouds': medium_clouds, + 'high_clouds': high_clouds, + 'geom_wkt': geom_wkt + }) + + except Exception as e: + print(f'Failed to parse XML: {e}') + + return record + +def convert_footprint_to_wkt(footprint): + """Convert footprint string to WKT format""" + coords = footprint.strip().split(' ') + wkt_coords = [] + for i in range(0, len(coords), 2): + wkt_coords.append(f"{coords[i+1]} {coords[i]}") + return f"MULTIPOLYGON ((({','.join(wkt_coords)})))" + +def zip2rec(fn_zip): + id_img = os.path.splitext(os.path.basename(fn_zip))[0] + archive = zipfile.ZipFile(fn_zip, 'r') + fn_xml = archive.open(os.path.join(f'{id_img}.SAFE', 'MTD_MSIL2A.xml')) + rec = parse_metadata(fn_xml) + rec['geometry'] = wkt.loads(rec['geom_wkt']) + thumb = archive.open(os.path.join(f'{id_img}.SAFE', f'{id_img}-ql.jpg')) + thumb = iio.imread(thumb) + rec['thumb'] = thumb + return rec + +def build_catalog(path, fn='catalog'): + ''' + fn: filename or None + ''' + list_fnames = glob(os.path.join(path, 'S2*.zip')) + + list_rec = [] + for fn_zip in list_fnames: + rec = zip2rec(fn_zip) + list_rec.append(rec) + + gdf = gpd.GeoDataFrame(list_rec, crs='EPSG:4326').drop(columns='geom_wkt') + if fn is not None: + fn_geojson = os.path.join(path, f"{fn}.geojson") + gdf.to_file(fn_geojson, driver='GeoJSON') + return fn_geojson + else: + return gdf + +def make_full_name(rec, band): + dict_bands = { + 'B2': ['B02', '10m'], + 'B3': ['B03', '10m'], + 'B4': ['B04', '10m'], + 'B8': ['B08', '10m'], + 'B5': ['B05', '20m'], + 'B6': ['B06', '20m'], + 'B7': ['B07', '20m'], + 'B8A': ['B8A', '20m'], + 'B11': ['B11', '20m'], + 'B12': ['B12', '20m'], + 'SCL': ['SCL', '20m'], + } + fn_template = os.path.join( + '{p0}', 'GRANULE', + '{p1}', 'IMG_DATA', "R{p2}", + '{p3}_{p4}_{p2}.jp2' + ) + return fn_template.format(**{ + 'p0': rec['product_uri'], + 'p0b': rec['product_uri'].split('.')[0], + 'p1': rec['granule_path'], + 'p2': dict_bands[band][1], + 'p3': rec['img_name'], + 'p4': dict_bands[band][0], + }) + +def warp( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs, +): + options_warp = gdal.WarpOptions( + format="MEM", + outputBounds=outputBounds, + outputBoundsSRS=outputBoundsSRS, + xRes=xRes, yRes=yRes, targetAlignedPixels=targetAlignedPixels, + **kwargs, + ) + ds_warp = gdal.Warp('', ds, options=options_warp) + return ds_warp + +def get_ndarray( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs, +): + ds_warp = warp( + ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=10, yRes=10, targetAlignedPixels=True, + **kwargs + ) + arr = ds_warp.ReadAsArray() + ds_warp = None + return arr + diff --git a/tools/pretraining_data_builder/rsi_process/utils_wv.py b/tools/pretraining_data_builder/rsi_process/utils_wv.py new file mode 100644 index 0000000..feb2256 --- /dev/null +++ b/tools/pretraining_data_builder/rsi_process/utils_wv.py @@ -0,0 +1,237 @@ +import os +from osgeo import gdal +import numpy as np +from datetime import datetime +import xml.etree.ElementTree as ET + +def parse_metadata(meta_xml_file): + """Parse the WorldView metadata XML file + + Args: + meta_xml_file: Metadata XML file path + + Returns: + dict: Dictionary containing key metadata information + """ + record = {} + + try: + tree = ET.parse(meta_xml_file) + root = tree.getroot() + + ns = {'imd': root.tag.split('}')[0].strip('{')} + + # Get basic information + record['satellite_id'] = root.find('.//imd:satelliteID', ns).text + record['product_type'] = root.find('.//imd:productType', ns).text + + # Get acquisition time + acq_time = root.find('.//imd:firstLineTime', ns).text + record['sensing_start'] = datetime.strptime(acq_time, '%Y-%m-%dT%H:%M:%S.%fZ') + + # Get solar angle + record['sun_azimuth'] = float(root.find('.//imd:meanSunAz', ns).text) + record['sun_elevation'] = float(root.find('.//imd:meanSunEl', ns).text) + + # Get satellite angle + record['satellite_azimuth'] = float(root.find('.//imd:meanSatAz', ns).text) + record['satellite_elevation'] = float(root.find('.//imd:meanSatEl', ns).text) + + # Get cloud cover + cloud_cover = root.find('.//imd:cloudCover', ns) + record['cloud_cover'] = float(cloud_cover.text) if cloud_cover is not None else None + + # Get image range + record['ul_lon'] = float(root.find('.//imd:ULLon', ns).text) + record['ul_lat'] = float(root.find('.//imd:ULLat', ns).text) + record['ur_lon'] = float(root.find('.//imd:URLon', ns).text) + record['ur_lat'] = float(root.find('.//imd:URLat', ns).text) + record['ll_lon'] = float(root.find('.//imd:LLLon', ns).text) + record['ll_lat'] = float(root.find('.//imd:LLLat', ns).text) + record['lr_lon'] = float(root.find('.//imd:LRLon', ns).text) + record['lr_lat'] = float(root.find('.//imd:LRLat', ns).text) + + # Build WKT format geometry information + record['geom_wkt'] = create_footprint_wkt(record) + + except Exception as e: + print(f"Error parsing metadata: {str(e)}") + return None + + return record + +def create_footprint_wkt(record): + """Create a WKT format polygon based on corner coordinates + + Args: + record: Dictionary containing corner coordinates + + Returns: + str: WKT format polygon string + """ + coords = [ + (record['ul_lon'], record['ul_lat']), + (record['ur_lon'], record['ur_lat']), + (record['lr_lon'], record['lr_lat']), + (record['ll_lon'], record['ll_lat']), + (record['ul_lon'], record['ul_lat']) + ] + + coord_str = ', '.join([f"{lon} {lat}" for lon, lat in coords]) + return f"POLYGON(({coord_str}))" + +def get_band_info(ds): + """Get the band information of the image + + Args: + ds: GDAL dataset + + Returns: + list: Band information list + """ + bands = [] + for i in range(1, ds.RasterCount + 1): + band = ds.GetRasterBand(i) + band_info = { + 'band_number': i, + 'data_type': gdal.GetDataTypeName(band.DataType), + 'nodata_value': band.GetNoDataValue() + } + bands.append(band_info) + return bands + +def read_as_array(ds, window=None): + """Read image data as a numpy array + + Args: + ds: GDAL dataset + window: Read window, format as (xoff, yoff, xsize, ysize) + + Returns: + numpy.ndarray: Image data array + """ + if window is None: + return ds.ReadAsArray() + else: + xoff, yoff, xsize, ysize = window + return ds.ReadAsArray(xoff, yoff, xsize, ysize) + +def get_image_info(fn_img): + """Get basic information of WorldView image + + Args: + fn_img: Image file path + + Returns: + dict: Image information dictionary + """ + ds = gdal.Open(fn_img) + if ds is None: + raise Exception(f"Cannot open {fn_img}") + + info = { + 'width': ds.RasterXSize, + 'height': ds.RasterYSize, + 'bands': ds.RasterCount, + 'projection': ds.GetProjection(), + 'geotransform': ds.GetGeoTransform(), + 'band_info': get_band_info(ds) + } + + xml_file = fn_img.replace('.tif', '.xml') + if os.path.exists(xml_file): + metadata = parse_metadata(xml_file) + if metadata: + info.update(metadata) + + ds = None + return info + +def calculate_stats(fn_img, percentiles=[2, 98]): + """Calculate the statistics of the image + + Args: + fn_img: Image file path + percentiles: List of percentiles + + Returns: + dict: Statistics dictionary + """ + ds = gdal.Open(fn_img) + stats = {} + + for i in range(1, ds.RasterCount + 1): + band = ds.GetRasterBand(i) + array = band.ReadAsArray() + valid_data = array[array != band.GetNoDataValue()] + + stats[f'band_{i}'] = { + 'min': np.min(valid_data), + 'max': np.max(valid_data), + 'mean': np.mean(valid_data), + 'std': np.std(valid_data), + 'percentiles': { + p: np.percentile(valid_data, p) + for p in percentiles + } + } + + ds = None + return stats + +def create_quicklook(fn_img, output_file, size=(1024, 1024)): + """Create a thumbnail + + Args: + fn_img: Image file path + output_file: Output file path + size: Output image size + """ + ds = gdal.Open(fn_img) + + if ds.RasterCount >= 3: + r = ds.GetRasterBand(1).ReadAsArray() + g = ds.GetRasterBand(2).ReadAsArray() + b = ds.GetRasterBand(3).ReadAsArray() + + def stretch(arr): + p2, p98 = np.percentile(arr[arr > 0], (2, 98)) + return np.clip((arr - p2) / (p98 - p2) * 255, 0, 255).astype(np.uint8) + + rgb = np.dstack([stretch(r), stretch(g), stretch(b)]) + + from PIL import Image + img = Image.fromarray(rgb) + img.thumbnail(size) + img.save(output_file) + + ds = None + +def warp(ds, outputBounds, + outputBoundsSRS='EPSG:4326', + xRes=2, yRes=2, + targetAlignedPixels=True, + **kwargs): + """Reprojection and resampling + + Args: + ds: GDAL dataset + outputBounds: Output range + outputBoundsSRS: Output coordinate system + xRes, yRes: Output resolution + targetAlignedPixels: Whether to align pixels + **kwargs: Other GDAL.Warp parameters + + Returns: + GDAL dataset + """ + options_warp = gdal.WarpOptions( + format="MEM", + outputBounds=outputBounds, + outputBoundsSRS=outputBoundsSRS, + xRes=xRes, yRes=yRes, + targetAlignedPixels=targetAlignedPixels, + **kwargs + ) + ds_warp = gdal.Warp('', ds, options=options_warp) + return ds_warp \ No newline at end of file diff --git a/tools/pretraining_data_builder/run_data_builder.sh b/tools/pretraining_data_builder/run_data_builder.sh new file mode 100644 index 0000000..5017db2 --- /dev/null +++ b/tools/pretraining_data_builder/run_data_builder.sh @@ -0,0 +1,11 @@ +#! /bin/bash +source activate data_builder +export USERNAME=your_username +export PASSWORD=your_password +export API_KEY=your_api_key + +export PYTHONPATH=$PYTHONPATH:$(pwd) + +LMDB_PATH=your_lmdb_path + +python rsi_pipeline/data_builder.py $LMDB_PATH \ No newline at end of file diff --git a/tools/run.py b/tools/run.py new file mode 100644 index 0000000..6e1d049 --- /dev/null +++ b/tools/run.py @@ -0,0 +1,44 @@ +from antmmf.utils.env import setup_compatibility +from antmmf.utils.flags import flags +from antmmf.run import plain_run + +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__)) +from lib import * # noqa make sure all modules have been registered. + +usage = """ + Usage: + python tools/run.py --config configs/foo/bar.yml [OPTIONS] [OPTS] + + Options: + --config_override override.yml configurations from this file will override the --config one. like + python tools/run.py --config configs/foo/bar.yml --config_override my_foobar.yml + + --local_rank local rank of your machine, used in parallel mode + + OPTS: override specific value in config, like + python tools/run.py --config configs/foo/bar.yml \\ + training_parameters.device cuda:0 \\ + training_parameters.max_epochs 5 \\ + task_attributes.hateful_memes.dataset_attributes.foo.images.train \\ + "[foo/defaults/images]" + + Priority: + OPTS OVERRIDE --config_override OVERRIDE --config, see antmmf/common/build.py::build_config for details +""" + + +def run(): + parser = flags.get_parser() + try: + args = parser.parse_args() + plain_run(args) + except SystemExit: + exit(2) + + +if __name__ == "__main__": + setup_compatibility() + run() \ No newline at end of file diff --git a/tools/run_1shot.sh b/tools/run_1shot.sh new file mode 100644 index 0000000..9474d46 --- /dev/null +++ b/tools/run_1shot.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -x +export PYTHONPATH=`pwd`:$PYTHONPATH +cd antmmf +export PYTHONPATH=`pwd`:$PYTHONPATH +cd ../ + +export CUDA_VISIBLE_DEVICES=$1 +dataset_name=$2 + +CONFIG_PATH=configs/eval_skysense_pp_${dataset_name}.yml +MODEL_PATH=pretrain/skysensepp_release.ckpt + +SAVE_DIR=eval/${dataset_name}_1shot/save +GT_DIR=eval_datasets/${dataset_name}/targets +GT_LIST_PATH=eval_datasets/${dataset_name}/val.txt + +mkdir -p $SAVE_DIR + +# predictor +python lib/predictors/${dataset_name}_1shot.py \ + --model_path $MODEL_PATH \ + --config $CONFIG_PATH \ + --save_dir $SAVE_DIR \ + --seed 0 + +# eval +python lib/evaluation/segm_eval_base.py \ + --pred_dir ${SAVE_DIR} \ + --gt_dir ${GT_DIR} \ + --gt_list_path ${GT_LIST_PATH} \ + --gt_suffix '.png' \ + --dataset_name ${dataset_name} \ + --dist_type 'abs' \ + --model_name skysense++_1shot \ No newline at end of file diff --git a/tools/run_finetune.sh b/tools/run_finetune.sh new file mode 100644 index 0000000..6e7c718 --- /dev/null +++ b/tools/run_finetune.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +cd finetune/ +cfg=$1 +gpu_num=8 + +bash finetune/tools/dist_train.sh $cfg $gpu_num + + diff --git a/tools/run_pretrain.sh b/tools/run_pretrain.sh new file mode 100644 index 0000000..3cbd6cc --- /dev/null +++ b/tools/run_pretrain.sh @@ -0,0 +1,24 @@ +#!/bin/sh +set -x +export PYTHONPATH=`pwd`:$PYTHONPATH +cd antmmf +export PYTHONPATH=`pwd`:$PYTHONPATH +cd .. +export NCCL_DEBUG=INFO + +export OMP_NUM_THREADS=4 + +NUM_GPU=8 + +CONFIG_FILE=configs/pretrain_skysensepp.yml +SAVE_DIR=save/skysensepp_pretrain + +mkdir -p ${SAVE_DIR}/$1/ + +pip install lmdb +nohup python -m antmmf.utils.launch --nproc_per_node=${NUM_GPU} --master_port 12345 --nnodes=4 --node_rank=$1 --master_addr=$2 tools/run.py --config $CONFIG_FILE \ + training_parameters.distributed True \ + training_parameters.save_dir ${SAVE_DIR} > ${SAVE_DIR}/$1/nohup.log 2>&1 & + +sleep 3s +tail -f ${SAVE_DIR}/$1/nohup.log