# Implement New Models

PocketPose provides a flexible, extensible framework for integrating pose estimation models, with a particular focus on mobile devices. At its core is a well-defined model interface, which makes it straightforward to add new models and extend the library's capabilities.

```{mermaid}
classDiagram
    IModel <|-- TFLiteModel
    IModel <|-- ONNModel
    TFLiteModel <|-- YourCustomModel
    IModel: +process_image()
    IModel: +predict()
    IModel: +postprocess_prediction()
    class TFLiteModel{
      +predict()
      +postprocess_prediction()
    }
    class ONNModel{
      +predict()
      +postprocess_prediction()
    }
    class YourCustomModel{
      +predict()
      +postprocess_prediction()
    }
```
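
The contract shown in the diagram can be approximated in plain Python with the standard `abc` module. The sketch below is illustrative only: the type hints, the `run` helper that chains the three stages, and the toy `DummyModel` are assumptions for this example, not PocketPose's actual `IModel` source.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

Keypoint = Tuple[int, int, float]  # (x, y, score)


class IModel(ABC):
    """Hypothetical stand-in for PocketPose's IModel interface."""

    @abstractmethod
    def process_image(self, image: Any) -> Any:
        """Prepare a raw image for inference (resize, normalize, ...)."""

    @abstractmethod
    def predict(self, model_input: Any) -> Any:
        """Run inference and return the raw model output."""

    @abstractmethod
    def postprocess_prediction(self, prediction: Any,
                               original_size: Tuple[int, int]) -> List[Keypoint]:
        """Convert the raw output into (x, y, score) keypoints."""

    def run(self, image: Any, original_size: Tuple[int, int]) -> List[Keypoint]:
        # Hypothetical convenience method: chains the three stages in order.
        model_input = self.process_image(image)
        prediction = self.predict(model_input)
        return self.postprocess_prediction(prediction, original_size)


class DummyModel(IModel):
    """Toy subclass that always 'detects' one keypoint."""

    def process_image(self, image):
        return image  # no preprocessing needed for the toy model

    def predict(self, model_input):
        return [(0.5, 0.25, 0.9)]  # one normalized (y, x, score) row

    def postprocess_prediction(self, prediction, original_size):
        height, width = original_size  # assumed (height, width) order
        return [(int(x * width), int(y * height), s) for y, x, s in prediction]


keypoints = DummyModel().run(image=None, original_size=(480, 640))
print(keypoints)  # [(160, 240, 0.9)]
```

Declaring the stages as abstract methods means a subclass that omits one of them cannot be instantiated, which is the kind of consistency the interface is meant to enforce.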

To integrate a new model into PocketPose, follow these steps:

1. **Inherit from IModel**: Your model class should inherit from the `IModel` interface or one of its direct subclasses, such as `TFLiteModel`. This keeps every model consistent and interoperable with the rest of the PocketPose framework.
   ```python
   from pocketpose.models.interfaces import IModel

   class YourCustomModel(IModel):
       ...
   ```

2. **Implement Required Methods**: Implement the abstract methods defined in the `IModel` interface. These typically include `process_image`, `predict`, and `postprocess_prediction`.
   - `process_image`: Prepare the input image for prediction.
   - `predict`: Run the model inference.
   - `postprocess_prediction`: Process the raw model output to extract meaningful information, such as keypoint coordinates.

3. **Add Custom Logic (if needed)**: Depending on your model's specific requirements, you can add additional methods or override existing ones for customized behavior.

4. **Integrate with the Model Factory**: Optionally, register your model with the `ModelFactory` so that it can be instantiated by name.
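
Registration by name is commonly implemented as a dictionary of classes that a class decorator fills in, plus a factory method that looks a class up and instantiates it. The sketch below illustrates only that general pattern; `SimpleRegistry`, `register`, `build`, `names`, and `ToyModel` are hypothetical names, not PocketPose's `model_registry` or `ModelFactory` API.

```python
from typing import Callable, Dict, TypeVar

T = TypeVar("T", bound=type)


class SimpleRegistry:
    """Hypothetical name-to-class registry illustrating the factory pattern."""

    def __init__(self) -> None:
        self._classes: Dict[str, type] = {}

    def register(self, name: str) -> Callable[[T], T]:
        # Returns a class decorator that records the class under `name`.
        def decorator(cls: T) -> T:
            if name in self._classes:
                raise ValueError(f"model '{name}' is already registered")
            self._classes[name] = cls
            return cls  # the decorated class itself is returned unchanged
        return decorator

    def build(self, name: str, **kwargs):
        # Factory step: look the class up by name and instantiate it.
        try:
            cls = self._classes[name]
        except KeyError:
            raise KeyError(f"unknown model '{name}'") from None
        return cls(**kwargs)

    def names(self):
        return sorted(self._classes)


registry = SimpleRegistry()


@registry.register("toy_model")
class ToyModel:
    def __init__(self, input_size=(192, 192, 3)):
        self.input_size = input_size


model = registry.build("toy_model", input_size=(256, 256, 3))
print(type(model).__name__, model.input_size)  # ToyModel (256, 256, 3)
```

Because the decorator returns the class unchanged, the class remains usable directly, and registration happens as a side effect of importing the module that defines it.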


## Example

The following example demonstrates how to integrate a TFLite model into PocketPose.

```python
from pocketpose.models.interfaces import TFLiteModel
from pocketpose.models.registry import model_registry


@model_registry.register('model_name')
class CustomModel(TFLiteModel):
    def __init__(self,
                 model_path: str = "path/to/cache/model/file.tflite",
                 model_url: str = "https://url/to/download/model/from.tflite",
                 input_size: tuple = (192, 192, 3)):
        super().__init__(model_path, model_url, keypoints_type='coco',
                         input_size=input_size, output_type='keypoints')

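    def postprocess_prediction(self, prediction, original_size):
        # Squeeze the raw output to (17, 3) rows of normalized (y, x, score).
        keypoints = prediction.squeeze()
        # Scale (y, x) to pixels; assumes original_size is (height, width).
        keypoints[:, :2] *= original_size
        # Reorder to (x, y, score) with integer pixel coordinates.
        keypoints = [(int(x), int(y), float(s)) for y, x, s in keypoints]
        return keypoints  # list of 17 (x, y, score) tuples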
```

The `model_registry.register` decorator registers the model with the `ModelFactory` under the given name (`'model_name'` here), allowing it to be instantiated by that name. The `postprocess_prediction` method converts the model-specific output into a list of `(x, y, score)` keypoints, which is the output format PocketPose expects.
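
The coordinate conversion performed in `postprocess_prediction` can be checked in isolation with NumPy. The standalone function below is a sketch that assumes the raw output contains normalized `(y, x, score)` rows and that `original_size` is `(height, width)`; check both against your model's actual output layout. The function name `to_pixel_keypoints` is hypothetical, and unlike the example above it copies the array so the raw prediction is not modified in place.

```python
import numpy as np


def to_pixel_keypoints(prediction, original_size):
    """Convert normalized (y, x, score) rows into pixel (x, y, score) tuples.

    Assumes `original_size` is (height, width), matching the (y, x) order.
    """
    keypoints = np.asarray(prediction, dtype=np.float32).squeeze().copy()
    keypoints = keypoints.reshape(-1, 3)  # also handles a single keypoint
    # Broadcast (height, width) over the (y, x) columns of every row.
    keypoints[:, :2] *= np.asarray(original_size, dtype=np.float32)
    return [(int(x), int(y), float(s)) for y, x, s in keypoints]


# Fake raw output of shape (1, 1, 17, 3): every keypoint at y=0.5, x=0.25.
raw = np.tile(np.array([0.5, 0.25, 0.5], dtype=np.float32), (1, 1, 17, 1))
points = to_pixel_keypoints(raw, original_size=(480, 640))
print(points[0])  # (160, 240, 0.5)
```

Testing the conversion as a plain function like this makes it easy to catch swapped axes, a common mistake when a model emits `(y, x)` but the rest of the pipeline expects `(x, y)`.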