# How to Load a Dataset
# The Cityscapes Dataset
# Number of Classes and Their Corresponding trainId and color
The full label table is defined in the official scripts (linked below). The classes fall into 8 broad categories, of which the human and vehicle groups carry instance annotations. For semantic segmentation we only need the classes whose ignoreInEval field is False, 19 in total, so each pixel has 20 possible outputs: the 19 classes plus one background class. The class ids actually used during training come from the trainId field.
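If the cityscapesscripts package is installed (`pip install cityscapesscripts`), the 19 training classes can be listed straight from its label table; a minimal sketch:

```python
# List the classes that participate in evaluation/training,
# assuming the cityscapesscripts package is available.
from cityscapesscripts.helpers.labels import labels

train_labels = [l for l in labels if not l.ignoreInEval]
print(len(train_labels))  # 19
for l in train_labels:
    print(l.trainId, l.name, l.color)
```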
# Label Visualization
Each image in the official dataset comes with four annotation files, for example:

- aachen_000000_000019_gtFine_color.png: color visualization of the labels
- aachen_000000_000019_gtFine_instanceIds.png: instance ids for the human and vehicle groups
- aachen_000000_000019_gtFine_labelIds.png: visualization by labelId
- aachen_000000_000019_gtFine_labelTrainIds.png: visualization by trainId
Official scripts: https://github.com/mcordts/cityscapesScripts
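As a side note on the instanceIds encoding: for the two instance-level groups, each pixel stores labelId * 1000 plus a per-image instance index, while stuff classes just store the labelId. A small sketch (the filename is a placeholder; the 16-bit png must be read without conversion):

```python
import numpy as np
import mmcv

# instanceIds.png is 16-bit, so read it unchanged instead of as 8-bit grayscale
inst = mmcv.imread('aachen_000000_000019_gtFine_instanceIds.png', flag='unchanged')
for i in np.unique(inst):
    if i >= 1000:
        print(f'labelId={i // 1000}, instance #{i % 1000}')
    else:
        print(f'stuff pixel, labelId={i}')
```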
# Building a Dataloader for the Cityscapes Dataset
# Using the Official PyTorch Interface
```python
import torchvision

root = r'./Cityscapes/'
# target_type = ['instance', 'semantic', 'color', 'polygon']
target_type = 'instance'
cityscapes_dataset = torchvision.datasets.Cityscapes(root, split='train', mode='fine', target_type=target_type, transform=None, target_transform=None)
```
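Indexing the dataset returns an (image, target) pair; with target_type='instance' both come back as PIL images. A quick check, assuming the standard leftImg8bit/gtFine layout under root:

```python
# Fetch one sample; sizes are PIL (width, height) tuples
img, target = cityscapes_dataset[0]
print(img.size)
print(target.size)
```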
The code above uses the Dataset interface shipped with PyTorch directly, but I find it not flexible enough, so I decided to implement the Dataloader myself.
# Manual Implementation
```python
import os
import glob

import mmcv
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class cityscapesDataset(Dataset):
    def __init__(self, root, split='train', img_size=(1024, 512), img_norm=True, gt_type='gtCoarse'):
        super(cityscapesDataset, self).__init__()
        self.name_dataset = 'Cityscapes'
        self.num_classes = 19
        # labelId -> color; classes ignored in evaluation are painted black
        self.colormap = {
            0: (0, 0, 0),  # unlabeled
            1: (0, 0, 0),  # ego vehicle
            2: (0, 0, 0),  # rectification border
            3: (0, 0, 0),  # out of roi
            4: (0, 0, 0),  # static
            5: (0, 0, 0),  # dynamic
            6: (0, 0, 0),  # ground
            7: (128, 64, 128),  # road
            8: (244, 35, 232),  # sidewalk
            9: (0, 0, 0),  # parking
            10: (0, 0, 0),  # rail track
            11: (70, 70, 70),  # building
            12: (102, 102, 156),  # wall
            13: (190, 153, 153),  # fence
            14: (0, 0, 0),  # guard rail
            15: (0, 0, 0),  # bridge
            16: (0, 0, 0),  # tunnel
            17: (153, 153, 153),  # pole
            18: (0, 0, 0),  # polegroup
            19: (250, 170, 30),  # traffic light
            20: (220, 220, 0),  # traffic sign
            21: (107, 142, 35),  # vegetation
            22: (152, 251, 152),  # terrain
            23: (0, 130, 180),  # sky
            24: (220, 20, 60),  # person
            25: (255, 0, 0),  # rider
            26: (0, 0, 142),  # car
            27: (0, 0, 70),  # truck
            28: (0, 60, 100),  # bus
            29: (0, 0, 0),  # caravan
            30: (0, 0, 0),  # trailer
            31: (0, 80, 100),  # train
            32: (0, 0, 230),  # motorcycle
            33: (119, 11, 32),  # bicycle
            -1: (0, 0, 0)  # license plate
            # official colors of the ignored classes, for reference:
            # 5: (111, 74, 0),  # dynamic
            # 6: (81, 0, 81),  # ground
            # 9: (250, 170, 160),  # parking
            # 10: (230, 150, 140),  # rail track
            # 14: (180, 165, 180),  # guard rail
            # 15: (150, 100, 100),  # bridge
            # 16: (150, 120, 90),  # tunnel
            # 18: (153, 153, 153),  # polegroup
            # 29: (0, 0, 90),  # caravan
            # 30: (0, 0, 110),  # trailer
        }
        self.labels = {
            0: 'unlabeled',
            1: 'ego vehicle',
            2: 'rectification border',
            3: 'out of roi',
            4: 'static',
            5: 'dynamic',
            6: 'ground',
            7: 'road',
            8: 'sidewalk',
            9: 'parking',
            10: 'rail track',
            11: 'building',
            12: 'wall',
            13: 'fence',
            14: 'guard rail',
            15: 'bridge',
            16: 'tunnel',
            17: 'pole',
            18: 'polegroup',
            19: 'traffic light',
            20: 'traffic sign',
            21: 'vegetation',
            22: 'terrain',
            23: 'sky',
            24: 'person',
            25: 'rider',
            26: 'car',
            27: 'truck',
            28: 'bus',
            29: 'caravan',
            30: 'trailer',
            31: 'train',
            32: 'motorcycle',
            33: 'bicycle',
            -1: 'license plate'
        }
        # the 19 labelIds used for training, and everything that gets ignored
        self.trainId = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.ignoreId = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # labelId -> contiguous train index 0..18
        self.classmap = dict(zip(self.trainId, range(self.num_classes)))
        self.split = split
        self.img_size = img_size
        self.img_norm = img_norm
        self.ignore_index = 250
        self.gt_type = gt_type
        self.root = root
        self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
        self.annotations_base = os.path.join(self.root, self.gt_type, self.split)
        self.files = {}
        self.files[self.split] = glob.glob(os.path.join(self.images_base, '*/*.png'))

    def __getitem__(self, index):
        img_path = self.files[self.split][index].rstrip()
        # leftImg8bit/<split>/<city>/<name>_leftImg8bit.png
        #   -> <gt_type>/<split>/<city>/<name>_<gt_type>_labelIds.png
        label_path = os.path.join(
            self.annotations_base,
            img_path.split(os.sep)[-2],
            os.path.basename(img_path)[:-15] + f"{self.gt_type}_labelIds.png"
        )
        img = mmcv.imread(img_path, channel_order='rgb')
        label = mmcv.imread(label_path, flag='grayscale')
        label = self.encode_segmap(label)
        # transform & augmentations
        img, label = self.transform(img, label)
        return img, label

    def __len__(self):
        return len(self.files[self.split])

    def decode_segmap(self, mask):
        # map train indices back to colors; returns a BGR float image in [0, 1]
        # so it can be shown directly with mmcv.imshow (OpenCV expects BGR)
        r = mask.copy()
        g = mask.copy()
        b = mask.copy()
        for cls in range(self.num_classes):
            color_cls = self.trainId[cls]
            r[mask == cls] = self.colormap[color_cls][0]
            g[mask == cls] = self.colormap[color_cls][1]
            b[mask == cls] = self.colormap[color_cls][2]
        rgb = np.zeros((mask.shape[0], mask.shape[1], 3))
        rgb[:, :, 0] = b / 255.0
        rgb[:, :, 1] = g / 255.0
        rgb[:, :, 2] = r / 255.0
        return rgb

    def encode_segmap(self, mask):
        # map raw labelIds to contiguous train indices; everything else
        # becomes ignore_index. Remapping trainId in ascending order keeps
        # the new values (0..18) from colliding with labelIds not yet remapped.
        for ignorecls in self.ignoreId:
            mask[mask == ignorecls] = self.ignore_index
        for traincls in self.trainId:
            mask[mask == traincls] = self.classmap[traincls]
        return mask

    def transform(self, img, label):
        # mmcv.imresize takes (width, height)
        img = mmcv.imresize(img, (self.img_size[0], self.img_size[1]))
        img = img.astype(np.float64)
        if self.img_norm:
            img = img / 255.0
        classes = np.unique(label)
        label = label.astype(float)
        # nearest-neighbor interpolation keeps the label values intact
        label = mmcv.imresize(label, (self.img_size[0], self.img_size[1]), interpolation='nearest')
        label = label.astype(int)
        if not np.all(classes == np.unique(label)):
            print("WARN: resizing labels yielded fewer classes")
        if not np.all(np.unique(label[label != self.ignore_index]) < self.num_classes):
            print("after det", classes, np.unique(label))
            raise ValueError("Segmentation map contained invalid class values")
        img = torch.from_numpy(img).float()
        label = torch.from_numpy(label).long()
        return img, label


# dataloader test
if __name__ == '__main__':
    train_dataset = cityscapesDataset(root='/home/muyun99/data/dataset/cityscapes', split='train',
                                      img_size=(1024, 512), img_norm=True, gt_type='gtFine')
    train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=8, pin_memory=True)
    print(len(train_dataset))
    for i, batch in enumerate(train_dataloader):
        imgs, masks = batch
        if i == 0:
            print("img batch shape: ", imgs.shape)
            print("mask batch shape: ", masks.shape)
        img = imgs[0].numpy()
        mask = masks[0].numpy()
        mask = train_dataset.decode_segmap(mask)
        # mmcv.imshow expects BGR; the image was read as RGB, so flip channels
        mmcv.imshow(img[..., ::-1], wait_time=1000)
        mmcv.imshow(mask, wait_time=1000)
        # break
```
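Note that this loader returns images in H x W x C layout, while PyTorch models expect N x C x H x W, so a permute is needed before a forward pass. A minimal sketch:

```python
# N x H x W x C -> N x C x H x W before feeding a network
imgs, masks = next(iter(train_dataloader))
imgs = imgs.permute(0, 3, 1, 2).contiguous()
print(imgs.shape)  # torch.Size([2, 3, 512, 1024])
```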
# References
https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/loader/cityscapes_loader.py
https://github.com/Tramac/awesome-semantic-segmentation-pytorch/blob/master/core/data/dataloader/cityscapes.py
https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
Takeaways from these reference implementations:

1. Change ignore_id;
2. Map the ignore_id classes to 255;
3. Exclude class 255 when computing the loss (see the sketch below).
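For point 3, no manual masking is needed: torch.nn.CrossEntropyLoss takes an ignore_index argument, so pixels labeled 255 contribute nothing to the loss. A minimal sketch with random tensors:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=255)
logits = torch.randn(2, 19, 512, 1024)         # N x C x H x W
target = torch.randint(0, 19, (2, 512, 1024))  # train ids 0..18
target[:, :10, :10] = 255                      # ignored pixels
loss = criterion(logits, target)               # class 255 is excluded
```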