# fed-yolo/utils/fed_util.py

import os
import re
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any

from nets import nn


def _image_to_label_path(img_path: str) -> str:
"""
Convert an image path like ".../images/train2017/xxx.jpg"
to the corresponding label path ".../labels/train2017/xxx.txt".
Works for POSIX/Windows separators.
"""
# swap "/images/" (or "\images\") to "/labels/"
label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
# swap extension to .txt
root, _ = os.path.splitext(label_path)
return root + ".txt"
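

# Illustrative usage of _image_to_label_path (hypothetical paths, shown for
# documentation only; not executed at import time):
#   _image_to_label_path("/COCO/images/train2017/000001.jpg")
#       -> "/COCO/labels/train2017/000001.txt"
#   _image_to_label_path("D:\\data\\images\\val\\img.png")
#       -> "D:\\data\\labels\\val\\img.txt"

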
def _parse_yolo_label_file(label_path: str) -> Set[int]:
"""
Return a set of class_ids found in a YOLO .txt label file.
Empty file -> empty set. Missing file -> empty set.
Robust to blank lines / trailing spaces.
Args:
label_path: path to the label file
Returns:
set of class IDs (integers) found in the file
"""
class_ids: Set[int] = set()
if not os.path.exists(label_path):
return class_ids
try:
with open(label_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
# YOLO format: cls cx cy w h
parts = line.split()
if not parts:
continue
try:
cls = int(parts[0])
except ValueError:
# handle weird case like '23.0'
try:
cls = int(float(parts[0]))
except ValueError:
# skip malformed line
continue
class_ids.add(cls)
except Exception:
# If the file can't be read for some reason, treat as no labels
return set()
return class_ids
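

# Illustrative usage of _parse_yolo_label_file (hypothetical label file, shown
# for documentation only): a label file containing the lines
#   0 0.5 0.5 0.2 0.3
#   23 0.1 0.2 0.05 0.05
#   23.0 0.4 0.4 0.1 0.1
# yields the set {0, 23}; a missing or unreadable file yields an empty set,
# matching the behavior documented above.

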
def divide_trainset(
trainset_path: str,
num_local_class: int,
num_client: int,
min_data: int,
max_data: int,
mode: str = "overlap", # "overlap" or "disjoint"
seed: Optional[int] = None,
) -> Dict[str, Any]:
"""
Build a federated split from a YOLO dataset list file.
Args:
trainset_path: path to a .txt file containing one image path per line
e.g. /COCO/images/train2017/1111.jpg
num_local_class: how many distinct classes to sample for each client
num_client: number of clients
min_data: minimum number of images per client
max_data: maximum number of images per client
mode: "overlap" -> images may be shared across clients
"disjoint" -> each image is used by at most one client
seed: optional random seed for reproducibility
Returns:
trainset_divided = {
"users": ["c_00001", ...],
"user_data": {
"c_00001": {"filename": [img_path, ...]},
...
},
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
}
Example:
dataset = divide_trainset(
trainset_path="/COCO/train2017.txt",
num_local_class=3,
num_client=5,
min_data=10,
max_data=20,
mode="disjoint", # or "overlap"
seed=42
)
print(dataset["users"]) # ['c_00001', ..., 'c_00005']
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
print(dataset["user_data"]["c_00001"]["filename"][:3])
"""
if seed is not None:
random.seed(seed)
# ---- Basic validations (defensive programming) ----
if num_client <= 0:
raise ValueError("num_client must be > 0")
if num_local_class <= 0:
raise ValueError("num_local_class must be > 0")
if min_data < 0 or max_data < 0:
raise ValueError("min_data/max_data must be >= 0")
if max_data < min_data:
raise ValueError("max_data must be >= min_data")
if mode not in {"overlap", "disjoint"}:
raise ValueError('mode must be "overlap" or "disjoint"')
# ---- 1) Read image list ----
with open(trainset_path, "r", encoding="utf-8") as f:
all_images_raw = [ln.strip() for ln in f if ln.strip()]
# Normalize and deduplicate image paths (safe)
all_images: List[str] = []
seen = set()
for p in all_images_raw:
        # keep exact string (don't join with cwd), just normalize slashes
norm = os.path.normpath(p)
if norm not in seen:
seen.add(norm)
all_images.append(norm)
# ---- 2) Build mappings from labels ----
class_to_images: Dict[int, Set[str]] = defaultdict(set)
image_to_classes: Dict[str, Set[int]] = {}
missing_label_files = 0
empty_label_files = 0
parsed_images = 0
for img in all_images:
lbl = _image_to_label_path(img)
if not os.path.exists(lbl):
# Missing labels: skip image (no class info)
missing_label_files += 1
continue
classes = _parse_yolo_label_file(lbl)
if not classes:
# No objects in this image -> skip (no class bucket)
empty_label_files += 1
continue
image_to_classes[img] = classes
for c in classes:
class_to_images[c].add(img)
parsed_images += 1
if not class_to_images:
# No usable images found
return {
"users": [f"c_{i + 1:05d}" for i in range(num_client)],
"user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
"num_samples": [0 for _ in range(num_client)],
}
all_classes: List[int] = sorted(class_to_images.keys())
# Available pool for disjoint mode (only images with labels)
available_images: Set[str] = set(image_to_classes.keys())
# ---- 3) Allocate to clients ----
result = {"users": [], "user_data": {}, "num_samples": []}
for cid in range(num_client):
user_id = f"c_{cid + 1:05d}"
result["users"].append(user_id)
# Pick the classes for this client (sample without replacement from global class set)
k = min(num_local_class, len(all_classes))
chosen_classes = random.sample(all_classes, k) if k > 0 else []
# Decide how many samples for this client
need = min_data if min_data == max_data else random.randint(min_data, max_data)
# Build the candidate pool for this client
if mode == "overlap":
pool_set: Set[str] = set()
for c in chosen_classes:
pool_set.update(class_to_images[c])
else: # "disjoint": restrict to currently available images
pool_set = set()
for c in chosen_classes:
# intersect with available images
pool_set.update(class_to_images[c] & available_images)
# Deduplicate and sample
pool_list = list(pool_set)
if len(pool_list) <= need:
chosen_imgs = pool_list[:] # take all (can be fewer than need)
else:
chosen_imgs = random.sample(pool_list, need)
# Record for the user
result["user_data"][user_id] = {"filename": chosen_imgs}
result["num_samples"].append(len(chosen_imgs))
# If disjoint, remove selected images from availability everywhere
if mode == "disjoint" and chosen_imgs:
for img in chosen_imgs:
if img in available_images:
available_images.remove(img)
# remove from every class bucket this image belongs to
for c in image_to_classes.get(img, []):
if img in class_to_images[c]:
class_to_images[c].remove(img)
# Optional: prune empty classes from all_classes to speed up later loops
# (keep list stable; just skip empties naturally)
# (Optional) You can print some quick diagnostics if helpful:
# print(f"[INFO] Parsed images with labels: {parsed_images}")
# print(f"[INFO] Missing label files: {missing_label_files}")
# print(f"[INFO] Empty label files: {empty_label_files}")
return result


def init_model(model_name, num_classes):
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: name of the model variant, e.g. "yolo_v11_n".
        num_classes: number of object classes.

    Returns:
        The initialized model.
    """
model = None
if model_name == "yolo_v11_n":
model = nn.yolo_v11_n(num_classes=num_classes)
elif model_name == "yolo_v11_s":
model = nn.yolo_v11_s(num_classes=num_classes)
elif model_name == "yolo_v11_m":
model = nn.yolo_v11_m(num_classes=num_classes)
elif model_name == "yolo_v11_l":
model = nn.yolo_v11_l(num_classes=num_classes)
elif model_name == "yolo_v11_x":
model = nn.yolo_v11_x(num_classes=num_classes)
else:
raise ValueError("Model {} is not supported.".format(model_name))
return model
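

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's API. The dataset
    # path below is a placeholder assumption; point it at a real YOLO list
    # file (one image path per line) before running. Parameters mirror the
    # example in the divide_trainset docstring.
    demo_list = "/COCO/train2017.txt"  # hypothetical path, adjust as needed
    if os.path.exists(demo_list):
        split = divide_trainset(
            trainset_path=demo_list,
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",
            seed=42,
        )
        print(split["users"])
        print(split["num_samples"])
    # Model construction relies on the local `nets.nn` package imported above.
    model = init_model("yolo_v11_n", num_classes=80)
    print(type(model))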