优化文档字符串,明确参数说明
This commit is contained in:
@@ -64,7 +64,7 @@ class FedYoloClient(object):
|
||||
"""
|
||||
Load the local training dataset
|
||||
Args:
|
||||
:param train_dataset: Training dataset
|
||||
train_dataset: Training dataset
|
||||
"""
|
||||
self.train_dataset = train_dataset
|
||||
self.n_data = len(self.train_dataset)
|
||||
@@ -72,8 +72,9 @@ class FedYoloClient(object):
|
||||
def update(self, Global_model_state_dict):
|
||||
"""
|
||||
Update the local model with the global model parameters
|
||||
|
||||
Args:
|
||||
:param Global_model_state_dict: State dictionary of the global model
|
||||
Global_model_state_dict: State dictionary of the global model
|
||||
"""
|
||||
|
||||
if not hasattr(self, "model") or self.model is None:
|
||||
@@ -85,7 +86,15 @@ class FedYoloClient(object):
|
||||
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
|
||||
"""
|
||||
Train the local model.
|
||||
Returns: (state_dict, n_data, avg_loss_per_image)
|
||||
|
||||
Args:
|
||||
args: training arguments including
|
||||
|
||||
Returns:
|
||||
(state_dict, n_data, avg_loss_per_image): A tuple including:
|
||||
- state_dict: State dictionary of the trained local model
|
||||
- n_data: Number of training data samples
|
||||
- avg_loss_per_image: Average training loss per image over all epochs
|
||||
"""
|
||||
|
||||
# ---- Dist init (if any) ----
|
||||
|
||||
Reference in New Issue
Block a user