用自己的数据集训练Mask R-CNN

分类: 深度学习 图像分割

准备工作

原训练代码:https://github.com/multimodallearning/pytorch-mask-rcnn

标注工具labelmehttps://github.com/wkentaro/labelme

数据生成

  • 安装好labelme后,标注好数据生成json文件
  • labelme_json_to_dataset命令生成数据集文件(一个xxx_json文件夹,xxx是json文件的名字)
  • 生成的文件夹中有img.png(原图),info.yaml(图中包含的类的名字),label.png(标签mask),label_viz.png(标签mask 可视化文件)四个文件,需要用的只有info.yaml以及label.png

(labelme)的安装和使用请参考其repo里的readme.md

代码修改

  • 在全局定义一个iter_num = 0
  • 修改coco.py里的CoCoDataset类,这里重新写一个类,代替原来的CoCoDataset类:

class DrugDataset(utils.Dataset):

  • 为这个类添加针对labelme标注工具的一些方法:
#得到该图中有多少个实例(物体)
def get_obj_index(self, image):
        n = np.max(image)
        return n
#解析labelme中得到的yaml文件,从而得到mask每一层对应的实例标签
def from_yaml_get_class(self,image_id):
        info=self.image_info[image_id]
        with open(info['yaml_path']) as f:
            temp=yaml.load(f.read())
            labels=temp['label_names']
            del labels[0]
        return labels
#重新写draw_mask
def draw_mask(self, num_obj, mask, image):
        info = self.image_info[image_id]
        for index in range(num_obj):
            for i in range(info['width']):
                for j in range(info['height']):
                    at_pixel = image.getpixel((i, j))
                    if at_pixel == index + 1:
                        mask[j, i, index] =1
        return mask
#重新写load_shapes,里面包含自己的自己的类别
#并在self.image_info信息中添加了path、mask_path 、yaml_path
def load_shapes(self, count, height, width, img_folder, mask_folder, imglist,dataset_root_path):
        """Generate the requested number of synthetic images.
        count: number of images to generate.
        height, width: the size of the generated images.
        """
        # Add classes
        self.add_class('shapes', 1, 'your_class_1')
        self.add_class('shapes', 2, 'your_class_2')
        for i in range(count):
            filestr = imglist[i].split(".")[0]
            filestr = filestr.split("_")[1]
            mask_path = 'your mask path'
            yaml_path = 'your yaml path'
            self.add_image("shapes", image_id=i, path=img_folder + "/" + imglist[i],
                           width=width, height=height, mask_path=mask_path,yaml_path=yaml_path)
#重写load_mask
def load_mask(self, image_id):
    """
    Generate instance masks for shapes of the given image ID.
    """
    global iter_num
    info = self.image_info[image_id]
    count = 1  # number of object
    img = Image.open(info['mask_path'])
    num_obj = self.get_obj_index(img)
    mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
    mask = self.draw_mask(num_obj, mask, img)
    occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
    for i in range(count - 2, -1, -1):
        mask[:, :, i] = mask[:, :, i] * occlusion
        occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
    labels=[]
    labels=self.from_yaml_get_class(image_id)
    labels_form=[]
    for i in range(len(labels)):
        if labels[i] == 'your_class1':
            labels_form.append('your_class1')
        elif labels[i] == 'your_class2':
            labels_form.append('your_class2')
    class_ids = np.array([self.class_names.index(s) for s in labels_form])
    return mask, class_ids.astype(np.int32)
  • 全局信息设置
#基础设置
dataset_root_path = 'your dataset root path'
img_folder = dataset_root_path + 'your image folder'
mask_folder = dataset_root_path+ 'your mask folder'
imglist = listdir(img_folder)
count = len(imglist)
width = 640
height = 400
  • 训练集准备
#train与val数据集准备
dataset_train = DrugDataset()
dataset_train.load_shapes(count, width, height, img_folder, mask_folder, imglist,dataset_root_path)
dataset_train.prepare()

dataset_val = DrugDataset()
dataset_val.load_shapes(count, width, height, img_folder, mask_folder, imglist,dataset_root_path)
dataset_val.prepare()

参考资料

  1. http://blog.csdn.net/l297969586/article/details/79140840
  2. http://blog.csdn.net/xiongchao99/article/details/79106588

上一篇: Git命令记录(持续更新ing)
下一篇: Ubuntu14.04下安装python3.5