MMSeg无法使用单类自定义数据集训练(2023/2/17更新)

摘要:将三通道标注图像转为单通道图像,并将各类别的像素值统一映射为 0, 1, 2, …,以解决 MMSeg 的报错与无法训练的问题

2023/02/17更新:添加了一个效果不好的真·单分类训练方式,以及效果图

描述 - 用二分类实现单类分割

  • 跑自定义数据集时报错,理论上其它东西都没错,那就只能是图片问题。
  • 但我这次弄了两个数据集。上一个数据集虽然也报过这个错,不过用某些方式解决了:能正常训练的数据集的 GT 是彩色图片,报错的数据集的 GT 是黑白图片。检查后发现黑白图片也是三通道的,所以按理说不该是通道问题。
  • 但查官方 issue 后,发现他们推荐单通道:https://github.com/open-mmlab/mmsegmentation/issues/1625#issuecomment-1140384065
  • 在改为单通道后,下面这个报错消失了,但出现了新的问题:指标/损失异常。
ValueError: Input and output must have the same number of spatial dimensions, but got input with with spatial dimensions of [128, 128] and output size of torch.Size([512, 512, 3]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

另一种单分类训练方式

  • 最近跑了几次实验,训练 iter 从 20k 到 80k,目标的 iou 始终只有10~20,但 acc 比较高,有80+,即便将损失权重设为 (1, 100) 也无法改善低性能的问题。
  • 我又去查了一下 issue,发现另一个单类训练的方法:
  • 这样训练后,只会训练一类,iou 能稳定 20+,但 acc 只有 40+。从指标来看比较好看,但是实际上还是推荐用二分类来训练单分类:
    • 我可视化结果之后,单类的会在目标附近框一个大框,只有一小部分是目标
    • 而二分类的,则可以更精确的标注目标,仅在目标附近稍微扩散。
    • 从结果来看,二分类才是正确的做法。即便二分类的指标看起来很奇怪,就像没有在训练一样,我猜这可能是因为数据量太小。
  • 单分类训练结果可视化:
    1.jpg
  • 二分类训练结果可视化:
    1.jpg

代码

  • 排序代码有点难写,不想动脑子,因此只有一个量体裁衣的代码。
  • 给定图片的背景值为 (0, 0, 0)、目标值为 (255, 255, 255),代码会将它们分别改为 (0) 和 (1)(实际判断的是第一个通道:该通道非 0 的像素都会被设为 1)。
  • 以下两个代码放在同级文件夹下,运行 channel3to1.py,输入待处理文件夹(支持递归),输出结果见 log/Channels3to1/<时间戳> 文件夹。
    • 我其实写了好多类似的小工具,但是就传了最初版到 GitHub 上,太懒了...
# channel3to1.py
from base_model import BaseModel
import cv2
import numpy as np

class Channels3to1(BaseModel):
    """Turn 3-channel binary masks into single-channel 0/1 label images.

    run() asks for a folder and reads every file below it (recursively)
    with cv2.imread. It sets label 1 where the first channel is non-zero
    and 0 elsewhere. The result is written under the same file name into
    self.log_dir, a per-run timestamp folder (see __init__).
    """

    def __init__(self):
        super().__init__()
        # Put logs and converted images in a fresh timestamped sub-folder so
        # repeated runs do not overwrite each other.
        self.change_log_path()

    def run(self):
        """Interactively convert all files under a user-given folder.

        Files cv2 cannot read are logged and skipped. Outputs keep the input
        file name, so same-named files in different sub-folders overwrite
        one another in the output folder.
        """
        path = input("Input path: ")
        files_path = self.get_path_content(path, 'allfile')

        # The first self.log() call also creates self.log_dir (via log2file ->
        # mkdir). cv2.imwrite below relies on this because it does not create
        # folders itself.
        self.log(f"Path: {path}")
        for i, file_path in enumerate(files_path):
            self.log(f"{i+1}: {file_path}")
        for i, file_path in enumerate(files_path):
            img = cv2.imread(file_path)
            if img is None:
                # cv2.imread signals unreadable / non-image files by returning
                # None instead of raising.
                self.log(f"{i+1}: {file_path} skipped, not a readable image")
                continue
            # Vectorised labelling: channel 0 non-zero -> 1, otherwise 0.
            # An explicit uint8 array is written. The old per-pixel Python
            # loop built np.array() of Python ints, whose integer dtype
            # depends on the platform.
            label = (img[:, :, 0] != 0).astype(np.uint8)
            save_path = self.log_dir + "/" + self.path2name(file_path, keep_ext=True)
            # NOTE: the input extension is kept. A lossy format such as .jpg
            # would corrupt the 0/1 labels, so prefer .png ground truth.
            if not cv2.imwrite(save_path, label, [int(cv2.IMWRITE_PNG_COMPRESSION), 0]):
                self.log(f"{i+1}: {file_path} could not be written to {save_path}")
                continue
            self.log(f"{i+1}: {file_path} converted (H, W, 3) -> (H, W, 1) to {save_path}")

if __name__ == "__main__":
    # Script entry point: prompt for a folder and convert every mask in it.
    converter = Channels3to1()
    converter.run()
# base_model.py
import os
import os.path as osp
import re
import json
import time
import datetime

class BaseModel:
    """Shared utilities ("utils") for small command-line tools.

    Provides timestamped logging to a per-class log folder, directory
    listing, file filtering, size-unit conversion, natural sorting and a
    few helpers for "path label" text files.
    """

    # Number of bytes in one unit, keyed by the single-letter unit prefix.
    _UNIT_BYTES = {'B': 1, 'K': 1024, 'M': 1024 ** 2, 'G': 1024 ** 3}

    # Template listing every key is_file_meet() understands; it is used as
    # that method's default. NOTE: the values contradict each other
    # (size_max == size_min, and the same extensions/names are both allowed
    # and forbidden), so this default effectively matches no file. Callers
    # should pass their own dict.
    _EXAMPLE_CONDITION = {
        'size_max': '10M',
        'size_min': '10M',
        'ext_allow': ['pth', 'pt', 't'],
        'ext_forbid': ['pth', 'pt', 't'],
        'name_allow': ['epoch_99.t'],
        'name_forbid': ['epoch_99.t'],
    }

    def __init__(self, log_dir='', lang='en'):
        """Set up log locations and language-table settings.

        Args:
            log_dir: Root folder for logs. Empty (or None) selects
                ``./log/<ClassName>``.
            lang: Preferred language code, stored in self.lang_code.
                Call parse_from_language() to actually load a table.
        """
        self.log_root = log_dir or f"./log/{self.__class__.__name__}"
        # log_dir starts at the root. change_log_path() can move it into a
        # per-run timestamp sub-folder.
        self.log_dir = self.log_root
        self.timestamp = time.time()
        self.log_file = f"{self.__class__.__name__}_{self.timestamp}.log"
        # Language-table settings read by parse_from_language(). They were
        # previously commented out, which made that method raise AttributeError.
        self.lang_code = lang
        self.lang_path = "./languages"
        self.lang_dict = {
            "en": "English.json",
            "zh": "Chinese.json"
        }
        self.lang_encoding = {
            "en": "utf-8",
            "zh": "gb18030"
        }
        self.lang = {}

    def help(self):
        """Log this object's class docstring as a help message."""
        self.log(self.__doc__)

    def change_log_path(self, mode="timestamp"):
        """Point self.log_dir somewhere under self.log_root.

        Args:
            mode: "timestamp" -> ``<log_root>/<self.timestamp>`` (one folder
                per run); "root" -> ``log_root`` itself. Any other value
                leaves log_dir unchanged.
        """
        if mode == "timestamp":
            self.log_dir = osp.join(self.log_root, str(self.timestamp))
        elif mode == "root":
            self.log_dir = self.log_root

    def init_log_file(self):
        """Switch to a fresh log file name based on the current time."""
        self.log_file = f"{self.__class__.__name__}_{time.time()}.log"

    def get_path_content(self, path, mode='allfile'):
        """List entries under ``path``.

        Args:
            path: Folder to scan.
            mode:
                allfile: All files in path, including files in subfolders.
                file: Files directly inside path (no recursion).
                dir: Folders directly inside path (no recursion).

        Returns:
            List of joined paths in os.walk order (not sorted). The list is
            empty when path does not exist or mode is unknown.
        """
        if mode == 'allfile':
            return [osp.join(root, name)
                    for root, _, files in os.walk(path)
                    for name in files]
        for root, dirs, files in os.walk(path):
            # Only the first (top-level) os.walk entry is needed here.
            if mode == 'file':
                return [osp.join(root, name) for name in files]
            if mode == 'dir':
                return [osp.join(root, name) for name in dirs]
            break
        return []

    def is_file_meet(self, file_path, condition=None):
        """Check whether a file satisfies every rule in ``condition``.

        Args:
            file_path: File to test. Size rules call os.path.getsize on it.
            condition: Dict of rules, all of which must hold. Supported keys:
                size_max / size_min: size limit, either a number of bytes or
                    a string such as '10M' (see unit_conversion).
                ext_allow / ext_forbid: extensions without the leading dot.
                name_allow / name_forbid: exact file names.
                Unknown keys are ignored. None selects _EXAMPLE_CONDITION,
                which is only a template and matches practically nothing.

        Returns:
            True if every rule passes, otherwise False.
        """
        if condition is None:
            condition = self._EXAMPLE_CONDITION
        file_name = osp.basename(file_path)
        ext = osp.splitext(file_name)[1][1:]  # drop the leading dot
        for key, value in condition.items():
            if key == 'size_max':
                ok = osp.getsize(file_path) <= self.unit_conversion(value, 'B')
            elif key == 'size_min':
                ok = osp.getsize(file_path) >= self.unit_conversion(value, 'B')
            elif key == 'ext_allow':
                ok = ext in value
            elif key == 'ext_forbid':
                ok = ext not in value
            elif key == 'name_allow':
                ok = file_name in value
            elif key == 'name_forbid':
                ok = file_name not in value
            else:
                continue
            if not ok:
                return False
        return True

    def unit_conversion(self, size, output_unit='B'):
        """Convert a size string such as '10M' or '1.5GB' to another unit.

        Binary multiples are used (1K = 1024 B).

        Args:
            size: String "<number><unit>" with unit B, K/KB, M/MB or G/GB
                (case-insensitive). Non-string values are returned unchanged,
                without any conversion.
            output_unit: Target unit: B, K/KB, M/MB or G/GB
                (case-insensitive).

        Returns:
            The size expressed in output_unit, as a float.

        Raises:
            ValueError: If size has no recognised unit suffix, its number
                cannot be parsed, or output_unit is unknown.
        """
        if not isinstance(size, str):
            return size
        text = size.strip().upper()
        if text[-2:] in ('GB', 'MB', 'KB'):
            unit, number = text[-2], text[:-2]
        elif text[-1:] in ('G', 'M', 'K', 'B'):
            unit, number = text[-1], text[:-1]
        else:
            # This was a bare `raise`, which only produced
            # "RuntimeError: No active exception to reraise".
            raise ValueError(f"Unrecognised size {size!r}: expected a B/K/KB/M/MB/G/GB suffix")
        try:
            size_num = float(number) * self._UNIT_BYTES[unit]
        except ValueError:
            raise ValueError(f"Unrecognised size {size!r}: bad number {number!r}") from None

        target = str(output_unit).strip().upper()
        if target not in ('B', 'K', 'KB', 'M', 'MB', 'G', 'GB'):
            raise ValueError(f"Unknown output_unit {output_unit!r}")
        return size_num / self._UNIT_BYTES[target[0]]

    def mkdir(self, path):
        """Create ``path`` (and any missing parents) if it does not exist."""
        if not osp.exists(path):
            # exist_ok guards against the folder appearing between the check
            # and the call.
            os.makedirs(path, exist_ok=True)

    def split_content(self, content):
        """Apply os.path.split to every path in ``content``.

        Args:
            content: A list of path strings, or a list of lists of them.

        Returns:
            A list of (head, tail) tuples, mirroring the input nesting.
            Returns [] for empty input, and None if the first element is
            neither a str nor a list.
        """
        if not content:
            return []
        first = content[0]
        if isinstance(first, str):
            return [osp.split(path) for path in content]
        if isinstance(first, list):
            return [[osp.split(path) for path in group] for group in content]
        return None

    def path_to_last_dir(self, path):
        """Return the name of the folder directly containing ``path``."""
        return osp.basename(osp.dirname(path))

    def path2name(self, path, keep_ext=False):
        """Return the file name of ``path``, without its extension unless keep_ext."""
        _, filename = osp.split(path)
        if keep_ext:
            return filename
        file, _ = osp.splitext(filename)
        return file

    def sort_list(self, list):
        """Return a new list of strings in natural order.

        Digit runs compare as integers:
        1, 10, 2, 20, 3, 4, 5 -> 1, 2, 3, 4, 5, 10, 20.
        (Originally adapted from: https://www.modb.pro/db/162223)

        Args:
            list: Iterable of strings. The name shadows the builtin; it is
                kept so that keyword callers still work.
        """
        def natural_key(text):
            # re.split with a capturing group alternates text and digit runs,
            # e.g. 'img10.png' -> ['img', '10', '.png']; odd indices hold digits.
            # The old key appended a '0' sentinel that merged with trailing
            # digits ('img2' was keyed as 20), misordering 'img3_x' before 'img2'.
            parts = re.split(r'(\d+)', text)
            return [int(part) if i % 2 else part for i, part in enumerate(parts)]
        return sorted(list, key=natural_key)

    def file_last_subtract_1(self, path, mode='-'):
        """Shift the trailing integer label on every line of a text file.

        Just for myself. Rewrites ``path`` in place:
            file:                 file:
                xxx.png 1   '-'       xxx.png 0
                ccc.png 2   --->      ccc.png 1

        Args:
            path: Text file of "<item> <integer label>" lines.
            mode: '-' subtracts 1 and '+' adds 1. Any other value leaves the
                content unchanged.

        Lines without a trailing integer (e.g. blank lines) are kept as-is.
        Each line's trailing whitespace and newline are preserved.
        """
        step = {'-': -1, '+': 1}.get(mode, 0)
        with open(path, 'r') as f:
            lines = f.readlines()
        res = []
        for line in lines:
            # Match the whole trailing integer (multi-digit, optional sign).
            # The old code read only the last character, which broke labels >= 10.
            match = re.search(r'(-?\d+)(\s*)$', line)
            if match is None or step == 0:
                res.append(line)
                continue
            new_label = str(int(match.group(1)) + step)
            res.append(line[:match.start(1)] + new_label + match.group(2))
        with open(path, 'w') as f:
            f.write("".join(res))

    def log(self, content):
        """Print ``content`` with a timestamp and append it to the log file."""
        time_now = datetime.datetime.now()
        content = f"{time_now}: {content}\n"
        self.log2file(content, self.log_file, mode='a')
        print(content, end='')

    def append2file(self, path, text):
        """Append ``text`` to the file at ``path``."""
        with open(path, 'a') as f:
            f.write(text)

    def log2file(self, content, log_path='log.txt', mode='w', show=False):
        """Write ``content`` to ``<log_dir>/<log_path>`` (UTF-8).

        Lists are concatenated, strings written as-is, dicts dumped as
        indented JSON, and anything else written via str(). The log folder
        is created if needed. If show is true, the saved path is logged.
        """
        self.mkdir(self.log_dir)
        path = osp.join(self.log_dir, log_path)
        with open(path, mode, encoding='utf8') as f:
            if isinstance(content, list):
                f.write("".join(content))
            elif isinstance(content, str):
                f.write(content)
            elif isinstance(content, dict):
                json.dump(content, f, indent=2, sort_keys=True, ensure_ascii=False)
            else:
                f.write(str(content))
        if show:
            self.log(f"Log save to: {path}")

    def list2tuple2str(self, list):
        """Return str(tuple(list)), e.g. [1, 2] -> "(1, 2)"."""
        return str(tuple(list))

    def dict_plus(self, dict, key, value=1):
        """Add ``value`` to dict[key] in place, initialising it to value if absent."""
        if key in dict.keys():
            dict[key] += value
        else:
            dict[key] = value

    def sort_by_label(self, path_label_list):
        """Group "path label" strings by their last whitespace-separated token.

        list:[
            "mwhls.jpg 1",                      # path and label
            "mwhls.png 0",                      # path and label
            "mwhls.gif 0"]                      # path and label
        -->
        list:[
            ["0", "1"],                         # labels, natural-sorted
            ["mwhls.png 0", "mwhls.gif 0"],     # class 0
            ["mwhls.jpg 1"]]                    # class 1

        Labels are natural-sorted (previously set order, which varied between
        runs because of string hash randomisation). Input order is kept
        within each group.
        """
        groups = {}
        for path_label in path_label_list:
            groups.setdefault(path_label.split()[-1], []).append(path_label)
        labels = self.sort_list(groups)
        return [labels] + [groups[label] for label in labels]

    def clear_taobao_link(self, text):
        """Reduce a Taobao item URL to its short canonical form.

        Example:
            "https://item.taobao.com/item.htm?spm=x&id=123456789012&ns=1"
            -> "https://item.taobao.com/item.htm?id=123456789012"

        A "?id=" parameter takes precedence over "&id=". If neither is
        followed by digits, ``text`` is returned unchanged.
        """
        link = "https://item.taobao.com/item.htm?"
        for pattern in (r'\?(id=\d+)', r'&(id=\d+)'):
            match = re.search(pattern, text)
            if match:
                # Take the whole digit run. The old code sliced a fixed 15
                # characters ("id=" + 12 digits) and relied on bare excepts.
                return link + match.group(1)
        return text

    def parse_from_language(self, lang='en'):
        """Load the JSON translation table for ``lang`` into self.lang.

        Reads ``<lang_path>/<lang_dict[lang]>`` with the encoding listed in
        lang_encoding (UTF-8 if none). Raises KeyError for a language missing
        from lang_dict and OSError if the file cannot be opened.
        """
        path = osp.join(self.lang_path, self.lang_dict[lang])
        # Decode explicitly: json.load on a binary handle only auto-detects
        # UTF-8/16/32, so a gb18030 table could not be read.
        with open(path, "r", encoding=self.lang_encoding.get(lang, "utf-8")) as f:
            self.lang = json.load(f)

if __name__ == '__main__':
    # This module is a helper library; running it directly does nothing.
    # Scratch snippets kept for manual use:
    #   os.system("pyinstaller -F main.py")   # bundle a script into a .exe
    #   print(get_path_content("test2"))
    #   file_last_subtract_1("path_label.txt")
    pass

版权声明:
作者:MWHLS
链接:https://mwhls.top/4423.html
来源:无镣之涯
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
打赏
< <上一篇
下一篇>>