-
Notifications
You must be signed in to change notification settings - Fork 366
Description
代码如下,这是在该代码库基础上实现的 Mask R-CNN:
class MaskRCNNHead(nn.Module):
    """Mask-prediction head: stacked 3x3 conv + ReLU layers and a 1x1 classifier.

    For every ROI the head pools a fixed-size feature patch via
    ``roi_pooling``, runs it through ``num_convs`` conv/ReLU pairs, predicts
    one mask channel per class and resizes the sigmoid probabilities to
    ``img_size``.

    Args:
        n_class: Number of output mask channels (one per class).
        roi_size: Side length of the square patch pooled for each ROI.
        spatial_scale: Scale factor forwarded to ``roi_pooling``.
        num_convs: Number of 3x3 conv + ReLU pairs before the predictor.
        conv_dim: Output channels of every 3x3 conv.
        mask_out_dim: Accepted for interface compatibility; currently unused.
        in_channels: Channel count of the incoming feature map. Defaults to
            256, the value that was previously hard-coded.
    """

    def __init__(self, n_class, roi_size, spatial_scale, num_convs=4,
                 conv_dim=256, mask_out_dim=28, in_channels=256):
        super(MaskRCNNHead, self).__init__()
        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale

        # Build a *new* Conv2d/ReLU pair per stage. Multiplying a list of
        # modules only repeats references, which would make every stage share
        # a single set of weights.
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(channels, conv_dim, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            channels = conv_dim
        self.conv_layers = nn.Sequential(*layers)

        # Final 1x1 conv producing one mask logit map per class.
        self.mask_pred = nn.Conv2d(channels, n_class, kernel_size=1)

        # Weight initialisation: small Gaussian weights, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, features, rois, roi_indices, img_size):
        """Predict per-class mask probabilities for each ROI.

        Args:
            features: Feature map handed to ``roi_pooling``.
            rois: Iterable of ROIs; element ``i`` is passed to ``roi_pooling``
                together with ``roi_indices[i]``.
            roi_indices: Per-ROI index, indexable in step with ``rois``.
            img_size: Target spatial size for the returned masks.

        Returns:
            A list with one tensor per ROI holding sigmoid probabilities
            resized to ``img_size``; an empty list when ``rois`` is empty.
        """
        # Nothing to pool: concatenating an empty list would raise.
        if len(rois) == 0:
            return []

        # Crop a fixed-size patch out of the feature map for every ROI.
        pooled = [
            roi_pooling(features, roi, roi_indices[i], self.roi_size, self.spatial_scale)
            for i, roi in enumerate(rois)
        ]
        # Join along the batch dimension. Each patch is normalised to
        # (k, C, H, W) first, so a leading singleton batch axis from the
        # pooling op does not turn into an extra 5th dimension.
        roi_features = torch.cat(
            [p.reshape(-1, *p.shape[-3:]) for p in pooled], dim=0
        )

        # Predict mask logits, then convert them to probabilities.
        mask_logits = self.mask_pred(self.conv_layers(roi_features))
        mask_probs = torch.sigmoid(mask_logits)

        # NOTE(review): every mask is stretched to the full image size rather
        # than pasted into its ROI box - confirm this is what callers expect.
        resized = F.interpolate(
            mask_probs, size=img_size, mode='bilinear', align_corners=False
        )
        return list(resized.unbind(dim=0))
def roi_pooling(features, roi, roi_index, roi_size, spatial_scale):
    """Pool one ROI out of ``features`` into a fixed-size patch with RoIAlign.

    Args:
        features: Batched feature map to crop from.
        roi: Four box coordinates (tensor or sequence). torchvision's RoIAlign
            reads boxes as (x1, y1, x2, y2) - TODO confirm the caller uses
            that order.
        roi_index: Index of the batch element in ``features`` this ROI
            belongs to.
        roi_size: Side length of the square output patch.
        spatial_scale: Factor mapping box coordinates onto ``features``.

    Returns:
        The RoIAlign output for this single box (leading dimension of 1).
    """
    # RoIAlign takes boxes as Tensor[K, 5] with the batch index in column 0,
    # so prepend roi_index to the four coordinates instead of dropping it.
    coords = torch.as_tensor(
        roi, dtype=features.dtype, device=features.device
    ).reshape(1, 4)
    batch_idx = torch.as_tensor(
        roi_index, dtype=features.dtype, device=features.device
    ).reshape(1, 1)
    box = torch.cat([batch_idx, coords], dim=1)

    roi_align = torchvision.ops.RoIAlign((roi_size, roi_size), spatial_scale, sampling_ratio=2)
    return roi_align(features, box)
还有就是,再建立一个文件maskrcnn.py:
import torch.nn as nn
from nets.classifier import Resnet50RoIHead, VGG16RoIHead, MaskRCNNHead
from nets.resnet50 import resnet50
from nets.rpn import RegionProposalNetwork
from nets.vgg16 import decom_vgg16
class MaskRCNN(nn.Module):
    """Two-stage detector with an additional mask head.

    Bundles a backbone feature extractor, a region proposal network, a
    box classification/regression head and a mask head. ``forward`` exposes
    each stage separately through its ``mode`` argument.

    Args:
        num_classes: Number of foreground classes; both heads are built with
            ``num_classes + 1`` outputs.
        mode: Passed through to ``RegionProposalNetwork``.
        feat_stride: Passed through to ``RegionProposalNetwork``.
        anchor_scales: Anchor scales for the RPN; defaults to ``[8, 16, 32]``.
        ratios: Anchor aspect ratios for the RPN; defaults to ``[0.5, 1, 2]``.
        backbone: ``'vgg'`` or ``'resnet50'``.
        pretrained: Passed through to the backbone constructor.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, num_classes,
                 mode="training",
                 feat_stride=16,
                 anchor_scales=None,
                 ratios=None,
                 backbone='vgg',
                 pretrained=False):
        super(MaskRCNN, self).__init__()
        # Fresh lists per instance instead of shared mutable defaults.
        anchor_scales = [8, 16, 32] if anchor_scales is None else anchor_scales
        ratios = [0.5, 1, 2] if ratios is None else ratios
        self.feat_stride = feat_stride

        # ---------------------------------#
        #   Two backbones are supported:
        #   vgg and resnet50
        # ---------------------------------#
        if backbone == 'vgg':
            self.extractor, classifier = decom_vgg16(pretrained)
            rpn_in_channels, head_cls, head_roi_size = 512, VGG16RoIHead, 7
        elif backbone == 'resnet50':
            self.extractor, classifier = resnet50(pretrained)
            rpn_in_channels, head_cls, head_roi_size = 1024, Resnet50RoIHead, 14
        else:
            raise ValueError(
                "Unsupported backbone %r; expected 'vgg' or 'resnet50'." % (backbone,)
            )

        # ---------------------------------#
        #   Region proposal network
        # ---------------------------------#
        self.rpn = RegionProposalNetwork(
            rpn_in_channels, 512,
            ratios=ratios,
            anchor_scales=anchor_scales,
            feat_stride=self.feat_stride,
            mode=mode
        )
        # ---------------------------------#
        #   Box classification / regression head
        # ---------------------------------#
        self.head = head_cls(
            n_class=num_classes + 1,
            roi_size=head_roi_size,
            spatial_scale=1,
            classifier=classifier
        )
        # ---------------------------------#
        #   Mask head
        # ---------------------------------#
        # NOTE(review): the RPN above is built for 512/1024-channel features;
        # confirm MaskRCNNHead's first conv accepts the same channel count,
        # and that ROIs are in feature-map coordinates given spatial_scale=1.
        self.mask_head = MaskRCNNHead(
            n_class=num_classes + 1,
            roi_size=14,
            spatial_scale=1
        )

    def forward(self, x, scale=1., mode="forward"):
        """Run one stage of the network, selected by ``mode``.

        Modes and the expected ``x``:
            * ``"forward"``: image batch; returns
              ``(roi_cls_locs, roi_scores, rois, roi_indices)``. The mask head
              is not invoked in this mode.
            * ``"extractor"``: image batch; returns the backbone feature map.
            * ``"rpn"``: ``(base_feature, img_size)``; returns
              ``(rpn_locs, rpn_scores, rois, roi_indices, anchor)``.
            * ``"head"``: ``(base_feature, rois, roi_indices, img_size)``;
              returns ``(roi_cls_locs, roi_scores)``.
            * ``"mask_head"``: ``(base_feature, rois, roi_indices, img_size)``;
              returns the mask head output.

        Raises:
            ValueError: If ``mode`` is not one of the names above.
        """
        if mode == "forward":
            # Spatial size of the input image (everything after N, C).
            img_size = x.shape[2:]
            # Backbone features.
            base_feature = self.extractor.forward(x)
            # Region proposals.
            _, _, rois, roi_indices, _ = self.rpn.forward(base_feature, img_size, scale)
            # Classification and box-regression results for the proposals.
            roi_cls_locs, roi_scores = self.head.forward(base_feature, rois, roi_indices, img_size)
            return roi_cls_locs, roi_scores, rois, roi_indices

        if mode == "extractor":
            # Backbone features only.
            return self.extractor.forward(x)

        if mode == "rpn":
            base_feature, img_size = x
            # Region proposals plus the raw RPN outputs and anchors.
            rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn.forward(base_feature, img_size, scale)
            return rpn_locs, rpn_scores, rois, roi_indices, anchor

        if mode == "head":
            base_feature, rois, roi_indices, img_size = x
            # Classification and box-regression results for the given ROIs.
            roi_cls_locs, roi_scores = self.head.forward(base_feature, rois, roi_indices, img_size)
            return roi_cls_locs, roi_scores

        if mode == "mask_head":
            base_feature, rois, roi_indices, img_size = x
            # Mask predictions for the given ROIs.
            return self.mask_head.forward(base_feature, rois, roi_indices, img_size)

        # An unknown mode used to fall through and return None silently.
        raise ValueError(
            "Unknown mode %r; expected one of 'forward', 'extractor', 'rpn', "
            "'head', 'mask_head'." % (mode,)
        )

    def freeze_bn(self):
        """Put every ``BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
即可完成maskrcnn的训练功能!