init
This commit is contained in:
112
finetune/tools/torchserve/mmseg2torchserve.py
Normal file
112
finetune/tools/torchserve/mmseg2torchserve.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from mmengine import Config
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
# `torch-model-archiver` is an optional dependency: when it is not installed,
# `package_model` is set to None so the `__main__` entry point below can raise
# an actionable ImportError instead of failing at module import time.
try:
    from model_archiver.model_packaging import package_model
    from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
    package_model = None
|
||||
|
||||
|
||||
def mmseg2torchserve(
    config_file: str,
    checkpoint_file: str,
    output_folder: str,
    model_name: str,
    model_version: str = '1.0',
    force: bool = False,
):
    """Package an mmsegmentation model (config + checkpoint) into a
    TorchServe `.mar` archive.

    Args:
        config_file:
            In MMSegmentation config format.
            The contents vary for each task repository.
        checkpoint_file:
            In MMSegmentation checkpoint format.
            The contents vary for each task repository.
        output_folder:
            Folder where `{model_name}.mar` will be created.
            The file created will be in TorchServe archive format.
        model_name:
            If not None, used for naming the `{model_name}.mar` file
            that will be created under `output_folder`.
            If None, `{Path(checkpoint_file).stem}` will be used.
        model_version:
            Model's version.
        force:
            If True, if there is an existing `{model_name}.mar`
            file under `output_folder` it will be overwritten.
    """
    mkdir_or_exist(output_folder)

    cfg = Config.fromfile(config_file)

    # The archiver needs the config as an on-disk file; dump it into a
    # throwaway directory that is cleaned up after packaging.
    with TemporaryDirectory() as tmpdir:
        dumped_config = f'{tmpdir}/config.py'
        cfg.dump(dumped_config)

        # `package_model` expects argparse-style attributes, so build a
        # Namespace mirroring the torch-model-archiver CLI options.
        archiver_args = Namespace(
            model_file=dumped_config,
            serialized_file=checkpoint_file,
            # The handler ships alongside this script.
            handler=f'{Path(__file__).parent}/mmseg_handler.py',
            model_name=model_name or Path(checkpoint_file).stem,
            version=model_version,
            export_path=output_folder,
            force=force,
            requirements_file=None,
            extra_files=None,
            runtime='python',
            archive_format='default')
        manifest = ModelExportUtils.generate_manifest_json(archiver_args)
        package_model(archiver_args, manifest)
|
||||
|
||||
|
||||
def parse_args():
    """Build the CLI parser for the converter and parse `sys.argv`."""
    parser = ArgumentParser(
        description='Convert mmseg models to TorchServe `.mar` format.')
    parser.add_argument('config', type=str, help='config file path')
    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
    parser.add_argument(
        '--output-folder',
        required=True,
        type=str,
        help='Folder where `{model_name}.mar` will be created.')
    parser.add_argument(
        '--model-name',
        default=None,
        type=str,
        help='If not None, used for naming the `{model_name}.mar`'
        'file that will be created under `output_folder`.'
        'If None, `{Path(checkpoint_file).stem}` will be used.')
    parser.add_argument(
        '--model-version',
        default='1.0',
        type=str,
        help='Number used for versioning.')
    parser.add_argument(
        '-f',
        '--force',
        action='store_true',
        help='overwrite the existing `{model_name}.mar`')
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = parse_args()

    # `package_model` is None when `torch-model-archiver` failed to import at
    # module load time -- fail early with an actionable message instead of a
    # NameError inside mmseg2torchserve.
    if package_model is None:
        # Fixed: the original implicit string concatenation produced the
        # run-together message "...required.Try: pip install...".
        raise ImportError('`torch-model-archiver` is required. '
                          'Try: pip install torch-model-archiver')

    mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
                     args.model_name, args.model_version, args.force)
|
||||
56
finetune/tools/torchserve/mmseg_handler.py
Normal file
56
finetune/tools/torchserve/mmseg_handler.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import base64
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import torch
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from ts.torch_handler.base_handler import BaseHandler
|
||||
|
||||
from mmseg.apis import inference_model, init_model
|
||||
|
||||
|
||||
class MMsegHandler(BaseHandler):
    """TorchServe handler that wraps an MMSegmentation model.

    Loads the model from the unpacked `.mar` archive, decodes request
    payloads into images, runs inference, and returns PNG-encoded masks.
    """

    def initialize(self, context):
        """Load config + checkpoint from the model archive onto a device.

        Args:
            context: TorchServe context carrying `system_properties`
                (model_dir, gpu_id) and the archive `manifest`.
        """
        props = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Pin to the worker's assigned GPU when CUDA is usable; otherwise
        # stay on CPU.
        if torch.cuda.is_available():
            self.device = torch.device(self.map_location + ':' +
                                       str(props.get('gpu_id')))
        else:
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest

        model_dir = props.get('model_dir')
        checkpoint = os.path.join(model_dir,
                                  self.manifest['model']['serializedFile'])
        # The converter dumps the config as `config.py` next to the weights.
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_model(self.config_file, checkpoint, self.device)
        # SyncBN layers cannot run in a single-process serving worker.
        self.model = revert_sync_batchnorm(self.model)
        self.initialized = True

    def preprocess(self, data):
        """Decode each request's raw (or base64) payload into an image."""
        images = []
        for request in data:
            payload = request.get('data') or request.get('body')
            if isinstance(payload, str):
                # String payloads are assumed to be base64-encoded bytes.
                payload = base64.b64decode(payload)
            images.append(mmcv.imfrombytes(payload))
        return images

    def inference(self, data, *args, **kwargs):
        """Run segmentation on every decoded image."""
        return [inference_model(self.model, image) for image in data]

    def postprocess(self, data):
        """PNG-encode each result and return the raw bytes per request."""
        return [
            cv2.imencode('.png', result[0].astype('uint8'))[1].tobytes()
            for result in data
        ]
|
||||
58
finetune/tools/torchserve/test_torchserve.py
Normal file
58
finetune/tools/torchserve/test_torchserve.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from argparse import ArgumentParser
|
||||
from io import BytesIO
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mmcv
|
||||
import requests
|
||||
|
||||
from mmseg.apis import inference_model, init_model
|
||||
|
||||
|
||||
def parse_args():
    """Build the CLI parser for the comparison script and parse `sys.argv`."""
    parser = ArgumentParser(
        description='Compare result of torchserve and pytorch,'
        'and visualize them.')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('model_name', help='The model name in the server')
    parser.add_argument(
        '--inference-addr',
        default='127.0.0.1:8080',
        help='Address and port of the inference server')
    parser.add_argument(
        '--result-image',
        default=None,
        type=str,
        help='save server output in result-image')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')

    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """POST the image to the TorchServe endpoint, display its response,
    then run local PyTorch inference on the same image for comparison.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
    """
    url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
    with open(args.img, 'rb') as image:
        server_response = requests.post(url, image)
    content = server_response.content

    if args.result_image:
        # Persist the server's PNG output, then display it from disk.
        with open(args.result_image, 'wb') as out_image:
            out_image.write(content)
        plt.imshow(mmcv.imread(args.result_image, 'grayscale'))
    else:
        # No output path given: decode the response bytes in memory.
        plt.imshow(plt.imread(BytesIO(content)))
    plt.show()

    # Local reference run with the same config/checkpoint.
    model = init_model(args.config, args.checkpoint, args.device)
    result = inference_model(model, mmcv.imread(args.img))
    plt.imshow(result[0])
    plt.show()
|
||||
|
||||
|
||||
# Script entry point: parse CLI arguments and run the comparison.
if __name__ == '__main__':
    main(parse_args())
|
||||
Reference in New Issue
Block a user