优化Dataset类,简化参数处理,修复图像加载错误处理,移除冗余代码

This commit is contained in:
TY1667
2025-10-19 21:28:04 +08:00
parent 40de29591b
commit 300ce2e93f

View File

@@ -8,16 +8,11 @@ import torch
from PIL import Image
from torch.utils import data
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"
class Dataset(data.Dataset):
params: dict
mosaic: bool
augment: bool
input_size: int
def __init__(self, filenames, input_size: int, params: dict, augment: bool):
def __init__(self, filenames, input_size, params, augment):
self.params = params
self.mosaic = augment
self.augment = augment
@@ -48,8 +43,6 @@ class Dataset(data.Dataset):
else:
# Load image
image, shape = self.load_image(index)
if image is None:
raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
h, w = image.shape[:2]
# Resize
@@ -57,7 +50,7 @@ class Dataset(data.Dataset):
label = self.labels[index].copy()
if label.size:
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
if self.augment:
image, label = random_perspective(image, label, self.params)
@@ -84,25 +77,24 @@ class Dataset(data.Dataset):
if nl:
box[:, 0] = 1 - box[:, 0]
# target_cls = torch.zeros((nl, 1))
# target_box = torch.zeros((nl, 4))
# if nl:
# target_cls = torch.from_numpy(cls)
# target_box = torch.from_numpy(box)
# fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
# XXX: when nl=0, torch.from_numpy(box) will error
if nl:
target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1)
target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4)
else:
target_cls = torch.zeros((0, 1), dtype=torch.float32)
target_box = torch.zeros((0, 4), dtype=torch.float32)
# target_cls = torch.zeros((nl, 1))
# target_box = torch.zeros((nl, 4))
# if nl:
# target_cls = torch.from_numpy(cls)
# target_box = torch.from_numpy(box)
# Convert HWC to CHW, BGR to RGB
sample = image.transpose((2, 0, 1))[::-1]
sample = numpy.ascontiguousarray(sample)
# init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
# return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)
def __len__(self):
@@ -111,7 +103,7 @@ class Dataset(data.Dataset):
def load_image(self, i):
image = cv2.imread(self.filenames[i])
if image is None:
raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
h, w = image.shape[:2]
r = self.input_size / max(h, w)
if r != 1:
@@ -173,8 +165,8 @@ class Dataset(data.Dataset):
x2b = min(shape[1], x2a - x1a)
y2b = min(y2a - y1a, shape[0])
pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
pad_w = x1a - x1b
pad_h = y1a - y1b
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
# Labels
@@ -197,14 +189,8 @@ class Dataset(data.Dataset):
def collate_fn(batch):
samples, cls, box, indices = zip(*batch)
# ensure empty tensor shape is correct
cls = [c.view(-1, 1) for c in cls]
box = [b.reshape(-1, 4) for b in box]
indices = [i for i in indices]
cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
cls = torch.cat(cls, dim=0)
box = torch.cat(box, dim=0)
new_indices = list(indices)
for i in range(len(indices)):
@@ -215,7 +201,7 @@ class Dataset(data.Dataset):
return torch.stack(samples, dim=0), targets
@staticmethod
def load_label_use_cache(filenames):
def load_label(filenames):
path = f"{os.path.dirname(filenames[0])}.cache"
if os.path.exists(path):
return torch.load(path, weights_only=False)
@@ -228,14 +214,11 @@ class Dataset(data.Dataset):
image.verify() # PIL verify
shape = image.size # image size
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert image.format is not None and image.format.lower() in FORMATS, (
f"invalid image format {image.format}"
)
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
# verify labels
a = f"{os.sep}images{os.sep}"
b = f"{os.sep}labels{os.sep}"
if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
label = [x.split() for x in f.read().strip().splitlines() if len(x)]
@@ -260,50 +243,6 @@ class Dataset(data.Dataset):
torch.save(x, path)
return x
@staticmethod
def load_label(filenames):
    """Verify every image and read its label file, without using a cache.

    Each image is opened with PIL and checked: ``verify()`` must pass, both
    sides must be larger than 9 pixels, and the format must be listed in
    ``FORMATS``. The label path is derived from the image path by replacing
    the last ``<sep>images<sep>`` component with ``<sep>labels<sep>`` and
    the file extension with ``.txt``.

    Args:
        filenames: Iterable of image file paths.

    Returns:
        dict mapping each filename to a ``float32`` array with 5 columns
        (one row per 5-field line of the label file; lines with any other
        number of fields are ignored). Duplicate rows are removed, and the
        remaining rows follow the order given by ``numpy.unique``. An array
        of shape ``(0, 5)`` is stored when the image is missing or fails a
        check, when there is no label file, or when the labels are invalid.
    """

    def empty_label():
        return numpy.zeros((0, 5), dtype=numpy.float32)

    # Loop-invariant path fragments used to map an image path to its label path.
    image_dir = f"{os.sep}images{os.sep}"
    label_dir = f"{os.sep}labels{os.sep}"

    x = {}
    for filename in filenames:
        try:
            # verify images
            with open(filename, "rb") as f:
                image = Image.open(f)
                image.verify()
            shape = image.size
            # Explicit checks instead of `assert`, so validation still runs
            # when Python is started with -O.
            if shape[0] <= 9 or shape[1] <= 9:
                raise ValueError(f"image size {shape} <10 pixels")
            if image.format is None or image.format.lower() not in FORMATS:
                raise ValueError(f"invalid image format {image.format}")

            # verify labels
            label_path = label_dir.join(filename.rsplit(image_dir, 1)).rsplit(".", 1)[0] + ".txt"
            if os.path.isfile(label_path):
                rows = []
                with open(label_path) as f:
                    for line in f:
                        parts = line.split()
                        if len(parts) == 5:  # YOLO format
                            # float() raises ValueError on a non-numeric
                            # token; handled below like any invalid label.
                            rows.append([float(value) for value in parts])
                label = numpy.array(rows, dtype=numpy.float32) if rows else empty_label()
                if label.shape[0]:
                    if not (label >= 0).all():
                        raise ValueError(f"negative label values in {label_path}")
                    if not (label[:, 1:] <= 1.0001).all():
                        raise ValueError(f"label values above 1 in {label_path}")
                    # Drop duplicate rows.
                    _, i = numpy.unique(label, axis=0, return_index=True)
                    label = label[i]
            else:
                label = empty_label()
        except (FileNotFoundError, AssertionError, ValueError):
            # Missing file or failed validation: keep the entry, but empty.
            label = empty_label()
        x[filename] = label
    return x
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
# Convert nx4 boxes
@@ -462,9 +401,7 @@ class Albumentations:
albumentations.ToGray(p=0.01),
albumentations.MedianBlur(p=0.01),
]
self.transform = albumentations.Compose(
transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
)
self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))
except ImportError: # package not installed, skip
pass