Optimize the Dataset class: simplify parameter handling, fix image-loading error handling, and remove redundant code
@@ -8,16 +8,11 @@ import torch
from PIL import Image
from torch.utils import data

FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"


class Dataset(data.Dataset):
    params: dict
    mosaic: bool
    augment: bool
    input_size: int

    def __init__(self, filenames, input_size: int, params: dict, augment: bool):
    def __init__(self, filenames, input_size, params, augment):
        self.params = params
        self.mosaic = augment
        self.augment = augment
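Note on the FORMATS change: the format check later in this file lower-cases PIL's reported format before the membership test, so the upper-case entries are redundant either way. A minimal illustration (image_format is a made-up stand-in for PIL's Image.format):

FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"

image_format = "JPEG"  # PIL reports formats in upper case
assert image_format.lower() in FORMATS  # passes without "JPEG" in FORMATS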
@@ -48,8 +43,6 @@ class Dataset(data.Dataset):
        else:
            # Load image
            image, shape = self.load_image(index)
            if image is None:
                raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
            h, w = image.shape[:2]

            # Resize
@@ -57,7 +50,7 @@ class Dataset(data.Dataset):

            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
            if self.augment:
                image, label = random_perspective(image, label, self.params)

@@ -84,25 +77,24 @@ class Dataset(data.Dataset):
                if nl:
                    box[:, 0] = 1 - box[:, 0]

        # target_cls = torch.zeros((nl, 1))
        # target_box = torch.zeros((nl, 4))
        # if nl:
        #     target_cls = torch.from_numpy(cls)
        #     target_box = torch.from_numpy(box)

        # fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
        # XXX: when nl=0, torch.from_numpy(box) will error
        if nl:
            target_cls = torch.from_numpy(cls).view(-1, 1).float()  # always (N,1)
            target_box = torch.from_numpy(box).reshape(-1, 4).float()  # always (N,4)
        else:
            target_cls = torch.zeros((0, 1), dtype=torch.float32)
            target_box = torch.zeros((0, 4), dtype=torch.float32)
        # target_cls = torch.zeros((nl, 1))
        # target_box = torch.zeros((nl, 4))
        # if nl:
        #     target_cls = torch.from_numpy(cls)
        #     target_box = torch.from_numpy(box)

        # Convert HWC to CHW, BGR to RGB
        sample = image.transpose((2, 0, 1))[::-1]
        sample = numpy.ascontiguousarray(sample)

        # init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
        # return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
        return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)

    def __len__(self):
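On the "fix [cls, box] empty bug" note above: the reason for forcing (N, 1) / (N, 4) shapes here is that torch.cat in the DataLoader's collate step needs every per-image target to have the same number of dimensions, including images with zero labels. A standalone sketch of that constraint (the arrays are made up):

import numpy
import torch

box_a = numpy.random.rand(2, 4).astype(numpy.float32)  # image with two boxes
box_b = numpy.zeros((0, 4), dtype=numpy.float32)        # image with no boxes

target_a = torch.from_numpy(box_a).reshape(-1, 4)        # shape (2, 4)
target_b = torch.from_numpy(box_b).reshape(-1, 4)        # shape (0, 4)
print(torch.cat([target_a, target_b], dim=0).shape)      # torch.Size([2, 4])

# A 1-D empty target such as torch.zeros(0) would make the same cat call fail
# with "Tensors must have same number of dimensions".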
@@ -111,7 +103,7 @@ class Dataset(data.Dataset):
    def load_image(self, i):
        image = cv2.imread(self.filenames[i])
        if image is None:
            raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
            raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
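For context on the error handling: cv2.imread returns None for a missing or unreadable file instead of raising, which is why load_image has to raise explicitly. A self-contained sketch of the same pattern (path handling and interpolation choice are illustrative, not this repo's exact code):

import cv2

def load_image_sketch(path, input_size=640):
    image = cv2.imread(path)  # returns None on failure, no exception
    if image is None:
        raise FileNotFoundError(f"Image Not Found {path}")
    h, w = image.shape[:2]
    r = input_size / max(h, w)  # scale so the longer side becomes input_size
    if r != 1:
        image = cv2.resize(image, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR)
    return image, (h, w)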
@@ -173,8 +165,8 @@ class Dataset(data.Dataset):
            x2b = min(shape[1], x2a - x1a)
            y2b = min(y2a - y1a, shape[0])

            pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
            pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
            pad_w = x1a - x1b
            pad_h = y1a - y1b
            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

            # Labels
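The pad_w/pad_h values are the shift that moves boxes from source-image coordinates into mosaic coordinates: the crop image[y1b:y2b, x1b:x2b] is pasted at image4[y1a:y2a, x1a:x2a], so every box moves by (x1a - x1b, y1a - y1b). A tiny worked example with made-up numbers:

x1a, y1a, x1b, y1b = 320, 240, 20, 0  # paste origin vs. crop origin
pad_w = x1a - x1b                      # 300
pad_h = y1a - y1b                      # 240

box = [30, 10, 80, 60]                 # x1, y1, x2, y2 in the source image
shifted = [box[0] + pad_w, box[1] + pad_h, box[2] + pad_w, box[3] + pad_h]
print(shifted)                         # [330, 250, 380, 300]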
@@ -197,14 +189,8 @@ class Dataset(data.Dataset):
    def collate_fn(batch):
        samples, cls, box, indices = zip(*batch)

        # ensure empty tensor shape is correct
        cls = [c.view(-1, 1) for c in cls]
        box = [b.reshape(-1, 4) for b in box]
        indices = [i for i in indices]

        cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
        box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
        indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
        cls = torch.cat(cls, dim=0)
        box = torch.cat(box, dim=0)

        new_indices = list(indices)
        for i in range(len(indices)):
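For context, collate_fn concatenates the variable-length per-image targets into one flat tensor per field and tags each row with the index of the image it belongs to. A minimal sketch of that pattern, independent of this class (the field names are illustrative):

import torch

def collate_sketch(batch):
    # batch: list of (sample, cls, box, index) tuples as produced by __getitem__
    samples, cls, box, indices = zip(*batch)
    # Shift each per-image index tensor by its position in the batch, so the
    # concatenated rows can be mapped back to their source image.
    new_indices = [idx + i for i, idx in enumerate(indices)]
    targets = {"cls": torch.cat(cls, dim=0),
               "box": torch.cat(box, dim=0),
               "idx": torch.cat(new_indices, dim=0)}
    return torch.stack(samples, dim=0), targets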
@@ -215,7 +201,7 @@ class Dataset(data.Dataset):
        return torch.stack(samples, dim=0), targets

    @staticmethod
    def load_label_use_cache(filenames):
    def load_label(filenames):
        path = f"{os.path.dirname(filenames[0])}.cache"
        if os.path.exists(path):
            return torch.load(path, weights_only=False)
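This method caches the verification result: the expensive pass over all images runs once, the resulting dict is saved next to the image folder as a .cache file, and later runs simply load it. A hedged sketch of that shape (build_fn is a hypothetical stand-in for the verification loop):

import os
import torch

def cached_labels_sketch(filenames, build_fn):
    path = f"{os.path.dirname(filenames[0])}.cache"  # e.g. ".../train.cache"
    if os.path.exists(path):
        return torch.load(path, weights_only=False)  # dict built on a previous run
    x = build_fn(filenames)                          # expensive verification pass
    torch.save(x, path)
    return x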
@@ -228,14 +214,11 @@ class Dataset(data.Dataset):
                    image.verify()  # PIL verify
                shape = image.size  # image size
                assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                assert image.format is not None and image.format.lower() in FORMATS, (
                    f"invalid image format {image.format}"
                )
                assert image.format.lower() in FORMATS, f"invalid image format {image.format}"

                # verify labels
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"

                if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
                    with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
                        label = [x.split() for x in f.read().strip().splitlines() if len(x)]
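The rsplit/join expression above maps an image path to its label path by swapping the last /images/ segment for /labels/ and replacing the extension with .txt. A worked example with a made-up path:

import os

a = f"{os.sep}images{os.sep}"
b = f"{os.sep}labels{os.sep}"
filename = os.sep.join(["", "data", "coco", "images", "train", "000001.jpg"])
label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
print(label_path)  # /data/coco/labels/train/000001.txt on a POSIX system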
@@ -260,50 +243,6 @@ class Dataset(data.Dataset):
        torch.save(x, path)
        return x

    @staticmethod
    def load_label(filenames):
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, "rb") as f:
                    image = Image.open(f)
                    image.verify()
                shape = image.size
                assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                assert image.format is not None and image.format.lower() in FORMATS, (
                    f"invalid image format {image.format}"
                )

                # verify labels
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"
                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"

                if os.path.isfile(label_path):
                    rows = []
                    with open(label_path) as f:
                        for line in f:
                            parts = line.strip().split()
                            if len(parts) == 5:  # YOLO format
                                rows.append([float(x) for x in parts])
                    label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)

                    if label.shape[0]:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1.0001).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        label = label[i]
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)

            except (FileNotFoundError, AssertionError):
                label = numpy.zeros((0, 5), dtype=numpy.float32)

            x[filename] = label
        return x


def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    # Convert nx4 boxes
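The wh2xy body falls outside this hunk; for reference, the usual YOLO-style conversion it names turns normalized (cx, cy, w, h) boxes into pixel (x1, y1, x2, y2) with an optional letterbox offset. A sketch of that conversion, not necessarily this file's exact implementation:

import numpy

def wh2xy_sketch(x, w=640, h=640, pad_w=0, pad_h=0):
    # x: (n, 4) normalized [cx, cy, bw, bh]; returns pixel [x1, y1, x2, y2]
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # left
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # right
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom
    return y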
@@ -462,9 +401,7 @@ class Albumentations:
                albumentations.ToGray(p=0.01),
                albumentations.MedianBlur(p=0.01),
            ]
            self.transform = albumentations.Compose(
                transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
            )
            self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))

        except ImportError:  # package not installed, skip
            pass
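The Compose change looks behavior-preserving: "yolo" and ["class_labels"] are simply passed positionally, since format and label_fields are the first two parameters of BboxParams and bbox_params is the second parameter of Compose. A hedged sketch of the equivalence (transform list shortened for illustration):

import albumentations

transforms = [albumentations.Blur(p=0.01), albumentations.MedianBlur(p=0.01)]
keyword = albumentations.Compose(
    transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
)
positional = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))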