Metadata-Version: 2.1
Name: torchcat
Version: 0.0.8
Summary: TorchCat 是用于封装 PyTorch 模型的工具
Home-page: https://gitee.com/kkkaiyu/torchcat
Author: KaiYu
Author-email: 2971934557@qq.com
License: GPL-3.0
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: torchsummary
Requires-Dist: numpy

# TorchCat 🐱

# 简介

TorchCat 是用于封装 PyTorch 模型的工具

提供以下功能：

- 简化训练过程
- 简化测试过程
- 记录训练日志

# 用法

导入 TorchCat

```python
import torchcat
```

## 封装模型

使用 `Cat` 封装你的模型。如果不进行训练，也可以忽略 `loss_fn`、`torchcat` 参数

```python
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

net = torchcat.Cat(model=net,
                   loss_fn=nn.CrossEntropyLoss(),
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.0003))
```

| 参数      | 说明     |
| --------- | -------- |
| model     | 你的模型 |
| loss_fn   | 损失函数 |
| optimizer | 优化器   |

## 查看架构

在封装模型后，使用 **net.summary()**，可以查看模型的架构。**input_size** 参数需填写模型的输入形状，如：`net.summary(1, 28, 28)`

## 训练模型

使用 `net.train()`，可以开始模型的训练。训练结束后会返回训练日志

```python
log = net.train(train_set=train_set, epochs=5, valid_set=test_set)
```

`log` 记录了训练时的日志，包括以下内容

* 训练集损失（**log['train** **loss']**）
* 训练集准确率（**log['train** **acc']**）
* 验证集损失（**log['valid** **loss']**）
* 验证集准确率（**log['validacc']**）

| 参数      | 说明                |
| --------- | ------------------- |
| train_set | 训练集              |
| epochs    | 训练轮次            |
| valid_set | 验证集（默认 None） |

## 验证模型

使用 `net.valid(valid_set, show=True, train=False)`，能够验证模型在给定验证集上的性能，包括损失值、准确率。验证后模型将切换为推理模式

| 参数      | 说明                                           |
| --------- | ---------------------------------------------- |
| valid_set | 验证集                                         |
| show      | 是否输出验证集损失值、准确率（默认 True）      |
| train     | 验证后，模型是否且切换为训练模式（默认 False） |

# 其他

## 切换计算设备

TorchCat 提供了方法 `to_cpu()`、`to_cuda()` 用于切换计算设备（CPU 或 GPU🚀）

## 检查模型当前模式

使用 `training` 方法，模型当前是否处于训练模式。返回 `True` 表示处于训练模式，`False` 表示处于推理模式

## 模型推理

* 方法名：`__call__`
* 功能描述：执行模型的前向推理过程
* 参数：`x` - 输入数据
* 返回值：模型对输入数据 `x` 的推理结果
