Library API

The ReFedEz library provides a high-level API for implementing federated learning algorithms. It abstracts away the complexities of distributed communication, allowing you to focus on your machine learning models and training logic.

What the Library Does

The library enables you to write federated learning code that runs seamlessly across distributed machines. Instead of dealing with low-level networking, serialization, and coordination, you write standard ML code and the library handles:

  • Model Distribution: Automatically sending model weights between server and clients
  • Aggregation: Coordinating federated averaging and other aggregation strategies
  • Lifecycle Management: Handling training rounds, validation, and synchronization
  • Framework Integration: Supporting PyTorch and NumPy backends out of the box

Where to Use It

Use the library in your model.py or training script files. This is where you define your federated learning algorithm. The library is designed for:

  • Research Prototyping: Quickly test federated algorithms without infrastructure setup
  • Production ML Code: Write clean, framework-agnostic federated training code
  • Algorithm Development: Focus on the ML logic while the library handles distribution

Basic Usage

  1. Choose a Backend: Inherit from FederatedTorch for PyTorch models or FederatedNumpy for NumPy-based implementations.

  2. Use the @Federated Decorator: Apply this decorator to your class to specify the server and client configurations. The decorator automatically reads from your refedez.yaml and connects to the running ReFedEz deployment.

  3. Implement Required Methods:

       • get_weights(): Return current model parameters
       • set_weights(weights): Load new model parameters
       • train_step(): Perform one round of local training
       • validate(weights): Evaluate model performance

Example

from refedez.lib import Federated, Server, Client, FederatedTorch


@Federated(
    server=Server("server.localhost", save_model_path="/models/final_model.pt"),
    clients=[Client("site1"), Client("site2")],
    refedez_config="./refedez.yaml"
)
class MyFederatedModel(FederatedTorch):
    def __init__(self):
        super().__init__()
        # Your model initialization here

    def get_weights(self):
        # Return model weights
        pass

    def set_weights(self, weights):
        # Load model weights
        pass

    def train_step(self):
        # Local training logic
        pass

    def validate(self, weights):
        # Validation logic
        pass

When you run this script, the library automatically:

  • Connects to the ReFedEz-deployed server and clients
  • Coordinates federated training rounds
  • Handles model synchronization and aggregation
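
For reference, here is one way the stub methods above might be filled in. This is a minimal sketch under stated assumptions, not the library's reference implementation: the tiny linear model, the synthetic in-memory tensors, the fixed SGD learning rate, and the choice to exchange weights as a PyTorch state dict are all illustrative. Check the FederatedTorch definition below for the exact interface (the base class also declares forward() and accepts a learning_rate argument on train_step()).

import torch
import torch.nn as nn

from refedez.lib import FederatedTorch


class MyFederatedModel(FederatedTorch):
    # Decorate with @Federated(...) exactly as in the example above.

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        self.loss_fn = nn.MSELoss()
        # Synthetic local data, purely for illustration; a real client
        # would load its own dataset here.
        self.x = torch.randn(64, 3)
        self.y = torch.randn(64, 1)

    def forward(self, x):
        return self.linear(x)

    def get_weights(self):
        # Hand the current parameters to the server; a state dict is one
        # plausible format (the aggregator's expected format may differ).
        return self.state_dict()

    def set_weights(self, weights):
        # Load the aggregated parameters received from the server.
        self.load_state_dict(weights)

    def train_step(self):
        # One local pass of plain SGD over the synthetic data.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        optimizer.zero_grad()
        loss = self.loss_fn(self.forward(self.x), self.y)
        loss.backward()
        optimizer.step()
        return self.get_weights()

    def validate(self, weights):
        # Evaluate the aggregated weights on the local data.
        self.set_weights(weights)
        with torch.no_grad():
            return self.loss_fn(self.forward(self.x), self.y).item()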


Library Definitions

FederatedNumpy

Bases: ABC

Abstract base class for federated learning models using NumPy.

This class defines the interface for models that participate in federated learning workflows using NumPy arrays for weight representation.

Subclasses must implement the abstract methods to define the model's forward pass, weight retrieval and setting, and training step logic.

Source code in refedez/lib/backends/numpy.py
class FederatedNumpy(ABC):
    """Abstract base class for federated learning models using NumPy.

    This class defines the interface for models that participate in federated learning
    workflows using NumPy arrays for weight representation.

    Subclasses must implement the abstract methods to define the model's forward pass,
    weight retrieval and setting, and training step logic.
    """

    def forward(self, x: Any):
        raise NotImplementedError()

    def get_weights(self) -> Any:
        raise NotImplementedError()

    def set_weights(self, new_weights: Any) -> None:
        raise NotImplementedError()

    def train_step(self, learning_rate=1.0) -> Any:
        raise NotImplementedError()
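
As a rough illustration of this interface, the sketch below fits a two-coefficient linear model with plain gradient descent. It is an assumption-laden example, not library code: the synthetic data, the ½·MSE loss, the flat NumPy weight array, and the import location of FederatedNumpy (assumed to mirror FederatedTorch in the example above) are all placeholders.

import numpy as np

from refedez.lib import FederatedNumpy  # import location assumed


class LinearRegressionNumpy(FederatedNumpy):
    def __init__(self):
        # Weights as a flat NumPy array: [slope, intercept].
        self.weights = np.zeros(2)
        # Synthetic local data, purely for illustration.
        rng = np.random.default_rng(0)
        self.x = rng.normal(size=100)
        self.y = 2.0 * self.x + 1.0 + rng.normal(scale=0.1, size=100)

    def forward(self, x):
        return self.weights[0] * x + self.weights[1]

    def get_weights(self):
        return self.weights

    def set_weights(self, new_weights):
        self.weights = np.asarray(new_weights)

    def train_step(self, learning_rate=1.0):
        # One full-batch gradient descent step on the ½·MSE loss.
        error = self.forward(self.x) - self.y
        grad = np.array([np.mean(error * self.x), np.mean(error)])
        self.weights = self.weights - learning_rate * grad
        return self.weights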

FederatedTorch

Bases: Module

Base class for federated learning models using PyTorch.

This class extends PyTorch's nn.Module and provides the interface for models in federated learning scenarios. It handles the initialization and defines abstract methods for weight management and training.

Subclasses should implement the forward pass and training logic specific to their model architecture.

Source code in refedez/lib/backends/torch.py
class FederatedTorch(nn.Module):
    """Base class for federated learning models using PyTorch.

    This class extends PyTorch's nn.Module and provides the interface for models
    in federated learning scenarios. It handles the initialization and defines
    abstract methods for weight management and training.

    Subclasses should implement the forward pass and training logic specific to
    their model architecture.
    """

    def __init__(self):
        super(FederatedTorch, self).__init__()

    def forward(self, x: Any):
        raise NotImplementedError()

    def get_weights(self) -> torch.Tensor:
        raise NotImplementedError()

    def set_weights(self, new_weights: torch.Tensor) -> None:
        raise NotImplementedError()

    def train_step(self, learning_rate=1.0) -> torch.Tensor:
        raise NotImplementedError()

Server dataclass

Configuration for the federated learning server.

Represents the central server in the federated learning system, responsible for coordinating the training process and aggregating model updates.

Attributes:

  • name (str): The unique name identifier for the server.
  • save_model_path (str | None): Optional path where the trained model should be saved.

Source code in refedez/lib/config.py
@dataclass(frozen=True)
class Server:
    """Configuration for the federated learning server.

    Represents the central server in the federated learning system, responsible
    for coordinating the training process and aggregating model updates.

    Attributes:
        name: The unique name identifier for the server.
        save_model_path: Optional path where the trained model should be saved.
    """

    name: str
    save_model_path: str | None = None

Client dataclass

Configuration for a federated learning client.

Represents a client participant in the federated learning system, including its name and any environment variables required for execution.

Attributes:

  • name (str): The unique name identifier for the client.
  • env_vars (Dict[str, str]): Dictionary of environment variables to set for the client.

Source code in refedez/lib/config.py
@dataclass(frozen=True)
class Client:
    """Configuration for a federated learning client.

    Represents a client participant in the federated learning system, including
    its name and any environment variables required for execution.

    Attributes:
        name: The unique name identifier for the client.
        env_vars: Dictionary of environment variables to set for the client.
    """

    name: str
    env_vars: Dict[str, str] = field(default_factory=dict)
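
Used together, the two dataclasses describe one deployment. A minimal sketch follows; the names, the model path, and the DATA_DIR environment variable are placeholders for illustration, not values the library requires.

from refedez.lib import Server, Client

server = Server("server.localhost", save_model_path="/models/final_model.pt")

clients = [
    Client("site1"),
    # env_vars injects per-client settings, e.g. a hypothetical local data path.
    Client("site2", env_vars={"DATA_DIR": "/data/site2"}),
]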

Federated(server, clients, refedez_config, num_rounds=1, backend=Backend.PYTORCH, algorithm=FedAlgorithm.FED_AVG)

Runs a federated learning experiment using the specified backend and configuration.

This function orchestrates the communication and training rounds between a central server and multiple clients according to the provided refedez_config file. The backend determines the machine learning framework used (e.g., PyTorch or NumPy).

Parameters:

  • server (Server, required): The federated server responsible for aggregating model updates and coordinating the training process.
  • clients (List[Client], required): A list of participating clients, each responsible for local model training and update submission.
  • refedez_config (str, required): Path to the ReFedEz configuration file (YAML or JSON) that defines the experiment setup, data paths, and hyperparameters.
  • num_rounds (int, default: 1): Number of federated training rounds to execute.
  • backend (Backend, default: Backend.PYTORCH): Backend to use for computation. Can be one of the supported frameworks in Backend (e.g., Backend.PYTORCH, Backend.NUMPY).
  • algorithm (FedAlgorithm, default: FedAlgorithm.FED_AVG): Aggregation algorithm to use. Can be one of the supported algorithms in FedAlgorithm (e.g., FedAlgorithm.FED_AVG, FedAlgorithm.SCAFFOLD).

Returns:

  None. This function does not return a value directly, but performs training, logging, and potentially saves model artifacts or metrics as side effects.

Raises:

  • ValueError: If the configuration file path is invalid or unreadable.
  • RuntimeError: If the backend fails to initialize or training fails mid-process.
  • ConnectionError: If communication with clients fails during aggregation rounds.

Example:
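
As a hedged illustration of the non-default arguments, a call might look like the sketch below. The import location of Backend and FedAlgorithm is an assumption (they are only referenced by name in the docstring), and whether every algorithm is available for every backend is not documented here.

from refedez.lib import Federated, Server, Client, FederatedNumpy
from refedez.lib import Backend, FedAlgorithm  # import path assumed


@Federated(
    server=Server("server.localhost"),
    clients=[Client("site1"), Client("site2")],
    refedez_config="./refedez.yaml",
    num_rounds=10,
    backend=Backend.NUMPY,
    algorithm=FedAlgorithm.SCAFFOLD,
)
class MyNumpyModel(FederatedNumpy):
    ...  # implement forward, get_weights, set_weights and train_step as shown earlier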

Source code in refedez/lib/refedez.py
def Federated(
    server: Server,
    clients: List[Client],
    refedez_config: str,
    num_rounds=1,
    backend=Backend.PYTORCH,
    algorithm=FedAlgorithm.FED_AVG,
):
    """
     Runs a federated learning experiment using the specified backend and configuration.

    This function orchestrates the communication and training rounds between a central
    server and multiple clients according to the provided `refedez_config` file.
    The backend determines the machine learning framework used (e.g., PyTorch or NumPy).

    Args:
        server (Server): The federated server responsible for aggregating model updates
            and coordinating the training process.
        clients (List[Client]): A list of participating clients, each responsible for
            local model training and update submission.
        refedez_config (str): Path to the ReFedEz configuration file (YAML or JSON)
            that defines the experiment setup, data paths, and hyperparameters.
        num_rounds (int, optional): Number of federated training rounds to execute.
            Defaults to 1.
        backend (Backend, optional): Backend to use for computation. Can be one of the
            supported frameworks in `Backend` (e.g., `Backend.PYTORCH`, `Backend.NUMPY`).
            Defaults to `Backend.PYTORCH`.
        algorithm (FedAlgorithm, optional): Aggregation algorithm to use. Can be one of the
            supported algorithms in `FedAlgorithm` (e.g., `FedAlgorithm.FED_AVG`, `FedAlgorithm.SCAFFOLD`).
            Defaults to `FedAlgorithm.FED_AVG`.

    Returns:
        None: This function does not return a value directly, but performs training,
        logging, and potentially saves model artifacts or metrics as side effects.

    Raises:
        ValueError: If the configuration file path is invalid or unreadable.
        RuntimeError: If the backend fails to initialize or training fails mid-process.
        ConnectionError: If communication with clients fails during aggregation rounds.

    Example:
    """

    def decorator(cls):
        # Only run the logic if this module is executed as __main__
        if cls.__module__ == "__main__":
            args = sys.argv
            should_execute = len(args) >= 2 and ENV_RUN_JOB_STAGE in args[1]

            if should_execute:
                if backend == Backend.PYTORCH:
                    execute_torch(cls())
                elif backend == Backend.NUMPY:
                    execute_numpy(cls())
                elif backend == Backend.TENSORFLOW:
                    execute_tensorflow(cls())
                else:
                    raise NotImplementedError(f"Unsupported backend {backend}")
            else:
                script_path = getfile(cls)
                job_config = JobConfiguration(
                    server=server,
                    clients=clients,
                    refedez_config=refedez_config,
                    to_execute_script_path=script_path,
                    num_rounds=num_rounds,
                    model_cls=cls,
                    type=backend,
                    algorithm=algorithm,
                )
                execute_job_config(job_config)
                sys.exit(0)
        return cls

    return decorator

FederatedTensorFlow

Bases: Model

Base class for federated learning models using TensorFlow (Keras).

This class extends tf.keras.Model and provides an interface for models in federated learning scenarios. It handles initialization and defines abstract methods for weight management and local training.

Source code in refedez/lib/backends/tensorflow.py
class FederatedTensorFlow(tf.keras.Model):
    """Base class for federated learning models using TensorFlow (Keras).

    This class extends tf.keras.Model and provides an interface for models
    in federated learning scenarios. It handles initialization
    and defines abstract methods for weight management and local training.
    """

    def __init__(self):
        super(FederatedTensorFlow, self).__init__()

    def forward(self, inputs):
        pass

    def get_weights(self):
        pass

    def set_weights(self, new_weights):
        pass

    def train_step(self, learning_rate=1.0):
        pass
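
A minimal subclass might look like the sketch below. It is illustrative only: the single Dense layer, the synthetic tensors, the ½·MSE loss, the assumption that weights are exchanged as a list of NumPy arrays, and the assumption that FederatedTensorFlow is importable from refedez.lib (mirroring FederatedTorch) are not prescribed by the library.

import tensorflow as tf

from refedez.lib import FederatedTensorFlow  # import location assumed


class TinyKerasModel(FederatedTensorFlow):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        # Synthetic local data, purely for illustration.
        self.x = tf.random.normal((32, 3))
        self.y = tf.random.normal((32, 1))
        self.dense(self.x)  # build the layer so its weights exist before round 1

    def forward(self, inputs):
        return self.dense(inputs)

    def get_weights(self):
        # One plausible weight format: a list of NumPy arrays (kernel, bias).
        return [w.numpy() for w in self.dense.weights]

    def set_weights(self, new_weights):
        for var, value in zip(self.dense.weights, new_weights):
            var.assign(value)

    def train_step(self, learning_rate=1.0):
        # One full-batch gradient descent step on the ½·MSE loss.
        with tf.GradientTape() as tape:
            loss = 0.5 * tf.reduce_mean(tf.square(self.forward(self.x) - self.y))
        grads = tape.gradient(loss, self.dense.trainable_variables)
        for var, grad in zip(self.dense.trainable_variables, grads):
            var.assign_sub(learning_rate * grad)
        return self.get_weights()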
