# distFedPAQ-simulation / distFedPAQ / utils / argument_parser.py
from argparse import ArgumentParser

import numpy as np
from beartype.typing import Callable, List, Optional, Sequence, Tuple
from numpy.typing import NDArray

from distFedPAQ.functional.tools import parallel_trainer, sequential_trainer

# Public API of this module: `arg_parser` is the only exported name.
__all__ = ["arg_parser"]

# Text handed to `ArgumentParser(description=...)`; argparse shows it in the
# `--help` output.
__description = """
Test the distFedPAQ module. 
In order to play with it, user needs to run python in interactive mode by running the command
$ python3 -i main.py -n ...
"""


def _build_parser() -> ArgumentParser:
    """Create the command-line parser holding every option of the simulation."""
    parser = ArgumentParser(description=__description)

    parser.add_argument(
        "--nodes",
        "-n",
        type=int,
        required=True,
        help="number of nodes for the training",
    )
    parser.add_argument(
        "--local-update",
        "-loc",
        type=int,
        default=10,
        help="number of iterations per node for each local update. Default to 10.",
    )
    parser.add_argument(
        "--external-update",
        "-ext",
        type=int,
        default=100,
        help="number of time that some nodes (see -n-ext-ave) will update externally by averaging. Default to 100.",
    )
    parser.add_argument(
        "--nodes-external-averaging",
        "-n-ext-ave",
        type=int,
        default=2,
        help="number nodes will update externally by averaging. Default to 2.",
    )
    parser.add_argument(
        "--batch-size",
        "-bs",
        type=int,
        default=1,
        help="size of a batch from local data of a node in each iteration of its local update. Default to 1.",
    )
    parser.add_argument(
        "--learning-rate",
        "-lr",
        type=float,
        default=1e-3,
        help="learning rate for nodes' local update. Default to 1e-3.",
    )
    parser.add_argument(
        "--momentum",
        "-mom",
        type=float,
        # argparse applies `type` to string defaults only, so the default has
        # to be a float itself to honour the declared return type.
        default=0.0,
        help="momentum for a local update. Default to 0.",
    )
    parser.add_argument(
        "--with-bias",
        "-bias",
        type=str,
        default="y",
        help="Flag for using 'bias' for the model [y/n]. Default to y.",
    )
    parser.add_argument(
        "--trainer",
        "-tr",
        type=str,
        default="seq",
        help="trainer for the nodes [[seq/sequential]/[par/parallel]]. Default to seq",
    )
    parser.add_argument(
        "--repeat-single",
        "-rs",
        type=int,
        default=1,
        help="number of time that the 'single node' will do loops of local updates before evaluation. Default to 1.",
    )
    parser.add_argument(
        "--path-to-probability-P",
        "-P-path",
        type=str,
        help="path to the probability file which encode the graph. If not set, probability will be uniform on all nodes.",
    )
    return parser


def _require(parser: ArgumentParser, condition: bool, message: str) -> None:
    """Abort through ``parser.error`` (usage + exit status 2) unless *condition* holds."""
    if not condition:
        parser.error(message)


def arg_parser(argv: Optional[Sequence[str]] = None) -> Tuple[
    int,
    int,
    int,
    int,
    int,
    float,
    float,
    bool,
    Callable[..., List[float]],
    NDArray,
    int,
]:
    """Parse and validate the command-line options of the simulation.

    Validation failures are reported with ``parser.error`` rather than
    ``assert`` so that they are still enforced when Python runs with ``-O``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        A tuple, in this order: number of nodes, local updates per node,
        number of external updates, number of nodes averaged per external
        update, batch size, learning rate, momentum, bias flag (``True`` only
        when ``--with-bias`` is exactly ``"y"``), trainer function, the
        ``n x n`` probability matrix ``P``, and the repeat count of the
        single node.

    Raises:
        SystemExit: On any invalid option (argparse prints the usage and the
            error message, then exits with status 2).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    n = args.nodes
    n_loc_update = args.local_update
    n_ext_update = args.external_update
    nodes_external_averaging = args.nodes_external_averaging
    batch_size = args.batch_size
    lr = args.learning_rate
    mom = args.momentum
    add_ones = args.with_bias == "y"
    tr = args.trainer
    rs = args.repeat_single
    path_proba = args.path_to_probability_P

    # Every accepted trainer name, mapped to the function it selects.
    trainers = {
        "seq": sequential_trainer,
        "sequential": sequential_trainer,
        "par": parallel_trainer,
        "parallel": parallel_trainer,
    }

    # > checkers
    _require(parser, n > 0, f"n ({n}) must be a positive integer")
    _require(
        parser,
        n_loc_update > 0,
        f"n_loc_update ({n_loc_update}) must be a positive integer",
    )
    _require(
        parser,
        batch_size > 0,
        f"batch size ({batch_size}) must be a positive integer",
    )
    _require(parser, lr > 0, f"learning rate ({lr}) must be a positive float")
    _require(parser, 0 <= mom < 1, f"momentum ({mom}) must be in [0,1)")
    _require(
        parser,
        0 < nodes_external_averaging <= n,
        f"n-ext-ave ({nodes_external_averaging}) must be positive and cannot exceed n ({n})",
    )
    _require(
        parser,
        tr in trainers,
        f"trainer ({tr}) must be one of 'seq, sequential, par, parallel'.",
    )
    _require(
        parser,
        rs > 0,
        f"number of time  ({rs}) that the single node will be repeated must be a positive integer",
    )
    # > end checkers

    trainer = trainers[tr]

    if path_proba is not None:
        # Comma-separated file; missing entries are read as 0.
        P = np.genfromtxt(path_proba, dtype=float, delimiter=",", filling_values=0)
        _require(
            parser,
            P.ndim == 2,
            f"Probability must be a 2 dim array, {P.ndim} dim were given.",
        )
        _require(
            parser,
            P.shape == (n, n),
            f"Probability must be a {n}x{n} square matrix, {P.shape[0]}x{P.shape[1]} were given.",
        )
    else:
        # The uniform matrix divides by n - 1, so a single node would raise
        # ZeroDivisionError; report it as a usage error instead.
        _require(
            parser,
            n > 1,
            f"n ({n}) must be at least 2 when no probability file is given",
        )
        # Zero diagonal, 1 / (n - 1) everywhere else. For n = 5:
        # P = array([[0.  , 0.25, 0.25, 0.25, 0.25],
        #    [0.25, 0.  , 0.25, 0.25, 0.25],
        #    [0.25, 0.25, 0.  , 0.25, 0.25],
        #    [0.25, 0.25, 0.25, 0.  , 0.25],
        #    [0.25, 0.25, 0.25, 0.25, 0.  ]])
        P = 1 / (n - 1) * (np.ones((n, n)) - np.eye(n))

    return (
        n,
        n_loc_update,
        n_ext_update,
        nodes_external_averaging,
        batch_size,
        lr,
        mom,
        add_ones,
        trainer,
        P,
        rs,
    )