commit 01adcfdf60
Author: esenke
Date:   2025-12-08 22:16:31 +08:00
305 changed files with 50879 additions and 0 deletions


@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .class_names import (ade_classes, ade_palette, bdd100k_classes,
bdd100k_palette, cityscapes_classes,
cityscapes_palette, cocostuff_classes,
cocostuff_palette, dataset_aliases, get_classes,
get_palette, isaid_classes, isaid_palette,
loveda_classes, loveda_palette, potsdam_classes,
potsdam_palette, stare_classes, stare_palette,
synapse_classes, synapse_palette, vaihingen_classes,
vaihingen_palette, voc_classes, voc_palette)
# yapf: enable
from .collect_env import collect_env
from .get_templates import get_predefined_templates
from .io import datafrombytes
from .misc import add_prefix, stack_batch
from .set_env import register_all_modules
from .tokenizer import tokenize
from .typing_utils import (ConfigType, ForwardResults, MultiConfig,
OptConfigType, OptMultiConfig, OptSampleList,
SampleList, TensorDict, TensorList)
# isort: off
from .mask_classification import MatchMasks, seg_data_to_instance_data
__all__ = [
'collect_env',
'register_all_modules',
'stack_batch',
'add_prefix',
'ConfigType',
'OptConfigType',
'MultiConfig',
'OptMultiConfig',
'SampleList',
'OptSampleList',
'TensorDict',
'TensorList',
'ForwardResults',
'cityscapes_classes',
'ade_classes',
'voc_classes',
'cocostuff_classes',
'loveda_classes',
'potsdam_classes',
'vaihingen_classes',
'isaid_classes',
'stare_classes',
'cityscapes_palette',
'ade_palette',
'voc_palette',
'cocostuff_palette',
'loveda_palette',
'potsdam_palette',
'vaihingen_palette',
'isaid_palette',
'stare_palette',
'dataset_aliases',
'get_classes',
'get_palette',
'datafrombytes',
'synapse_palette',
'synapse_classes',
'get_predefined_templates',
'tokenize',
'seg_data_to_instance_data',
'MatchMasks',
'bdd100k_classes',
'bdd100k_palette',
]

Binary file not shown.


@@ -0,0 +1,548 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import is_str
def cityscapes_classes():
"""Cityscapes class names for external use."""
return [
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'
]
def ade_classes():
"""ADE20K class names for external use."""
return [
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag'
]
def voc_classes():
"""Pascal VOC class names for external use."""
return [
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
'tvmonitor'
]
def pcontext_classes():
"""Pascal Context class names for external use."""
return [
'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird',
'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat',
'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain',
'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
'person', 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window',
'wood'
]
def cocostuff_classes():
"""CocoStuff class names for external use."""
return [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'paper',
'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
'window-blind', 'window-other', 'wood'
]
def loveda_classes():
"""LoveDA class names for external use."""
return [
'background', 'building', 'road', 'water', 'barren', 'forest',
'agricultural'
]
def potsdam_classes():
"""Potsdam class names for external use."""
return [
'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
'clutter'
]
def vaihingen_classes():
"""Vaihingen class names for external use."""
return [
'impervious_surface', 'building', 'low_vegetation', 'tree', 'car',
'clutter'
]
def isaid_classes():
"""iSAID class names for external use."""
return [
'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court',
'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle',
'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout',
'Soccer_ball_field', 'plane', 'Harbor'
]
def stare_classes():
"""stare class names for external use."""
return ['background', 'vessel']
def mapillary_v1_classes():
"""mapillary_v1 class names for external use."""
return [
'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'
]
def mapillary_v1_palette():
"""mapillary_v1_ palette for external use."""
return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]
def mapillary_v2_classes():
"""mapillary_v2 class names for external use."""
return [
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
'Bicyclist', 'Motorcyclist', 'Other Rider',
'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
'Lane Marking - Arrow (Right)',
'Lane Marking - Arrow (Split Left or Straight)',
'Lane Marking - Arrow (Split Right or Straight)',
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
'Lane Marking - Hatched (Chevron)',
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
'Phone Booth', 'Pothole', 'Signage - Advertisement',
'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
'Traffic Light - General (Upright)',
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled'
]
def mapillary_v2_palette():
"""mapillary_v2_ palette for external use."""
return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
[196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
[250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
[102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
[244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
[150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
[255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
[250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
[250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
[250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
[250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
[255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
[230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
[250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
[210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
[0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
[250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
[192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
[220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
[140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
[0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
[111, 111, 0], [0, 0, 0]]
def cityscapes_palette():
"""Cityscapes palette for external use."""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
[0, 0, 230], [119, 11, 32]]
def ade_palette():
"""ADE20K palette for external use."""
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]]
def voc_palette():
"""Pascal VOC palette for external use."""
return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
def pcontext_palette():
"""Pascal Context palette for external use."""
return [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
def cocostuff_palette():
"""CocoStuff palette for external use."""
return [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
[0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
[128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
[64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160], [0, 32, 0],
[0, 128, 128], [64, 128, 160], [128, 160, 0], [0, 128, 0],
[192, 128, 32], [128, 96, 128], [0, 0, 128], [64, 0, 32],
[0, 224, 128], [128, 0, 0], [192, 0, 160], [0, 96, 128],
[128, 128, 128], [64, 0, 160], [128, 224, 128], [128, 128, 64],
[192, 0, 32], [128, 96, 0], [128, 0, 192], [0, 128, 32],
[64, 224, 0], [0, 0, 64], [128, 128, 160], [64, 96, 0],
[0, 128, 192], [0, 128, 160], [192, 224, 0], [0, 128, 64],
[128, 128, 32], [192, 32, 128], [0, 64, 192], [0, 0, 32],
[64, 160, 128], [128, 64, 64], [128, 0, 160], [64, 32, 128],
[128, 192, 192], [0, 0, 160], [192, 160, 128], [128, 192, 0],
[128, 0, 96], [192, 32, 0], [128, 64, 128], [64, 128, 96],
[64, 160, 0], [0, 64, 0], [192, 128, 224], [64, 32, 0],
[0, 192, 128], [64, 128, 224], [192, 160, 0], [0, 192, 0],
[192, 128, 96], [192, 96, 128], [0, 64, 128], [64, 0, 96],
[64, 224, 128], [128, 64, 0], [192, 0, 224], [64, 96, 128],
[128, 192, 128], [64, 0, 224], [192, 224, 128], [128, 192, 64],
[192, 0, 96], [192, 96, 0], [128, 64, 192], [0, 128, 96],
[0, 224, 0], [64, 64, 64], [128, 128, 224], [0, 96, 0],
[64, 192, 192], [0, 128, 224], [128, 224, 0], [64, 192, 64],
[128, 128, 96], [128, 32, 128], [64, 0, 192], [0, 64, 96],
[0, 160, 128], [192, 0, 64], [128, 64, 224], [0, 32, 128],
[192, 128, 192], [0, 64, 224], [128, 160, 128], [192, 128, 0],
[128, 64, 32], [128, 32, 64], [192, 0, 128], [64, 192, 32],
[0, 160, 64], [64, 0, 0], [192, 192, 160], [0, 32, 64],
[64, 128, 128], [64, 192, 160], [128, 160, 64], [64, 128, 0],
[192, 192, 32], [128, 96, 192], [64, 0, 128], [64, 64, 32],
[0, 224, 192], [192, 0, 0], [192, 64, 160], [0, 96, 192],
[192, 128, 128], [64, 64, 160], [128, 224, 192], [192, 128, 64],
[192, 64, 32], [128, 96, 64], [192, 0, 192], [0, 192, 32],
[64, 224, 64], [64, 0, 64], [128, 192, 160], [64, 96, 64],
[64, 128, 192], [0, 192, 160], [192, 224, 64], [64, 128, 64],
[128, 192, 32], [192, 32, 192], [64, 64, 192], [0, 64, 32],
[64, 160, 192], [192, 64, 64], [128, 64, 160], [64, 32, 192],
[192, 192, 192], [0, 64, 160], [192, 160, 192], [192, 192, 0],
[128, 64, 96], [192, 32, 64], [192, 64, 128], [64, 192, 96],
[64, 160, 64], [64, 64, 0]]
def loveda_palette():
"""LoveDA palette for external use."""
return [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
[159, 129, 183], [0, 255, 0], [255, 195, 128]]
def potsdam_palette():
"""Potsdam palette for external use."""
return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]]
def vaihingen_palette():
"""Vaihingen palette for external use."""
return [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]]
def isaid_palette():
"""iSAID palette for external use."""
    return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
            [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
            [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
            [0, 127, 191], [0, 127, 255], [0, 100, 155]]
def stare_palette():
"""STARE palette for external use."""
return [[120, 120, 120], [6, 230, 230]]
def synapse_palette():
"""Synapse palette for external use."""
return [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [0, 255, 255],
[255, 0, 255], [255, 255, 0], [60, 255, 255], [240, 240, 240]]
def synapse_classes():
"""Synapse class names for external use."""
return [
'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney',
'liver', 'pancreas', 'spleen', 'stomach'
]
def lip_classes():
"""LIP class names for external use."""
return [
'background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes',
'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt',
'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe',
'rightShoe'
]
def lip_palette():
    """LIP palette (RGB) for external use."""
    # one RGB triplet per class, in the same order as lip_classes()
    return [[0, 0, 0], [128, 0, 0], [255, 0, 0], [0, 85, 0], [170, 0, 51],
            [255, 85, 0], [0, 0, 85], [0, 119, 221], [85, 85, 0], [0, 85, 85],
            [85, 51, 0], [52, 86, 128], [0, 128, 0], [0, 0, 255],
            [51, 170, 221], [0, 255, 255], [85, 255, 170], [170, 255, 85],
            [255, 255, 0], [255, 170, 0]]
def bdd100k_classes():
"""BDD100K class names for external use(the class name is compatible with
Cityscapes )."""
return [
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle'
]
def bdd100k_palette():
"""bdd100k palette for external use(same with cityscapes)"""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
[0, 0, 230], [119, 11, 32]]
def hsidrive_classes():
"""HSI Drive 2.0 class names for external use."""
return [
'unlabelled', 'road', 'road marks', 'vegetation', 'painted metal',
'sky', 'concrete', 'pedestrian', 'water', 'unpainted metal', 'glass'
]
def hsidrive_palette():
"""HSI Drive 2.0 palette for external use."""
return [[0, 0, 0], [77, 77, 77], [255, 255, 255], [0, 255, 0], [255, 0, 0],
[0, 0, 255], [102, 51, 0], [255, 255, 0], [0, 207, 250],
[255, 166, 0], [0, 204, 204]]
dataset_aliases = {
'cityscapes': ['cityscapes'],
'ade': ['ade', 'ade20k'],
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug'],
'pcontext': ['pcontext', 'pascal_context', 'voc2010'],
'loveda': ['loveda'],
'potsdam': ['potsdam'],
'vaihingen': ['vaihingen'],
'cocostuff': [
'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff',
'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k',
'coco_stuff164k'
],
'isaid': ['isaid', 'iSAID'],
'stare': ['stare', 'STARE'],
'lip': ['LIP', 'lip'],
'mapillary_v1': ['mapillary_v1'],
'mapillary_v2': ['mapillary_v2'],
'bdd100k': ['bdd100k'],
'hsidrive': [
'hsidrive', 'HSIDrive', 'HSI-Drive', 'hsidrive20', 'HSIDrive20',
'HSI-Drive20'
]
}
def get_classes(dataset):
"""Get class names of a dataset."""
alias2name = {}
for name, aliases in dataset_aliases.items():
for alias in aliases:
alias2name[alias] = name
if is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_classes()')
else:
raise ValueError(f'Unrecognized dataset: {dataset}')
else:
        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
return labels
def get_palette(dataset):
"""Get class palette (RGB) of a dataset."""
alias2name = {}
for name, aliases in dataset_aliases.items():
for alias in aliases:
alias2name[alias] = name
if is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_palette()')
else:
raise ValueError(f'Unrecognized dataset: {dataset}')
else:
        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
return labels
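# Illustrative usage (a sketch, not part of the diff): dataset names are
# resolved through `dataset_aliases`, so aliases such as 'voc12' and
# 'ade20k' work as well as the canonical names; unknown names raise
# ValueError.
if __name__ == '__main__':
    print(len(get_classes('voc12')))     # 21 Pascal VOC classes
    print(get_palette('cityscapes')[0])  # [128, 64, 128], the 'road' color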


@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env
import mmseg
def collect_env():
"""Collect the information of the running environments."""
env_info = collect_base_env()
env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
return env_info
if __name__ == '__main__':
for name, val in collect_env().items():
print(f'{name}: {val}')


@@ -0,0 +1,109 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
PREDEFINED_TEMPLATES = {
'imagenet': [
'a bad photo of a {}.',
'a photo of many {}.',
'a sculpture of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a tattoo of a {}.',
'the embroidered {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'the plastic {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a plastic {}.',
'a photo of the dirty {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a {} in a video game.',
'a photo of one {}.',
'a doodle of a {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'the origami {}.',
'the {} in a video game.',
'a sketch of a {}.',
'a doodle of the {}.',
'a origami {}.',
'a low resolution photo of a {}.',
'the toy {}.',
'a rendition of the {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a rendition of a {}.',
'a photo of a nice {}.',
'a photo of a weird {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a sketch of the {}.',
'a embroidered {}.',
'a pixelated photo of a {}.',
'itap of the {}.',
'a jpeg corrupted photo of the {}.',
'a good photo of a {}.',
'a plushie {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'the cartoon {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a black and white photo of a {}.',
'the plushie {}.',
'a dark photo of a {}.',
'itap of a {}.',
'graffiti of the {}.',
'a toy {}.',
'itap of my {}.',
'a photo of a cool {}.',
'a photo of a small {}.',
'a tattoo of the {}.',
],
'vild': [
'a photo of a {}.',
'This is a photo of a {}',
'There is a {} in the scene',
'There is the {} in the scene',
'a photo of a {} in the scene',
'a photo of a small {}.',
'a photo of a medium {}.',
'a photo of a large {}.',
'This is a photo of a small {}.',
'This is a photo of a medium {}.',
'This is a photo of a large {}.',
'There is a small {} in the scene.',
'There is a medium {} in the scene.',
'There is a large {} in the scene.',
],
}
def get_predefined_templates(template_set_name: str) -> List[str]:
if template_set_name not in PREDEFINED_TEMPLATES:
raise ValueError(f'Template set {template_set_name} not found')
return PREDEFINED_TEMPLATES[template_set_name]
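# Illustrative usage (a sketch, not part of the diff): every template holds
# a single '{}' slot that a class name is substituted into when building
# CLIP-style text prompts.
if __name__ == '__main__':
    vild = get_predefined_templates('vild')
    print(len(vild))              # 14 templates
    print(vild[0].format('car'))  # -> 'a photo of a car.'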


@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import gzip
import io
import pickle
import cv2
import numpy as np
def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
"""Data decoding from bytes.
Args:
        content (bytes): The data bytes read from files or other streams.
backend (str): The data decoding backend type. Options are 'numpy',
'nifti', 'cv2' and 'pickle'. Defaults to 'numpy'.
Returns:
numpy.ndarray: Loaded data array.
"""
if backend == 'pickle':
data = pickle.loads(content)
else:
with io.BytesIO(content) as f:
if backend == 'nifti':
f = gzip.open(f)
                try:
                    from nibabel import FileHolder, Nifti1Image
                except ImportError:
                    raise ImportError(
                        'nifti file io depends on nibabel, please run '
                        '`pip install nibabel` to install it')
fh = FileHolder(fileobj=f)
data = Nifti1Image.from_file_map({'header': fh, 'image': fh})
data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
elif backend == 'numpy':
data = np.load(f)
elif backend == 'cv2':
data = np.frombuffer(f.read(), dtype=np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_UNCHANGED)
else:
                raise ValueError(f'Unsupported backend: {backend}')
return data
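# Illustrative usage (a sketch, not part of the diff): round-trip an array
# through np.save bytes and the default 'numpy' backend. The 'nifti'
# backend additionally requires nibabel to be installed.
if __name__ == '__main__':
    buffer = io.BytesIO()
    np.save(buffer, np.arange(6).reshape(2, 3))
    print(datafrombytes(buffer.getvalue()))  # [[0 1 2] [3 4 5]]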


@@ -0,0 +1,205 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor
from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList
def seg_data_to_instance_data(ignore_index: int,
batch_data_samples: SampleList):
"""Convert the paradigm of ground truth from semantic segmentation to
instance segmentation.
Args:
ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The data samples.
            Each usually includes information such as ``gt_sem_seg``.
    Returns:
        List[InstanceData]: Batch of gt_instance. Each ``InstanceData``
            usually includes ``labels``, the unique ground truth label ids
            of an image with shape (num_gt, ), and ``masks``, the ground
            truth masks of the instances with shape (num_gt, h, w).
    """
batch_gt_instances = []
for data_sample in batch_data_samples:
gt_sem_seg = data_sample.gt_sem_seg.data
classes = torch.unique(
gt_sem_seg,
sorted=False,
return_inverse=False,
return_counts=False)
# remove ignored region
gt_labels = classes[classes != ignore_index]
masks = []
for class_id in gt_labels:
masks.append(gt_sem_seg == class_id)
if len(masks) == 0:
gt_masks = torch.zeros(
(0, gt_sem_seg.shape[-2],
gt_sem_seg.shape[-1])).to(gt_sem_seg).long()
else:
gt_masks = torch.stack(masks).squeeze(1).long()
instance_data = InstanceData(labels=gt_labels, masks=gt_masks)
batch_gt_instances.append(instance_data)
return batch_gt_instances
class MatchMasks:
"""Match the predictions to category labels.
Args:
        num_points (int): The number of sampled points to compute cost.
        num_queries (int): The number of prediction masks.
        num_classes (int): The number of classes.
        assigner (BaseAssigner): The assigner to compute matching.
"""
def __init__(self,
num_points: int,
num_queries: int,
num_classes: int,
assigner: ConfigType = None):
        assert assigner is not None, "'assigner' in decode_head.train_cfg " \
            'cannot be None'
assert num_points > 0, 'num_points should be a positive integer.'
self.num_points = num_points
self.num_queries = num_queries
self.num_classes = num_classes
self.assigner = TASK_UTILS.build(assigner)
def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
batch_gt_instances: List[InstanceData]) -> Tuple:
"""Compute best mask matches for all images for a decoder layer.
Args:
cls_scores (List[Tensor]): Mask score logits from a single
decoder layer for all images. Each with shape (num_queries,
cls_out_channels).
mask_preds (List[Tensor]): Mask logits from a single decoder
layer for all images. Each with shape (num_queries, h, w).
batch_gt_instances (List[InstanceData]): each contains
``labels`` and ``masks``.
Returns:
tuple: a tuple containing the following targets.
- labels (List[Tensor]): Labels of all images.\
Each with shape (num_queries, ).
- mask_targets (List[Tensor]): Mask targets of\
all images. Each with shape (num_queries, h, w).
- mask_weights (List[Tensor]): Mask weights of\
all images. Each with shape (num_queries, ).
- avg_factor (int): Average factor that is used to
average the loss. `avg_factor` is usually equal
to the number of positive priors.
"""
batch_size = cls_scores.shape[0]
results = dict({
'labels': [],
'mask_targets': [],
'mask_weights': [],
})
for i in range(batch_size):
labels, mask_targets, mask_weights\
= self._get_targets_single(cls_scores[i],
mask_preds[i],
batch_gt_instances[i])
results['labels'].append(labels)
results['mask_targets'].append(mask_targets)
results['mask_weights'].append(mask_weights)
# shape (batch_size, num_queries)
labels = torch.stack(results['labels'], dim=0)
        # shape (num_total_gts, h, w)
mask_targets = torch.cat(results['mask_targets'], dim=0)
# shape (batch_size, num_queries)
mask_weights = torch.stack(results['mask_weights'], dim=0)
avg_factor = sum(
[len(gt_instances.labels) for gt_instances in batch_gt_instances])
res = (labels, mask_targets, mask_weights, avg_factor)
return res
def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
gt_instances: InstanceData) \
-> Tuple[Tensor, Tensor, Tensor]:
"""Compute a set of best mask matches for one image.
Args:
cls_score (Tensor): Mask score logits from a single decoder layer
for one image. Shape (num_queries, cls_out_channels).
mask_pred (Tensor): Mask logits for a single decoder layer for one
image. Shape (num_queries, h, w).
gt_instances (:obj:`InstanceData`): It contains ``labels`` and
``masks``.
Returns:
tuple[Tensor]: A tuple containing the following for one image.
- labels (Tensor): Labels of each image. \
shape (num_queries, ).
- mask_targets (Tensor): Mask targets of each image. \
shape (num_queries, h, w).
- mask_weights (Tensor): Mask weights of each image. \
shape (num_queries, ).
"""
gt_labels = gt_instances.labels
gt_masks = gt_instances.masks
# when "gt_labels" is empty, classify all queries to background
if len(gt_labels) == 0:
labels = gt_labels.new_full((self.num_queries, ),
self.num_classes,
dtype=torch.long)
mask_targets = gt_labels
mask_weights = gt_labels.new_zeros((self.num_queries, ))
return labels, mask_targets, mask_weights
# sample points
num_queries = cls_score.shape[0]
num_gts = gt_labels.shape[0]
point_coords = torch.rand((1, self.num_points, 2),
device=cls_score.device)
# shape (num_queries, num_points)
mask_points_pred = point_sample(
mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
1)).squeeze(1)
# shape (num_gts, num_points)
gt_points_masks = point_sample(
gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
1)).squeeze(1)
sampled_gt_instances = InstanceData(
labels=gt_labels, masks=gt_points_masks)
sampled_pred_instances = InstanceData(
scores=cls_score, masks=mask_points_pred)
# assign and sample
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]
        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]
        return labels, mask_targets, mask_weights
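# Illustrative usage (a sketch, not part of the diff): convert a tiny
# semantic map into instance-style targets. `MatchMasks` itself needs an
# assigner config registered in TASK_UTILS (e.g. a Hungarian-style
# assigner), so only the conversion helper is exercised here.
if __name__ == '__main__':
    from mmengine.structures import PixelData
    from mmseg.structures import SegDataSample
    sample = SegDataSample()
    # a 1x4x4 map with classes {0, 1} and an ignored region (255)
    sem = torch.tensor([[[0, 0, 1, 1],
                         [0, 0, 1, 1],
                         [255, 255, 1, 1],
                         [0, 0, 0, 0]]])
    sample.gt_sem_seg = PixelData(data=sem)
    gt = seg_data_to_instance_data(255, [sample])
    print(gt[0].labels)       # the unique, non-ignored labels, e.g. [0, 1]
    print(gt[0].masks.shape)  # torch.Size([2, 4, 4])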


@@ -0,0 +1,128 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from .typing_utils import SampleList
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f'{prefix}.{name}'] = value
return outputs
def stack_batch(inputs: List[torch.Tensor],
data_samples: Optional[SampleList] = None,
size: Optional[tuple] = None,
size_divisor: Optional[int] = None,
pad_val: Union[int, float] = 0,
seg_pad_val: Union[int, float] = 255) -> torch.Tensor:
"""Stack multiple inputs to form a batch and pad the images and gt_sem_segs
to the max shape use the right bottom padding mode.
Args:
inputs (List[Tensor]): The input multiple tensors. each is a
CHW 3D-tensor.
data_samples (list[:obj:`SegDataSample`]): The list of data samples.
It usually includes information such as `gt_sem_seg`.
size (tuple, optional): Fixed padding size.
size_divisor (int, optional): The divisor of padded size.
        pad_val (int, float): The image padding value. Defaults to 0.
        seg_pad_val (int, float): The segmentation map padding value.
            Defaults to 255.
    Returns:
        Tensor: The 4D-tensor of stacked and padded inputs.
        List[:obj:`SegDataSample`]: The data samples with padded gt_seg_map.
    """
assert isinstance(inputs, list), \
f'Expected input type to be list, but got {type(inputs)}'
    assert len({tensor.ndim for tensor in inputs}) == 1, \
        f'Expected the dimensions of all inputs to be the same, ' \
        f'but got {[tensor.ndim for tensor in inputs]}'
    assert inputs[0].ndim == 3, f'Expected tensor dimension to be 3, ' \
        f'but got {inputs[0].ndim}'
    assert len({tensor.shape[0] for tensor in inputs}) == 1, \
        f'Expected the channels of all inputs to be the same, ' \
        f'but got {[tensor.shape[0] for tensor in inputs]}'
# only one of size and size_divisor should be valid
assert (size is not None) ^ (size_divisor is not None), \
'only one of size and size_divisor should be valid'
padded_inputs = []
padded_samples = []
inputs_sizes = [(img.shape[-2], img.shape[-1]) for img in inputs]
max_size = np.stack(inputs_sizes).max(0)
if size_divisor is not None and size_divisor > 1:
# the last two dims are H,W, both subject to divisibility requirement
max_size = (max_size +
(size_divisor - 1)) // size_divisor * size_divisor
for i in range(len(inputs)):
tensor = inputs[i]
if size is not None:
width = max(size[-1] - tensor.shape[-1], 0)
height = max(size[-2] - tensor.shape[-2], 0)
# (padding_left, padding_right, padding_top, padding_bottom)
padding_size = (0, width, 0, height)
elif size_divisor is not None:
width = max(max_size[-1] - tensor.shape[-1], 0)
height = max(max_size[-2] - tensor.shape[-2], 0)
padding_size = (0, width, 0, height)
else:
padding_size = [0, 0, 0, 0]
# pad img
pad_img = F.pad(tensor, padding_size, value=pad_val)
padded_inputs.append(pad_img)
# pad gt_sem_seg
if data_samples is not None:
data_sample = data_samples[i]
pad_shape = None
if 'gt_sem_seg' in data_sample:
gt_sem_seg = data_sample.gt_sem_seg.data
del data_sample.gt_sem_seg.data
data_sample.gt_sem_seg.data = F.pad(
gt_sem_seg, padding_size, value=seg_pad_val)
pad_shape = data_sample.gt_sem_seg.shape
if 'gt_edge_map' in data_sample:
gt_edge_map = data_sample.gt_edge_map.data
del data_sample.gt_edge_map.data
data_sample.gt_edge_map.data = F.pad(
gt_edge_map, padding_size, value=seg_pad_val)
pad_shape = data_sample.gt_edge_map.shape
if 'gt_depth_map' in data_sample:
gt_depth_map = data_sample.gt_depth_map.data
del data_sample.gt_depth_map.data
data_sample.gt_depth_map.data = F.pad(
gt_depth_map, padding_size, value=seg_pad_val)
pad_shape = data_sample.gt_depth_map.shape
data_sample.set_metainfo({
'img_shape': tensor.shape[-2:],
'pad_shape': pad_shape,
'padding_size': padding_size
})
padded_samples.append(data_sample)
else:
padded_samples.append(
dict(
img_padding_size=padding_size,
pad_shape=pad_img.shape[-2:]))
return torch.stack(padded_inputs, dim=0), padded_samples
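# Illustrative usage (a sketch, not part of the diff): two CHW images of
# different spatial sizes are right/bottom padded up to a common shape
# divisible by 16, then stacked into a single 4D batch.
if __name__ == '__main__':
    imgs = [torch.rand(3, 18, 25), torch.rand(3, 20, 31)]
    batch, samples = stack_batch(imgs, size_divisor=16)
    print(batch.shape)  # torch.Size([2, 3, 32, 32])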


@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import warnings
from mmengine import DefaultScope
def register_all_modules(init_default_scope: bool = True) -> None:
"""Register all modules in mmseg into the registries.
Args:
init_default_scope (bool): Whether initialize the mmseg default scope.
When `init_default_scope=True`, the global default scope will be
set to `mmseg`, and all registries will build modules from mmseg's
registry node. To understand more about the registry, please refer
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
""" # noqa
import mmseg.datasets # noqa: F401,F403
import mmseg.engine # noqa: F401,F403
import mmseg.evaluation # noqa: F401,F403
import mmseg.models # noqa: F401,F403
import mmseg.structures # noqa: F401,F403
if init_default_scope:
never_created = DefaultScope.get_current_instance() is None \
or not DefaultScope.check_instance_created('mmseg')
if never_created:
DefaultScope.get_instance('mmseg', scope_name='mmseg')
return
current_scope = DefaultScope.get_current_instance()
if current_scope.scope_name != 'mmseg':
        warnings.warn('The current default scope '
                      f'"{current_scope.scope_name}" is not "mmseg", '
                      '`register_all_modules` will force the current '
                      'default scope to be "mmseg". If this is not '
                      'expected, please set `init_default_scope=False`.')
# avoid name conflict
new_instance_name = f'mmseg-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name='mmseg')
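# Illustrative usage (a sketch): call once at program start, before building
# any components from configs, so the default scope resolves to mmseg.
if __name__ == '__main__':
    register_all_modules(init_default_scope=True)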


@@ -0,0 +1,240 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""CLIP tokenizer.
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright
(c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import List, Union
import ftfy
import regex as re
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
@lru_cache()
def default_bpe():
return os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'bpe_simple_vocab_16e6.txt.gz')
@lru_cache()
def bytes_to_unicode():
"""Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings. This means you need a
large # of unicode characters in your vocab if you want to avoid UNKs. When
you're at something like a 10B token dataset you end up needing around 5K
for decent coverage. This is a significant percentage of your normal, say,
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
unicode strings. And avoids mapping to whitespace/control characters the
bpe code barfs on.
"""
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length
strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer:
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['<start_of_text>', '<end_of_text>']
else:
special_tokens = ['<start_of_text>', '<end_of_text>'
] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t: t for t in special_tokens}
special = '|'.join(special_tokens)
self.pat = re.compile(
special +
r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except: # noqa: E722, E261
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
_tokenizer = SimpleTokenizer()
def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)
def tokenize(texts: Union[str, List[str]],
context_length: int = 77) -> torch.LongTensor:
"""Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens,
shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder['<start_of_text>']
eot_token = _tokenizer.encoder['<end_of_text>']
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
class HFTokenizer:
"""HuggingFace tokenizer wrapper."""
def __init__(self, tokenizer_name: str):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self,
texts: Union[str, List[str]],
context_length: int = 77) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it
# more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
return input_ids
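# Illustrative usage (a sketch, not part of the diff): each prompt becomes a
# fixed-length row of token ids, bracketed by the '<start_of_text>' and
# '<end_of_text>' ids and zero-padded to `context_length`.
if __name__ == '__main__':
    ids = tokenize(['a photo of a cat.', 'a diagram of a dog.'])
    print(ids.shape)  # torch.Size([2, 77])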


@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Collecting some commonly used type hint in mmflow."""
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from mmengine.config import ConfigDict
from mmseg.structures import SegDataSample
# Type hint of config data
ConfigType = Union[ConfigDict, dict]
OptConfigType = Optional[ConfigType]
# Type hint of one or more config data
MultiConfig = Union[ConfigType, Sequence[ConfigType]]
OptMultiConfig = Optional[MultiConfig]
SampleList = Sequence[SegDataSample]
OptSampleList = Optional[SampleList]
# Type hint of Tensor
TensorDict = Dict[str, torch.Tensor]
TensorList = Sequence[torch.Tensor]
ForwardResults = Union[Dict[str, torch.Tensor], List[SegDataSample],
Tuple[torch.Tensor], torch.Tensor]
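# Illustrative usage (a sketch, not part of the diff): `_normalize_cfg` is a
# hypothetical helper showing how OptConfigType lets one signature accept a
# ConfigDict, a plain dict, or None; the 'BN' default is made up for the
# example.
def _normalize_cfg(cfg: OptConfigType = None) -> ConfigType:
    return cfg if cfg is not None else dict(type='BN')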