import math
import os
import random

import cv2
import numpy
import torch
from PIL import Image
from torch.utils import data

# Image format names accepted by the label loaders (compared against the lower-cased PIL format).
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"


class Dataset(data.Dataset):
    """Detection dataset reading images with OpenCV and YOLO-format ``.txt`` label files.

    When ``augment`` is true, samples may be built as 4-image mosaics (optionally
    mixed with a second mosaic) and are passed through random affine warping,
    optional Albumentations transforms, HSV jitter and random flips.
    """

    params: dict
    mosaic: bool
    augment: bool
    input_size: int

    def __init__(self, filenames, input_size: int, params: dict, augment: bool):
        """Store the configuration and load the label array for every file.

        Args:
            filenames: iterable of image paths; labels are looked up by replacing
                the last ``/images/`` path component with ``/labels/``.
            input_size: target size (pixels) of the longer image side.
            params: augmentation hyper-parameters (keys read in this module:
                ``mosaic``, ``mix_up``, ``flip_ud``, ``flip_lr``, ``hsv_h``,
                ``hsv_s``, ``hsv_v``, ``degrees``, ``scale``, ``shear``, ``translate``).
            augment: enables mosaic and every other augmentation.
        """
        self.params = params
        self.mosaic = augment
        self.augment = augment
        self.input_size = input_size

        # Read labels
        labels = self.load_label(filenames)
        self.labels = list(labels.values())
        self.filenames = list(labels.keys())  # update
        self.n = len(self.filenames)  # number of samples
        self.indices = range(self.n)
        # Albumentations (optional, only used if package is installed)
        self.albumentations = Albumentations()

    def __getitem__(self, index):
        """Build one training/validation sample.

        Returns:
            A 4-tuple ``(image, cls, box, idx)`` where ``image`` is a CHW tensor
            with the channel order reversed relative to what OpenCV loaded,
            ``cls`` is a float tensor of shape (N, 1), ``box`` is a float tensor
            of shape (N, 4) holding normalized ``[x_center, y_center, w, h]``,
            and ``idx`` is a zero-filled long tensor of shape (N, 1) that
            ``collate_fn`` turns into per-label batch indices.
        """
        index = self.indices[index]

        if self.mosaic and random.random() < self.params["mosaic"]:
            # Load MOSAIC
            image, label = self.load_mosaic(index, self.params)
            # MixUp augmentation
            if random.random() < self.params["mix_up"]:
                index = random.choice(self.indices)
                mix_image1, mix_label1 = image, label
                mix_image2, mix_label2 = self.load_mosaic(index, self.params)

                image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
        else:
            # Load image (load_image raises ValueError itself if the file cannot be read)
            image, _ = self.load_image(index)
            h, w = image.shape[:2]

            # Resize
            image, ratio, pad = resize(image, self.input_size, self.augment)

            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
            if self.augment:
                image, label = random_perspective(image, label, self.params)

        nl = len(label)  # number of labels
        h, w = image.shape[:2]
        cls = label[:, 0:1]
        box = label[:, 1:5]
        box = xy2wh(box, w, h)

        if self.augment:
            # Albumentations
            image, box, cls = self.albumentations(image, box, cls)
            nl = len(box)  # update after albumentations
            # HSV color-space
            augment_hsv(image, self.params)
            # Flip up-down
            if random.random() < self.params["flip_ud"]:
                image = numpy.flipud(image)
                if nl:
                    box[:, 1] = 1 - box[:, 1]
            # Flip left-right
            if random.random() < self.params["flip_lr"]:
                image = numpy.fliplr(image)
                if nl:
                    box[:, 0] = 1 - box[:, 0]

        # Always emit (N, 1) / (N, 4) tensors, including N == 0, so that the
        # concatenation in collate_fn never sees mismatched shapes.
        if nl:
            target_cls = torch.from_numpy(cls).view(-1, 1).float()  # always (N,1)
            target_box = torch.from_numpy(box).reshape(-1, 4).float()  # always (N,4)
        else:
            target_cls = torch.zeros((0, 1), dtype=torch.float32)
            target_box = torch.zeros((0, 4), dtype=torch.float32)

        # Convert HWC to CHW, BGR to RGB
        sample = image.transpose((2, 0, 1))[::-1]
        sample = numpy.ascontiguousarray(sample)

        return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def load_image(self, i):
        """Read image ``i`` and scale it so its longer side equals ``input_size``.

        Returns:
            ``(image, (h, w))`` where ``(h, w)`` is the size before scaling.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        image = cv2.imread(self.filenames[i])
        if image is None:
            raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
            image = cv2.resize(
                image,
                dsize=(int(w * r), int(h * r)),
                interpolation=resample() if self.augment else cv2.INTER_LINEAR,
            )
        return image, (h, w)

    def load_mosaic(self, index, params):
        """Compose a 4-image mosaic around a random centre and warp it.

        The image at ``index`` plus three randomly chosen images are pasted into
        the four quadrants of a ``2 * input_size`` canvas; their labels are
        shifted into canvas pixel coordinates, clipped to the canvas, and the
        result is passed through ``random_perspective`` with a negative border
        so the output is cropped back to ``input_size``.

        Returns:
            ``(image, label)`` with ``label`` rows ``[cls, x1, y1, x2, y2]`` in pixels.
        """
        label4 = []
        canvas = self.input_size * 2
        border = [-self.input_size // 2, -self.input_size // 2]
        image4 = numpy.full((canvas, canvas, 3), 0, dtype=numpy.uint8)

        # Mosaic centre
        xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))

        indices = [index] + random.choices(self.indices, k=3)
        random.shuffle(indices)

        for i, index in enumerate(indices):
            # Load image
            image, _ = self.load_image(index)
            shape = image.shape
            # "a" coordinates address the canvas, "b" coordinates address the source image.
            if i == 0:  # top left
                x1a = max(xc - shape[1], 0)
                y1a = max(yc - shape[0], 0)
                x2a = xc
                y2a = yc
                x1b = shape[1] - (x2a - x1a)
                y1b = shape[0] - (y2a - y1a)
                x2b = shape[1]
                y2b = shape[0]
            elif i == 1:  # top right
                x1a = xc
                y1a = max(yc - shape[0], 0)
                x2a = min(xc + shape[1], canvas)
                y2a = yc
                x1b = 0
                y1b = shape[0] - (y2a - y1a)
                x2b = min(shape[1], x2a - x1a)
                y2b = shape[0]
            elif i == 2:  # bottom left
                x1a = max(xc - shape[1], 0)
                y1a = yc
                x2a = xc
                y2a = min(canvas, yc + shape[0])
                x1b = shape[1] - (x2a - x1a)
                y1b = 0
                x2b = shape[1]
                y2b = min(y2a - y1a, shape[0])
            else:  # bottom right (i == 3)
                x1a = xc
                y1a = yc
                x2a = min(xc + shape[1], canvas)
                y2a = min(canvas, yc + shape[0])
                x1b = 0
                y1b = 0
                x2b = min(shape[1], x2a - x1a)
                y2b = min(y2a - y1a, shape[0])

            pad_w = x1a - x1b
            pad_h = y1a - y1b
            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

            # Labels
            label = self.labels[index].copy()
            if len(label):
                label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
            label4.append(label)

        # Concat/clip labels (single in-place clip over the coordinate columns)
        label4 = numpy.concatenate(label4, 0)
        numpy.clip(label4[:, 1:], 0, canvas, out=label4[:, 1:])

        # Augment
        image4, label4 = random_perspective(image4, label4, params, border)

        return image4, label4

    @staticmethod
    def collate_fn(batch):
        """Merge samples produced by ``__getitem__`` into one batch.

        Each sample's index placeholder is offset by the sample's position in
        the batch *before* concatenation, so ``targets["idx"][k]`` is the batch
        index of the image that label ``k`` belongs to. Offsetting per sample
        (rather than per concatenated row) also keeps the concatenation valid
        when no sample in the batch has any label.

        Returns:
            ``(images, targets)`` where ``images`` is the stacked image tensor
            and ``targets`` has keys ``"cls"`` (N, 1), ``"box"`` (N, 4) and
            ``"idx"`` (N,) long.
        """
        samples, cls, box, indices = zip(*batch)

        # ensure empty tensor shape is correct
        cls = [c.view(-1, 1) for c in cls]
        box = [b.reshape(-1, 4) for b in box]
        # Non in-place add: the placeholders handed in by the caller are left untouched.
        idx = [ind.reshape(-1).long() + i for i, ind in enumerate(indices)]

        targets = {
            "cls": torch.cat(cls, dim=0),
            "box": torch.cat(box, dim=0),
            "idx": torch.cat(idx, dim=0),
        }
        return torch.stack(samples, dim=0), targets

    @staticmethod
    def load_label_use_cache(filenames):
        """Load labels for ``filenames``, using a ``.cache`` file next to the image directory.

        Images failing an assertion (too small, unknown format, invalid label
        values) are skipped; a missing image file yields an empty label array.

        Returns:
            dict mapping filename to a float32 array of shape (n, 5).
        """
        path = f"{os.path.dirname(filenames[0])}.cache"
        if os.path.exists(path):
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load cache files from a trusted location.
            return torch.load(path, weights_only=False)
        labels = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, "rb") as f:
                    image = Image.open(f)
                    image.verify()  # PIL verify
                    shape = image.size  # image size
                assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                assert image.format is not None and image.format.lower() in FORMATS, (
                    f"invalid image format {image.format}"
                )

                # verify labels
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"
                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
                if os.path.isfile(label_path):
                    with open(label_path) as f:
                        label = [line.split() for line in f.read().strip().splitlines() if len(line)]
                        label = numpy.array(label, dtype=numpy.float32)
                    nl = len(label)
                    if nl:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        if len(i) < nl:  # duplicate row check
                            label = label[i]  # remove duplicates
                    else:
                        label = numpy.zeros((0, 5), dtype=numpy.float32)
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)
            except FileNotFoundError:
                label = numpy.zeros((0, 5), dtype=numpy.float32)
            except AssertionError:
                continue
            labels[filename] = label
        torch.save(labels, path)
        return labels

    @staticmethod
    def load_label(filenames):
        """Load labels for ``filenames`` without caching.

        Only label lines with exactly five fields are kept. Unlike
        ``load_label_use_cache``, a file failing an assertion is kept with an
        empty label array instead of being skipped.

        Returns:
            dict mapping filename to a float32 array of shape (n, 5).
        """
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, "rb") as f:
                    image = Image.open(f)
                    image.verify()
                    shape = image.size
                assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                assert image.format is not None and image.format.lower() in FORMATS, (
                    f"invalid image format {image.format}"
                )

                # verify labels
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"
                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"

                if os.path.isfile(label_path):
                    rows = []
                    with open(label_path) as f:
                        for line in f:
                            parts = line.strip().split()
                            if len(parts) == 5:  # YOLO format
                                rows.append([float(value) for value in parts])
                    label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)

                    if label.shape[0]:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1.0001).all()  # small tolerance above 1
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        label = label[i]
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)
            except (FileNotFoundError, AssertionError):
                label = numpy.zeros((0, 5), dtype=numpy.float32)
            x[filename] = label
        return x


def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    """Convert nx4 boxes from normalized [x, y, w, h] to pixel [x1, y1, x2, y2].

    ``w``/``h`` scale the normalized values and ``pad_w``/``pad_h`` are added as
    offsets; xy1 is the top-left corner and xy2 the bottom-right. Returns a copy.
    """
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom right y
    return y


def xy2wh(x, w, h):
    """Convert nx4 boxes from pixel [x1, y1, x2, y2] to normalized [x, y, w, h].

    Warning: ``x`` is clipped in place to the image bounds before conversion;
    the converted boxes are returned as a new array.
    """
    # warning: inplace clip
    x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1e-3)  # x1, x2
    x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1e-3)  # y1, y2

    y = numpy.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def resample():
    """Return a randomly chosen OpenCV interpolation flag."""
    choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4)
    return random.choice(seq=choices)


def augment_hsv(image, params):
    """Apply random HSV gains to ``image`` in place.

    Gains are drawn uniformly from ``1 ± params["hsv_h" | "hsv_s" | "hsv_v"]``
    and applied through per-channel lookup tables. Nothing is returned; the
    result is written back into ``image``.
    """
    h = params["hsv_h"]
    s = params["hsv_s"]
    v = params["hsv_v"]

    r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1
    h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))

    x = numpy.arange(0, 256, dtype=r.dtype)
    lut_h = ((x * r[0]) % 180).astype("uint8")  # hue wraps at 180
    lut_s = numpy.clip(x * r[1], 0, 255).astype("uint8")
    lut_v = numpy.clip(x * r[2], 0, 255).astype("uint8")

    hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
    cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image)  # no return needed


def resize(image, input_size, augment):
    """Resize ``image`` to fit ``input_size`` and pad it to a square with a constant border.

    Without ``augment`` the image is only ever scaled down, never up.

    Returns:
        ``(image, (r, r), (w, h))``: the padded image, the scale ratio (same for
        both axes), and the per-side horizontal/vertical padding (may be fractional).
    """
    shape = image.shape[:2]  # current shape [height, width]

    # Scale ratio (new / old)
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    pad = int(round(shape[1] * r)), int(round(shape[0] * r))
    w = (input_size - pad[0]) / 2
    h = (input_size - pad[1]) / 2

    if shape[::-1] != pad:  # resize
        image = cv2.resize(image, dsize=pad, interpolation=resample() if augment else cv2.INTER_LINEAR)
    # The -0.1 / +0.1 offsets split an odd total padding between the two sides.
    top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
    left, right = int(round(w - 0.1)), int(round(w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT)  # add border
    return image, (r, r), (w, h)


def candidates(box1, box2):
    """Return a boolean mask of boxes worth keeping after a transform.

    ``box1`` (4, n) holds the boxes before and ``box2`` (4, n) after the
    transform. A box is kept when its new width and height exceed 2, its area
    is more than 10% of the original area, and its aspect ratio is below 100.
    """
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)


def random_perspective(image, label, params, border=(0, 0)):
    """Apply a random rotation/scale/shear/translation to ``image`` and its boxes.

    Args:
        image: input image array.
        label: array with rows ``[cls, x1, y1, x2, y2]`` in pixels.
        params: reads ``degrees``, ``scale``, ``shear`` and ``translate``.
        border: per-axis border added (twice) to the output size; negative
            values crop, as used by the mosaic loader.

    Returns:
        ``(image, label)`` with boxes transformed, clipped to the output size
        and filtered through ``candidates``.
    """
    h = image.shape[0] + border[0] * 2
    w = image.shape[1] + border[1] * 2

    # Center
    center = numpy.eye(3)
    center[0, 2] = -image.shape[1] / 2  # x translation (pixels)
    center[1, 2] = -image.shape[0] / 2  # y translation (pixels)

    # Perspective (left as identity, so the combined transform is affine)
    perspective = numpy.eye(3)

    # Rotation and Scale
    rotate = numpy.eye(3)
    a = random.uniform(-params["degrees"], params["degrees"])
    s = random.uniform(1 - params["scale"], 1 + params["scale"])
    rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    shear = numpy.eye(3)
    shear[0, 1] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)
    shear[1, 0] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)

    # Translation
    translate = numpy.eye(3)
    translate[0, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * w
    translate[1, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * h

    # Combined rotation matrix, order of operations (right to left) is IMPORTANT
    matrix = translate @ shear @ rotate @ perspective @ center
    if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any():  # image changed
        image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))

    # Transform label coordinates
    n = len(label)
    if n:
        xy = numpy.ones((n * 4, 3))
        xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ matrix.T  # transform
        xy = xy[:, :2].reshape(n, 8)  # perspective rescale or affine

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip
        box[:, [0, 2]] = box[:, [0, 2]].clip(0, w)
        box[:, [1, 3]] = box[:, [1, 3]].clip(0, h)
        # filter candidates
        indices = candidates(box1=label[:, 1:5].T * s, box2=box.T)

        label = label[indices]
        label[:, 1:5] = box[indices]

    return image, label


def mix_up(image1, label1, image2, label2):
    """Blend two images with a Beta(32, 32) ratio and concatenate their labels.

    Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    """
    alpha = numpy.random.beta(a=32.0, b=32.0)  # mix-up ratio, alpha=beta=32.0
    image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8)
    label = numpy.concatenate((label1, label2), 0)
    return image, label


class Albumentations:
    """Optional Albumentations pipeline; a no-op when the package is not installed."""

    def __init__(self):
        self.transform = None
        try:
            import albumentations

            transforms = [
                albumentations.Blur(p=0.01),
                albumentations.CLAHE(p=0.01),
                albumentations.ToGray(p=0.01),
                albumentations.MedianBlur(p=0.01),
            ]
            self.transform = albumentations.Compose(
                transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
            )

        except ImportError:  # package not installed, skip
            pass

    def __call__(self, image, box, cls):
        """Apply the pipeline if available and return ``(image, box, cls)``.

        When a transform is configured, ``box`` and ``cls`` are rebuilt as new
        numpy arrays from the transform output; otherwise inputs pass through.
        """
        if self.transform:
            x = self.transform(image=image, bboxes=box, class_labels=cls)
            image = x["image"]
            box = numpy.array(x["bboxes"])
            cls = numpy.array(x["class_labels"])
        return image, box, cls