Muyun99's wiki Muyun99's wiki
首页
学术搬砖
学习笔记
生活杂谈
wiki搬运
资源收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

Muyun99

努力成为一个善良的人
首页
学术搬砖
学习笔记
生活杂谈
wiki搬运
资源收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 代码实践-目标检测

  • 代码实践-图像分割

  • 代码实践-自监督学习

  • 竞赛笔记-视觉竞赛

  • 框架解析-mmlab系列

    • MMClassifiction 框架学习导言
    • mmcls 是如何能够通过config 就搭建好一个模型的?
    • 为自己的 inicls 框架加上 fp16 训练
    • 为自己的 inicls 框架集成 Horovod
    • 为自己的 inicls 框架集成 DALI
    • mmsegmentation框架解析(上)
    • mmsegmentation框架解析(中)
      • mmsegmentation 框架解析
        • Tutorial 3:自定义数据 Pipeline
        • Tutorial 4:自定义模型
    • mmsegmentation框架解析(下)
    • mmcv 使用
    • mmcv使用(中)--Config
    • 什么是 Register
    • 什么是 ABCMeta
    • mmseg数据集
    • mmseg 推理单张图像并保存
    • 计算loss和计算metric
  • 讲座记录-有意思的文章集合

  • 体会感悟-产品沉思录观后有感

  • 体会感悟-摄影

  • 系列笔记-

  • 系列笔记-乐理和五线谱

  • 系列笔记-爬虫实践

  • 系列笔记-Django学习笔记

  • 系列笔记-Git 使用笔记

  • 系列笔记-网站搭建

  • 系列笔记-图卷积网络

  • 课程笔记-MIT-NULL

  • 系列笔记-OpenCV-Python

  • 系列笔记-使用 Beancount 记账

  • 系列笔记-Python设计模式

  • 系列笔记-MLOps

  • 系列笔记-Apollo自动驾驶

  • 系列笔记-PaddlePaddle

  • 系列笔记-视频操作

  • Vue+Django前后端分离开发

  • 深度学习及机器学习理论知识学习笔记

  • PyTorch Tricks

  • 学习笔记
  • 框架解析-mmlab系列
Muyun99
2021-04-13

mmsegmentation框架解析(中)

# mmsegmentation 框架解析

# Tutorial 3:自定义数据 Pipeline

因为语义分割的数据集可能不是相同的尺寸,所以在 MMCV 引入了一个新的 DataContainer 数据类型,来帮助组织不同尺寸的数据。

dataset 和 data preparation pipeline 是解耦的,dataset 定义的是如何处理图像和标注,data pipeline 定义了所有准备数据的步骤。一个 pipeline 包含了一系列的操作,每个操作接收一个 dict 作为输入,并输出一个 dict 用作下一步的transform。下面是一个 Cityscapes 数据集的dataset pipeline

# dataset settings
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

# 1、data pipeline 的组成部分

# 1.1 数据加载
  • LoadImageFromFile
  • LoadAnnotations
# 1.2 预处理
  • Resize
  • RandomFlip
  • Pad
  • RandomCrop
  • Normalize
  • SegRescale
  • PhotoMetricDistortion
# 1.3 格式化
  • ToTensor

  • ImageToTensor

  • Transpose

  • ToDataContainer

  • DefaultFormatBundle

  • Collect

# 1.4 测试时增强(TTA)
  • MultiScaleFlipAug

# 2、如何自定义 pipeline

首先实现自己的 MyTransform 类

from mmseg.datasets import PIPELINES

@PIPELINES.register_module()
class MyTransform:

    def __call__(self, results):
        results['dummy'] = True
        return results
1
2
3
4
5
6
7
8

然后在 _init_.py 中 import 该类

from .my_pipeline import MyTransform
1

在 config 中使用

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='MyTransform'),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

# Tutorial 4:自定义模型

# 1、自定义 optimizer

先定义自己的优化器 MyOptimizer,并注册在 _init_.py 中 import 该优化器

from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c)
1
2
3
4
5
6
7
8
from .my_optimizer import MyOptimizer
1

2、在 config 中使用自定义的优化器

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
1
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
1
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
1

# 2、自定义 optimizer constructor

from mmcv.utils import build_from_cfg

from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer


@OPTIMIZER_BUILDERS.register_module
class CocktailOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):

    def __call__(self, model):

        return my_optimizer
1
2
3
4
5
6
7
8
9
10
11
12
13
14

# 3、开发一个新组件

# 3.1 添加新的 backbone
import torch.nn as nn

from ..registry import BACKBONES

@BACKBONES.register_module
class MobileNet(nn.Module):
    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from .mobilenet import MobileNet
1
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
1
2
3
4
5
6
7
# 3.2 添加新的 head
@HEADS.register_module()
class PSPHead(BaseDecodeHead):

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)

    def init_weights(self):

    def forward(self, inputs):
1
2
3
4
5
6
7
8
9
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 3.3 添加新的 loss
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from .my_loss import MyLoss, my_loss
1
loss_decode=dict(type='MyLoss', loss_weight=1.0))
1
上次更新: 2021/08/02, 21:04:52
mmsegmentation框架解析(上)
mmsegmentation框架解析(下)

← mmsegmentation框架解析(上) mmsegmentation框架解析(下)→

最近更新
01
Structured Knowledge Distillation for Semantic Segmentation
06-03
02
README 美化
05-20
03
常见 Tricks 代码片段
05-12
更多文章>
Theme by Vdoing | Copyright © 2021-2023 Muyun99 | MIT License
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式
×