Metadata-Version: 2.4
Name: sanet
Version: 1.0.11
Summary: Spiking Attention Network model
Project-URL: Homepage, https://www.qtvo.dev/
Project-URL: License, https://opensource.org/licenses/MIT
Author-email: Quoc Thinh Vo <contact@qtvo.dev>
License: MIT
License-File: LICENSE
Keywords: attention,neural-network,pytorch,spiking
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Python: <3.13,>=3.9
Requires-Dist: black>=25.11.0
Requires-Dist: conformer==0.3.2
Requires-Dist: numpy==2.0.2
Requires-Dist: pytorch-lightning==2.4.0
Requires-Dist: scikit-learn==1.5.2
Requires-Dist: snntorch==0.9.4
Requires-Dist: torch==2.4.1
Description-Content-Type: text/markdown

# sanet

Spiking Attention Network (sanet) model package for PyTorch.

[![PyPI version](https://img.shields.io/pypi/v/sanet.svg?logo=pypi&logoColor=FFE873)](https://pypi.org/project/sanet/)
[![Supported Python versions](https://img.shields.io/pypi/pyversions/sanet.svg?logo=python&logoColor=FFE873)](https://pypi.org/project/sanet/)
[![PyPI downloads](https://img.shields.io/pypi/dm/sanet.svg)](https://pypistats.org/packages/sanet)
[![Licence](https://img.shields.io/github/license/qtvo93/spiking-nw-ssl.svg)](LICENSE)
[![Code style: Black](https://img.shields.io/badge/code%20style-Black-000000.svg)](https://github.com/psf/black)

Paper: [Spiking Attention Network: A Hybrid Neuromorphic Approach to Underwater Acoustic Localization and Zero-shot Adaptation](https://ieeexplore.ieee.org/document/11464621)

## Install

```bash
pip3 install sanet
```

Or, with uv:

```bash
uv add sanet
```

## Usage

### Minimal

```python
import torch
import sanet

model = sanet.SA_NET()
model.eval()  # disable dropout for inference

batch_size = 2
time_steps = 1500
x = torch.randn(batch_size, time_steps, 21)  # [batch, time, channels]

with torch.no_grad():
    y = model(x)

print(y.shape)  # expected: torch.Size([2, 1])
```

### All Parameters

```python
import torch
import sanet

model = sanet.SA_NET(
    input_channels=21,
    output_channels=1,
    middle_channels=11,
    seed=42,
    spike_slope=25,
    lif1_beta=0.9956,
    lif2_beta=0.9821,
    lif3_beta=0.930,
    conformer_dim=512,
    conformer_depth=2,
    conformer_dim_head=64,
    conformer_heads=8,
    conformer_ff_mult=4,
    conformer_conv_expansion_factor=2,
    conformer_conv_kernel_size=24,
    conformer_attn_dropout=0.1,
    conformer_ff_dropout=0.1,
    conformer_conv_dropout=0.1,
    dropout_p1=0.0,
    dropout_p2=0.1,
    dropout_p3=0.1,
    dropout_p4=0.1,
)
model.eval()  # disable dropout for inference

batch_size = 2
time_steps = 1500
x = torch.randn(batch_size, time_steps, 21)  # last dim must equal input_channels

with torch.no_grad():
    y = model(x)

print(y.shape)  # expected: torch.Size([2, 1])
```

## Model Notes

- Input tensor shape: `[batch, time, channels]`, where `channels` must equal `input_channels`.
- The forward pass applies per-channel standardization to the input before the backbone.
- The network combines ResNet-style 1D blocks, leaky integrate-and-fire (LIF) spiking neurons, and Conformer layers.
- Model initialization seeds the Python, NumPy, and PyTorch random number generators with `seed` and enables deterministic CUDA behavior.
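
The standardization step can be pictured with the minimal NumPy sketch below. It assumes each channel is standardized over the time axis separately for every sample, with a small `eps` guarding against division by zero. The helper `standardize_per_channel` and its `eps` default are hypothetical, and `sanet` may compute its statistics differently (for example over the whole batch).

```python
import numpy as np


def standardize_per_channel(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Hypothetical helper: zero-mean, unit-variance scaling of each channel.

    x has shape [batch, time, channels]. Statistics are taken over the time
    axis (axis=1), separately for every sample and channel (an assumption).
    """
    mean = x.mean(axis=1, keepdims=True)  # shape [batch, 1, channels]
    std = x.std(axis=1, keepdims=True)    # shape [batch, 1, channels]
    return (x - mean) / (std + eps)


# Toy input: 2 samples, 1500 time steps, 21 channels with a non-zero offset and scale.
rng = np.random.default_rng(0)
x = 5.0 + 3.0 * rng.standard_normal((2, 1500, 21))
z = standardize_per_channel(x)

print(z.shape)                                      # (2, 1500, 21)
print(np.allclose(z.mean(axis=1), 0.0, atol=1e-6))  # True
```

After the transform, every channel of every sample has approximately zero mean and unit variance along time, so channels recorded at very different scales reach the first ResNet block on a comparable footing.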

## Architecture Diagram

```mermaid
flowchart TB
	A["Input (batch, time, channels)"] --> B["Per-channel Standardization"]

	subgraph Backbone
		C["ResNet1 + MaxPool + Dropout (p1)"] --> D["LIF1 Spiking"]
		D --> E["ResNet2 + MaxPool + Dropout (p2)"]
		E --> F["LIF2 Spiking"]
		F --> G["ResNet3 + MaxPool + Dropout (p3)"]
		G --> H["LIF3 Spiking"]
		H --> I["ResNet4 + MaxPool + Dropout (p4)"]
	end

	B --> C
	I --> J["Conformer Stack"]
	J --> K["Multilayer Perceptron"]
	K --> L["Output"]
```
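
To make the LIF1, LIF2 and LIF3 stages concrete, here is a minimal pure-Python sketch of a single leaky integrate-and-fire neuron and of a fast-sigmoid surrogate gradient. It assumes the spiking layers behave like snntorch's `Leaky` neuron with a firing threshold of 1.0 and subtractive reset, and that `spike_slope` is the slope of a fast-sigmoid surrogate. These are assumptions, and the helpers `lif_sequence` and `fast_sigmoid_surrogate_grad` are hypothetical, not part of the package.

```python
def lif_sequence(inputs, beta, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a sequence of inputs.

    Subtractive reset modeled on snntorch's `Leaky` neuron: if the previous
    membrane value exceeded the threshold, the threshold is subtracted at
    the next update.
    """
    mem = 0.0
    spikes, mems = [], []
    for x in inputs:
        reset = 1.0 if mem > threshold else 0.0     # previous step fired
        mem = beta * mem + x - reset * threshold    # leak, integrate, reset
        spikes.append(1 if mem > threshold else 0)  # Heaviside spike
        mems.append(mem)
    return spikes, mems


def fast_sigmoid_surrogate_grad(mem, threshold=1.0, slope=25.0):
    """Smooth stand-in for d(spike)/d(mem) used during backpropagation."""
    return 1.0 / (slope * abs(mem - threshold) + 1.0) ** 2


constant_drive = [0.3] * 20
fast_leak, _ = lif_sequence(constant_drive, beta=0.5)  # settles near 0.6, never fires
slow_leak, _ = lif_sequence(constant_drive, beta=0.9)  # accumulates past 1.0 and fires
print(sum(fast_leak), sum(slow_leak) > 0)  # 0 True

print(fast_sigmoid_surrogate_grad(1.0))  # 1.0 (peak at the threshold)
print(fast_sigmoid_surrogate_grad(1.1) < fast_sigmoid_surrogate_grad(1.05))  # True
```

With the same input, the neuron with the larger `beta` retains more of its membrane potential between steps, so it reaches the threshold and fires while the leakier one never does. The surrogate gradient peaks at the threshold and narrows as the slope grows, which is what allows the otherwise non-differentiable spike to be trained with backpropagation.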

## Parameters

- `input_channels` (default: 21): Number of input channels.
- `output_channels` (default: 1): Number of output channels.
- `middle_channels` (default: 11): Number of intermediate channels before the final projection.
- `seed` (default: 42): Random seed for reproducibility.
- `spike_slope` (default: 25): Slope for the surrogate spike gradient.
- `lif1_beta`, `lif2_beta`, `lif3_beta` (default: 0.9956, 0.9821, 0.930): Membrane decay factors of the three LIF spiking layers; values closer to 1 make the membrane potential decay more slowly.
- `conformer_dim` (default: 512): Conformer model dimension.
- `conformer_depth` (default: 2): Number of Conformer blocks.
- `conformer_dim_head` (default: 64): Attention head dimension.
- `conformer_heads` (default: 8): Number of attention heads.
- `conformer_ff_mult` (default: 4): Feedforward expansion multiplier.
- `conformer_conv_expansion_factor` (default: 2): Channel expansion factor of the Conformer convolution module.
- `conformer_conv_kernel_size` (default: 24): Kernel size of the Conformer convolution module.
- `conformer_attn_dropout` (default: 0.1): Attention dropout in Conformer.
- `conformer_ff_dropout` (default: 0.1): Feedforward dropout in Conformer.
- `conformer_conv_dropout` (default: 0.1): Convolution dropout in Conformer.
- `dropout_p1`, `dropout_p2`, `dropout_p3`, `dropout_p4` (default: 0.0, 0.1, 0.1, 0.1): Dropout probabilities applied after ResNet stages 1 to 4, respectively.
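
The defaults above imply a few derived quantities that are easy to check by hand. The sketch below computes the membrane time constant of each LIF layer (the number of steps for an undriven membrane to decay to 1/e of its value, since it is multiplied by `beta` each step) and the hidden widths inside one Conformer block. The width formulas assume the Conformer parameters are forwarded to the `Conformer` class of the `conformer` package listed under `Requires-Dist`; both helper functions are hypothetical.

```python
import math


def membrane_time_constant(beta: float) -> float:
    """Steps for an undriven membrane to decay to 1/e, since mem_t = beta**t * mem_0."""
    return -1.0 / math.log(beta)


def conformer_inner_sizes(dim, dim_head, heads, ff_mult, conv_expansion_factor):
    """Hidden widths inside one Conformer block (assumed `conformer` package layout)."""
    return {
        "attention_inner": dim_head * heads,
        "feedforward_hidden": dim * ff_mult,
        "conv_inner": dim * conv_expansion_factor,
    }


# Default LIF betas from the parameter list: tau ~ 226.8, 55.4 and 13.8 steps.
for name, beta in [("lif1_beta", 0.9956), ("lif2_beta", 0.9821), ("lif3_beta", 0.930)]:
    print(f"{name}={beta}: tau ~ {membrane_time_constant(beta):.1f} steps")

# Default Conformer settings from the parameter list.
sizes = conformer_inner_sizes(dim=512, dim_head=64, heads=8, ff_mult=4, conv_expansion_factor=2)
print(sizes)  # {'attention_inner': 512, 'feedforward_hidden': 2048, 'conv_inner': 1024}
```

Under the defaults, `conformer_dim_head * conformer_heads` equals `conformer_dim`. Each time constant is measured in the time steps seen by that layer, which may be fewer than the raw input steps after the preceding max-pooling stages.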

## Citation

If you find this package useful, please consider citing our paper:

```bibtex
@inproceedings{vo2026sa-net,
	title={Spiking Attention Network: A Hybrid Neuromorphic Approach to Underwater Acoustic Localization and Zero-shot Adaptation},
	author={Vo, Quoc Thinh and Han, David K},
	booktitle={2026 51st IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP)},
	pages={1--5},
	year={2026},
	organization={IEEE}
}
```
