优化验证集文件读取逻辑
This commit is contained in:
@@ -267,7 +267,7 @@ def init_model(model_name, num_classes) -> YOLO:
|
||||
return model
|
||||
|
||||
|
||||
def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
|
||||
def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") -> Optional[Dataset]:
|
||||
"""
|
||||
Try to build a validation Dataset.
|
||||
- If cfg['val_txt'] exists, use it.
|
||||
@@ -276,6 +276,8 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
|
||||
Args:
|
||||
cfg: config dict
|
||||
params: params dict for Dataset
|
||||
args: optional args object (for input_size)
|
||||
val_name: name of the validation set folder with no prefix (default: "val2017")
|
||||
Returns:
|
||||
Dataset or None
|
||||
"""
|
||||
@@ -283,18 +285,24 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
|
||||
val_txt = cfg.get("val_txt", "")
|
||||
if not val_txt:
|
||||
ds_root = cfg.get("dataset_path", "")
|
||||
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
|
||||
guess = os.path.join(ds_root, f"{val_name}.txt") if ds_root else ""
|
||||
val_txt = guess if os.path.exists(guess) else ""
|
||||
|
||||
val_files = _read_list_file(val_txt)
|
||||
if not val_files:
|
||||
# val_files = _read_list_file(val_txt)
|
||||
|
||||
filenames = []
|
||||
with open(val_txt, "r", encoding="utf-8") as f:
|
||||
for filename in f.readlines():
|
||||
filename = os.path.basename(filename.rstrip())
|
||||
filenames.append(f"{ds_root}/images/{val_name}/" + filename)
|
||||
if not filenames:
|
||||
import warnings
|
||||
|
||||
warnings.warn("No validation dataset found.")
|
||||
return None
|
||||
|
||||
return Dataset(
|
||||
filenames=val_files,
|
||||
filenames=filenames,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
augment=True,
|
||||
|
Reference in New Issue
Block a user