Initial commit
main.py · 331 lines · Executable file
@@ -0,0 +1,331 @@
import copy
import csv
import os
import warnings
from argparse import ArgumentParser
from typing import cast

import torch
import tqdm
import yaml
from torch.utils import data
from torch.amp.autocast_mode import autocast

from nets import nn
from utils import util
from utils.dataset import Dataset

warnings.filterwarnings("ignore")

data_dir = "/home/image1325/ssd1/dataset/coco"

def train(args, params):
    # Model
    model = nn.yolo_v11_n(len(params["names"]))
    model.cuda()

    # Optimizer
    accumulate = max(round(64 / (args.batch_size * args.world_size)), 1)
    params["weight_decay"] *= args.batch_size * args.world_size * accumulate / 64

    optimizer = torch.optim.SGD(
        util.set_params(model, params["weight_decay"]), params["min_lr"], params["momentum"], nesterov=True
    )
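    # `accumulate` batches gradients so that each optimizer step corresponds
    # to a nominal total batch of 64 images (e.g. batch_size=32 on a single
    # GPU gives accumulate = round(64 / 32) = 2), and weight decay is rescaled
    # by the same factor so regularization matches the nominal batch.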

    # EMA
    ema = util.EMA(model) if args.local_rank == 0 else None
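    # Only rank 0 keeps an EMA copy of the weights; it is the model that gets
    # evaluated and checkpointed below.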

    filenames = []
    with open(f"{data_dir}/train2017.txt") as f:
        for filename in f.readlines():
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{data_dir}/images/train2017/" + filename)

    sampler = None
    dataset = Dataset(filenames, args.input_size, params, augment=True)

    if args.distributed:
        sampler = data.DistributedSampler(dataset)

    loader = data.DataLoader(
        dataset,
        args.batch_size,
        sampler is None,
        sampler,
        num_workers=8,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )
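    # The positional `sampler is None` is DataLoader's `shuffle` argument:
    # shuffling is left to DistributedSampler when one is present, otherwise
    # the loader shuffles itself.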

    # Scheduler
    num_steps = len(loader)
    scheduler = util.LinearLR(args, params, num_steps)

    if args.distributed:
        # DDP mode
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[args.local_rank], output_device=args.local_rank
        )
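    # BatchNorm layers are converted to SyncBatchNorm before the DDP wrap so
    # normalization statistics are synchronized across ranks rather than
    # computed per-GPU.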

    best = 0
    amp_scale = torch.amp.grad_scaler.GradScaler()
    criterion = util.ComputeLoss(model, params)

    with open("weights/step.csv", "w") as log:
        if args.local_rank == 0:
            logger = csv.DictWriter(
                log, fieldnames=["epoch", "box", "cls", "dfl", "Recall", "Precision", "mAP@50", "mAP"]
            )
            logger.writeheader()

        for epoch in range(args.epochs):
            model.train()
            if args.distributed and sampler:
                sampler.set_epoch(epoch)
            if args.epochs - epoch == 10:
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False
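                # Mosaic augmentation is switched off for the final 10 epochs
                # so the model finishes training on undistorted images, a
                # common practice in recent YOLO training recipes.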

            p_bar = enumerate(loader)

            if args.local_rank == 0:
                print(("\n" + "%10s" * 5) % ("epoch", "memory", "box", "cls", "dfl"))
                p_bar = tqdm.tqdm(p_bar, total=num_steps, ascii=" >-")

            optimizer.zero_grad()
            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()
            for i, (samples, targets) in p_bar:
                step = i + num_steps * epoch
                scheduler.step(step, optimizer)

                samples = samples.cuda().float() / 255

                # Forward
                with autocast("cuda"):
                    outputs = model(samples)  # forward
                loss_box, loss_cls, loss_dfl = criterion(outputs, targets)

                avg_box_loss.update(loss_box.item(), samples.size(0))
                avg_cls_loss.update(loss_cls.item(), samples.size(0))
                avg_dfl_loss.update(loss_dfl.item(), samples.size(0))

                loss_box *= args.batch_size  # loss scaled by batch_size
                loss_cls *= args.batch_size  # loss scaled by batch_size
                loss_dfl *= args.batch_size  # loss scaled by batch_size
                loss_box *= args.world_size  # gradient averaged between devices in DDP mode
                loss_cls *= args.world_size  # gradient averaged between devices in DDP mode
                loss_dfl *= args.world_size  # gradient averaged between devices in DDP mode
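                # The criterion presumably returns per-sample mean losses and
                # DDP averages gradients across ranks, so scaling by
                # batch_size and world_size restores a sum-over-images
                # gradient magnitude.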

                # Backward
                amp_scale.scale(loss_box + loss_cls + loss_dfl).backward()

                # Optimize
                if step % accumulate == 0:
                    # amp_scale.unscale_(optimizer)  # unscale gradients
                    # util.clip_gradients(model)  # clip gradients
                    amp_scale.step(optimizer)  # optimizer.step
                    amp_scale.update()
                    optimizer.zero_grad()
                    if ema:
                        ema.update(model)
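                # Standard GradScaler flow: scale the loss before backward to
                # avoid fp16 underflow, then step (which skips the update on
                # inf/NaN gradients) and update the scale factor. Gradient
                # unscaling and clipping are optional and kept commented out
                # above.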

                torch.cuda.synchronize()

                # Log
                if args.local_rank == 0:
                    memory = f"{torch.cuda.memory_reserved() / 1e9:.4g}G"  # (GB)
                    s = ("%10s" * 2 + "%10.3g" * 3) % (
                        f"{epoch + 1}/{args.epochs}",
                        memory,
                        avg_box_loss.avg,
                        avg_cls_loss.avg,
                        avg_dfl_loss.avg,
                    )
                    p_bar = cast(tqdm.tqdm, p_bar)
                    p_bar.set_description(s)

            if args.local_rank == 0:
                # mAP
                last = test(args, params, ema.ema if ema else None)

                logger.writerow(
                    {
                        "epoch": str(epoch + 1).zfill(3),
                        "box": f"{avg_box_loss.avg:.3f}",
                        "cls": f"{avg_cls_loss.avg:.3f}",
                        "dfl": f"{avg_dfl_loss.avg:.3f}",
                        "mAP": f"{last[0]:.3f}",
                        "mAP@50": f"{last[1]:.3f}",
                        "Recall": f"{last[2]:.3f}",
                        "Precision": f"{last[3]:.3f}",
                    }
                )
                log.flush()

                # Update best mAP
                if last[0] > best:
                    best = last[0]

                # Save model
                save = {"epoch": epoch + 1, "model": copy.deepcopy(ema.ema if ema else None)}

                # Save last, best and delete
                torch.save(save, f="./weights/last.pt")
                if best == last[0]:
                    torch.save(save, f="./weights/best.pt")
                del save

    if args.local_rank == 0:
        util.strip_optimizer("./weights/best.pt")  # strip optimizers
        util.strip_optimizer("./weights/last.pt")  # strip optimizers


@torch.no_grad()
def test(args, params, model=None):
    filenames = []
    with open(f"{data_dir}/val2017.txt") as f:
        for filename in f.readlines():
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{data_dir}/images/val2017/" + filename)

    dataset = Dataset(filenames, args.input_size, params, augment=False)
    loader = data.DataLoader(
        dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True, collate_fn=Dataset.collate_fn
    )

    plot = False
    if not model:
        plot = True
        model = torch.load(f="./weights/best.pt", map_location="cuda", weights_only=False)
        model = model["model"].float().fuse()

    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # iou vector for mAP@0.5:0.95
    n_iou = iou_v.numel()
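    # The 10 thresholds are 0.50, 0.55, ..., 0.95; mAP averages over all of
    # them, while mAP@50 uses only the first.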

    m_pre = 0
    m_rec = 0
    map50 = 0
    mean_ap = 0
    metrics = []
    p_bar = tqdm.tqdm(loader, desc=("%10s" * 5) % ("", "precision", "recall", "mAP50", "mAP"), ascii=" >-")
    for samples, targets in p_bar:
        samples = samples.cuda()
        samples = samples.half()  # uint8 to fp16/32
        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
        _, _, h, w = samples.shape  # batch size, channels, height, width
        scale = torch.tensor((w, h, w, h)).cuda()
        # Inference
        outputs = model(samples)
        # NMS
        outputs = util.non_max_suppression(outputs)
        # Metrics
        for i, output in enumerate(outputs):
            # Ensure idx is a 1D boolean mask (squeeze any trailing dimension)
            # to match cls/box shapes
            idx = targets["idx"]
            if idx.dim() > 1:
                idx = idx.squeeze(-1)
            idx = idx == i

            # XXX: the original `idx = targets["idx"] == i` caused a shape
            # mismatch when idx carries an extra trailing dimension
            cls = targets["cls"][idx]
            box = targets["box"][idx]

            cls = cls.cuda()
            box = box.cuda()

            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()

            if output.shape[0] == 0:
                if cls.shape[0]:
                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
                continue
            # Evaluate
            if cls.shape[0]:
                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                metric = util.compute_metric(output[:, :6], target, iou_v)
            # Append
            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
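            # Each entry is (true-positive matrix over the 10 IoU thresholds,
            # confidence, predicted class, target class), which appears to be
            # the layout util.compute_ap expects.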

    # Compute metrics
    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=plot, names=params["names"])
    # Print results
    print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
    # Return results
    model.float()  # for training
    return mean_ap, map50, m_rec, m_pre


def profile(args, params):
    import thop

    shape = (1, 3, args.input_size, args.input_size)
    model = nn.yolo_v11_n(len(params["names"])).fuse()

    model.eval()
    model(torch.zeros(shape))

    x = torch.empty(shape)
    flops, num_params = thop.profile(model, inputs=[x], verbose=False)
    flops, num_params = thop.clever_format(nums=[2 * flops, num_params], format="%.3f")
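    # thop counts multiply-accumulate operations, so the count is doubled to
    # report FLOPs.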

    if args.local_rank == 0:
        print(f"Number of parameters: {num_params}")
        print(f"Number of FLOPs: {flops}")


def main():
    parser = ArgumentParser()
    parser.add_argument("--input-size", default=640, type=int)
    parser.add_argument("--batch-size", default=32, type=int)
    # accept both spellings so older torch.distributed launchers keep working
    parser.add_argument("--local-rank", "--local_rank", default=0, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--test", action="store_true")

    args = parser.parse_args()

    args.local_rank = int(os.getenv("LOCAL_RANK", 0))
    args.world_size = int(os.getenv("WORLD_SIZE", 1))
    args.distributed = args.world_size > 1
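    # LOCAL_RANK and WORLD_SIZE are the environment variables set by torchrun
    # (and torch.distributed.launch), so the values above override whatever
    # was passed on the command line.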

    if args.distributed:
        torch.cuda.set_device(device=args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if args.local_rank == 0:
        if not os.path.exists("weights"):
            os.makedirs("weights")

    with open("utils/args.yaml", errors="ignore") as f:
        params = yaml.safe_load(f)

    util.setup_seed()
    util.setup_multi_processes()

    profile(args, params)

    if args.train:
        train(args, params)
    if args.test:
        test(args, params)

    # Clean
    if args.distributed:
        torch.distributed.destroy_process_group()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
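
# Typical invocations (assuming the torchrun launcher):
#   python main.py --train                          # single GPU
#   torchrun --nproc_per_node=4 main.py --train     # 4-GPU DDP
#   python main.py --test                           # evaluate weights/best.pt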