distFedPAQ-simulation / distFedPAQ / nodes / nodes.py
nodes.py
Raw
from __future__ import annotations

from distFedPAQ.datasets import Data

import numpy as np
from numpy.typing import NDArray

from beartype import beartype
from beartype.typing import Union, Tuple, List, Optional, Callable

__all__ = ["Node"]


class Node:
    @beartype
    def __init__(
        self,
        data: NDArray,
        weight_size: Union[Tuple, int],
        grad_f: Callable,
        batch_size: int = 1,
        momentum: float = 0,
    ) -> None:
        """
        Trainer node with local `sgd` updater.

        Parameters
        ----------
        data : NDArray
            local dataset for the node
        weight_size : Union[Tuple, int]
            size of the weight of the objective function `f`.
        grad_f : Callable
            gradient of the objective function `f`. (will be use for sgd local update)
        batch_size : int, optional
            size of the batch of local dataset's sampler, by default 1
        lr : float, optional
            learning rate for sgd local update, by default 1e-2
        use_tqdm : bool, optional
            boolean flag for viewing local update, by default False
        """
        self.local_data = Data(content=data, batch_size=batch_size)

        self.__batch_size = batch_size

        self.grad_f = grad_f  # gradient of the objective function
        self.__w = 5 * np.random.randn(
            *weight_size
        )  # local weight for the objective function
        self.__grad = np.zeros_like(self.__w)
        self.__momentum = momentum
        self.external_update = self.__external_update

    @property
    def data_batch_size(self):
        """
        Retrieve the value of `batch_size` of the sampler of the local dataset.

        Returns
        -------
        int
            the value of `batch_size` of the local dataset
        """
        return self.__batch_size

    @beartype
    @data_batch_size.setter
    def data_batch_size(self, value: int):
        """
        Update the `batch_size` for the sampler of the local dataset.

        Parameters
        ----------
        value : int
            new value for `batch_size`
        """
        self.__batch_size = value
        self.local_data.batch_size = value

    @property
    def w(self):
        return self.__w

    @beartype
    @w.setter
    def w(self, new_w: NDArray[np.float64]):
        assert (
            new_w.shape == self.__w.shape
        ), f"Incorrect shape for new weight. {self.__w.shape} expected but {new_w.shape} were given."
        self.__w = new_w

    @beartype
    def local_update(self, n_iter: int, lr: float = 1e-4):
        """
        Update the local parameter `w` with `sgd` on the local dataset.

        Parameters
        ----------
        n_iter : int
            Number of iterations for the update.
        """
        for _ in range(n_iter):
            x = self.local_data.sample()

            self.__grad = self.__momentum * self.__grad + (
                1 - self.__momentum
            ) * self.grad_f(x, self.__w)

            self.__w -= lr * self.__grad

        # import time
        # time.sleep(0.5)

    @beartype
    def __external_update(
        self,
        *others,
        weight: Optional[float] = None,
        others_weights: Optional[List[float]] = None,
    ):
        """
        Update the parameter for the node and other nodes by (weighted) averaging.

        It is a copy of the `static method` called `external_update`.
        Parameters
        ----------
        weight : Optional[float], optional
            the weight for the current node, by default None
        others_weights : Optional[List[float]], optional
            weight of all the other nodes, by default None
        """
        if weight is not None and others_weights is not None:
            ave_weights = [weight, *others_weights]
        else:
            ave_weights = None

        Node.external_update(self, *others, ave_weights=ave_weights)

    @beartype
    @staticmethod
    def external_update(*nodes: Node, ave_weights: Optional[List[float]] = None):
        """
        Update the parameter `w` of each `node` of `nodes` with their (weighted) average.

        Parameters
        ----------
        ave_weights : Optional[List[float]], optional
            Weight to be used during averaging, by default None.
        """
        if ave_weights is not None:
            assert len(nodes) == len(
                ave_weights
            ), "Length of ave_weights must be equal to the number of nodes."
            assert all(
                [weight > 0 for weight in ave_weights]
            ), "Weights must be a positive number."
        else:
            ave_weights = [1] * len(nodes)

        total_weight = sum(ave_weights)

        w = np.zeros_like(nodes[0].w)
        for i, node in enumerate(nodes):
            w += ave_weights[i] * node.w

        w /= total_weight

        for node in nodes:
            node.w = w