mmseg 推理单张图像并保存
from mmseg.apis import inference_segmentor, init_segmentor
import mmcv
import pandas as pd
import os
from tqdm import tqdm
def generate_pseudo_masks(config_file, checkpoint_file, dir_save_pseudo_masks,
                          image_paths=None, device='cuda:1'):
    """Run a segmentor over a list of images and save the rendered results.

    Args:
        config_file: Path to the mmseg model config, passed to
            ``init_segmentor``.
        checkpoint_file: Path to the model checkpoint, passed to
            ``init_segmentor``.
        dir_save_pseudo_masks: Output directory. It is created (including
            missing parent directories) if it does not exist.
        image_paths: Iterable of image file paths to process. When ``None``,
            the module-level ``list_images`` is used, which is what the
            script entry point defines.
        device: Device string handed to ``init_segmentor``.

    Each result is written to ``dir_save_pseudo_masks`` under the input
    image's base file name.
    """
    if image_paths is None:
        # Backward compatibility: earlier callers pass only three arguments
        # and rely on a module-level ``list_images``.
        image_paths = list_images

    model = init_segmentor(config_file, checkpoint_file, device=device)

    # Gray palette: entry i is [i, i, i] for i in 0..149.
    # NOTE(review): presumably 150 matches the class count of the ADE20K
    # checkpoint, so the rendered gray value equals the class index - confirm.
    model.PALETTE = [[i, i, i] for i in range(150)]

    # makedirs creates missing parents and tolerates an existing directory,
    # whereas os.mkdir raises if any parent directory is absent.
    os.makedirs(dir_save_pseudo_masks, exist_ok=True)

    for image_name in tqdm(image_paths):
        img = mmcv.imread(image_name)
        result = inference_segmentor(model, img)
        out_file = os.path.join(dir_save_pseudo_masks,
                                os.path.basename(image_name))
        # opacity=1 is forwarded to show_result together with the palette.
        model.show_result(img, result, out_file=out_file, opacity=1)
if __name__ == '__main__':
    # Image paths come from the 'filename' column of the CSV.
    df = pd.read_csv('train_coarse_0.csv')
    # Kept as a module-level name: generate_pseudo_masks reads it as a global
    # when no explicit image list is passed.
    list_images = df['filename'].tolist()

    config_file = 'configs/deeplabv3plus/deeplabv3plus_r101-d8_512x512_160k_ade20k.py'
    checkpoint_file = 'checkpoints/ade20k/deeplabv3plus_r101-d8_512x512_160k_ade20k_20200615_123232-38ed86bb.pth'
    dir_save_pseudo_masks = '/home/muyun99/data/dataset/Public-Dataset/Cityscapes/cityscapes_pseudo_mask/deeplabv3plus_r101-d8_512x512_160k_ade20k'

    generate_pseudo_masks(config_file, checkpoint_file, dir_save_pseudo_masks)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
上次更新: 2021/10/03, 15:21:19
- 02
- README 美化 05-20
- 03
- 常见 Tricks 代码片段 05-12