Metadata-Version: 2.4
Name: tf-kan-latest
Version: 1.1.0
Summary: A Keras-native implementation of Kolmogorov-Arnold Networks (KANs) for TensorFlow.
Home-page: https://github.com/sathyasubrahmanya/tf-kan
Author: Sathyasubrahmanya v S
Author-email: sathyapel0005@gmail.com
Keywords: tensorflow,keras,kan,kolmogorov-arnold,neural-networks,machine-learning
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: tensorflow>=2.16.0
Requires-Dist: numpy>=1.23.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license-file
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary



# TF-KAN: Kolmogorov-Arnold Networks for TensorFlow


A Keras-native, high-performance implementation of **Kolmogorov-Arnold Networks (KANs)** for **TensorFlow 2.16+**.

This library provides easy-to-use Keras layers that replace standard linear transformations with learnable B-spline activation functions, allowing for more expressive and interpretable models.

-----

## Key Features

  * **🧠 Learnable Activations**: Goes beyond fixed activation functions like ReLU or SiLU by learning complex, data-driven activations on each connection (edge) of the network.
  * **🧩 Seamless Keras Integration**: Use `DenseKAN` and `Conv*DKAN` layers as direct, drop-in replacements for standard Keras layers.
  * **⚡ High Performance**: Core mathematical operations are compiled into static graphs with `@tf.function` for maximum speed.
  * **🔄 Adaptive Grids**: Dynamically update spline resolutions based on data, allowing the model to allocate its parameters more effectively.
  * **💾 Modern Serialization**: Save and load models containing KAN layers with `model.save()` and `tf.keras.models.load_model()`—no `custom_objects` needed.

-----

## Installation

```bash
pip install tf-kan
```

-----

## Core Concepts

In a traditional neural network, a connection is a single weight (`w`). In a KAN, each connection is a learnable 1D function (a **B-spline**), like a smart dimmer switch that can apply a complex curve to the input signal.

You control these functions with two hyperparameters:

  * **`grid_size`**: The resolution of the function. A larger size allows for more complex, "wiggly" functions.
  * **`spline_order`**: The smoothness of the function. An order of 3 (cubic) is recommended for smooth curves.
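To make these two hyperparameters concrete, the B-spline basis they control can be sketched in plain NumPy with the Cox-de Boor recursion. This illustrates the underlying math only, not `tfkan`'s internal implementation; the grid layout on `[-1, 1]` is an assumption for the sketch:

```python
import numpy as np

def bspline_basis(x, knots, order):
    """Evaluate all B-spline basis functions of `order` at points x
    via the Cox-de Boor recursion."""
    x = np.asarray(x, dtype=float)
    # Order-0 bases: indicator functions of each knot interval.
    bases = [((x >= knots[i]) & (x < knots[i + 1])).astype(float)
             for i in range(len(knots) - 1)]
    for k in range(1, order + 1):
        bases = [
            (x - knots[i]) / (knots[i + k] - knots[i]) * bases[i]
            + (knots[i + k + 1] - x) / (knots[i + k + 1] - knots[i + 1]) * bases[i + 1]
            for i in range(len(knots) - k - 1)
        ]
    return np.stack(bases, axis=-1)

grid_size, spline_order = 8, 3
# Uniform grid on [-1, 1], padded by `spline_order` knots on each side
# so every point in the interval is covered by a full set of bases.
h = 2.0 / grid_size
knots = np.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0

x = np.linspace(-1, 0.999, 200)
basis = bspline_basis(x, knots, spline_order)
print(basis.shape)  # (200, grid_size + spline_order) basis functions
```

A larger `grid_size` means more (narrower) basis functions, hence a "wigglier" learnable curve; a higher `spline_order` makes each basis function smoother. Note that at every point the bases sum to 1 (a partition of unity), which keeps the learned activation well-scaled.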

-----

## Examples

Here are several examples demonstrating how to use `tfkan` for different tasks.

### 1. Basic Regression

This example builds a simple model to learn a 1D function, showcasing the `DenseKAN` layer.

```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Generate some synthetic data for y = sin(pi*x)
x_train = np.linspace(-1, 1, 100)[:, np.newaxis]
y_train = np.sin(np.pi * x_train)

# 2. Build the KAN model
# A small model is enough to learn this simple function
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    DenseKAN(units=16, grid_size=8, spline_order=3, name='kan_layer_1'),
    DenseKAN(units=1, name='kan_output')
])

# 3. Compile and train
model.compile(optimizer='adam', loss='mean_squared_error')
print("--- Training a simple regressor ---")
model.fit(x_train, y_train, epochs=50, verbose=0)

# 4. Test the model
print("--- Prediction ---")
test_input = tf.constant([[0.5]]) # sin(pi * 0.5) = 1.0
prediction = model.predict(test_input)
print(f"Model prediction for input 0.5: {prediction[0][0]:.4f}")
model.summary()
```

### 2. Image Classification (Hybrid CNN)

Mix standard Keras layers with `Conv2DKAN` and `DenseKAN` to build a powerful hybrid classifier.

```python
import tensorflow as tf
from tfkan.layers import Conv2DKAN, DenseKAN

# 1. Load the CIFAR-10 dataset (labels are integer class indices)
(x_train, y_train), _ = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0

# 2. Build the hybrid model
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32, 32, 3)),

    # Standard Conv block
    tf.keras.layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),

    # KAN Conv block with specific KAN arguments
    Conv2DKAN(
        filters=64,
        kernel_size=3,
        padding='same',
        name='kan_conv',
        kan_kwargs={'grid_size': 5, 'spline_order': 3}
    ),
    tf.keras.layers.GlobalAveragePooling2D(),

    # KAN Dense layers for final classification
    DenseKAN(units=128, grid_size=8, name='kan_dense'),
    tf.keras.layers.Dense(units=10, name='output_logits') # Standard output layer
])

# 3. Compile and train
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("\n--- Training a hybrid CNN for image classification ---")
# model.fit(x_train, y_train, epochs=1, batch_size=64) # Uncomment to train
model.summary()
```

### 3. Advanced Usage: Adaptive Grid Updates

KANs can dynamically update their internal grids to better fit the data distribution. This is useful for refining a pre-trained model.

```python
import tensorflow as tf
import numpy as np
from tfkan.layers import DenseKAN

# 1. Create a model and a sample data batch
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    DenseKAN(16, grid_size=5, name='my_kan_layer')
])
sample_data = np.random.randn(100, 32).astype('float32')

# 2. Get the KAN layer from the model
kan_layer = model.get_layer('my_kan_layer')
print(f"Initial grid size: {kan_layer.grid_size}")

# 3. Update the grid based on the sample data
# This re-calculates knot locations to better cover the data's features
print("Updating grid from samples...")
kan_layer.update_grid_from_samples(sample_data)
print("Grid updated successfully.")

# 4. You can also extend the grid to a higher resolution
print("Extending grid to a larger size...")
try:
    kan_layer.extend_grid_from_samples(sample_data, extend_grid_size=10)
    print(f"Grid extended successfully. New grid size: {kan_layer.grid_size}")
except Exception as e:
    print(f"Error during extension: {e}")
```

### 4. Time Series Forecasting

Use `Conv1DKAN` to find complex temporal patterns in sequential data.

```python
import tensorflow as tf
from tfkan.layers import Conv1DKAN, DenseKAN

# 1. Define model parameters for a time series task
lookback_window = 20  # Number of past time steps to use as input
num_features = 5      # Number of features at each time step
num_classes = 3       # Number of output classes

# 2. Build a model for sequence classification
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(lookback_window, num_features)),
    
    # 1D KAN convolution to extract temporal features
    Conv1DKAN(
        filters=32,
        kernel_size=3,
        kan_kwargs={'grid_size': 8}
    ),
    tf.keras.layers.GlobalAveragePooling1D(),
    
    # Dense KAN layers for classification
    DenseKAN(64),
    tf.keras.layers.Dense(num_classes)
])

# 3. Compile the model
# The output layer has no softmax, so the loss must expect raw logits
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
)
print("\n--- Time Series Model ---")
model.summary()
```

-----

## Contributing

Contributions are welcome! Please feel free to submit a pull request or open an issue.

## License

This project is licensed under the MIT License.
