Metadata-Version: 2.4
Name: rsai
Version: 2026.5.23
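Summary: Remote sensing image I/O toolkit with multi-backend read/write, OGC map service access, vector data processing, and training sample generation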
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: gdal==3.4.3
Requires-Dist: numpy<2.0.0,>=1.24.4
Requires-Dist: pyproj
Requires-Dist: rich
Requires-Dist: loguru
Requires-Dist: shapely
Requires-Dist: mercantile>=1.2.1
Requires-Dist: aiohttp>=3.13.3
Provides-Extra: cv2
Requires-Dist: opencv-python; extra == "cv2"
Provides-Extra: pil
Requires-Dist: pillow; extra == "pil"
Provides-Extra: rasterio
Requires-Dist: rasterio<=1.3.10; extra == "rasterio"
Provides-Extra: all
Requires-Dist: opencv-python; extra == "all"
Requires-Dist: pillow; extra == "all"
Requires-Dist: rasterio<=1.3.10; extra == "all"

# rsai

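A remote sensing image I/O library that provides:

- multi-backend read/write
- access to OGC map services
- vector data processing
- training sample generation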

## Installation

```bash
pip install rsai

# Optional backends
pip install "rsai[cv2]"       # OpenCV
pip install "rsai[pil]"       # Pillow
pip install "rsai[rasterio]"  # Rasterio
pip install "rsai[all]"       # all of the above
```

> Core dependencies: GDAL 3.4.3, NumPy (<2.0), pyproj, Shapely, mercantile, aiohttp, rich, loguru
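Because the backends are optional extras, you may want to check which ones are importable before choosing `backend=`. The stdlib-only sketch below does this by looking up the module that each extra provides. `opencv-python` imports as `cv2`, `pillow` as `PIL`, and `rasterio` as `rasterio`. The helper `available_backends` is hypothetical and not part of rsai.

```python
import importlib.util

# Assumed mapping from rsai extra name -> top-level module installed by that extra.
OPTIONAL_BACKENDS = {
    "cv2": "cv2",            # pip install "rsai[cv2]" installs opencv-python
    "pil": "PIL",            # pip install "rsai[pil]" installs pillow
    "rasterio": "rasterio",  # pip install "rsai[rasterio]"
}


def available_backends() -> dict:
    """Hypothetical helper: report which optional backends can be imported."""
    # find_spec returns None for a missing top-level module without importing it.
    return {
        extra: importlib.util.find_spec(module) is not None
        for extra, module in OPTIONAL_BACKENDS.items()
    }


print(available_backends())
```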

---

## Quick Start

### Reading and Writing Local Images

```python
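from rsio import open_image
from rsio.open_image import OpenImage

# Read (GDAL backend by default)
with open_image("image.tif") as img:
    array = img.read()           # read the whole image into a numpy array
    print(img.width, img.height) # image size
    print(img.channel)           # number of bands
    print(img.dtype)             # data type
    print(img.geo_transform)     # geotransform parameters
    print(img.geo_projection)    # projection (WKT)
    print(img.geo_bounds)        # geographic bounds (left, top, right, bottom)
    print(img.pixel_size)        # pixel size (x, y)
    print(img.nodata)            # nodata value
    # Keep the georeferencing so the array can be written back later
    geo_transform = img.geo_transform
    projection = img.geo_projection

# Choose a backend explicitly
with open_image("image.tif", backend="rasterio") as img:
    array = img.read()

# Write (georeferencing must be passed in explicitly)
OpenImage.write_array(
    "output.tif",
    array,
    backend="gdal",
    geo_transform=geo_transform,
    geo_projection=projection,
)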
```
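The `geo_transform` printed above is assumed to follow GDAL's six-coefficient affine convention, `(origin_x, pixel_width, row_rotation, origin_y, column_rotation, pixel_height)`. Under that convention, the pixel corner at column `col`, row `row` maps to:

- `X = GT[0] + col*GT[1] + row*GT[2]`
- `Y = GT[3] + col*GT[4] + row*GT[5]`

For a north-up image (both rotation terms zero), `geo_bounds` can be derived as sketched below. The helpers `pixel_to_geo` and `bounds_from_geotransform` are hypothetical and not necessarily how rsai computes its bounds.

```python
# Hypothetical helpers (not part of rsai) illustrating GDAL's geotransform convention.
def pixel_to_geo(gt, col, row):
    """Map pixel-corner coordinates (col, row) to georeferenced (X, Y)."""
    x = gt[0] + col * gt[1] + row * gt[2]
    y = gt[3] + col * gt[4] + row * gt[5]
    return x, y


def bounds_from_geotransform(gt, width, height):
    """Return (left, top, right, bottom), assuming a north-up image (gt[2] == gt[4] == 0)."""
    left, top = pixel_to_geo(gt, 0, 0)               # upper-left corner of the image
    right, bottom = pixel_to_geo(gt, width, height)  # lower-right corner of the image
    return left, top, right, bottom


# A 1000 x 800 image with 0.001-degree pixels whose upper-left corner is at (116.0, 40.0).
gt = (116.0, 0.001, 0.0, 40.0, 0.0, -0.001)
print(bounds_from_geotransform(gt, 1000, 800))  # approximately (116.0, 40.0, 117.0, 39.2)
```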

### Sliding-Window Tiling

```python
with open_image("large.tif") as img:
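    # Generator that reads one tile at a time to keep memory usage low
    for tile_array, (x, y) in img.sliding_window_tiles(
        window_size=512,   # or (width, height)
        stride=256,        # step size; defaults to window_size (no overlap)
        max_tiles=100,     # upper limit on the number of tiles
        band_indexes=[1, 2, 3],  # bands to read
    ):
        pass

    # Estimate the total number of tiles
    total = img.count_tiles(stride=256)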
```

### Even Block Partitioning

```python
with open_image("large.tif") as img:
    # Automatically picks the row/column counts that keep each block as close to square as possible
    for block_img, (x, y) in img.image_blocks(num_blocks=16):
        array = block_img.read()
```
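One way to choose the row/column split for `num_blocks` works as follows:

1. Try every factor pair with `rows * cols == num_blocks`.
2. Keep the pair whose blocks have an aspect ratio closest to 1.

This is a sketch of that idea, not rsai's documented rule; the library may, for example, allow `rows * cols` to differ from `num_blocks`. `best_grid` is a hypothetical name.

```python
# Hypothetical helper (not part of rsai): pick the squarest exact grid.
def best_grid(width, height, num_blocks):
    """Return (rows, cols) with rows * cols == num_blocks and the squarest blocks."""
    best = None
    for rows in range(1, num_blocks + 1):
        if num_blocks % rows:
            continue  # only exact factorizations
        cols = num_blocks // rows
        block_w, block_h = width / cols, height / rows
        # max(w/h, h/w) equals 1 for a square block and grows as blocks get elongated.
        score = max(block_w / block_h, block_h / block_w)
        if best is None or score < best[0]:
            best = (score, rows, cols)
    return best[1], best[2]


print(best_grid(4000, 4000, 16))  # (4, 4)
print(best_grid(8000, 2000, 16))  # (2, 8)
```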

### Cropping

```python
with open_image("image.tif") as img:
    # Crop by pixel coordinates
    img.crop(x_off=100, y_off=100, window_x_size=512, window_y_size=512)

    # Crop by geographic bounds (left, top, right, bottom)
    img.crop_by_geo_bounds((116.3, 40.1, 116.5, 39.9))

# Crop by GeoJSON (requires the rasterio backend)
with open_image("image.tif", backend="rasterio") as img:
    img.crop_by_geojson(geojson_dict_or_path)  # a GeoJSON dict or a path to a GeoJSON file
```
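Cropping by geographic bounds amounts to inverting the geotransform. The sketch below makes these assumptions:

- the image is north-up, so `geo_transform` has no rotation terms;
- the window is rounded outward to whole pixels;
- the window is clipped to the image.

rsai's exact rounding may differ. `geo_bounds_to_window` is a hypothetical helper. It returns the same `(x_off, y_off, window_x_size, window_y_size)` form that `crop` takes.

```python
import math


# Hypothetical helper (not part of rsai): geographic bounds -> pixel window.
def geo_bounds_to_window(gt, bounds, width, height):
    """Map (left, top, right, bottom) to (x_off, y_off, window_x_size, window_y_size).

    Assumes a north-up geotransform: gt[2] == gt[4] == 0 and gt[5] < 0.
    """
    left, top, right, bottom = bounds
    # Invert X = gt[0] + col * gt[1] and Y = gt[3] + row * gt[5], rounding outward.
    col0 = math.floor((left - gt[0]) / gt[1])
    col1 = math.ceil((right - gt[0]) / gt[1])
    row0 = math.floor((top - gt[3]) / gt[5])
    row1 = math.ceil((bottom - gt[3]) / gt[5])
    # Clip to the raster so the window never reads outside the image.
    col0, row0 = max(col0, 0), max(row0, 0)
    col1, row1 = min(col1, width), min(row1, height)
    if col1 <= col0 or row1 <= row0:
        raise ValueError("bounds do not overlap the image")
    return col0, row0, col1 - col0, row1 - row0


# 100 x 100 image with 0.5-unit pixels and its upper-left corner at (0, 100).
gt = (0.0, 0.5, 0.0, 100.0, 0.0, -0.5)
print(geo_bounds_to_window(gt, (10.0, 90.0, 20.0, 80.0), 100, 100))  # (20, 20, 20, 20)
```

With real coordinates such as 116.3, floating-point division can yield values like 299.9999999. Outward rounding may then add one pixel. A small tolerance before `floor`/`ceil` avoids this if exact sizes matter.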

### Reprojection

```python
with open_image("image.tif") as img:
    # Accepts an EPSG integer, an "EPSG:xxxx" string, a WKT string, or another OpenImage object
    img.reproject(4326)
    img.reproject("EPSG:32650")
    img.reproject(other_img)
```
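Before reprojecting, the four accepted target forms can be normalized to a single representation. The dispatch sketched below assumes an `OpenImage`-like object can be recognized by its `geo_projection` attribute (WKT, as in the read example). `normalize_target_crs` is hypothetical. In practice, `pyproj.CRS.from_user_input` (pyproj is a core dependency) accepts EPSG integers, `"EPSG:xxxx"` strings and WKT alike.

```python
# Hypothetical helper (not part of rsai): normalize the accepted CRS inputs.
def normalize_target_crs(target):
    """Return 'EPSG:<code>' for EPSG inputs, otherwise a WKT string."""
    if isinstance(target, bool):
        raise TypeError("bool is not a valid CRS")  # bool is a subclass of int
    if isinstance(target, int):
        return f"EPSG:{target}"
    if isinstance(target, str):
        text = target.strip()
        if text.upper().startswith("EPSG:"):
            return "EPSG:" + text.split(":", 1)[1].strip()
        return text  # anything else is assumed to be WKT
    # Assumption: image objects expose their projection as WKT via geo_projection.
    wkt = getattr(target, "geo_projection", None)
    if isinstance(wkt, str) and wkt:
        return wkt
    raise TypeError(f"unsupported CRS input: {target!r}")


print(normalize_target_crs(4326))          # EPSG:4326
print(normalize_target_crs("epsg:32650"))  # EPSG:32650
```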

### Intersecting Multiple Images

```python
from rsio import open_image
from rsio.open_image import OpenImage

img_a = open_image("a.tif")
img_b = open_image("b.tif")

# Aligns projections automatically and crops every image to their common geographic extent
img_a, img_b = OpenImage.intersection_image([img_a, img_b])
```
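The common extent is the overlap of all images' `geo_bounds`. The sketch below computes that overlap under two assumptions:

- every image has already been brought into the same CRS;
- bounds use the `(left, top, right, bottom)` convention with north-up axes (top > bottom).

`intersect_bounds` is a hypothetical helper.

```python
# Hypothetical helper (not part of rsai): overlap of several bounding boxes.
def intersect_bounds(bounds_list):
    """Intersect (left, top, right, bottom) boxes; return None if they do not overlap."""
    left = max(b[0] for b in bounds_list)
    top = min(b[1] for b in bounds_list)
    right = min(b[2] for b in bounds_list)
    bottom = max(b[3] for b in bounds_list)
    if left >= right or bottom >= top:
        return None  # empty or degenerate intersection
    return left, top, right, bottom


a = (116.0, 40.2, 116.6, 39.8)
b = (116.3, 40.5, 117.0, 39.9)
print(intersect_bounds([a, b]))  # (116.3, 40.2, 116.6, 39.9)
```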

### Resizing

```python
from rsio.open_image import OpenImage
import numpy as np

array = ...  # (H, W, C) or (C, H, W)

# Plain resize
resized = OpenImage.resize_img(array, target_size=(512, 512))

# Keep the aspect ratio and pad the remainder
resized = OpenImage.resize_img(
    array,
    target_size=512,
    keep_ratio=True,
    pad=0,           # padding fill value
    resample=1,      # 0=nearest 1=bilinear 2=bicubic 3=area 4=Lanczos
)
```
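With `keep_ratio=True`, the usual approach scales by the smaller of the two ratios and fills the remainder with `pad`. The sketch below shows the size and padding arithmetic under three assumptions:

- an integer `target_size` means a square target;
- a tuple is `(width, height)`;
- padding is split evenly on both sides.

rsai may instead pad only the right/bottom. `letterbox_params` is hypothetical.

```python
# Hypothetical helper (not part of rsai): aspect-preserving resize arithmetic.
def letterbox_params(src_h, src_w, target_size):
    """Return (new_h, new_w, (top, bottom, left, right)) for a keep-ratio resize."""
    if isinstance(target_size, int):
        target_w = target_h = target_size
    else:
        target_w, target_h = target_size  # assumed (width, height) order
    scale = min(target_w / src_w, target_h / src_h)  # fit the longer side
    new_w = max(1, round(src_w * scale))
    new_h = max(1, round(src_h * scale))
    pad_w, pad_h = target_w - new_w, target_h - new_h
    # Split padding evenly; an odd leftover pixel goes to the bottom/right.
    top, left = pad_h // 2, pad_w // 2
    return new_h, new_w, (top, pad_h - top, left, pad_w - left)


print(letterbox_params(600, 1000, 512))  # (307, 512, (102, 103, 0, 0))
```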

### Valid-Pixel Mask

```python
from rsio.open_image import OpenImage

# A pixel is valid unless its channel values are all 0 or all 255
mask = OpenImage.valid_mask_of_image(array)  # shape (H, W), dtype bool
```
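The validity rule from the comment above can be spelled out directly. The dependency-free sketch below takes a channel-last `(H, W, C)` image given as nested lists, and shows the vectorized NumPy form in a comment. The channel-last layout is an assumption: `resize_img` also accepts `(C, H, W)`, and how `valid_mask_of_image` handles that layout is not documented here. `valid_mask` is a hypothetical name.

```python
# Hypothetical helper (not part of rsai): the valid-pixel rule for an (H, W, C) image.
def valid_mask(image):
    """Return an (H, W) grid of bools: False where all channels are 0 or all are 255."""
    return [
        [not (all(v == 0 for v in px) or all(v == 255 for v in px)) for px in row]
        for row in image
    ]
    # NumPy equivalent for an (H, W, C) array `a`:
    #   ~(np.all(a == 0, axis=-1) | np.all(a == 255, axis=-1))


img = [[[0, 0, 0], [255, 255, 255]],
       [[0, 255, 0], [10, 20, 30]]]
print(valid_mask(img))  # [[False, False], [True, True]]
```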

---

## OGC Map Services

### Generic Interface

```python
from rsio import open_ogc

# XYZ tile service
with open_ogc("https://tile.openstreetmap.org/{z}/{x}/{y}.png", service="xyz") as ogc:
    array = ogc.read(bbox=[116.3, 39.9, 116.5, 40.1], zoom=14)

# WMTS service
with open_ogc(url, service="wmts", layer="layer_name") as ogc:
    array = ogc.read(bbox=..., zoom=14)

# WMS service
with open_ogc(url, service="wms", layer="layer_name") as ogc:
    array = ogc.read(bbox=..., width=512, height=512)
```
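For XYZ services, `read(bbox=..., zoom=...)` must work out which Web Mercator tiles cover the bbox. The sketch below uses the standard slippy-map formulas and assumes `bbox` is `[west, south, east, north]` in degrees, as in the XYZ example. `mercantile.tile` and `mercantile.tiles` (mercantile is a core dependency) perform the same computation. `lonlat_to_tile` and `tile_range` are hypothetical names.

```python
import math

MAX_LAT = 85.0511287798  # latitude limit of the Web Mercator tile grid


# Hypothetical helpers (not part of rsai): standard slippy-map tile math.
def lonlat_to_tile(lon, lat, zoom):
    """(x, y) indices of the tile containing a lon/lat point at the given zoom."""
    lat = max(min(lat, MAX_LAT), -MAX_LAT)
    n = 2 ** zoom
    x = math.floor((lon + 180.0) / 360.0 * n)
    y = math.floor((1.0 - math.asinh(math.tan(math.radians(lat))) / math.pi) / 2.0 * n)
    # Clamp so lon = 180 or a latitude at the limit stays inside the grid.
    return min(max(x, 0), n - 1), min(max(y, 0), n - 1)


def tile_range(bbox, zoom):
    """(x_min, y_min, x_max, y_max) of the tiles covering [west, south, east, north]."""
    west, south, east, north = bbox
    x_min, y_min = lonlat_to_tile(west, north, zoom)  # tile y increases southward
    x_max, y_max = lonlat_to_tile(east, south, zoom)
    return x_min, y_min, x_max, y_max


x0, y0, x1, y1 = tile_range([116.3, 39.9, 116.5, 40.1], 14)
print(f"{(x1 - x0 + 1) * (y1 - y0 + 1)} tiles at zoom 14")
```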

### Preset Map Services

```python
from rsio import open_ogc_service

# Tianditu vector basemap (requires a token)
with open_ogc_service("tianditu_vector", token="your_token") as svc:
    image_array = svc.image.read(bbox=..., zoom=14)
    label_array = svc.label.read(bbox=..., zoom=14)  # annotation layer
    shp = svc.shp                                     # vector data

# ESRI imagery
with open_ogc_service("esri_image") as svc:
    array = svc.image.read(bbox=..., zoom=16)

# Google imagery
with open_ogc_service("google_image") as svc:
    array = svc.image.read(bbox=..., zoom=16)

# CartoDB vector basemap
with open_ogc_service("cartodb_vector") as svc:
    array = svc.image.read(bbox=..., zoom=14)
```

| Service | Description | Token required |
|---------|-------------|----------------|
| `tianditu_vector` | Tianditu vector basemap + annotations | Yes |
| `cartodb_vector` | CartoDB vector basemap | No |
| `esri_image` | ESRI satellite imagery | No |
| `google_image` | Google satellite imagery | No |

---

## Shapefile Operations

```python
from rsio import open_shp

# Read properties
with open_shp("data.shp") as shp:
    print(shp.feature_count)    # number of features
    print(shp.extent)           # extent (min_x, min_y, max_x, max_y)
    print(shp.geo_bounds)       # geographic bounds
    print(shp.geo_projection)   # projection
    for feat in shp.features:
        geom = feat.GetGeometryRef()  # OGR geometry of the feature

# Reproject and save
with open_shp("data.shp") as shp:
    shp.reproject(target_epsg=4326)
    shp.save("reprojected.shp")

# Merge another shapefile into this one
with open_shp("base.shp") as shp:
    shp.add_other_shp("extra.shp")
    shp.save("merged.shp")

# Simplify geometries
with open_shp("data.shp") as shp:
    shp.simplify_shp(tolerance=0.001)
    shp.save("simplified.shp")

# Rasterize (extent and resolution are taken from the reference image)
with open_shp("data.shp") as shp:
    img = shp.convert_to_open_image(ref_image="ref.tif")
    img.save("rasterized.tif")
```
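`simplify_shp(tolerance=...)` presumably delegates to a geometry library. GEOS, the geometry engine behind OGR and Shapely, implements simplification with the Douglas-Peucker algorithm, optionally in a topology-preserving variant. Which variant `simplify_shp` uses is not stated here. The sketch below illustrates what the tolerance means for a single polyline: intermediate vertices within `tolerance` of the simplified line are dropped. It demonstrates the classic algorithm and is not rsai's code. `douglas_peucker` is a hypothetical name.

```python
import math


# Hypothetical helpers (not part of rsai): classic Douglas-Peucker simplification.
def _point_line_distance(p, a, b):
    """Perpendicular distance from p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)  # a and b coincide
    return abs(dy * (px - ax) - dx * (py - ay)) / math.hypot(dx, dy)


def douglas_peucker(points, tolerance):
    """Simplify a polyline given as (x, y) tuples, always keeping both endpoints."""
    if len(points) < 3:
        return list(points)
    # Find the interior vertex farthest from the chord between the endpoints.
    index, max_dist = 0, -1.0
    for i in range(1, len(points) - 1):
        d = _point_line_distance(points[i], points[0], points[-1])
        if d > max_dist:
            index, max_dist = i, d
    if max_dist <= tolerance:
        return [points[0], points[-1]]  # all interior vertices are within tolerance
    # Otherwise split at that vertex and simplify both halves recursively.
    left = douglas_peucker(points[: index + 1], tolerance)
    right = douglas_peucker(points[index:], tolerance)
    return left[:-1] + right  # drop the duplicated split vertex


print(douglas_peucker([(0, 0), (1, 0.0002), (2, 0), (2, 2)], 0.001))  # [(0, 0), (2, 0), (2, 2)]
```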

---

## Training Sample Generation

### Semantic Segmentation Samples

```python
from rsio.build_samples import SemanticSampleBuilder

builder = SemanticSampleBuilder(
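    input_images="images/",          # image directory or list of paths
    input_labels="labels/",          # label directory or list of paths (tif/shp supported)
    output_dir="samples/",           # output directory
    window_size=512,                 # tile size; also accepts (w, h)
    stride=256,                      # step size; defaults to window_size
    max_samples=None,                # maximum number of samples; None means no limit
    band_indexes=[1, 2, 3],          # bands to use; None means all bands
    palette=None,                    # color palette (for color-coded labels)
    split_val_ratio=0.1,             # validation split ratio
    split_test_ratio=0.1,            # test split ratio
    save_blank_sample_prob=0.05,     # probability of keeping an all-background sample (0 = never)
    input_mode=0,                    # 0: path lists  1: directories matched by filename  2: directories matched by geographic extent
    num_workers=4,                   # number of worker processes
    attribute_field="class_id",      # name of the attribute field used for shp labels
    mapping_field={"road": 1},       # maps shp attribute values to class ids
    image_prefix=None,               # filter images by filename prefix
    label_suffix=None,               # filter labels by filename suffix
    crop=False,                      # whether to crop images to the label extent
    log_file="build.log",            # log file path
)
builder.build()

# Generate color labels for visualization (class ids mapped to colors)
builder.batch_build_color_label(
    input_labels="labels/",
    output_dir="color_labels/",
    palette={0: (0, 0, 0), 1: (255, 0, 0)},
)

# Generate overlay visualizations (labels blended semi-transparently over the images)
builder.batch_build_visualizer_label(
    input_images="images/",
    input_labels="labels/",
    output_dir="visualize/",
)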
```
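For each label image, `batch_build_color_label` maps every class id to its palette color, which is essentially a lookup table. The dependency-free sketch below assumes the palette is a `{class_id: (R, G, B)}` dict, as in the example, and that unknown ids fall back to black; rsai's actual fallback is not documented here. `colorize_label` is hypothetical.

```python
# Hypothetical helper (not part of rsai): class-id grid -> RGB grid via a palette.
def colorize_label(label, palette, default=(0, 0, 0)):
    """Map an (H, W) grid of class ids to an (H, W) grid of RGB tuples."""
    # Unknown class ids fall back to `default` (an assumption, see above).
    return [[palette.get(class_id, default) for class_id in row] for row in label]
    # NumPy equivalent: build `lut` of shape (num_classes, 3) and index it as lut[label].


palette = {0: (0, 0, 0), 1: (255, 0, 0)}
print(colorize_label([[0, 1], [1, 0]], palette))
```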

Output directory layout:

```
samples/
├── train/
│   ├── images/
│   └── labels/
├── val/
│   ├── images/
│   └── labels/
└── test/
    ├── images/
    └── labels/
```
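`split_val_ratio` and `split_test_ratio` determine how samples are distributed across `train/`, `val/` and `test/`. The sketch below works as follows:

- it shuffles the sample ids with a fixed seed;
- it rounds each held-out split down to whole samples;
- it assigns the remainder to training.

This is an assumption: rsai may instead draw a split for each sample at random. `split_samples` is hypothetical.

```python
import random


# Hypothetical helper (not part of rsai): reproducible train/val/test split.
def split_samples(sample_ids, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Shuffle sample ids and partition them into train/val/test lists."""
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)  # fixed seed -> same split on every run
    n_val = int(len(ids) * val_ratio)
    n_test = int(len(ids) * test_ratio)
    return {
        "val": ids[:n_val],
        "test": ids[n_val:n_val + n_test],
        "train": ids[n_val + n_test:],  # everything left over goes to training
    }


splits = split_samples(range(100))
print({name: len(ids) for name, ids in splits.items()})  # {'val': 10, 'test': 10, 'train': 80}
```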

### Change Detection Samples

```python
from rsio.build_samples import ChangedSampleBuilder

builder = ChangedSampleBuilder(
    input_images_a="time_a/",        # image directory for epoch A
    input_images_b="time_b/",        # image directory for epoch B
    input_labels="labels/",          # change label directory
    output_dir="samples/",
    window_size=512,
    stride=256,
    split_val_ratio=0.1,
    split_test_ratio=0.1,
    num_workers=4,
)
builder.build()
```

Output directory layout:

```
samples/
├── train/
│   ├── images_a/
│   ├── images_b/
│   └── labels/
├── val/
│   ├── images_a/
│   ├── images_b/
│   └── labels/
└── test/
    ├── images_a/
    ├── images_b/
    └── labels/
```
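The change-detection builder needs to know which epoch-A image, epoch-B image and label belong together. The README does not say how these triples are formed. The sketch below assumes they are matched by a shared file stem, such as `t1` in `time_a/t1.tif`, `time_b/t1.tif` and `labels/t1.png`. `pair_by_stem` is a hypothetical name.

```python
from pathlib import Path


# Hypothetical helper (not part of rsai): match A/B/label files by file stem.
def pair_by_stem(images_a, images_b, labels):
    """Return sorted (a, b, label) triples whose stems appear in all three lists."""
    def index(paths):
        return {Path(p).stem: p for p in paths}

    a, b, lab = index(images_a), index(images_b), index(labels)
    common = sorted(set(a) & set(b) & set(lab))  # stems present in every input
    return [(a[s], b[s], lab[s]) for s in common]


triples = pair_by_stem(
    ["time_a/t1.tif", "time_a/t2.tif"],
    ["time_b/t1.tif", "time_b/t3.tif"],
    ["labels/t1.png", "labels/t2.png"],
)
print(triples)  # [('time_a/t1.tif', 'time_b/t1.tif', 'labels/t1.png')]
```

Files whose stem is missing from any of the three inputs are skipped silently in this sketch. A real builder would probably log them.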

---

## Backends

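| Backend | Keyword | Use case |
|---------|---------|----------|
| GDAL (default) | `gdal` | Georeferenced raster data with full geospatial metadata; the recommended choice |
| Rasterio | `rasterio` | Georeferenced raster data; supports GeoJSON cropping |
| PIL | `pil` | Ordinary images; lightweight read/write without geospatial metadata |
| OpenCV | `cv2` | Image processing in BGR channel order; resizing and display |

Select a backend with the `backend=` parameter. All backends implement the same abstract interface (`ImageBackend`), so switching backends does not change the calling code. Backend-specific features, such as geospatial metadata or GeoJSON cropping, are only available on the backends listed for them in the table above.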
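A shared abstract interface is what makes the `backend=` switch possible. The minimal sketch below shows the pattern with `abc`. Only the class name `ImageBackend` comes from this README. Everything else is hypothetical and not rsai's actual API:

- the method set (`read`, `width`, `height`);
- the in-memory `ListBackend`;
- the `BACKENDS` registry;
- the `open_with` factory.

Because its abstract members are unimplemented, instantiating `ImageBackend` directly raises `TypeError`.

```python
from abc import ABC, abstractmethod


class ImageBackend(ABC):
    """Hypothetical minimal interface; the real rsio base class is richer."""

    @property
    @abstractmethod
    def width(self) -> int:
        """Image width in pixels."""

    @property
    @abstractmethod
    def height(self) -> int:
        """Image height in pixels."""

    @abstractmethod
    def read(self):
        """Return the full image as rows of pixel values."""


class ListBackend(ImageBackend):
    """Toy in-memory backend standing in for GDAL/Rasterio/PIL/OpenCV."""

    def __init__(self, rows):
        self._rows = rows

    @property
    def width(self) -> int:
        return len(self._rows[0]) if self._rows else 0

    @property
    def height(self) -> int:
        return len(self._rows)

    def read(self):
        return [list(row) for row in self._rows]  # return a copy


BACKENDS = {"list": ListBackend}  # registry keyed like backend="gdal"


def open_with(backend, *args):
    """Look up a backend by keyword; callers only rely on the ImageBackend interface."""
    return BACKENDS[backend](*args)


img = open_with("list", [[1, 2, 3], [4, 5, 6]])
print(img.width, img.height, img.read())  # 3 2 [[1, 2, 3], [4, 5, 6]]
```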

---

## Project Structure

```
rsio/
├── open_image.py        # Local image I/O (OpenImage)
├── open_ogc.py          # OGC services (WMS/WMTS/XYZ)
├── open_ogc_services.py # Preset map services
├── open_shp.py          # Shapefile I/O
├── build_samples.py     # Training sample generation
├── enums.py             # Enum definitions
├── backends/
│   ├── base_backend.py      # Abstract base class ImageBackend
│   ├── gdal_backend.py      # GDALImage
│   ├── rasterio_backend.py  # RasterioImage
│   ├── pil_backend.py       # PILImage
│   └── cv2_backend.py       # OpencvImage
└── utils/
    ├── converters.py    # Coordinate/array/color/BBox conversions
    ├── decorators.py    # Common decorators (type checks, dependency checks, etc.)
    ├── exceptions.py    # Unified exception hierarchy
    ├── geo.py           # Raster/vector geospatial operations
    └── misc.py          # Logging, safe imports, list utilities
```
