您好,欢迎访问代理记账网站
  • 价格透明
  • 信息保密
  • 进度掌控
  • 售后无忧

MMSegmentation-Docs-Tutorial 4: Customize Models

MMSegmentation-Docs-Tutorial 4: Customize Models

源文档 https://mmsegmentation.readthedocs.io/en/latest/tutorials/customize_models.html



1 Customize optimizer

假设你想要添加一个名为MyOptimizer的优化器,该优化器有参数abc。首先需要在文件中实现这个优化器,即mmseg/core/optimizer/my_optimizer.py

from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer

@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c)

然后将这个模块添加到mmseg/core/optimizer/__init__.py中,这样注册表就会找到新模块并添加它,

from .my_optimizer import MyOptimizer

然后你可以在配置文件的optimizer字段中使用MyOptimizer。在配置中,优化器是由如下字段optimizer定义的:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

要使用自己的优化器,可以将字段更改为:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)

我们已经支持使用PyTorch实现的所有优化器,唯一的修改是更改配置文件的优化器optimizer字段。例如,如果您想使用ADAM,尽管性能会下降很多,但修改可以如下所示:

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

可以根据PyTorch的API文档直接设置参数

2 Customize optimizer constructor

一些模型可能有一些特定参数的优化设置,例如,BatchNoarm层的weight decay。用户可以通过自定义优化器构造函数来进行这些细粒度参数调优。

from mmcv.utils import build_from_cfg

from mmcv.runner import OPTIMIZER_BUILDERS
from .cocktail_optimizer import CocktailOptimizer


@OPTIMIZER_BUILDERS.register_module
class CocktailOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):

    def __call__(self, model):

        return my_optimizer

3 Develop new components

MMSegmentation主要有两种类型的组件。

  • backbone:通常是堆叠的卷积网络来提取特征地图,如ResNet, HRNet。
  • head:用于语义分割地特征图解码的组件
3.1 Add new backbones

这里通过一个MobileNet的例子来展示如何开发新的组件。

  • 创建一个新文件mmseg/models/backbones/mobilenet.py
import torch.nn as nn

from ..registry import BACKBONES

@BACKBONES.register_module
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass
  • mmseg/models/backbones/__init__.py中import模块
from .mobilenet import MobileNet

  • 在配置文件中使用
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
3.2 Add new heads

在MMSegmentation中,为所有的分割头提供了一个基本的BaseDecodeHead。所有新实现的解码头都应该从它派生出来。下面将使用PSPNet的示例演示如何开发一个新的头。
首先,在mmseg/models/decode_heads/psp_head.py中添加一个新的解码头。PSPNet实现了一个用于分割解码的解码头。为了实现一个解码头,我们基本上需要实现新模块的三个功能,如下所示:

@HEADS.register_module()
class PSPHead(BaseDecodeHead):

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)

    def init_weights(self):

    def forward(self, inputs):

接下来,需要在mmseg/models/decode_heads/__init__.py中添加模块,这样对应的注册表就可以找到并加载它们。
PSPNet的配置文件如下所示:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
3.3 Add new loss

假设想添加一个新的损失作为MyLoss用于分割解码。要添加一个新的损失函数,需要在mmseg/models/losses/my_loss.py中实现它。weighted_loss可以对每个元素的损失进行加权。

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

然后需要将其添加到mmseg/models/losses/__init__.py

from .my_loss import MyLoss, my_loss

若要使用,需修改loss_xxx字段。然后需要修改head中的loss_decode字段。Loss_weight可以用来平衡多个损失

loss_decode=dict(type='MyLoss', loss_weight=1.0))


分享:

低价透明

统一报价,无隐形消费

金牌服务

一对一专属顾问7*24小时金牌服务

信息保密

个人信息安全有保障

售后无忧

服务出问题客服经理全程跟进