init
This commit is contained in:
70
finetune/mmseg/utils/__init__.py
Normal file
70
finetune/mmseg/utils/__init__.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# yapf: disable
|
||||
from .class_names import (ade_classes, ade_palette, bdd100k_classes,
|
||||
bdd100k_palette, cityscapes_classes,
|
||||
cityscapes_palette, cocostuff_classes,
|
||||
cocostuff_palette, dataset_aliases, get_classes,
|
||||
get_palette, isaid_classes, isaid_palette,
|
||||
loveda_classes, loveda_palette, potsdam_classes,
|
||||
potsdam_palette, stare_classes, stare_palette,
|
||||
synapse_classes, synapse_palette, vaihingen_classes,
|
||||
vaihingen_palette, voc_classes, voc_palette)
|
||||
# yapf: enable
|
||||
from .collect_env import collect_env
|
||||
from .get_templates import get_predefined_templates
|
||||
from .io import datafrombytes
|
||||
from .misc import add_prefix, stack_batch
|
||||
from .set_env import register_all_modules
|
||||
from .tokenizer import tokenize
|
||||
from .typing_utils import (ConfigType, ForwardResults, MultiConfig,
|
||||
OptConfigType, OptMultiConfig, OptSampleList,
|
||||
SampleList, TensorDict, TensorList)
|
||||
|
||||
# isort: off
|
||||
from .mask_classification import MatchMasks, seg_data_to_instance_data
|
||||
|
||||
# Public API of ``mmseg.utils``: environment helpers, typing aliases,
# dataset class-name/palette accessors and mask-classification utilities.
# Order is preserved from the original declaration.
__all__ = [
    # environment / misc helpers
    'collect_env', 'register_all_modules', 'stack_batch', 'add_prefix',
    # typing aliases
    'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig',
    'SampleList', 'OptSampleList', 'TensorDict', 'TensorList',
    'ForwardResults',
    # dataset class names
    'cityscapes_classes', 'ade_classes', 'voc_classes', 'cocostuff_classes',
    'loveda_classes', 'potsdam_classes', 'vaihingen_classes', 'isaid_classes',
    'stare_classes',
    # dataset palettes
    'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
    'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
    'stare_palette',
    # lookup helpers and IO
    'dataset_aliases', 'get_classes', 'get_palette', 'datafrombytes',
    'synapse_palette', 'synapse_classes', 'get_predefined_templates',
    'tokenize', 'seg_data_to_instance_data', 'MatchMasks',
    'bdd100k_classes', 'bdd100k_palette',
]
|
||||
BIN
finetune/mmseg/utils/bpe_simple_vocab_16e6.txt.gz
Normal file
BIN
finetune/mmseg/utils/bpe_simple_vocab_16e6.txt.gz
Normal file
Binary file not shown.
548
finetune/mmseg/utils/class_names.py
Normal file
548
finetune/mmseg/utils/class_names.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import is_str
|
||||
|
||||
|
||||
def cityscapes_classes():
    """Return the 19 Cityscapes evaluation class names, in label-id order."""
    names = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
             'traffic light', 'traffic sign', 'vegetation', 'terrain',
             'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
             'motorcycle', 'bicycle')
    return list(names)
|
||||
|
||||
|
||||
def ade_classes():
    """ADE20K class names for external use."""
    # Semantic class labels of the ADE20K benchmark, one entry per label id.
    # NOTE(review): 'bed ' carries a trailing space — presumably matching the
    # upstream label file; confirm before normalizing.
    return [
        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
        'clock', 'flag'
    ]
|
||||
|
||||
|
||||
def voc_classes():
    """Return the 21 Pascal VOC class names (background first)."""
    foreground = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
        'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ]
    return ['background'] + foreground
|
||||
|
||||
|
||||
def pcontext_classes():
    """Return the 59 Pascal Context class names, in label order."""
    names = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
             'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
             'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
             'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
             'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
             'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
             'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
             'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
             'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
    return list(names)
|
||||
|
||||
|
||||
def cocostuff_classes():
    """CocoStuff class names for external use."""
    # COCO 'thing' classes ('person' ... 'toothbrush') followed by the
    # 'stuff' classes ('banner' ... 'wood'), in label-id order.
    return [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
        'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
        'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
        'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
        'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
        'window-blind', 'window-other', 'wood'
    ]
|
||||
|
||||
|
||||
def loveda_classes():
    """Return the seven LoveDA land-cover class names."""
    names = ('background', 'building', 'road', 'water', 'barren', 'forest',
             'agricultural')
    return list(names)
|
||||
|
||||
|
||||
def potsdam_classes():
    """Return the six ISPRS Potsdam class names."""
    names = ('impervious_surface', 'building', 'low_vegetation', 'tree',
             'car', 'clutter')
    return list(names)
|
||||
|
||||
|
||||
def vaihingen_classes():
    """Return the six ISPRS Vaihingen class names (same set as Potsdam)."""
    names = ('impervious_surface', 'building', 'low_vegetation', 'tree',
             'car', 'clutter')
    return list(names)
|
||||
|
||||
|
||||
def isaid_classes():
    """Return the 16 iSAID class names (background first)."""
    # Mixed upper/lower casing is kept exactly as released upstream.
    names = ('background', 'ship', 'store_tank', 'baseball_diamond',
             'tennis_court', 'basketball_court', 'Ground_Track_Field',
             'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
             'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
             'Harbor')
    return list(names)
|
||||
|
||||
|
||||
def stare_classes():
    """Return the STARE class names (binary vessel segmentation)."""
    names = ('background', 'vessel')
    return list(names)
|
||||
|
||||
|
||||
def mapillary_v1_classes():
    """mapillary_v1 class names for external use."""
    # Mapillary Vistas v1 label set, in label-id order, ending with
    # 'Unlabeled'.
    return [
        'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
        'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
        'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
        'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
        'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
        'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
        'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
        'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
        'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
        'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
        'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
        'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
        'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'
    ]
|
||||
|
||||
|
||||
def mapillary_v1_palette():
    """mapillary_v1_ palette for external use."""
    # One RGB triple per entry of mapillary_v1_classes(), same order.
    return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
            [180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
            [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
            [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
            [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
            [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
            [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
            [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
            [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
            [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
            [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
            [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
            [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
            [119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
            [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
            [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]
|
||||
|
||||
|
||||
def mapillary_v2_classes():
    """mapillary_v2 class names for external use."""
    # Mapillary Vistas v2 label set, in label-id order, ending with
    # 'Unlabeled'.
    # NOTE(review): 'Lane Marking (only) - Test' looks like a possible typo
    # for 'Text' — kept verbatim; confirm against the official label file.
    return [
        'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
        'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
        'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
        'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
        'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
        'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
        'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
        'Bicyclist', 'Motorcyclist', 'Other Rider',
        'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
        'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
        'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
        'Lane Marking - Arrow (Right)',
        'Lane Marking - Arrow (Split Left or Straight)',
        'Lane Marking - Arrow (Split Right or Straight)',
        'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
        'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
        'Lane Marking - Hatched (Chevron)',
        'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
        'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
        'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
        'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
        'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
        'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
        'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
        'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
        'Phone Booth', 'Pothole', 'Signage - Advertisement',
        'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
        'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
        'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
        'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
        'Traffic Light - General (Upright)',
        'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
        'Traffic Light - Other', 'Traffic Sign - Ambiguous',
        'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
        'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
        'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
        'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
        'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
        'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
        'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled'
    ]
|
||||
|
||||
|
||||
def mapillary_v2_palette():
    """mapillary_v2_ palette for external use."""
    # One RGB triple per entry of mapillary_v2_classes(), same order.
    # Several colors repeat (e.g. [255, 255, 255]) — multiple lane-marking
    # labels share a color upstream; kept verbatim.
    return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
            [196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
            [250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
            [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
            [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
            [230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
            [244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
            [150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
            [255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
            [255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
            [250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
            [250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
            [250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
            [250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
            [255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
            [230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
            [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
            [100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
            [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
            [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
            [250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
            [210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
            [0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
            [250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
            [192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
            [220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
            [140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
            [0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
            [0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
            [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
            [111, 111, 0], [0, 0, 0]]
|
||||
|
||||
|
||||
def cityscapes_palette():
    """Return the Cityscapes RGB palette (one color per evaluation class)."""
    colors = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
        (0, 0, 230), (119, 11, 32),
    )
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def ade_palette():
    """ADE20K palette for external use."""
    # One RGB triple per entry of ade_classes(), same order.
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
|
||||
|
||||
|
||||
def voc_palette():
    """Return the 21-color Pascal VOC RGB palette (background first)."""
    colors = (
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128),
        (128, 0, 128), (0, 128, 128), (128, 128, 128), (64, 0, 0),
        (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128),
        (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0),
        (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
    )
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def pcontext_palette():
    """Pascal Context palette for external use."""
    # One RGB triple per entry of pcontext_classes(), same order.
    return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
            [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
            [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
            [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
            [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
            [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
            [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
            [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
            [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
            [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
            [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
            [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
            [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
            [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
            [0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
||||
|
||||
|
||||
def cocostuff_palette():
    """CocoStuff palette for external use."""
    # One RGB triple per entry of cocostuff_classes(), same order.
    # NOTE(review): the first two entries are both [0, 192, 64] and other
    # colors repeat later — looks like an upstream quirk; kept verbatim.
    return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
            [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
            [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
            [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
            [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
            [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
            [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
            [0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0],
            [192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32],
            [0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
            [128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64],
            [192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
            [64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
            [0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64],
            [128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32],
            [64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128],
            [128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
            [128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
            [64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0],
            [0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0],
            [192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96],
            [64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128],
            [128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64],
            [192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96],
            [0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0],
            [64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64],
            [128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96],
            [0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128],
            [192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0],
            [128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32],
            [0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64],
            [64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0],
            [192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32],
            [0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192],
            [192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64],
            [192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32],
            [64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64],
            [64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64],
            [128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32],
            [64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192],
            [192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0],
            [128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96],
            [64, 160, 64], [64, 64, 0]]
|
||||
|
||||
|
||||
def loveda_palette():
    """Return the LoveDA RGB palette (one color per class)."""
    colors = ((255, 255, 255), (255, 0, 0), (255, 255, 0), (0, 0, 255),
              (159, 129, 183), (0, 255, 0), (255, 195, 128))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def potsdam_palette():
    """Return the ISPRS Potsdam RGB palette (one color per class)."""
    colors = ((255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0),
              (255, 255, 0), (255, 0, 0))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def vaihingen_palette():
    """Return the ISPRS Vaihingen RGB palette (identical to Potsdam)."""
    colors = ((255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0),
              (255, 255, 0), (255, 0, 0))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def isaid_palette():
    """Return the iSAID RGB palette (one color per class)."""
    colors = (
        (0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127),
        (0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127),
        (0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127),
        (0, 127, 191), (0, 127, 255), (0, 100, 155),
    )
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def stare_palette():
    """Return the STARE RGB palette (background, vessel)."""
    colors = ((120, 120, 120), (6, 230, 230))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def synapse_palette():
    """Return the Synapse RGB palette (one color per organ class)."""
    colors = ((0, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 0),
              (0, 255, 255), (255, 0, 255), (255, 255, 0), (60, 255, 255),
              (240, 240, 240))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def synapse_classes():
    """Return the nine Synapse abdominal-organ class names."""
    names = ('background', 'aorta', 'gallbladder', 'left_kidney',
             'right_kidney', 'liver', 'pancreas', 'spleen', 'stomach')
    return list(names)
|
||||
|
||||
|
||||
def lip_classes():
    """Return the 20 LIP (Look Into Person) class names."""
    names = ('background', 'hat', 'hair', 'glove', 'sunglasses',
             'upperclothes', 'dress', 'coat', 'socks', 'pants', 'jumpsuits',
             'scarf', 'skirt', 'face', 'leftArm', 'rightArm', 'leftLeg',
             'rightLeg', 'leftShoe', 'rightShoe')
    return list(names)
|
||||
|
||||
|
||||
def lip_palette():
    """LIP palette for external use."""
    # NOTE(review): unlike every other *_palette() here, this returns label
    # STRINGS rather than RGB triples — almost certainly a defect, but the
    # correct LIP colors are not available in this file and callers may
    # already depend on this shape, so behavior is preserved. Confirm
    # against the official LIP palette before fixing.
    return [
        'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes',
        'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
        'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe',
        'Right-shoe'
    ]
|
||||
|
||||
|
||||
def bdd100k_classes():
    """Return the BDD100K class names.

    The label set is identical to the 19 Cityscapes evaluation classes.
    """
    names = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
             'traffic light', 'traffic sign', 'vegetation', 'terrain',
             'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
             'motorcycle', 'bicycle')
    return list(names)
|
||||
|
||||
|
||||
def bdd100k_palette():
    """Return the BDD100K RGB palette (identical to Cityscapes)."""
    colors = (
        (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
        (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
        (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
        (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
        (0, 0, 230), (119, 11, 32),
    )
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
def hsidrive_classes():
    """Return the 11 HSI Drive 2.0 class names."""
    names = ('unlabelled', 'road', 'road marks', 'vegetation',
             'painted metal', 'sky', 'concrete', 'pedestrian', 'water',
             'unpainted metal', 'glass')
    return list(names)
|
||||
|
||||
|
||||
def hsidrive_palette():
    """Return the HSI Drive 2.0 RGB palette (one color per class)."""
    colors = ((0, 0, 0), (77, 77, 77), (255, 255, 255), (0, 255, 0),
              (255, 0, 0), (0, 0, 255), (102, 51, 0), (255, 255, 0),
              (0, 207, 250), (255, 166, 0), (0, 204, 204))
    return [list(rgb) for rgb in colors]
|
||||
|
||||
|
||||
# Maps each canonical dataset name to the alias strings accepted by
# get_classes() / get_palette(). The canonical name, prefixed to
# '_classes' / '_palette', selects the accessor function in this module.
dataset_aliases = {
    'cityscapes': ['cityscapes'],
    'ade': ['ade', 'ade20k'],
    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
    'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
    'loveda': ['loveda'],
    'potsdam': ['potsdam'],
    'vaihingen': ['vaihingen'],
    'cocostuff': [
        'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
        'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
        'coco_stuff164k'
    ],
    'isaid': ['isaid', 'iSAID'],
    'stare': ['stare', 'STARE'],
    'lip': ['LIP', 'lip'],
    'mapillary_v1': ['mapillary_v1'],
    'mapillary_v2': ['mapillary_v2'],
    'bdd100k': ['bdd100k'],
    'hsidrive': [
        'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20',
        'HSI-Drive20'
    ]
}
|
||||
|
||||
|
||||
def get_classes(dataset):
    """Get class names of a dataset.

    Args:
        dataset (str): Dataset name or any alias registered in
            ``dataset_aliases`` (e.g. ``'voc12'`` resolves to ``'voc'``).

    Returns:
        list[str]: Class names of the dataset.

    Raises:
        ValueError: If ``dataset`` matches no registered alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Validate the type up front; building the alias table for a bad input
    # would be wasted work.
    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')

    # Invert the alias table: alias -> canonical dataset name.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }

    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve '<name>_classes' from this module's namespace instead of
    # eval() — same result, without executing a constructed string.
    return globals()[alias2name[dataset] + '_classes']()
|
||||
|
||||
|
||||
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Dataset name or any alias registered in
            ``dataset_aliases`` (e.g. ``'voc12'`` resolves to ``'voc'``).

    Returns:
        list[list[int]]: RGB triples, one per class of the dataset.

    Raises:
        ValueError: If ``dataset`` matches no registered alias.
        TypeError: If ``dataset`` is not a string.
    """
    # Validate the type up front; building the alias table for a bad input
    # would be wasted work.
    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')

    # Invert the alias table: alias -> canonical dataset name.
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items() for alias in aliases
    }

    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # Resolve '<name>_palette' from this module's namespace instead of
    # eval() — same result, without executing a constructed string.
    return globals()[alias2name[dataset] + '_palette']()
|
||||
18
finetune/mmseg/utils/collect_env.py
Normal file
18
finetune/mmseg/utils/collect_env.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmengine.utils import get_git_hash
|
||||
from mmengine.utils.dl_utils import collect_env as collect_base_env
|
||||
|
||||
import mmseg
|
||||
|
||||
|
||||
def collect_env():
    """Collect the information of the running environments.

    Returns:
        dict: The base environment report from mmengine, extended with an
        ``'MMSegmentation'`` entry of the form ``'<version>+<git-sha7>'``.
    """
    info = collect_base_env()
    version_str = f'{mmseg.__version__}+{get_git_hash()[:7]}'
    info['MMSegmentation'] = version_str
    return info
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Print the environment report, one "key: value" line per entry.
    for key, value in collect_env().items():
        print(f'{key}: {value}')
|
||||
109
finetune/mmseg/utils/get_templates.py
Normal file
109
finetune/mmseg/utils/get_templates.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
# Prompt templates used to turn class names into text-encoder inputs.
# Each template contains a single `{}` placeholder for the class name.
# Keys are the set names accepted by ``get_predefined_templates``:
#   * 'imagenet': the prompt-ensemble templates from OpenAI's CLIP.
#   * 'vild': the templates used by the ViLD open-vocabulary detector.
PREDEFINED_TEMPLATES = {
    'imagenet': [
        'a bad photo of a {}.',
        'a photo of many {}.',
        'a sculpture of a {}.',
        'a photo of the hard to see {}.',
        'a low resolution photo of the {}.',
        'a rendering of a {}.',
        'graffiti of a {}.',
        'a bad photo of the {}.',
        'a cropped photo of the {}.',
        'a tattoo of a {}.',
        'the embroidered {}.',
        'a photo of a hard to see {}.',
        'a bright photo of a {}.',
        'a photo of a clean {}.',
        'a photo of a dirty {}.',
        'a dark photo of the {}.',
        'a drawing of a {}.',
        'a photo of my {}.',
        'the plastic {}.',
        'a photo of the cool {}.',
        'a close-up photo of a {}.',
        'a black and white photo of the {}.',
        'a painting of the {}.',
        'a painting of a {}.',
        'a pixelated photo of the {}.',
        'a sculpture of the {}.',
        'a bright photo of the {}.',
        'a cropped photo of a {}.',
        'a plastic {}.',
        'a photo of the dirty {}.',
        'a jpeg corrupted photo of a {}.',
        'a blurry photo of the {}.',
        'a photo of the {}.',
        'a good photo of the {}.',
        'a rendering of the {}.',
        'a {} in a video game.',
        'a photo of one {}.',
        'a doodle of a {}.',
        'a close-up photo of the {}.',
        'a photo of a {}.',
        'the origami {}.',
        'the {} in a video game.',
        'a sketch of a {}.',
        'a doodle of the {}.',
        'a origami {}.',
        'a low resolution photo of a {}.',
        'the toy {}.',
        'a rendition of the {}.',
        'a photo of the clean {}.',
        'a photo of a large {}.',
        'a rendition of a {}.',
        'a photo of a nice {}.',
        'a photo of a weird {}.',
        'a blurry photo of a {}.',
        'a cartoon {}.',
        'art of a {}.',
        'a sketch of the {}.',
        'a embroidered {}.',
        'a pixelated photo of a {}.',
        'itap of the {}.',
        'a jpeg corrupted photo of the {}.',
        'a good photo of a {}.',
        'a plushie {}.',
        'a photo of the nice {}.',
        'a photo of the small {}.',
        'a photo of the weird {}.',
        'the cartoon {}.',
        'art of the {}.',
        'a drawing of the {}.',
        'a photo of the large {}.',
        'a black and white photo of a {}.',
        'the plushie {}.',
        'a dark photo of a {}.',
        'itap of a {}.',
        'graffiti of the {}.',
        'a toy {}.',
        'itap of my {}.',
        'a photo of a cool {}.',
        'a photo of a small {}.',
        'a tattoo of the {}.',
    ],
    'vild': [
        'a photo of a {}.',
        'This is a photo of a {}',
        'There is a {} in the scene',
        'There is the {} in the scene',
        'a photo of a {} in the scene',
        'a photo of a small {}.',
        'a photo of a medium {}.',
        'a photo of a large {}.',
        'This is a photo of a small {}.',
        'This is a photo of a medium {}.',
        'This is a photo of a large {}.',
        'There is a small {} in the scene.',
        'There is a medium {} in the scene.',
        'There is a large {} in the scene.',
    ],
}


def get_predefined_templates(template_set_name: str) -> List[str]:
    """Return the prompt templates registered under ``template_set_name``.

    Args:
        template_set_name (str): Key into ``PREDEFINED_TEMPLATES``,
            e.g. ``'imagenet'`` or ``'vild'``.

    Returns:
        List[str]: Templates, each containing a ``{}`` placeholder for
        the class name.

    Raises:
        ValueError: If the name is not a known template set.
    """
    if template_set_name not in PREDEFINED_TEMPLATES:
        raise ValueError(f'Template set {template_set_name} not found')
    return PREDEFINED_TEMPLATES[template_set_name]
|
||||
42
finetune/mmseg/utils/io.py
Normal file
42
finetune/mmseg/utils/io.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import gzip
|
||||
import io
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
    """Decode raw bytes into a data array.

    Args:
        content (bytes): The data bytes got from files or other streams.
        backend (str): The data decoding backend type. Options are 'numpy',
            'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'.

    Returns:
        numpy.ndarray: Loaded data array (for ``backend='pickle'`` the
        unpickled object is returned as-is).

    Raises:
        ImportError: If ``backend='nifti'`` and nibabel is not installed.
        ValueError: If ``backend`` is not one of the supported options.
    """
    if backend == 'pickle':
        # NOTE(review): unpickling untrusted bytes can execute arbitrary
        # code; callers must only feed trusted data here.
        return pickle.loads(content)

    with io.BytesIO(content) as f:
        if backend == 'nifti':
            # NIfTI payloads are gzip-compressed; wrap the stream first.
            f = gzip.open(f)
            try:
                from nibabel import FileHolder, Nifti1Image
            except ImportError:
                # The original code only printed here and then crashed with
                # a NameError on FileHolder; fail loudly instead.
                raise ImportError(
                    'nifti files io depends on nibabel, please run '
                    '`pip install nibabel` to install it')
            fh = FileHolder(fileobj=f)
            data = Nifti1Image.from_file_map({'header': fh, 'image': fh})
            # Round-trip through bytes to obtain a fully materialized array.
            data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
        elif backend == 'numpy':
            data = np.load(f)
        elif backend == 'cv2':
            data = np.frombuffer(f.read(), dtype=np.uint8)
            data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
        else:
            raise ValueError(f'Unsupported decoding backend: {backend}')
    return data
|
||||
205
finetune/mmseg/utils/mask_classification.py
Normal file
205
finetune/mmseg/utils/mask_classification.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from mmcv.ops import point_sample
|
||||
from mmengine.structures import InstanceData
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import TASK_UTILS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
|
||||
|
||||
def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Convert the paradigm of ground truth from semantic segmentation to
    instance segmentation.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data
            Samples. It usually includes information such as
            `gt_sem_seg`.

    Returns:
        tuple[Tensor]: A tuple contains two lists.
            - batch_gt_instances (List[InstanceData]): Batch of
                gt_instance. It usually includes ``labels``, each is
                unique ground truth label id of images, with
                shape (num_gt, ) and ``masks``, each is ground truth
                masks of each instances of a image, shape (num_gt, h, w).
            - batch_img_metas (List[Dict]): List of image meta information.
    """
    batch_gt_instances = []

    for data_sample in batch_data_samples:
        gt_sem_seg = data_sample.gt_sem_seg.data
        # Every distinct label present in the map becomes one "instance".
        classes = torch.unique(
            gt_sem_seg,
            sorted=False,
            return_inverse=False,
            return_counts=False)

        # remove ignored region
        gt_labels = classes[classes != ignore_index]

        # One boolean mask per remaining class id.
        masks = []
        for class_id in gt_labels:
            masks.append(gt_sem_seg == class_id)

        if len(masks) == 0:
            # No valid classes: emit an empty (0, h, w) mask tensor on the
            # same device/dtype family as the input map.
            gt_masks = torch.zeros(
                (0, gt_sem_seg.shape[-2],
                 gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
        else:
            # squeeze(1) drops the channel dim of the (1, h, w) seg map.
            gt_masks = torch.stack(masks).squeeze(1).long()

        instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
        batch_gt_instances.append(instance_data)
    return batch_gt_instances
|
||||
|
||||
|
||||
class MatchMasks:
    """Match the predictions to category labels.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes.
        assigner (BaseAssigner): the assigner to compute matching.
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        # Keep the trailing space in the first fragment: the original
        # concatenation rendered as "...train_cfgcannot be None".
        assert assigner is not None, "'assigner' in decode_head.train_cfg " \
                                     'cannot be None'
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        # Build the matching assigner from its config dict.
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (List[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds (List[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (List[InstanceData]): each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (List[Tensor]): Labels of all images.\
                    Each with shape (num_queries, ).
                - mask_targets (List[Tensor]): Mask targets of\
                    all images. Each with shape (num_queries, h, w).
                - mask_weights (List[Tensor]): Mask weights of\
                    all images. Each with shape (num_queries, ).
                - avg_factor (int): Average factor that is used to
                    average the loss. `avg_factor` is usually equal
                    to the number of positive priors.
        """
        # NOTE(review): despite the List[Tensor] annotation, cls_scores is
        # indexed like a batched tensor (shape[0] == batch size).
        batch_size = cls_scores.shape[0]
        results = {
            'labels': [],
            'mask_targets': [],
            'mask_weights': [],
        }
        for i in range(batch_size):
            labels, mask_targets, mask_weights\
                = self._get_targets_single(cls_scores[i],
                                           mask_preds[i],
                                           batch_gt_instances[i])
            results['labels'].append(labels)
            results['mask_targets'].append(mask_targets)
            results['mask_weights'].append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(results['labels'], dim=0)
        # cat (not stack): per-image gt counts differ,
        # so shape is (sum of num_gts over the batch, h, w)
        mask_targets = torch.cat(results['mask_targets'], dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(results['mask_weights'], dim=0)

        # Total number of ground-truth instances in the batch, used to
        # normalize the loss.
        avg_factor = sum(
            [len(gt_instances.labels) for gt_instances in batch_gt_instances])

        return labels, mask_targets, mask_weights, avg_factor

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Labels of each image. \
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image. \
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image. \
                    shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # when "gt_labels" is empty, classify all queries to background
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            mask_targets = gt_labels
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights
        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        # Random normalized (x, y) coordinates shared by preds and gts so
        # the matching cost compares the same locations.
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        # Unmatched queries default to the background class.
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        # Only matched queries contribute to the mask loss.
        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights
|
||||
128
finetune/mmseg/utils/misc.py
Normal file
128
finetune/mmseg/utils/misc.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .typing_utils import SampleList
|
||||
|
||||
|
||||
def add_prefix(inputs, prefix):
    """Prepend ``prefix`` to every key of a dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: New dict whose keys are ``f'{prefix}.{key}'``; values are
        carried over unchanged.
    """
    return {f'{prefix}.{key}': value for key, value in inputs.items()}
|
||||
|
||||
|
||||
def stack_batch(inputs: List[torch.Tensor],
                data_samples: Optional[SampleList] = None,
                size: Optional[tuple] = None,
                size_divisor: Optional[int] = None,
                pad_val: Union[int, float] = 0,
                seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
    """Stack multiple inputs to form a batch and pad the images and gt_sem_segs
    to the max shape use the right bottom padding mode.

    Args:
        inputs (List[Tensor]): The input multiple tensors. each is a
            CHW 3D-tensor.
        data_samples (list[:obj:`SegDataSample`]): The list of data samples.
            It usually includes information such as `gt_sem_seg`.
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): The padding value. Defaults to 0
        seg_pad_val (int, float): The padding value. Defaults to 255

    Returns:
        Tensor: The 4D-tensor.
        List[:obj:`SegDataSample`]: After the padding of the gt_seg_map.
    """
    assert isinstance(inputs, list), \
        f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs must be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs must be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'

    # only one of size and size_divisor should be valid
    assert (size is not None) ^ (size_divisor is not None), \
        'only one of size and size_divisor should be valid'

    padded_inputs = []
    padded_samples = []
    # Largest (H, W) over the batch; padding never shrinks any image.
    inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
    max_size = np.stack(inputs_sizes).max(0)
    if size_divisor is not None and size_divisor > 1:
        # the last two dims are H,W, both subject to divisibility requirement
        max_size = (max_size +
                    (size_divisor - 1)) // size_divisor * size_divisor

    for i in range(len(inputs)):
        tensor = inputs[i]
        if size is not None:
            # max(..., 0) means images larger than `size` are left uncropped.
            width = max(size[-1] - tensor.shape[-1], 0)
            height = max(size[-2] - tensor.shape[-2], 0)
            # (padding_left, padding_right, padding_top, padding_bottom)
            padding_size = (0, width, 0, height)
        elif size_divisor is not None:
            width = max(max_size[-1] - tensor.shape[-1], 0)
            height = max(max_size[-2] - tensor.shape[-2], 0)
            padding_size = (0, width, 0, height)
        else:
            padding_size = [0, 0, 0, 0]

        # pad img
        pad_img = F.pad(tensor, padding_size, value=pad_val)
        padded_inputs.append(pad_img)
        # pad gt_sem_seg
        if data_samples is not None:
            data_sample = data_samples[i]
            pad_shape = None
            # Each ground-truth map present on the sample is padded with the
            # ignore value so padded pixels do not contribute to the loss.
            if 'gt_sem_seg' in data_sample:
                gt_sem_seg = data_sample.gt_sem_seg.data
                del data_sample.gt_sem_seg.data
                data_sample.gt_sem_seg.data = F.pad(
                    gt_sem_seg, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_sem_seg.shape
            if 'gt_edge_map' in data_sample:
                gt_edge_map = data_sample.gt_edge_map.data
                del data_sample.gt_edge_map.data
                data_sample.gt_edge_map.data = F.pad(
                    gt_edge_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_edge_map.shape
            if 'gt_depth_map' in data_sample:
                gt_depth_map = data_sample.gt_depth_map.data
                del data_sample.gt_depth_map.data
                data_sample.gt_depth_map.data = F.pad(
                    gt_depth_map, padding_size, value=seg_pad_val)
                pad_shape = data_sample.gt_depth_map.shape
            data_sample.set_metainfo({
                'img_shape': tensor.shape[-2:],
                'pad_shape': pad_shape,
                'padding_size': padding_size
            })
            padded_samples.append(data_sample)
        else:
            # Without data samples, still record how each image was padded.
            padded_samples.append(
                dict(
                    img_padding_size=padding_size,
                    pad_shape=pad_img.shape[-2:]))

    return torch.stack(padded_inputs, dim=0), padded_samples
|
||||
40
finetune/mmseg/utils/set_env.py
Normal file
40
finetune/mmseg/utils/set_env.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import warnings
|
||||
|
||||
from mmengine import DefaultScope
|
||||
|
||||
|
||||
def register_all_modules(init_default_scope: bool = True) -> None:
    """Register all modules in mmseg into the registries.

    Args:
        init_default_scope (bool): Whether initialize the mmseg default scope.
            When `init_default_scope=True`, the global default scope will be
            set to `mmseg`, and all registries will build modules from mmseg's
            registry node. To understand more about the registry, please refer
            to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
            Defaults to True.
    """  # noqa
    # Importing the subpackages runs their @register_module decorators,
    # which is the whole point of these otherwise-unused imports.
    import mmseg.datasets  # noqa: F401,F403
    import mmseg.engine  # noqa: F401,F403
    import mmseg.evaluation  # noqa: F401,F403
    import mmseg.models  # noqa: F401,F403
    import mmseg.structures  # noqa: F401,F403

    if init_default_scope:
        # Create the 'mmseg' scope only if no scope exists yet or the
        # 'mmseg' instance has never been created.
        never_created = DefaultScope.get_current_instance() is None \
                        or not DefaultScope.check_instance_created('mmseg')
        if never_created:
            DefaultScope.get_instance('mmseg', scope_name='mmseg')
            return
        current_scope = DefaultScope.get_current_instance()
        if current_scope.scope_name != 'mmseg':
            warnings.warn('The current default scope '
                          f'"{current_scope.scope_name}" is not "mmseg", '
                          '`register_all_modules` will force the current'
                          'default scope to be "mmseg". If this is not '
                          'expected, please set `init_default_scope=False`.')
            # avoid name conflict
            new_instance_name = f'mmseg-{datetime.datetime.now()}'
            DefaultScope.get_instance(new_instance_name, scope_name='mmseg')
|
||||
240
finetune/mmseg/utils/tokenizer.py
Normal file
240
finetune/mmseg/utils/tokenizer.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""CLIP tokenizer.
|
||||
|
||||
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
|
||||
(c) 2021 OpenAI.
|
||||
"""
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import List, Union
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
|
||||
|
||||
@lru_cache()
def default_bpe():
    """Return the path to the bundled BPE vocab file (result is cached)."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, 'bpe_simple_vocab_16e6.txt.gz')
|
||||
|
||||
|
||||
@lru_cache()
def bytes_to_unicode():
    """Returns list of utf-8 byte and a corresponding list of unicode strings.

    The reversible bpe codes work on unicode strings. This means you need a
    large # of unicode characters in your vocab if you want to avoid UNKs. When
    you're at something like a 10B token dataset you end up needing around 5K
    for decent coverage. This is a significant percentage of your normal, say,
    32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
    unicode strings. And avoids mapping to whitespace/control characters the
    bpe code barfs on.
    """
    # Printable latin-1 bytes keep their own code point ...
    printable = (list(range(ord('!'), ord('~') + 1)) +
                 list(range(ord('¡'), ord('¬') + 1)) +
                 list(range(ord('®'), ord('ÿ') + 1)))
    byte_values = printable[:]
    char_codes = printable[:]
    # ... every other byte is shifted into the 256+ range, in byte order.
    offset = 0
    for byte in range(2**8):
        if byte not in printable:
            byte_values.append(byte)
            char_codes.append(2**8 + offset)
            offset += 1
    return dict(zip(byte_values, (chr(code) for code in char_codes)))
|
||||
|
||||
|
||||
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being
    variable-length strings).

    Args:
        word (tuple[str]): Symbol sequence.

    Returns:
        set[tuple[str, str]]: All consecutive symbol pairs; empty for
        words with fewer than two symbols (the original raised
        IndexError on an empty word).
    """
    # zip pairs each symbol with its successor; short words naturally
    # yield no pairs.
    return set(zip(word, word[1:]))
|
||||
|
||||
|
||||
def basic_clean(text):
    """Fix mojibake with ftfy, unescape HTML entities, and trim whitespace."""
    fixed = ftfy.fix_text(text)
    # Unescape twice to undo double-encoded entities like '&amp;amp;'.
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
    """Collapse runs of whitespace to single spaces and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
|
||||
class SimpleTokenizer:
    """Byte-level BPE tokenizer used by CLIP.

    Builds the vocabulary from the bundled merges file: 256 byte symbols,
    their end-of-word variants, the learned merges, and the special tokens.
    """

    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
        # Skip the header line and keep exactly the merges that fit the
        # 49152-token CLIP vocab (minus 256 bytes and 2 specials).
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'
                              ] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = earlier merge = higher priority in bpe().
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens bypass BPE entirely via the cache.
        self.cache = {t: t for t in special_tokens}
        special = '|'.join(special_tokens)
        # NOTE: \p{L}/\p{N} classes require the third-party `regex` module
        # imported as `re` at the top of this file.
        self.pat = re.compile(
            special +
            r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """Apply BPE merges to one token; returns space-separated symbols."""
        if token in self.cache:
            return self.cache[token]
        # Mark the last character as end-of-word before merging.
        word = tuple(token[:-1]) + (token[-1] + '</w>', )
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        # Repeatedly merge the highest-priority (lowest-rank) pair until
        # no known merge remains or the word collapses to one symbol.
        while True:
            bigram = min(
                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:  # noqa: E722, E261
                    # `first` no longer occurs: copy the tail and stop.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[
                        i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Encode ``text`` into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map raw utf-8 bytes to the reversible unicode alphabet first.
            token = ''.join(self.byte_encoder[b]
                            for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token]
                              for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Decode a sequence of token ids back into a text string."""
        text = ''.join([self.decoder[token] for token in tokens])
        # Reverse the byte mapping, then turn end-of-word marks into spaces.
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            'utf-8', errors='replace').replace('</w>', ' ')
        return text
|
||||
|
||||
|
||||
# Module-level tokenizer shared by `decode` and `tokenize` below.
_tokenizer = SimpleTokenizer()


def decode(output_ids: torch.Tensor):
    """Decode a tensor of CLIP token ids back into a text string."""
    output_ids = output_ids.cpu().numpy()
    return _tokenizer.decode(output_ids)
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77) -> torch.LongTensor:
    """Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens,
    shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot = _tokenizer.encoder['<start_of_text>']
    eot = _tokenizer.encoder['<end_of_text>']
    token_lists = [[sot, *_tokenizer.encode(text), eot] for text in texts]
    result = torch.zeros(len(token_lists), context_length, dtype=torch.long)

    for row, tokens in enumerate(token_lists):
        if len(tokens) > context_length:
            # Truncate and force the final slot to stay the end-of-text id.
            tokens = tokens[:context_length]
            tokens[-1] = eot
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
|
||||
|
||||
|
||||
class HFTokenizer:
    """HuggingFace tokenizer wrapper."""

    def __init__(self, tokenizer_name: str):
        # Local import keeps transformers an optional dependency.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def save_pretrained(self, dest):
        """Persist the wrapped tokenizer files to ``dest``."""
        self.tokenizer.save_pretrained(dest)

    def __call__(self,
                 texts: Union[str, List[str]],
                 context_length: int = 77) -> torch.Tensor:
        """Tokenize ``texts`` into a padded/truncated id tensor."""
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it
        # more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]
        cleaned = [whitespace_clean(basic_clean(t)) for t in texts]
        encoded = self.tokenizer(
            cleaned,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        )
        return encoded.input_ids
|
||||
25
finetune/mmseg/utils/typing_utils.py
Normal file
25
finetune/mmseg/utils/typing_utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Collecting some commonly used type hints in mmseg."""
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine.config import ConfigDict
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
# Type hint of config data: a single config node is either an mmengine
# ConfigDict or a plain dict.
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]
# Type hint of one or more config data
MultiConfig = Union[ConfigType, Sequence[ConfigType]]
OptMultiConfig = Optional[MultiConfig]

# A batch of segmentation data samples, and its optional variant.
SampleList = Sequence[SegDataSample]
OptSampleList = Optional[SampleList]

# Type hint of Tensor
TensorDict = Dict[str, torch.Tensor]
TensorList = Sequence[torch.Tensor]

# Possible forward outputs of a model: a loss dict, a list of data
# samples (predictions), or raw tensor(s).
ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample],
                       Tuple[torch.Tensor], torch.Tensor]
|
||||
Reference in New Issue
Block a user