Compare commits

...

2 Commits

Author SHA1 Message Date
TY1667
e999853c94 fix pylance type warning 2025-11-04 13:18:43 +08:00
f658c6aca5 Fix Error: AttributeError: 'collections.OrderedDict' object has no attribute 'half' 2025-11-03 21:12:35 +08:00
2 changed files with 17 additions and 3 deletions

View File

@@ -14,6 +14,7 @@ from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser # args parser
from utils.fed_util import divide_trainset # divide_trainset
from utils import util
from utils import fed_util
from utils.fed_util import prepare_result_dir
@@ -189,7 +190,9 @@ def fed_run():
# Save final global model weights
# FIXME: save model not adaptive YOLOv11-pt specific
save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)}
global_model = fed_util.init_model(model_name, num_classes=len(cfg["names"]))
global_model.load_state_dict(global_state)
save_model = {"config": cfg, "model": copy.deepcopy(global_model if global_model else None)}
torch.save(save_model, f"{weights_root}/last.pt")
if best == mAP:
torch.save(save_model, f"{weights_root}/best.pt")

View File

@@ -101,6 +101,7 @@ class Dataset(data.Dataset):
return len(self.filenames)
def load_image(self, i):
# FIXME: for png maybe have something different with jpg
image = cv2.imread(self.filenames[i])
if image is None:
raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
@@ -165,6 +166,13 @@ class Dataset(data.Dataset):
x2b = min(shape[1], x2a - x1a)
y2b = min(y2a - y1a, shape[0])
if (
isinstance(x1a, type(None))
or isinstance(y1a, type(None))
or isinstance(x1b, type(None))
or isinstance(y1b, type(None))
):
raise ValueError("Mosaic calculation error")
pad_w = x1a - x1b
pad_h = y1a - y1b
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
@@ -205,7 +213,7 @@ class Dataset(data.Dataset):
path = f"{os.path.dirname(filenames[0])}.cache"
if os.path.exists(path):
# XXX: temporarily disable cache
os.remove(path)
# os.remove(path)
pass
# return torch.load(path, weights_only=False)
x = {}
@@ -217,7 +225,10 @@ class Dataset(data.Dataset):
image.verify() # PIL verify
shape = image.size # image size
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
if image.format:
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
else:
assert any(filename.lower().endswith(f".{x}") for x in FORMATS), "unknown image format"
# verify labels
a = f"{os.sep}images{os.sep}"