Metadata-Version: 2.4
Name: branchkey
Version: 2.9.0
Summary: Client application to interface with the BranchKey system
Home-page: https://branchkey.com
Author: BranchKey
Author-email: info@branchkey.com
Project-URL: Homepage, https://branchkey.com
Project-URL: Repository, https://gitlab.com/branchkey/client_application
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
Classifier: Operating System :: OS Independent
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: requests==2.32.3
Requires-Dist: numpy==1.26.4
Requires-Dist: pika==1.3.2
Requires-Dist: pysocks==1.7.1
Requires-Dist: websockets>=12.0
Requires-Dist: aiohttp>=3.9.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license-file
Dynamic: project-url
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# BranchKey Python Client

![BK_logo](https://branchkey.com/branding/bk-logo-medium.png)

[![PyPI version](https://badge.fury.io/py/branchkey.svg)](https://badge.fury.io/py/branchkey)
[![Python](https://img.shields.io/pypi/pyversions/branchkey.svg)](https://pypi.org/project/branchkey/)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)

Official Python client for the BranchKey federated learning platform. This library provides a simple interface to upload model weights, download aggregated results, and track training runs.

## Installation

```bash
pip install branchkey
```

**Requirements:** Python 3.9 or higher

## Quick Start

### 1. Get Credentials

Create a leaf entity on the BranchKey platform (via the `/v2/entities` API endpoint) to obtain credentials in the following form:

```python
credentials = {
    "id": "your-leaf-uuid",
    "name": "my-client",
    "session_token": "your-session-token-uuid",
    "owner_id": "your-user-uuid",
    "tree_id": "your-tree-uuid",
    "branch_id": "your-branch-uuid"
}
```
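Hard-coding a session token in source makes it easy to leak. A minimal sketch, assuming the credentials are supplied through environment variables. The `BRANCHKEY_*` variable names and the `load_credentials` helper are hypothetical and not part of the client:

```python
import os

# Keys the credentials dict carries (taken from the example above)
REQUIRED_KEYS = ("id", "name", "session_token", "owner_id", "tree_id", "branch_id")


def load_credentials(env=None, prefix="BRANCHKEY_"):
    """Build the credentials dict from environment variables.

    The BRANCHKEY_* variable names are an assumption made for this sketch.
    """
    env = os.environ if env is None else env
    missing = [prefix + k.upper() for k in REQUIRED_KEYS if not env.get(prefix + k.upper())]
    if missing:
        raise KeyError(f"Missing credential variables: {', '.join(missing)}")
    return {k: env[prefix + k.upper()] for k in REQUIRED_KEYS}
```

With the variables exported, `credentials = load_credentials()` yields a dict in the same shape as the example above.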

### 2. Initialize Client

```python
from branchkey.client import Client

# Connect to BranchKey
client = Client(credentials, host="https://app.branchkey.com")
```

### 3. Upload Model Weights

```python
import numpy as np

# Prepare model weights
weighting = 1000  # Weight for aggregation (typically number of samples)
parameters = [np.zeros((4, 3), dtype=np.float32), np.zeros(3, dtype=np.float32)]  # Placeholder layers; use your model's arrays

# Save and upload
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
print(f"Uploaded: {file_id}")
```

### 4. Download Aggregated Results

```python
# Wait for aggregation notification from RabbitMQ
# The queue is populated by the background RabbitMQ consumer
aggregation_id = client.queue.get(block=True)  # Blocks until aggregation ready
client.file_download(aggregation_id)
print(f"Downloaded to: ./aggregated_files/{aggregation_id}.npz")

# Or check without blocking
if not client.queue.empty():
    aggregation_id = client.queue.get(block=False)
    client.file_download(aggregation_id)
```
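A blocking `get` waits indefinitely if no aggregation ever arrives. A minimal sketch of a bounded wait, assuming `client.queue` follows the standard library `queue.Queue` interface (the `get` and `empty` calls above suggest this, but it is an assumption). The `wait_for_aggregation` helper is hypothetical:

```python
import queue


def wait_for_aggregation(q, timeout_s=300.0):
    """Return the next aggregation ID, or None if none arrives within timeout_s."""
    try:
        return q.get(block=True, timeout=timeout_s)
    except queue.Empty:
        return None


# Stand-in queue for demonstration; in practice pass client.queue
demo_queue = queue.Queue()
demo_queue.put("aggregation-uuid")
print(wait_for_aggregation(demo_queue, timeout_s=1.0))  # the queued ID
print(wait_for_aggregation(demo_queue, timeout_s=0.1))  # None: queue is now empty
```

A `None` result lets the caller decide whether to keep training locally, retry, or stop.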

## Configuration Options

### Basic Configuration

```python
client = Client(
    credentials,
    host="https://app.branchkey.com",         # API endpoint
    rbmq_host=None,                           # RabbitMQ host (auto-derived from host)
    rbmq_port=5671,                           # RabbitMQ port (5671 for TLS)
    rbmq_ssl=True,                            # Use TLS for RabbitMQ
    rbmq_max_reconnect_attempts=10,           # Max RabbitMQ reconnection attempts
    rbmq_reconnect_backoff_factor=2.0,        # Exponential backoff multiplier
    rbmq_reconnect_max_delay=60,              # Max reconnection delay (seconds)
    ssl=True,                                 # Verify SSL certificates
    wait_for_run=False,                       # Wait if run is paused
    run_check_interval_s=30,                  # Run status check interval
    proxies=None,                             # HTTP/HTTPS proxy dict
    retry_config=None                         # Custom retry configuration (optional)
)
```

### Retry Configuration

The client automatically retries failed HTTP requests with exponential backoff. Default settings:

```python
from branchkey.retry_config import RetryConfig

# Default configuration (applied automatically)
retry_config = RetryConfig(
    max_retries=3,              # Maximum retry attempts
    backoff_factor=1.0,         # Exponential backoff multiplier (seconds)
    total_timeout=30,           # Request timeout in seconds
    status_forcelist=(408, 429, 500, 502, 503, 504)  # HTTP codes to retry
)

client = Client(credentials, retry_config=retry_config)
```

**Retry Behaviour:**

- **Retries on:**
  - 408 Request Timeout - Client-side timeouts
  - 429 Too Many Requests - Rate limiting
  - 5xx Server Errors - 500, 502, 503, 504
  - Connection timeouts and network failures
- **Does NOT retry:** Other 4xx client errors (400, 401, 403, 404, etc.)
- **Backoff delays:** 1s, 2s, 4s with the default `backoff_factor=1.0` (the delay doubles on each retry)
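
The rules above can be sketched as two small functions. The exact formula the client uses is not stated here; this sketch assumes `delay = backoff_factor * 2**n` for the n-th retry (starting at 0), which reproduces the 1s, 2s, 4s sequence. Both helper names are hypothetical:

```python
# Default status_forcelist from the configuration above
RETRY_STATUSES = (408, 429, 500, 502, 503, 504)


def should_retry(status_code, forcelist=RETRY_STATUSES):
    """True when the HTTP status code is in the retry list."""
    return status_code in forcelist


def backoff_delays(max_retries=3, backoff_factor=1.0):
    """Delay before each retry, assuming delay = backoff_factor * 2**n."""
    return [backoff_factor * 2 ** n for n in range(max_retries)]


print(backoff_delays())          # default settings
print(backoff_delays(5, 2.0))    # the "production" settings
print(should_retry(503), should_retry(404))
```

Under this assumption the worst-case added wait is the sum of the list, which is useful when choosing `max_retries` against a training-round deadline.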

**Custom Configuration Examples:**

```python
# Production: More retries, longer timeout
production_retry = RetryConfig(max_retries=5, backoff_factor=2.0, total_timeout=60)
client = Client(credentials, retry_config=production_retry)

# Development: Faster failure
dev_retry = RetryConfig(max_retries=1, backoff_factor=0.5, total_timeout=10)
client = Client(credentials, retry_config=dev_retry)
```

### RabbitMQ Reconnection

The RabbitMQ consumer automatically reconnects with exponential backoff if the connection is lost.

**Default Settings:**

- Max reconnection attempts: 10
- Backoff factor: 2.0 (delays: 2s, 4s, 8s, 16s, 32s, then capped at 60s)
- Max delay: 60 seconds
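
To see how these three settings interact, the delay schedule can be sketched as follows. This assumes `delay = min(backoff_factor ** attempt, max_delay)`, which matches the sequence listed above but is not confirmed as the client's internal formula. `reconnect_delays` is a hypothetical helper:

```python
def reconnect_delays(max_attempts=10, backoff_factor=2.0, max_delay=60.0):
    """Delay before each reconnection attempt, capped at max_delay (assumed formula)."""
    return [
        min(backoff_factor ** attempt, max_delay)
        for attempt in range(1, max_attempts + 1)
    ]


delays = reconnect_delays()
print(delays)
print(f"Total wait before giving up: {sum(delays):.0f}s")
```

Under this assumption the defaults allow roughly six minutes of outage (362 seconds of accumulated delay) before the consumer stops retrying.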

**Custom Configuration:**

```python
# Pass reconnection settings as Client parameters
client = Client(
    credentials,
    rbmq_max_reconnect_attempts=20,     # More attempts
    rbmq_reconnect_backoff_factor=3.0,  # Faster backoff increase
    rbmq_reconnect_max_delay=120        # 2 minute max delay (seconds)
)
```

## Model Weight Format

Model weights are stored in compressed NPZ format.

### Structure

```python
# Format: (weighting, [list_of_parameter_arrays])
weighting = 1000  # Weight for aggregation (see below)
parameters = [layer1, layer2, ...]  # List of numpy arrays
```

### Weighting Options

The `weighting` parameter controls how much influence this update has during aggregation:

**1. By Sample Count (Most Common)**

```python
weighting = len(train_dataset)  # e.g., 1000 samples
# Client with 1000 samples has 2x influence of client with 500 samples
```

**2. Equal Weighting**

```python
weighting = 1  # All clients have equal influence
```

**3. Quality-Based Weighting**

```python
validation_accuracy = 0.85
weighting = len(train_dataset) * validation_accuracy  # Weight by quality
```

**4. Manual Weighting**

```python
weighting = 5.0  # Trusted client gets higher weight
```
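
To make the effect of `weighting` concrete, here is a sketch of a weighted mean over client updates. It assumes a FedAvg-style weighted average; the aggregation actually performed by the platform may differ. Plain Python lists stand in for NumPy arrays, and `weighted_average` is a hypothetical helper:

```python
def weighted_average(updates):
    """Combine (weighting, parameters) pairs into one weighted-mean parameter list.

    Each parameters entry is a list of layers; each layer is a flat list of floats.
    """
    total = sum(weight for weight, _ in updates)
    if total <= 0:
        raise ValueError("Sum of weightings must be positive")
    n_layers = len(updates[0][1])
    result = []
    for layer_idx in range(n_layers):
        size = len(updates[0][1][layer_idx])
        layer = [
            sum(weight * params[layer_idx][j] for weight, params in updates) / total
            for j in range(size)
        ]
        result.append(layer)
    return result


# A client with 1000 samples pulls the average twice as hard as one with 500
updates = [(1000, [[1.0, 1.0]]), (500, [[4.0, 4.0]])]
print(weighted_average(updates))
```

With equal weightings the same function reduces to a plain mean, which corresponds to option 2 above.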

### PyTorch Example

```python
import numpy as np

# Method 1: Using client helper (recommended)
weighting = len(train_dataset)
parameters = []
for _, param in model.named_parameters():
    parameters.append(param.detach().cpu().numpy())

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

```python
# Method 2: Using convert_pytorch_numpy
weighting, parameters = client.convert_pytorch_numpy(
    model.named_parameters(),
    weighting=len(train_dataset)
)
file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

### TensorFlow/Keras Example

```python
import numpy as np

weighting = len(train_dataset)
parameters = [layer.numpy() for layer in model.trainable_weights]

file_path = client.save_weights("model_weights", weighting, parameters)
file_id = client.file_upload(file_path)
```

### Manual NPZ Creation

```python
import numpy as np

# Save manually (without using client helper)
arrays_dict = {'weighting': np.array([weighting], dtype=np.float64)}
for i, arr in enumerate(parameters):
    arrays_dict[f'layer_{i}'] = arr

np.savez_compressed("model_weights.npz", **arrays_dict)  # Must include .npz
```
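
Before uploading a manually built file, it can help to confirm that it reads back as intended. A minimal round-trip sketch using an in-memory buffer, following the `weighting` plus `layer_N` layout shown above. `pack_npz` and `unpack_npz` are hypothetical helpers, not client methods:

```python
import io

import numpy as np


def pack_npz(weighting, parameters):
    """Serialize a 'weighting' array plus layer_0..layer_N into an in-memory NPZ."""
    arrays = {"weighting": np.array([weighting], dtype=np.float64)}
    for i, arr in enumerate(parameters):
        arrays[f"layer_{i}"] = np.asarray(arr)
    buffer = io.BytesIO()
    np.savez_compressed(buffer, **arrays)
    buffer.seek(0)  # rewind so the buffer can be read back
    return buffer


def unpack_npz(buffer):
    """Read back (weighting, parameters) with layers in index order."""
    with np.load(buffer) as data:
        weighting = float(data["weighting"][0])
        n_layers = sum(1 for key in data.files if key.startswith("layer_"))
        return weighting, [data[f"layer_{i}"] for i in range(n_layers)]


params = [np.ones((2, 3), dtype=np.float32), np.arange(4, dtype=np.float32)]
weighting, restored = unpack_npz(pack_npz(1000, params))
print(weighting, [a.shape for a in restored], [a.dtype for a in restored])
```

The check confirms that shapes and dtypes survive compression unchanged, so a mismatch after download points to the model definition rather than the file format.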

### Loading Aggregated Weights

```python
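import numpy as np
import torch

# Load aggregated weights from NPZ file (replace aggregation_id with the real ID)
npz_data = np.load("aggregated_files/aggregation_id.npz")

# Note: Aggregated results only contain layers (no weighting)
# Sort numerically: a plain sorted() would place 'layer_10' before 'layer_2'
layer_keys = sorted(
    (k for k in npz_data.files if k.startswith('layer_')),
    key=lambda k: int(k.split('_')[1]),
)
parameters = [npz_data[k] for k in layer_keys]

# Apply to your model (layer order must match model.parameters())
with torch.no_grad():
    for param, array in zip(model.parameters(), parameters):
        param.copy_(torch.from_numpy(array))  # keeps the parameter's device and dtype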
```

### Example NPZ File Contents

**Client Upload Format (with weighting):**

```python
>>> npz_data.files
['weighting', 'layer_0', 'layer_1', 'layer_2', 'layer_3', ...]

>>> npz_data['weighting']
array([1530.])  # Weight for aggregation

>>> npz_data['layer_0'].shape, npz_data['layer_0'].dtype
((32, 1, 5, 5), dtype('float32'))

>>> npz_data['layer_0'][:1, :2, :2, :]
array([[[[-0.18576819, -0.03041792,  0.19532707, -0.11234483, -0.01512307],
         [ 0.19993757, -0.06492048,  0.08324468, -0.19899307, -0.0412709 ]]]],
       dtype=float32)
```

**Aggregated Result Format (layers only):**

```python
>>> npz_data.files
['layer_0', 'layer_1', 'layer_2', ...]  # No weighting in aggregated results
```
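
The two layouts above differ only in the presence of `weighting`, so a file can be checked from its key list alone. A sketch, assuming layer keys are numbered contiguously from `layer_0` (as the `enumerate`-based naming implies). `check_npz_keys` is a hypothetical helper that takes the list returned by `npz_data.files`:

```python
def layer_index(key):
    """Numeric index of a 'layer_N' key."""
    return int(key.split("_", 1)[1])


def check_npz_keys(files, expect_weighting):
    """Return layer keys in numeric order; raise ValueError on an unexpected layout."""
    has_weighting = "weighting" in files
    if has_weighting != expect_weighting:
        raise ValueError(
            f"'weighting' present={has_weighting}, expected={expect_weighting}"
        )
    layers = sorted((k for k in files if k.startswith("layer_")), key=layer_index)
    if [layer_index(k) for k in layers] != list(range(len(layers))):
        raise ValueError("Layer keys are not contiguous from layer_0")
    return layers


upload_keys = ["weighting"] + [f"layer_{i}" for i in range(12)]
print(check_npz_keys(upload_keys, expect_weighting=True)[-3:])
print(check_npz_keys(["layer_1", "layer_0"], expect_weighting=False))
```

Pass `expect_weighting=True` for files you are about to upload and `False` for downloaded aggregation results.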

## Performance Metrics

Submit training or testing metrics:

```python
import json

metrics = {"accuracy": 0.95, "loss": 0.12}
client.send_performance_metrics(
    aggregation_id="aggregation-uuid",
    data=json.dumps(metrics),
    mode="test"  # "test", "train", or "non-federated"
)
```
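
Metric values often come straight from a training framework as NumPy scalars, and `json.dumps` cannot encode types such as `numpy.float32`; it also emits `NaN`, which is not valid JSON. A sketch of a small guard applied before sending. `build_metrics_payload` is a hypothetical helper, and the mode list is taken from the comment above:

```python
import json
import math

VALID_MODES = ("test", "train", "non-federated")


def build_metrics_payload(metrics, mode):
    """Validate mode and values, then return the JSON string to pass as `data`."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    clean = {}
    for name, value in metrics.items():
        value = float(value)  # converts NumPy scalars to plain Python floats
        if not math.isfinite(value):
            raise ValueError(f"Metric {name!r} is not finite: {value}")
        clean[name] = value
    return json.dumps(clean)


print(build_metrics_payload({"accuracy": 0.95, "loss": 0.12}, mode="test"))
```

The returned string can be passed as the `data` argument of `send_performance_metrics`.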

## Client Properties

```python
client.run_status        # Current run status: "start", "stop", or "pause"
client.run_number        # Current run iteration
client.leaf_id           # Your leaf UUID
client.branch_id         # Parent branch UUID
client.is_initialized    # Initialization status
```
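
When explicit control over a paused run is preferred to the `wait_for_run` option, `run_status` can be polled directly. A sketch under two assumptions: that `"stop"` is terminal, and that any other non-`"start"` value means "keep waiting". `wait_until_started` is a hypothetical helper, and the status source is passed in as a callable so it can be exercised without a live client:

```python
import time


def wait_until_started(get_status, interval_s=30.0, max_checks=10, sleep=time.sleep):
    """Poll get_status() until it returns "start".

    Returns False if the status is "stop" or the checks run out.
    """
    for check in range(max_checks):
        status = get_status()
        if status == "start":
            return True
        if status == "stop":
            return False
        if check < max_checks - 1:
            sleep(interval_s)  # paused (or unknown): wait, then re-check
    return False


# Demonstration with a scripted status sequence instead of a live client
statuses = iter(["pause", "pause", "start"])
print(wait_until_started(lambda: next(statuses), interval_s=0.01))
```

With a real client the call would look like `wait_until_started(lambda: client.run_status)`.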

## Branch Configuration

Fetch branch configuration including model-specific settings:

```python
config = client.get_branch_config()
model_config = config.get("model_config", {})
sklearn_params = model_config.get("sklearn_params", {})
```
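
A branch may define only some settings, or none, so it is often convenient to overlay the fetched values on local defaults. A sketch, assuming the config is a plain nested dict as in the snippet above. The default values and the `resolve_sklearn_params` helper are hypothetical:

```python
# Hypothetical local defaults, used when the branch does not override them
DEFAULT_SKLEARN_PARAMS = {"max_iter": 100, "tol": 1e-4}


def resolve_sklearn_params(config, defaults=DEFAULT_SKLEARN_PARAMS):
    """Overlay branch-provided sklearn_params on local defaults."""
    model_config = (config or {}).get("model_config") or {}
    overrides = model_config.get("sklearn_params") or {}
    return {**defaults, **overrides}


branch_config = {"model_config": {"sklearn_params": {"max_iter": 500}}}
print(resolve_sklearn_params(branch_config))
print(resolve_sklearn_params({}))
```

Values set on the branch win, so every leaf on the same branch trains with the same overrides.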

## Advanced Features

### Proxy Support

For networks requiring proxy access:

```python
proxies = {
    'http': 'http://user:password@proxy.example.com:8080',
    'https': 'http://user:password@proxy.example.com:8080',
}
client = Client(credentials, host="https://app.branchkey.com", proxies=proxies)
```
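
Proxy passwords containing characters such as `@`, `:` or `/` must be percent-encoded, otherwise the URL is parsed incorrectly. A sketch that builds the `proxies` dict safely; `build_proxies` is a hypothetical helper and the host and credentials are placeholders:

```python
from urllib.parse import quote


def build_proxies(host, port, user=None, password=None):
    """Return a proxies dict, percent-encoding any credentials."""
    auth = ""
    if user is not None:
        auth = quote(user, safe="")
        if password is not None:
            auth += ":" + quote(password, safe="")
        auth += "@"
    url = f"http://{auth}{host}:{port}"
    return {"http": url, "https": url}


print(build_proxies("proxy.example.com", 8080, "user", "p@ss:w/rd"))
```

The result can be passed directly as `proxies=` when constructing the client.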

### Context Manager

Use the client as a context manager for automatic cleanup:

```python
with Client(credentials) as client:
    # Upload model
    file_path = client.save_weights("model", 1000, parameters)
    file_id = client.file_upload(file_path)

    # Download aggregation
    if not client.queue.empty():
        aggregation_id = client.queue.get(block=False)
        client.file_download(aggregation_id)
# Connections automatically closed
```

### Error Handling

The client provides detailed error messages for debugging:

```python
try:
    file_id = client.file_upload(file_path)
except Exception as e:
    print(f"Upload failed: {e}")
    # Logs include:
    # - HTTP status codes
    # - Response content preview
    # - Retry attempt information
```
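
In long-running training loops, `print` output is easily lost. A sketch that records the full traceback through the standard `logging` module and then re-raises, so the caller still decides how to recover. The exception types raised by the client are not listed here, so the sketch catches `Exception` broadly; the logger name and `call_logged` helper are hypothetical:

```python
import logging

logger = logging.getLogger("my_training_app")  # hypothetical application logger


def call_logged(action, func, *args, **kwargs):
    """Run func; on failure, log the traceback and re-raise the original exception."""
    try:
        return func(*args, **kwargs)
    except Exception:
        logger.exception("%s failed", action)
        raise


# Demonstration with a stand-in function instead of a live client
print(call_logged("Addition", lambda a, b: a + b, 1, 2))
```

With a real client the call would look like `file_id = call_logged("Upload", client.file_upload, file_path)`.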

## Support

- **Website**: [https://branchkey.com](https://branchkey.com)
- **Repository**: [https://gitlab.com/branchkey/client_application](https://gitlab.com/branchkey/client_application)
- **Email**: info@branchkey.com

---

**BranchKey** - Federated Learning Platform
