Metadata-Version: 2.4
Name: branchkey
Version: 2.9.2
Summary: Client application to interface with the BranchKey system
Home-page: https://branchkey.com
Author: BranchKey
Author-email: info@branchkey.com
Project-URL: Homepage, https://branchkey.com
Project-URL: Repository, https://gitlab.com/branchkey/client_application
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
Classifier: Operating System :: OS Independent
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: requests==2.32.3
Requires-Dist: numpy==1.26.4
Requires-Dist: pika==1.3.2
Requires-Dist: pysocks==1.7.1
Requires-Dist: websockets>=12.0
Requires-Dist: aiohttp>=3.9.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license-file
Dynamic: project-url
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# BranchKey Python Client

![BK_logo](https://branchkey.com/branding/bk-logo-medium.png)

[![PyPI version](https://badge.fury.io/py/branchkey.svg)](https://badge.fury.io/py/branchkey)
[![Python](https://img.shields.io/pypi/pyversions/branchkey.svg)](https://pypi.org/project/branchkey/)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

## Installation

```bash
pip install branchkey
```

**Requirements:** Python 3.9 or higher

## Quick Start

### 1. Get Credentials

Create a leaf entity on the BranchKey platform (via the `/v2/entities` API endpoint) to obtain the credentials used below.

### 2. Initialise Client

```python
from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)

# Create credentials
credentials = Credentials(
    id="your-leaf-uuid",
    name="my-client",
    session_token="your-session-token-uuid",
    owner_id="your-user-uuid",
    tree_id="your-tree-uuid",
    branch_id="your-branch-uuid",
)

# Initialise client with default settings
client = Client(credentials)

# Or with custom configuration
client = Client(
    credentials=credentials,
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
    ),
    run_config=RunConfig(
        wait_for_run=False,
        check_interval_s=30,
    ),
)
```

### 3. Upload Model Weights

```python
import numpy as np

# Prepare model weights
weighting = 1000  # Weight for aggregation (typically number of samples)
parameters = [layer1_weights, layer2_weights, ...]  # One numpy array per layer

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")
```

### 4. Download Aggregated Results

```python
# Wait for aggregation notification
aggregation_id = client.queue.get(block=True)  # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_output/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
```

## Configuration

All configuration uses immutable dataclasses for type safety and clarity.

### Credentials

```python
from branchkey import Credentials

credentials = Credentials(
    id="leaf-uuid",
    name="my-leaf",
    session_token="token-uuid",
    owner_id="user-uuid",
    tree_id="tree-uuid",
    branch_id="branch-uuid",
)

# Or from a dictionary
credentials = Credentials.from_dict(creds_dict)
```
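For long-lived deployments it is often convenient to keep credentials out of source code. A minimal sketch, assuming a plain JSON file whose keys match the `Credentials` fields (the file layout here is an illustration, not a format the library mandates):

```python
import json

# Hypothetical creds.json content with keys matching the Credentials fields
creds_dict = json.loads("""
{
    "id": "leaf-uuid",
    "name": "my-leaf",
    "session_token": "token-uuid",
    "owner_id": "user-uuid",
    "tree_id": "tree-uuid",
    "branch_id": "branch-uuid"
}
""")

# With a real file you would use: creds_dict = json.load(open("creds.json"))
# credentials = Credentials.from_dict(creds_dict)
```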

### API Configuration

```python
from branchkey import APIConfig

api_config = APIConfig(
    host="https://app.branchkey.com",  # API endpoint (default)
    ssl=True,                           # Verify SSL certificates (default)
    proxies=None,                       # Optional proxy dict
)
```

### Transport: AMQP (RabbitMQ) vs WebSocket

The client supports two transport mechanisms for receiving aggregation notifications:

#### AMQP/RabbitMQ (Default)

```python
from branchkey import Client, Credentials, RabbitMQConfig

client = Client(
    credentials=credentials,
    rabbitmq_config=RabbitMQConfig(
        host=None,                      # Auto-derived from API host
        port=5671,                      # TLS port (default)
        ssl=True,                       # Use TLS (default)
        max_reconnect_attempts=0,       # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,   # Exponential backoff multiplier
        reconnect_max_delay=60,         # Max delay in seconds
    ),
    use_websocket=False,  # Default
)

# Receive aggregations via queue
aggregation_id = client.queue.get(block=True)
```

#### WebSocket

```python
from branchkey import Client, Credentials, WebSocketConfig

client = Client(
    credentials=credentials,
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=0,       # 0 = infinite retry (default)
        reconnect_backoff_factor=2.0,   # Exponential backoff multiplier
        reconnect_max_delay=60,         # Max delay in seconds
    ),
    use_websocket=True,  # Enable WebSocket transport
)

# Receive aggregations via polling
aggregation_id = client.get_latest_aggregation_id()
if aggregation_id:
    client.file_download(aggregation_id)
```
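A simple polling loop around `get_latest_aggregation_id` might look like the sketch below. The `poll` helper and its parameters are illustrative (not part of the library); it accepts any zero-argument callable, so the same loop works for tests or other sources:

```python
import time

def poll(fetch, interval_s=5.0, max_attempts=10):
    """Call `fetch` until it returns a truthy value or attempts run out."""
    for _ in range(max_attempts):
        result = fetch()
        if result:
            return result
        time.sleep(interval_s)
    return None

# Usage with the client (not executed here):
# aggregation_id = poll(client.get_latest_aggregation_id)
# if aggregation_id:
#     client.file_download(aggregation_id)
```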

### Run Configuration

```python
from branchkey import RunConfig

run_config = RunConfig(
    wait_for_run=False,     # Wait if run is paused before uploading
    check_interval_s=30,    # Run status check interval in seconds
)
```

### HTTP Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff:

```python
from branchkey import RetryConfig

retry_config = RetryConfig(
    max_retries=3,                                   # Maximum retry attempts
    backoff_factor=1.0,                              # Backoff multiplier (seconds)
    total_timeout=30,                                # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504), # HTTP codes to retry
    allowed_methods=("GET", "POST", "PUT"),          # Methods that support retry
)

client = Client(credentials, retry_config=retry_config)
```

**Retry Behaviour:**

- **Retries on:** 408, 429, 5xx errors, connection timeouts
- **Does NOT retry:** Other 4xx client errors (400, 401, 403, 404)
- **Backoff delays:** Exponential (1s, 2s, 4s, ...)
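As a rough illustration of the schedule, assuming delays of `backoff_factor * 2**attempt` (which reproduces the 1s, 2s, 4s progression above; the library's exact formula and cap may differ):

```python
def backoff_delays(max_retries: int, backoff_factor: float, max_delay: float = 60.0):
    """Illustrative exponential backoff: factor * 2**attempt, capped at max_delay."""
    return [min(backoff_factor * (2 ** attempt), max_delay)
            for attempt in range(max_retries)]

print(backoff_delays(3, 1.0))  # [1.0, 2.0, 4.0]
```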

**Configuration Examples:**

```python
# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)

# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
```

### Complete Configuration Example

```python
from branchkey import (
    Client,
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,
)

client = Client(
    credentials=Credentials(
        id="leaf-uuid",
        name="my-leaf",
        session_token="token",
        tree_id="tree-uuid",
        branch_id="branch-uuid",
        owner_id="user-uuid",
    ),
    api_config=APIConfig(
        host="https://app.branchkey.com",
        ssl=True,
    ),
    rabbitmq_config=RabbitMQConfig(
        port=5671,
        ssl=True,
        max_reconnect_attempts=10,
    ),
    websocket_config=WebSocketConfig(
        max_reconnect_attempts=10,
    ),
    run_config=RunConfig(
        wait_for_run=True,
        check_interval_s=15,
    ),
    retry_config=RetryConfig(
        max_retries=5,
        backoff_factor=2.0,
    ),
    use_websocket=False,  # False for AMQP, True for WebSocket
)
```

## Model Weight Format

Model weights are stored in compressed NPZ format.

### Structure

```python
# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays
```
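The layout can be sketched with plain NumPy. The key names `weighting` and `layer_<n>` are assumptions for illustration; in practice `client.save_weights` produces the file for you:

```python
import os
import tempfile

import numpy as np

weighting = 1000
parameters = [np.ones((2, 3)), np.zeros(5)]

# Write a compressed NPZ with the weighting plus one entry per layer
path = os.path.join(tempfile.mkdtemp(), "model_weights.npz")
np.savez_compressed(
    path,
    weighting=np.array(weighting),
    **{f"layer_{i}": arr for i, arr in enumerate(parameters)},
)

# Read it back
data = np.load(path)
restored = [data[f"layer_{i}"] for i in range(len(parameters))]
print(int(data["weighting"]), [a.shape for a in restored])  # 1000 [(2, 3), (5,)]
```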

### Weighting Options

The `weighting` parameter controls how much influence this update has during aggregation:

**1. By Sample Count (Most Common)**

```python
weighting = len(train_dataset)  # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples
```

**2. Equal Weighting**

```python
weighting = 1  # All clients have equal influence
```

**3. Quality-Based Weighting**

```python
validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality
```
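To see what the weighting means in practice, here is a minimal weighted-average (FedAvg-style) sketch of how an aggregator might combine two clients' updates. This is an illustration of the principle only, not the BranchKey server's actual implementation:

```python
import numpy as np

def weighted_average(updates):
    """updates: list of (weighting, [arrays]) tuples -> aggregated [arrays]."""
    total = sum(w for w, _ in updates)
    n_layers = len(updates[0][1])
    return [sum(w * params[i] for w, params in updates) / total
            for i in range(n_layers)]

client_a = (1000, [np.array([1.0, 1.0])])  # 1000 samples
client_b = (500, [np.array([4.0, 4.0])])   # 500 samples
agg = weighted_average([client_a, client_b])
print(agg[0])  # [2. 2.] -- client_a has twice the influence of client_b
```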

### PyTorch Example

```python
import numpy as np

# Using client helper
weighting = len(train_dataset)
parameters = []
for name, param in model.named_parameters():
    parameters.append(param.detach().cpu().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

```python
# Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

### TensorFlow/Keras Example

```python
import numpy as np

weighting = len(train_dataset)
parameters = [w.numpy() for w in model.trainable_weights]  # One array per trainable variable

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

### Loading Aggregated Weights

```python
import numpy as np

# Load aggregated weights from NPZ file
npz_data = np.load("aggregated_output/aggregation_id.npz")

# Note: Aggregated results only contain layers (no weighting)
# Sort numerically so that e.g. 'layer_10' comes after 'layer_2'
# (a plain lexicographic sort would misorder them)
layer_keys = sorted(
    [k for k in npz_data.files if k.startswith('layer_')],
    key=lambda k: int(k.split('_')[1]),
)
parameters = [npz_data[k] for k in layer_keys]

# Apply to PyTorch model (copy_ preserves each parameter's dtype and device)
import torch

with torch.no_grad():
    for param, arr in zip(model.parameters(), parameters):
        param.copy_(torch.from_numpy(arr))
```

## Performance Metrics

Submit training or testing metrics:

```python
import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)
```

## Client Properties

```python
client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.tree_id           # Tree UUID
client.is_initialized    # Initialisation status
client.use_websocket     # True if using WebSocket transport
```

## Branch Configuration

Fetch branch configuration including model-specific settings:

```python
config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})
```

## Advanced Features

### Proxy Support

```python
from branchkey import Client, Credentials, APIConfig

proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}

client = Client(
    credentials=credentials,
    api_config=APIConfig(proxies=proxies),
)
```

### Context Manager

Use the client as a context manager for automatic cleanup:

```python
from branchkey import Client, Credentials

with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections automatically closed
```

### Error Handling

```python
try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
    # Logs include:
    # - HTTP status codes
    # - Response content preview
    # - Retry attempt information
```

## Public API

```python
from branchkey import (
    # Main client
    Client,

    # Configuration (frozen dataclasses)
    Credentials,
    APIConfig,
    RabbitMQConfig,
    WebSocketConfig,
    RunConfig,
    RetryConfig,

    # Utilities
    get_metadata,
    AGGREGATED_OUTPUT_DIR,
)
```

## Support

- **Website**: [https://branchkey.com](https://branchkey.com)
- **Repository**: [https://gitlab.com/branchkey/client_application](https://gitlab.com/branchkey/client_application)
- **Email**: info@branchkey.com

---

**BranchKey** - Federated Learning Platform
