MMSeg fails to train on a single-class custom dataset (updated 2023/2/17)
Summary: convert the three-channel label images to single-channel images and remap the class pixel values to consecutive indices 0, 1, 2, to fix MMSeg's error and the failure to train.
Update 2023/02/17: added a true single-class training method (which does not work well) along with result visualizations.
Description
Single-class segmentation via binary classification
- An error came up when training on my custom dataset. In theory everything else was correct, so the images themselves had to be the problem.
- But this time I had two datasets. The previous one threw the same error and I had worked around it somehow. The working dataset's GT images were color and the failing one's were black and white, yet on inspection the black-and-white images were also three-channel, so the channel count should not have been the issue.
- After searching the official issues, however, I found that they recommend single-channel GT images: https://github.com/open-mmlab/mmsegmentation/issues/1625#issuecomment-1140384065
- After converting to single-channel, the error below disappeared, but a new problem showed up: abnormal metrics/loss.
ValueError: Input and output must have the same number of spatial dimensions, but got input with with spatial dimensions of [128, 128] and output size of torch.Size([512, 512, 3]). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.
- After several more tests, I split the single class into two classes, background and target, mapped to pixel values 0 and 1 (in the 0-255 range), and the error went away. (A config sketch for this two-class setup follows this list.)
- The target class's mIoU then came out as 0, which was fixed by changing the loss weights to (1, 10).
- Incidentally, MMSeg says it has updated the handling of the single-class case, but after upgrading to the latest version it still behaves the same as the old one for me.
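- To make the two-class setup concrete, below is a minimal sketch using the MMSeg 0.x Python API. The dataset name, file suffixes, and reading the (1, 10) weights as per-class weights in CrossEntropyLoss are my assumptions; only the background/target classes, num_classes=2, and the 0/1 masks come from the notes above.
# my_binary_dataset.py (sketch, not the exact config used in this post)
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class TargetBinaryDataset(CustomDataset):
    # background + target, matching single-channel masks whose pixel values are 0 and 1
    CLASSES = ('background', 'target')
    PALETTE = [[0, 0, 0], [255, 255, 255]]
    def __init__(self, **kwargs):
        super().__init__(
            img_suffix='.jpg',        # assumed image suffix, adjust to your data
            seg_map_suffix='.png',    # the 0/1 masks produced by channel3to1.py below
            reduce_zero_label=False,  # 0 is a real class (background), so keep it
            **kwargs)

# Config fragment for the decode head: two classes, with a heavier weight on the
# target class so its mIoU does not stay at 0 (the (1, 10) mentioned above).
decode_head = dict(
    num_classes=2,
    loss_decode=dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        loss_weight=1.0,
        class_weight=[1.0, 10.0]))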
Another single-class training approach
- I recently ran several experiments, with training schedules from 20k to 80k iterations. The target IoU stayed at only 10-20 even though accuracy was fairly high at 80+, and setting the loss weights to (1, 100) did not improve the poor performance.
- So I went back to the issues and found another way to train a single class:
  - Change all of the binary-classification settings above to a single class, i.e. num_classes=1, reduce_zero_label=True, and cut the dataset class's CLASSES and PALETTE down to one class accordingly.
  - Set the loss to type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4.
  - For more concrete code see: https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/faq.md#how-to-handle-binary-segmentation-task. (A rough config sketch for this variant is given at the end of this section.)
- Trained this way, only the one class is learned: IoU stabilizes at 20+, but accuracy is only 40+. The metrics look somewhat nicer, yet in practice I still recommend training single-class segmentation as binary classification:
  - When I visualized the results, the single-class model draws a large blob around the target, and only a small part of it is actually the target.
  - The binary-classification model, in contrast, labels the target much more precisely, with only a little spread around it.
- Judging from the outputs, the binary-classification setup is the correct one, even though its metrics look odd, as if it were barely training; I suspect the dataset is simply too small.
- Visualization of the single-class training results:
- Visualization of the binary-classification training results:
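- For reference, this is roughly how I read the single-class variant above. It is a sketch rather than the FAQ's verbatim config, and the field names can differ between MMSeg versions, so follow the linked faq.md for your installation.
# single_class_fragment.py (sketch)
# Dataset side: cut CLASSES / PALETTE down to one class, e.g. CLASSES = ('target',),
# and set reduce_zero_label=True so pixel value 0 is treated as ignored background.
dataset_settings = dict(reduce_zero_label=True)

# Head side: a single output channel trained with sigmoid / binary cross-entropy.
decode_head = dict(
    num_classes=1,
    loss_decode=dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,    # BCE on one channel instead of softmax over two classes
        loss_weight=0.4))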
Code
- A general-purpose script that sorts arbitrary label values and remaps them to consecutive indices is a bit of a pain to write, and I did not feel like thinking it through, so there is only a tailor-made version here.
- Given images whose background value is (0, 0, 0) and whose target value is (255, 255, 255), the code converts them to 0 and 1 respectively.
- Put the two scripts below in the same folder and run channel3to1.py; enter the folder to process when prompted (subfolders are handled recursively), and the results are written to the log folder. (A quick sanity check on the output follows the first script.)
- I have actually written quite a few similar small tools, but only ever pushed the very first version to GitHub; too lazy...
# channel3to1.py
from base_model import BaseModel
import cv2
import numpy as np
class Channels3to1(BaseModel):
def __init__(self):
super().__init__()
self.change_log_path()
pass
def run(self):
path = input("Input path: ")
files_path = self.get_path_content(path, 'allfile')
self.log(f"Path: {path}")
for i, file_path in enumerate(files_path):
self.log(f"{i+1}: {file_path}")
for i, file_path in enumerate(files_path):
img = cv2.imread(file_path)
            # Keep one channel and binarize: background stays 0, any non-zero value becomes 1.
            # uint8 is what cv2.imwrite expects for an 8-bit single-channel PNG.
            img = (img[:, :, 0] != 0).astype(np.uint8)
save_path = self.log_dir + "/"+ self.path2name(file_path, keep_ext=True)
cv2.imwrite(save_path, img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
self.log(f"{i+1}: {file_path} converted (H, W, 3) -> (H, W, 1) to {save_path}")
if __name__ == "__main__":
Channels3to1().run()
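- A quick way to sanity-check the converted masks before training; the file name below is a placeholder, since the script above writes its output under log/Channels3to1/<timestamp>/.
# check_mask.py (sketch)
import cv2
import numpy as np

# Replace the placeholder path with an actual file written by channel3to1.py.
mask = cv2.imread('log/Channels3to1/<timestamp>/your_mask.png', cv2.IMREAD_UNCHANGED)
print(mask.shape)       # expected: (H, W) -- a single-channel image
print(np.unique(mask))  # expected: [0 1] -- background and target values only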
# base_model.py
import os
import os.path as osp
import re
import json
import time
import datetime
class BaseModel():
"""
BaseModel, call it "utils" is OK.
"""
def __init__(self, log_dir='', lang='en'):
if log_dir == '':
self.log_root = f"./log/{self.__class__.__name__}"
else:
self.log_root = log_dir
self.log_dir = self.log_root
self.timestamp = time.time()
self.log_file = f"{self.__class__.__name__}_{self.timestamp}.log"
# self.lang_path = "./languages"
# self.lang_dict = {
# "en": "English.json",
# "zh": "Chinese.json"
# }
# self.lang_encoding = {
# "en": "utf-8",
# "zh": "gb18030"
# }
# self.lang = {}
# self.parse_from_language("zh")
def help(self):
""" Help function
Print the help message
"""
self.log(self.__doc__)
def change_log_path(self, mode="timestamp"):
if mode == "timestamp":
self.log_dir = osp.join(self.log_root, str(self.timestamp))
elif mode == "root":
self.log_dir = self.log_root
def init_log_file(self):
self.log_file = f"{self.__class__.__name__}_{time.time()}.log"
def get_path_content(self, path, mode='allfile'):
"""
mode:
allfile: All files in path, including files in subfolders.
file: Files in path, only including files in this dir: path
dir: Dirs in path, only including Dir in this dir: path
"""
path_content = []
index = 0
for root, dirs, files in os.walk(path):
index += 1
if mode == 'allfile':
for file in files:
file_path = osp.join(root, file)
path_content.append(file_path)
if mode == 'file':
for file in files:
file_path = osp.join(root, file)
path_content.append(file_path)
break
if mode == 'dir':
for dir in dirs:
dir_path = osp.join(root, dir)
path_content.append(dir_path)
break
return path_content
def is_file_meet(self,
file_path,
condition={
'size_max': '10M',
'size_min': '10M',
'ext_allow': ['pth', 'pt', 't'],
'ext_forbid': ['pth', 'pt', 't'],
'name_allow': ['epoch_99.t'],
'name_forbid': ['epoch_99.t']
}):
meet = True
for k, v in condition.items():
if k == 'size_max':
# file size should <= size_max
max_value = self.unit_conversion(v, 'B')
file_size = os.path.getsize(file_path)
if not file_size <= max_value:
meet = False
elif k == 'size_min':
# file size should >= size_min
min_value = self.unit_conversion(v, 'B')
file_size = os.path.getsize(file_path)
if not file_size >= min_value:
meet = False
elif k == 'ext_allow':
# file's extension name should in ext_allow[]
_, file_name = os.path.split(file_path)
_, ext = os.path.splitext(file_name)
ext = ext[1:]
if not ext in v:
meet = False
elif k == 'ext_forbid':
# file's extension name shouldn't in ext_forbid[]
_, file_name = os.path.split(file_path)
_, ext = os.path.splitext(file_name)
ext = ext[1:]
if ext in v:
meet = False
elif k == 'name_allow':
# file's name should in name_allow[]
_, file_name = os.path.split(file_path)
if not file_name in v:
meet = False
elif k == 'name_forbid':
# file's name shouldn't in name_forbid[]
_, file_name = os.path.split(file_path)
if file_name in v:
meet = False
return meet
def unit_conversion(self, size, output_unit='B'):
# convert [GB, MB, KB, B] to [GB, MB, KB, B]
if not isinstance(size, str):
return size
# to Byte
size = size.upper()
if 'GB' == size[-2:] or 'G' == size[-1]:
size = size.replace("G", '')
size = size.replace("B", '')
size_num = float(size)
size_num = size_num * 1024 * 1024 * 1024
elif 'MB' == size[-2:] or 'M' == size[-1]:
size = size.replace("M", '')
size = size.replace("B", '')
size_num = float(size)
size_num = size_num * 1024 * 1024
elif 'KB' == size[-2:] or 'K' == size[-1]:
size = size.replace("K", '')
size = size.replace("B", '')
size_num = float(size)
size_num = size_num * 1024
elif 'B' == size[-1]:
size = size.replace("B", '')
size_num = float(size)
else:
            raise ValueError(f"unit_conversion: unrecognized size string '{size}'")
# to output_unit
if output_unit in ['GB', 'G']:
size_num = size_num / 1024 / 1024 / 1024
if output_unit in ['MB', 'M']:
size_num = size_num / 1024 / 1024
if output_unit in ['KB', 'K']:
size_num = size_num / 1024
if output_unit in ['B']:
size_num = size_num
# return
return size_num
def mkdir(self, path):
if not osp.exists(path):
os.makedirs(path)
def split_content(self, content):
if isinstance(content[0], str):
content_split = []
for path in content:
content_split.append(osp.split(path))
return content_split
elif isinstance(content[0], list):
contents_split = []
for group in content:
content_split = []
for path in group:
content_split.append(osp.split(path))
contents_split.append(content_split)
return contents_split
def path_to_last_dir(self, path):
dirname = osp.dirname(path)
last_dir = osp.basename(dirname)
return last_dir
def path2name(self, path, keep_ext=False):
_, filename = osp.split(path)
if keep_ext:
return filename
file, _ = osp.splitext(filename)
return file
def sort_list(self, list):
# copy from: https://www.modb.pro/db/162223
# To make 1, 10, 2, 20, 3, 4, 5 -> 1, 2, 3, 4, 5, 10, 20
list = sorted(list, key=lambda s: [int(s) if s.isdigit() else s for s in sum(re.findall(r'(\D+)(\d+)', 'a'+s+'0'), ())])
return list
def file_last_subtract_1(self, path, mode='-'):
"""
Just for myself.
file:
xxx.png 1
ccc.png 2
---> mode='-' --->
file:
xxx.png 0
ccc.png 1
"""
with open(path, 'r') as f:
lines = f.readlines()
res = []
for line in lines:
last = -2 if line[-1] == '\n' else -1
line1, line2 = line[:last], line[last]
if mode == '-':
line2 = str(int(line2) - 1)
elif mode == '+':
line2 = str(int(line2) + 1)
line = line1 + line2 + "\n"
if last == -1:
line = line1 + line2
res.append(line)
with open(path, 'w') as f:
f.write("".join(res))
def log(self, content):
time_now = datetime.datetime.now()
content = f"{time_now}: {content}\n"
self.log2file(content, self.log_file, mode='a')
print(content, end='')
def append2file(self, path, text):
with open(path, 'a') as f:
f.write(text)
def log2file(self, content, log_path='log.txt', mode='w', show=False):
self.mkdir(self.log_dir)
path = osp.join(self.log_dir, log_path)
with open(path, mode, encoding='utf8') as f:
if isinstance(content, list):
f.write("".join(content))
elif isinstance(content, str):
f.write(content)
elif isinstance(content, dict):
json.dump(content, f, indent=2, sort_keys=True, ensure_ascii=False)
else:
f.write(str(content))
if show:
self.log(f"Log save to: {path}")
def list2tuple2str(self, list):
return str(tuple(list))
def dict_plus(self, dict, key, value=1):
if key in dict.keys():
dict[key] += value
else:
dict[key] = value
def sort_by_label(self, path_label_list):
"""
list:[
"mwhls.jpg 1", # path and label
"mwhls.png 0", # path and label
"mwhls.gif 0"] # path and label
-->
list:[
["0", "1"], # label
["mwhls.png 0", "mwhls.gif 0"], # class 0
["mwhls.jpg 1"]] # class 1
"""
label_list = []
for path_label in path_label_list:
label = path_label.split()[-1]
label_list.append(label)
label_set = set(label_list)
res_list = []
res_list.append(list(label_set))
for label in label_set:
            index_equal = []  # note: label_list == label compares the whole list with one string and just returns False; element-wise comparison only works on numpy arrays
for i, lab in enumerate(label_list):
if lab == label:
index_equal.append(i)
            res = [path_label_list[i] for i in index_equal]  # plain Python lists cannot be indexed with a list of indices; that is numpy-style fancy indexing
res_list.append(res)
return res_list
def clear_taobao_link(self, text):
# try:
link = "https://item.taobao.com/item.htm?"
try:
id_index_1 = text.index('&id=') + 1
id_index = id_index_1
except:
pass
try:
id_index_2 = text.index('?id=') + 1
id_index = id_index_2
except:
pass
try:
id = text[id_index: id_index+15]
text = link + id
except:
pass
return text
# except:
# return text
    def parse_from_language(self, lang='en'):
        # Requires the language settings (lang_path / lang_dict) that are commented out in __init__ above.
        path = osp.join(self.lang_path, self.lang_dict[lang])
with open(path, "rb") as f:
self.lang = json.load(f)
if __name__ == '__main__':
# .py to .exe
# os.system("pyinstaller -F main.py")
# print(get_path_content("test2"))
# file_last_subtract_1("path_label.txt")
pass