Compare commits
2 Commits
fdb70869f9
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e999853c94 | ||
| f658c6aca5 |
@@ -14,6 +14,7 @@ from fed_algo_cs.server_base import FedYoloServer
|
||||
from utils.args import args_parser # args parser
|
||||
from utils.fed_util import divide_trainset # divide_trainset
|
||||
from utils import util
|
||||
from utils import fed_util
|
||||
from utils.fed_util import prepare_result_dir
|
||||
|
||||
|
||||
@@ -189,7 +190,9 @@ def fed_run():
|
||||
|
||||
# Save final global model weights
|
||||
# FIXME: save model not adaptive YOLOv11-pt specific
|
||||
save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)}
|
||||
global_model = fed_util.init_model(model_name, num_classes=len(cfg["names"]))
|
||||
global_model.load_state_dict(global_state)
|
||||
save_model = {"config": cfg, "model": copy.deepcopy(global_model if global_model else None)}
|
||||
torch.save(save_model, f"{weights_root}/last.pt")
|
||||
if best == mAP:
|
||||
torch.save(save_model, f"{weights_root}/best.pt")
|
||||
|
||||
@@ -101,6 +101,7 @@ class Dataset(data.Dataset):
|
||||
return len(self.filenames)
|
||||
|
||||
def load_image(self, i):
|
||||
# FIXME: for png maybe have something different with jpg
|
||||
image = cv2.imread(self.filenames[i])
|
||||
if image is None:
|
||||
raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
|
||||
@@ -165,6 +166,13 @@ class Dataset(data.Dataset):
|
||||
x2b = min(shape[1], x2a - x1a)
|
||||
y2b = min(y2a - y1a, shape[0])
|
||||
|
||||
if (
|
||||
isinstance(x1a, type(None))
|
||||
or isinstance(y1a, type(None))
|
||||
or isinstance(x1b, type(None))
|
||||
or isinstance(y1b, type(None))
|
||||
):
|
||||
raise ValueError("Mosaic calculation error")
|
||||
pad_w = x1a - x1b
|
||||
pad_h = y1a - y1b
|
||||
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
|
||||
@@ -205,7 +213,7 @@ class Dataset(data.Dataset):
|
||||
path = f"{os.path.dirname(filenames[0])}.cache"
|
||||
if os.path.exists(path):
|
||||
# XXX: temporarily disable cache
|
||||
os.remove(path)
|
||||
# os.remove(path)
|
||||
pass
|
||||
# return torch.load(path, weights_only=False)
|
||||
x = {}
|
||||
@@ -217,7 +225,10 @@ class Dataset(data.Dataset):
|
||||
image.verify() # PIL verify
|
||||
shape = image.size # image size
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
|
||||
if image.format:
|
||||
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
|
||||
else:
|
||||
assert any(filename.lower().endswith(f".{x}") for x in FORMATS), "unknown image format"
|
||||
|
||||
# verify labels
|
||||
a = f"{os.sep}images{os.sep}"
|
||||
|
||||
Reference in New Issue
Block a user