Optimize the FedYoloClient and FedYoloServer classes

This commit is contained in:
TY1667
2025-10-19 21:27:19 +08:00
parent 101ffa51eb
commit 40de29591b
2 changed files with 270 additions and 232 deletions


@@ -3,11 +3,11 @@ import torch
 from torch import nn
 from torch.utils import data
 from torch.amp.autocast_mode import autocast
+from tqdm import tqdm
 from utils.fed_util import init_model
 from utils import util
 from utils.dataset import Dataset
 from typing import cast
-from tqdm import tqdm


 class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
         # load the global model parameters
         self.model.load_state_dict(Global_model_state_dict, strict=True)

-    def train(self, args):
+    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
-        Train the local model
-        Args:
-        :param args: Command line arguments
-            - local_rank: Local rank for distributed training
-            - world_size: World size for distributed training
-            - distributed: Whether to use distributed training
-            - input_size: Input size for the model
-        Returns:
-        :return: Local updated model, number of local data points, training loss
+        Train the local model.
+        Returns: (state_dict, n_data, avg_loss_per_image)
         """
+        # ---- Dist init (if any) ----
         if args.distributed:
             torch.cuda.set_device(device=args.local_rank)
             torch.distributed.init_process_group(backend="nccl", init_method="env://")
-            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
-            # self._device = torch.device("cuda", local_rank)
-        if args.local_rank == 0:
-            pass
-            # if not os.path.exists("weights"):
-            #     os.makedirs("weights")

         util.setup_seed()
         util.setup_multi_processes()

-        # model
-        # init model have been done in __init__()
-        self.model.to(self._device)
-        # show model architecture
-        # print(self.model)
+        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
+        # self.model.to(device)
+        self.model.cuda()

-        # Optimizer
-        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
-        self._weight_decay = self._batch_size * args.world_size * accumulate / 64
+        # ---- Optimizer / WD scaling & LR warmup/schedule ----
+        # accumulate = effective grad-accumulation steps to emulate global batch 64
+        world_size = getattr(args, "world_size", 1)
+        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
+        # scale weight_decay like YOLO recipes
+        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
         optimizer = torch.optim.SGD(
-            util.set_params(self.model, self._weight_decay),
+            util.set_params(self.model, scaled_wd),
             lr=self._min_lr,
             momentum=self._momentum,
             nesterov=True,
         )

-        # EMA
+        # ---- EMA (track the underlying module if DDP) ----
+        # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None

-        data_set = Dataset(
+        print(type(self.train_dataset))
+        # ---- Data ----
+        dataset = Dataset(
             filenames=self.train_dataset,
             input_size=args.input_size,
             params=self.params,
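
For context on the new accumulate / scaled_wd lines above: per the in-code comments, the intent is to emulate a nominal global batch of 64 via gradient accumulation and to scale weight decay with the effective batch, as in common YOLO training recipes. A quick arithmetic sanity check of that scaling, using illustrative values (batch_size=16, world_size=1, weight_decay=5e-4 are assumptions for this example, not values taken from the repo):

# Illustrative numbers only; batch_size, world_size and weight_decay are assumed here.
batch_size, world_size, weight_decay = 16, 1, 5e-4

accumulate = max(round(64 / (batch_size * max(world_size, 1))), 1)
scaled_wd = weight_decay * batch_size * max(world_size, 1) * accumulate / 64

print(accumulate)  # 4 micro-batches per optimizer step (16 * 1 * 4 = 64)
print(scaled_wd)   # 5e-4: unchanged, because the effective batch equals the nominal 64

Whenever batch_size * world_size * accumulate lands exactly on 64, the scaled weight decay reduces to the configured value; smaller effective batches shrink it proportionally.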
@@ -136,26 +132,28 @@ class FedYoloClient(object):
         if args.distributed:
             train_sampler = data.DistributedSampler(
-                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
+                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
             )
         else:
             train_sampler = None

         loader = data.DataLoader(
-            data_set,
+            dataset,
             batch_size=self._batch_size,
-            shuffle=train_sampler is None,
+            shuffle=(train_sampler is None),
             sampler=train_sampler,
             num_workers=self.num_workers,
             pin_memory=True,
             collate_fn=Dataset.collate_fn,
+            drop_last=False,
         )

-        # Scheduler
         num_steps = max(1, len(loader))
         scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

-        # DDP mode
-        if args.distributed:
+        # ---- SyncBN + DDP (if any) ----
+        is_ddp = bool(args.distributed)
+        if is_ddp:
             self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
             self.model = nn.parallel.DistributedDataParallel(
                 module=self.model,
@@ -164,102 +162,133 @@ class FedYoloClient(object):
                 find_unused_parameters=False,
             )

-        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # ---- AMP + loss ----
+        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # criterion = util.ComputeLoss(
+        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
+        #     self.params,
+        # )
         criterion = util.ComputeLoss(self.model, self.params)

-        # log
-        # if args.local_rank == 0:
-        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-        #     print("\n" + header)
-        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
-        #     p_bar.set_description(f"{self.name:>10}")
+        # ---- Training ----
         for epoch in range(args.epochs):
+            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
             self.model.train()
-            # when distributed, set epoch for shuffling
-            if args.distributed and train_sampler is not None:
+            if is_ddp and train_sampler is not None:
                 train_sampler.set_epoch(epoch)
-            if args.epochs - epoch == 10:
-                # disable mosaic augmentation in the last 10 epochs
+            # disable mosaic in the last 10 epochs (if dataset supports it)
+            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                 ds = cast(Dataset, loader.dataset)
                 ds.mosaic = False

             optimizer.zero_grad(set_to_none=True)
-            avg_box_loss = util.AverageMeter()
-            avg_cls_loss = util.AverageMeter()
-            avg_dfl_loss = util.AverageMeter()
-            # # --- header (once per epoch, YOLO-style) ---
-            # if args.local_rank == 0:
-            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-            #     print("\n" + header)
-            # p_bar = enumerate(loader)
-            # if args.local_rank == 0:
-            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)
-            for i, (samples, targets) in enumerate(loader):
-                global_step = i + num_steps * epoch
-                scheduler.step(step=global_step, optimizer=optimizer)
-                samples = samples.cuda(non_blocking=True).float() / 255.0
-                # Forward
-                with autocast("cuda", enabled=True):
-                    outputs = self.model(samples)
+            loss_box_meter = util.AverageMeter()
+            loss_cls_meter = util.AverageMeter()
+            loss_dfl_meter = util.AverageMeter()
+            for i, (images, targets) in enumerate(loader):
+                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+                step = i + epoch * num_steps
+                # scheduler per-step (your util.LinearLR expects step)
+                scheduler.step(step=step, optimizer=optimizer)
+                # images = images.to(device, non_blocking=True).float() / 255.0
+                images = images.cuda().float() / 255.0
+                bs = images.size(0)
+                # total_imgs_seen += bs
+                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
+                # Move to GPU here only if your loss requires it.
+                with autocast(device_type="cuda", enabled=True):
+                    outputs = self.model(images)  # DDP wraps forward
                     box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
-                # meters (use the *unscaled* values)
-                bs = samples.size(0)
-                avg_box_loss.update(box_loss.item(), bs)
-                avg_cls_loss.update(cls_loss.item(), bs)
-                avg_dfl_loss.update(dfl_loss.item(), bs)
-                # scale losses by batch/world if your loss is averaged internally per-sample/device
-                # box_loss = box_loss * self._batch_size * args.world_size
-                # cls_loss = cls_loss * self._batch_size * args.world_size
-                # dfl_loss = dfl_loss * self._batch_size * args.world_size
-                box_loss *= self._batch_size
-                cls_loss *= self._batch_size
-                dfl_loss *= self._batch_size
-                box_loss *= args.world_size
-                cls_loss *= args.world_size
-                dfl_loss *= args.world_size
+                # total_loss = box_loss + cls_loss + dfl_loss
+                # Gradient accumulation: normalize by 'accumulate' so LR stays effective
+                # total_loss = total_loss / accumulate
+                # IMPORTANT: assume criterion returns **average per image** in the batch.
+                # Keep logging on the true (unscaled) values:
+                loss_box_meter.update(box_loss.item(), bs)
+                loss_cls_meter.update(cls_loss.item(), bs)
+                loss_dfl_meter.update(dfl_loss.item(), bs)
                 total_loss = box_loss + cls_loss + dfl_loss
-                # Backward
-                amp_scale.scale(total_loss).backward()
-                # Optimize
-                if (i + 1) % accumulate == 0:
-                    amp_scale.unscale_(optimizer)  # unscale gradients
-                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
-                    amp_scale.step(optimizer)
-                    amp_scale.update()
+                scaler.scale(total_loss).backward()
+                # optimize
+                if step % accumulate == 0:
+                    # scaler.unscale_(optimizer)
+                    # util.clip_gradients(self.model)
+                    scaler.step(optimizer)
+                    scaler.update()
                     optimizer.zero_grad(set_to_none=True)
-                    if ema:
-                        ema.update(self.model)
-                # torch.cuda.synchronize()
-                # tqdm update
-                # if args.local_rank == 0:
-                #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-                #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
-                #         self.name,
-                #         mem,
-                #         avg_box_loss.avg,
-                #         avg_cls_loss.avg,
-                #         avg_dfl_loss.avg,
-                #     )
-                #     cast(tqdm, p_bar).set_description(desc)
-                #     p_bar.update(1)
-            # p_bar.close()
-        # clean
-        if args.distributed:
+                # # Step when we have 'accumulate' micro-batches, or at the end
+                # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
+                #     scaler.unscale_(optimizer)
+                #     util.clip_gradients(
+                #         model=(
+                #             self.model.module
+                #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                #             else self.model
+                #         ),
+                #         max_norm=10.0,
+                #     )
+                #     scaler.step(optimizer)
+                #     scaler.update()
+                #     optimizer.zero_grad(set_to_none=True)
+                if ema:
+                    # Update EMA from the underlying module
+                    ema.update(
+                        self.model.module
+                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                        else self.model
+                    )
+                # print loss to test
+                print(
+                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+                )
+            torch.cuda.synchronize()
+
+        # ---- Final average loss (per image) over the whole epoch span ----
+        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
+
+        # ---- Cleanup DDP ----
+        if is_ddp:
             torch.distributed.destroy_process_group()
         torch.cuda.empty_cache()

-        return (
-            self.model.state_dict() if not ema else ema.ema.state_dict(),
-            self.n_data,
-            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
-        )
+        # ---- Choose which weights to return ----
+        # - If EMA exists, return EMA weights (common YOLO eval practice)
+        # - Be careful with DDP: grab state_dict from the underlying module / EMA model
+        if ema:
+            # print("Using EMA weights")
+            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
+        else:
+            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
+            model_obj = getattr(self.model, "module", self.model)
+            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
+            # otherwise try to call state_dict() and finally fall back to wrapping the object.
+            if isinstance(model_obj, torch.nn.Module):
+                model_to_return = model_obj.state_dict()
+            elif isinstance(model_obj, dict):
+                model_to_return = model_obj
+            else:
+                try:
+                    model_to_return = model_obj.state_dict()
+                except Exception:
+                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
+                    model_to_return = {"state": model_obj}
+            return model_to_return, self.n_data, avg_loss_per_image
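
The training loop above combines AMP (autocast + GradScaler) with gradient accumulation: gradients from several micro-batches accumulate and the optimizer steps only every `accumulate` iterations. A minimal, self-contained sketch of that pattern on a toy regression model (the real loop uses the YOLO model, util.ComputeLoss, util.EMA and a per-step LR schedule; this stand-in only mirrors the scaler/accumulation mechanics and falls back to CPU when CUDA is unavailable):

# Sketch of the AMP + gradient-accumulation pattern, on a toy model so it also runs on CPU.
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
scaler = torch.amp.GradScaler(device, enabled=use_amp)  # no-op passthrough when disabled
accumulate = 4  # emulate a larger batch by stepping every 4 micro-batches

for step in range(12):
    images = torch.randn(16, 8, device=device)
    targets = torch.randn(16, 1, device=device)

    with torch.autocast(device_type=device, enabled=use_amp):
        loss = nn.functional.mse_loss(model(images), targets)

    scaler.scale(loss).backward()      # gradients accumulate across micro-batches
    if (step + 1) % accumulate == 0:   # step only every `accumulate` micro-batches
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

As in the loop above, the loss is not divided by `accumulate` here; whether to normalize depends on whether the criterion already averages per image, which the commit's comments call out as an assumption.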


@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from utils.fed_util import init_model
 from utils.dataset import Dataset
 from utils import util
+from nets import YOLO


 class FedYoloServer(object):
@@ -21,7 +22,7 @@ class FedYoloServer(object):
         self.client_n_data = {}
         self.selected_clients = []
-        self._batch_size = params.get("val_batch_size", 4)
+        self._batch_size = params.get("val_batch_size", 200)
         self.client_list = client_list
         self.valset = None
@@ -40,7 +41,7 @@ class FedYoloServer(object):
         self.model = init_model(model_name, self._num_classes)
         self.params = params

-    def load_valset(self, valset):
+    def load_valset(self, valset: Dataset):
         """Server loads the validation dataset."""
         self.valset = valset
@@ -48,78 +49,6 @@ class FedYoloServer(object):
"""Return global model weights.""" """Return global model weights."""
return self.model.state_dict() return self.model.state_dict()
@torch.no_grad()
def test(self, args) -> dict:
"""
Test the global model on the server's validation set.
Returns:
dict with keys: mAP, mAP50, precision, recall
"""
if self.valset is None:
return {}
loader = DataLoader(
self.valset,
batch_size=self._batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=Dataset.collate_fn,
)
dev = self._device
# move to device for eval; keep in float32 for stability
self.model.eval().to(dev).float()
iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
n_iou = iou_v.numel()
metrics = []
for samples, targets in loader:
samples = samples.to(dev, non_blocking=True).float() / 255.0
_, _, h, w = samples.shape
scale = torch.tensor((w, h, w, h), device=dev)
outputs = self.model(samples)
outputs = util.non_max_suppression(outputs)
for i, output in enumerate(outputs):
idx = targets["idx"] == i
cls = targets["cls"][idx].to(dev)
box = targets["box"][idx].to(dev)
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
if output.shape[0] == 0:
if cls.shape[0]:
metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
continue
if cls.shape[0]:
if cls.dim() == 1:
cls = cls.unsqueeze(1)
box_xy = util.wh2xy(box)
if not isinstance(box_xy, torch.Tensor):
box_xy = torch.tensor(box_xy, device=dev)
target = torch.cat((cls, box_xy * scale), dim=1)
metric = util.compute_metric(output[:, :6], target, iou_v)
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
if not metrics:
# move back to CPU before returning
self.model.to("cpu").float()
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
if len(metrics) and metrics[0].any():
_, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
else:
prec, rec, map50, mean_ap = 0, 0, 0, 0
# return model to CPU so next agg() stays device-consistent
self.model.to("cpu").float()
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
def select_clients(self, connection_ratio=1.0): def select_clients(self, connection_ratio=1.0):
""" """
Randomly select a fraction of clients. Randomly select a fraction of clients.
@@ -130,80 +59,69 @@ class FedYoloServer(object):
         self.n_data = 0
         for client_id in self.client_list:
             # Random selection based on connection ratio
-            if np.random.rand() <= connection_ratio:
+            s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
+            if s[0] == 1:
                 self.selected_clients.append(client_id)
-                self.n_data += self.client_n_data.get(client_id, 0)
+                self.n_data += self.client_n_data[client_id]

-    @torch.no_grad()
     def agg(self):
-        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
+        """
+        Server aggregates the local updates from selected clients using FedAvg.
+        :return: model_state: aggregated model weights
+        :return: avg_loss: weighted average training loss across selected clients
+        :return: n_data: total number of data points across selected clients
+        """
         if len(self.selected_clients) == 0 or self.n_data == 0:
-            return self.model.state_dict(), {}, 0
-
-        # Ensure global model is on CPU for safe load later
-        self.model.to("cpu")
-        global_state = self.model.state_dict()  # may hold CPU or CUDA refs; we're on CPU now
-
-        avg_loss = {}
-        total_n = float(self.n_data)
-
-        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
-        # For non-floating tensors (e.g., BN num_batches_tracked int64), we'll copy from the first client.
-        new_state = {}
-        first_client = None
-        for cid in self.selected_clients:
-            if cid in self.client_state:
-                first_client = cid
-                break
-        assert first_client is not None, "No client states available to aggregate."
-
-        for k, v in global_state.items():
-            if v.is_floating_point():
-                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
-            else:
-                # For non-float buffers, just copy from the first client (or keep global)
-                new_state[k] = self.client_state[first_client][k].clone()
-
-        # Accumulate floating tensors with weights; keep non-floats as assigned above
-        for cid in self.selected_clients:
-            if cid not in self.client_state:
+            import warnings
+            warnings.warn("No clients selected or no data available for aggregation.")
+            return self.model.state_dict(), 0, 0
+
+        # Initialize a model for aggregation
+        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
+        model_state = model.state_dict()
+        avg_loss = 0
+
+        # Aggregate the local updated models from selected clients
+        for i, name in enumerate(self.selected_clients):
+            if name not in self.client_state:
                 continue
-            weight = self.client_n_data[cid] / total_n
-            cst = self.client_state[cid]
-            for k in new_state.keys():
-                if new_state[k].is_floating_point():
-                    # cst[k] is CPU; ensure float32 for accumulation
-                    new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
-
-            # weighted average losses
-            for lk, lv in self.client_loss[cid].items():
-                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
-
-        # Load aggregated state back into the global model (model is on CPU)
-        with torch.no_grad():
-            self.model.load_state_dict(new_state, strict=True)
+            for key in self.client_state[name]:
+                if i == 0:
+                    # First client, initialize the model_state
+                    model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
+                else:
+                    # math equation: w = sum(n_k / n * w_k)
+                    model_state[key] = model_state[key] + self.client_state[name][key] * (
+                        self.client_n_data[name] / self.n_data
+                    )
+            avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
+
+        self.model.load_state_dict(model_state, strict=True)
         self.round += 1
-
-        # Return CPU state_dict (good for broadcasting to clients)
-        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
-
-    def rec(self, name, state_dict, n_data, loss_dict):
+        n_data = self.n_data
+        return model_state, avg_loss, n_data
+
+    def rec(self, name, state_dict, n_data, loss):
         """
         Receive local update from a client.
         - Store all floating tensors as CPU float32
         - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
         """
         self.n_data += n_data
-        safe_state = {}
-        with torch.no_grad():
-            for k, v in state_dict.items():
-                t = v.detach().cpu()
-                if t.is_floating_point():
-                    t = t.to(torch.float32)
-                safe_state[k] = t
-        self.client_state[name] = safe_state
+        self.client_state[name] = {}
+        self.client_n_data[name] = {}
+        self.client_loss[name] = {}
+        self.client_state[name].update(state_dict)
         self.client_n_data[name] = int(n_data)
-        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
+        self.client_loss[name] = loss

     def flush(self):
         """Clear stored client updates."""
@@ -211,3 +129,94 @@ class FedYoloServer(object):
         self.client_state.clear()
         self.client_n_data.clear()
         self.client_loss.clear()
+
+    def test(self):
+        """Evaluate the global model on the server's validation dataset."""
+        if self.valset is None:
+            import warnings
+            warnings.warn("No validation dataset available for testing.")
+            return {}
+        return test(self.valset, self.params, self.model)
+
+
+@torch.no_grad()
+def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
+    """
+    Evaluate the model on the validation dataset.
+    Args:
+        valset: validation dataset
+        params: dict of parameters (must include 'names')
+        model: YOLO model to evaluate
+        batch_size: batch size for evaluation
+    Returns:
+        (mean_ap, map50, m_rec, m_pre)
+    """
+    loader = DataLoader(
+        dataset=valset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        collate_fn=Dataset.collate_fn,
+    )
+
+    model.cuda()
+    model.half()
+    model.eval()
+
+    # Configure
+    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # iou vector for mAP@0.5:0.95
+    n_iou = iou_v.numel()
+
+    m_pre = 0
+    m_rec = 0
+    map50 = 0
+    mean_ap = 0
+    metrics = []
+    for samples, targets in loader:
+        samples = samples.cuda()
+        samples = samples.half()  # uint8 to fp16/32
+        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
+        _, _, h, w = samples.shape  # batch-size, channels, height, width
+        scale = torch.tensor((w, h, w, h)).cuda()
+        # Inference
+        outputs = model(samples)
+        # NMS
+        outputs = util.non_max_suppression(outputs)
+        # Metrics
+        for i, output in enumerate(outputs):
+            idx = targets["idx"]
+            if idx.dim() > 1:
+                idx = idx.squeeze(-1)
+            idx = idx == i
+            # idx = targets["idx"] == i
+            cls = targets["cls"][idx]
+            box = targets["box"][idx]
+            cls = cls.cuda()
+            box = box.cuda()
+            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
+            if output.shape[0] == 0:
+                if cls.shape[0]:
+                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
+                continue
+            # Evaluate
+            if cls.shape[0]:
+                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
+                metric = util.compute_metric(output[:, :6], target, iou_v)
+            # Append
+            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
+
+    # Compute metrics
+    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
+    if len(metrics) and metrics[0].any():
+        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+    # Print results
+    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
+
+    # Return results
+    model.float()  # for training
+    return mean_ap, map50, m_rec, m_pre