From 300ce2e93f0c2f128753eff9a5f029575f9dd93d Mon Sep 17 00:00:00 2001 From: TY1667 Date: Sun, 19 Oct 2025 21:28:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96Dataset=E7=B1=BB=EF=BC=8C?= =?UTF-8?q?=E7=AE=80=E5=8C=96=E5=8F=82=E6=95=B0=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9B=BE=E5=83=8F=E5=8A=A0=E8=BD=BD=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=A4=84=E7=90=86=EF=BC=8C=E7=A7=BB=E9=99=A4=E5=86=97?= =?UTF-8?q?=E4=BD=99=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/dataset.py | 99 +++++++++--------------------------------------- 1 file changed, 18 insertions(+), 81 deletions(-) diff --git a/utils/dataset.py b/utils/dataset.py index e3d71f6..f6eaea2 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -8,16 +8,11 @@ import torch from PIL import Image from torch.utils import data -FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF" +FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp" class Dataset(data.Dataset): - params: dict - mosaic: bool - augment: bool - input_size: int - - def __init__(self, filenames, input_size: int, params: dict, augment: bool): + def __init__(self, filenames, input_size, params, augment): self.params = params self.mosaic = augment self.augment = augment @@ -48,8 +43,6 @@ class Dataset(data.Dataset): else: # Load image image, shape = self.load_image(index) - if image is None: - raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}") h, w = image.shape[:2] # Resize @@ -57,7 +50,7 @@ class Dataset(data.Dataset): label = self.labels[index].copy() if label.size: - label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1])) + label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1]) if self.augment: image, label = random_perspective(image, label, self.params) @@ -84,25 +77,24 @@ class Dataset(data.Dataset): if nl: box[:, 0] = 1 - box[:, 0] - # target_cls = torch.zeros((nl, 1)) - # target_box = torch.zeros((nl, 4)) - # if nl: - # target_cls = torch.from_numpy(cls) - # target_box = torch.from_numpy(box) - - # fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation + # XXX: when nl=0, torch.from_numpy(box) will error if nl: target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1) target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4) else: target_cls = torch.zeros((0, 1), dtype=torch.float32) target_box = torch.zeros((0, 4), dtype=torch.float32) + # target_cls = torch.zeros((nl, 1)) + # target_box = torch.zeros((nl, 4)) + # if nl: + # target_cls = torch.from_numpy(cls) + # target_box = torch.from_numpy(box) # Convert HWC to CHW, BGR to RGB sample = image.transpose((2, 0, 1))[::-1] sample = numpy.ascontiguousarray(sample) - # init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl) + # return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl) return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long) def __len__(self): @@ -111,7 +103,7 @@ class Dataset(data.Dataset): def load_image(self, i): image = cv2.imread(self.filenames[i]) if image is None: - raise ValueError(f"Image not found or unable to open: {self.filenames[i]}") + raise FileNotFoundError(f"Image Not Found {self.filenames[i]}") h, w = image.shape[:2] r = self.input_size / max(h, w) if r != 1: @@ -173,8 +165,8 @@ class Dataset(data.Dataset): x2b = min(shape[1], x2a - x1a) y2b = min(y2a - y1a, shape[0]) - pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0) - pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0) + pad_w = x1a - x1b + pad_h = y1a - y1b image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b] # Labels @@ -197,14 +189,8 @@ class Dataset(data.Dataset): def collate_fn(batch): samples, cls, box, indices = zip(*batch) - # ensure empty tensor shape is correct - cls = [c.view(-1, 1) for c in cls] - box = [b.reshape(-1, 4) for b in box] - indices = [i for i in indices] - - cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1)) - box = torch.cat(box, dim=0) if box else torch.zeros((0, 4)) - indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long) + cls = torch.cat(cls, dim=0) + box = torch.cat(box, dim=0) new_indices = list(indices) for i in range(len(indices)): @@ -215,7 +201,7 @@ class Dataset(data.Dataset): return torch.stack(samples, dim=0), targets @staticmethod - def load_label_use_cache(filenames): + def load_label(filenames): path = f"{os.path.dirname(filenames[0])}.cache" if os.path.exists(path): return torch.load(path, weights_only=False) @@ -228,14 +214,11 @@ class Dataset(data.Dataset): image.verify() # PIL verify shape = image.size # image size assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" - assert image.format is not None and image.format.lower() in FORMATS, ( - f"invalid image format {image.format}" - ) + assert image.format.lower() in FORMATS, f"invalid image format {image.format}" # verify labels a = f"{os.sep}images{os.sep}" b = f"{os.sep}labels{os.sep}" - if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"): with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f: label = [x.split() for x in f.read().strip().splitlines() if len(x)] @@ -260,50 +243,6 @@ class Dataset(data.Dataset): torch.save(x, path) return x - @staticmethod - def load_label(filenames): - x = {} - for filename in filenames: - try: - # verify images - with open(filename, "rb") as f: - image = Image.open(f) - image.verify() - shape = image.size - assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" - assert image.format is not None and image.format.lower() in FORMATS, ( - f"invalid image format {image.format}" - ) - - # verify labels - a = f"{os.sep}images{os.sep}" - b = f"{os.sep}labels{os.sep}" - label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt" - - if os.path.isfile(label_path): - rows = [] - with open(label_path) as f: - for line in f: - parts = line.strip().split() - if len(parts) == 5: # YOLO format - rows.append([float(x) for x in parts]) - label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32) - - if label.shape[0]: - assert (label >= 0).all() - assert label.shape[1] == 5 - assert (label[:, 1:] <= 1.0001).all() - _, i = numpy.unique(label, axis=0, return_index=True) - label = label[i] - else: - label = numpy.zeros((0, 5), dtype=numpy.float32) - - except (FileNotFoundError, AssertionError): - label = numpy.zeros((0, 5), dtype=numpy.float32) - - x[filename] = label - return x - def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0): # Convert nx4 boxes @@ -462,9 +401,7 @@ class Albumentations: albumentations.ToGray(p=0.01), albumentations.MedianBlur(p=0.01), ] - self.transform = albumentations.Compose( - transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"]) - ) + self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"])) except ImportError: # package not installed, skip pass