Metadata-Version: 2.4
Name: mlcast
Version: 0.1.0
Summary: Machine learning weather nowcasting library
Project-URL: Documentation, https://github.com/mlcast-community/mlcast
Project-URL: Homepage, https://github.com/mlcast-community/mlcast
Project-URL: Issues, https://github.com/mlcast-community/mlcast/issues
Project-URL: Repository, https://github.com/mlcast-community/mlcast
Author-email: Gabriele Franch <franch@fbk.eu>, Leif Denby <lcd@dmi.dk>, Oscar Dewasmes <oscar.dewasmes@meteo.fr>
Maintainer-email: MLCast Community <franch@fbk.eu>
License: Apache-2.0
License-File: LICENSE
License-File: LICENSE-APACHE
License-File: LICENSE-BSD
Keywords: machine-learning,meteorology,nowcasting,precipitation,pytorch,radar,weather
Classifier: Development Status :: 2 - Pre-Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Atmospheric Science
Requires-Python: >=3.11
Requires-Dist: absl-py>=2.4
Requires-Dist: beartype>=0.18
Requires-Dist: cf-xarray>=0.10
Requires-Dist: etils>=1.13
Requires-Dist: fiddle>=0.3
Requires-Dist: fire>=0.7
Requires-Dist: importlib-resources>=5
Requires-Dist: ipykernel>=6.29.5
Requires-Dist: jaxtyping>=0.2
Requires-Dist: loguru>=0.7.3
Requires-Dist: mlflow[mlflow]>=3.11.1
Requires-Dist: numpy>=2.2.6
Requires-Dist: nvidia-ml-py[mlflow]>=13.595.45
Requires-Dist: pandas>=2
Requires-Dist: psutil[mlflow]>=7.2.2
Requires-Dist: pytorch-lightning>=2
Requires-Dist: pyyaml>=6
Requires-Dist: rich>=13
Requires-Dist: s3fs>=2026.3
Requires-Dist: tensorboard>=2.20
Requires-Dist: torch>=2.7.1
Requires-Dist: torchvision>=0.20
Requires-Dist: xarray>=2025.6.1
Requires-Dist: zarr>=3
Provides-Extra: dev
Requires-Dist: aiohttp>=3.9.3; extra == 'dev'
Requires-Dist: fsspec>=2024.2; extra == 'dev'
Requires-Dist: mypy>=1; extra == 'dev'
Requires-Dist: pre-commit>=3.5; extra == 'dev'
Requires-Dist: pytest-cov>=4.1; extra == 'dev'
Requires-Dist: pytest>=8; extra == 'dev'
Requires-Dist: ruff>=0.8; extra == 'dev'
Requires-Dist: s3fs>=2024.2; extra == 'dev'
Provides-Extra: gpu-cu128
Requires-Dist: torch>=2.7.1; extra == 'gpu-cu128'
Requires-Dist: torchvision>=0.20; extra == 'gpu-cu128'
Provides-Extra: gpu-cu130
Requires-Dist: torch>=2.7.1; extra == 'gpu-cu130'
Requires-Dist: torchvision>=0.20; extra == 'gpu-cu130'
Description-Content-Type: text/markdown

# mlcast

<!-- SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause -->

> ⚠️ This package is under active development. The API and functionality are subject to change until the v1.0.0 release.

The MLCast Community is a collaborative effort bringing together meteorological services, research institutions, and academia across Europe to develop a unified Python package for AI-based nowcasting. This is an initiative of the E-AI WG6 (Nowcasting) of EUMETNET.

This repo contains the `mlcast` package for machine learning-based weather nowcasting.

## Installation

Because `mlcast` is in rapid development, the recommended path is to clone the
repository and install it locally rather than installing a pinned release from PyPI.

### Local development: clone and install locally with uv

[Fork the repository](https://github.com/mlcast-community/mlcast/fork) on GitHub first, then clone your fork. This lets you track
upstream changes while keeping your own modifications on a separate branch:

```bash
git clone https://github.com/<your-github-username>/mlcast
cd mlcast

# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh

# Then install dependencies, depending on whether you have a GPU available.
# CPU
uv sync

# GPU — CUDA 12.8
uv sync --extra gpu-cu128

# GPU — CUDA 13.0
uv sync --extra gpu-cu130
```

Next, you can jump to [using mlcast](#usage) or, if you intend to modify the code, set up the development toolchain as described below:

```bash
# Install dev dependencies
uv sync --extra dev

# Install the pre-commit git hook (runs checks automatically on every commit)
uv run pre-commit install
```

### PyPI release

Tagged releases are published to PyPI and can be installed with pip:

```bash
pip install mlcast
```

For active development or access to unreleased changes, clone the repository and install locally with `uv` as described above.

## Usage

mlcast exposes two interfaces for training: a **command-line interface (CLI)**
for interactive and scripted use, and a **Python API** for programmatic control.
Both are built on [Fiddle](https://github.com/google/fiddle) — a configuration
library that lets you build a full experiment graph, override any parameter, and
reproduce runs exactly from a saved YAML file.

### Configuration model

Training in mlcast is currently built around a single base configuration
function, [`training_experiment`](src/mlcast/config/base.py), which defines the
default ConvGRU ensemble nowcasting setup: dataset, data module, network,
Lightning module, and trainer. Rather than writing a new config from scratch,
the intended workflow is to start from this base and apply targeted
modifications:

- **`set:` overrides** — change a single scalar parameter (e.g. batch size,
  learning rate, number of epochs)
- **fiddlers** — apply a named mutator function that keeps multiple related
  parameters in sync (e.g. switching the dataset class, toggling masking,
  changing the logger)
- **direct graph edits** (Python API only) — replace a sub-object entirely,
  for example swapping in a different network architecture

Any combination of these can be layered on top of the base config, and the
fully resolved config is always saved to YAML alongside the training logs so
runs can be reproduced exactly.
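
The difference between a `set:` override and a fiddler can be sketched without any dependencies. The snippet below is only an illustration: it uses `types.SimpleNamespace` as a stand-in for the real Fiddle config graph, the default values are made up, and the helpers `apply_set_override` and `toggle_masking_sketch` are hypothetical names rather than part of mlcast.

```python
import ast
from types import SimpleNamespace


def make_stand_in_config() -> SimpleNamespace:
    # Hypothetical stand-in for the graph returned by
    # training_experiment.as_buildable(); the values are illustrative only.
    return SimpleNamespace(
        data=SimpleNamespace(
            batch_size=16,
            dataset_factory=SimpleNamespace(return_mask=False),
        ),
        pl_module=SimpleNamespace(masked_loss=False),
        trainer=SimpleNamespace(max_epochs=10),
    )


def apply_set_override(cfg: SimpleNamespace, assignment: str) -> None:
    """Apply one 'dotted.path=value' assignment to a nested config."""
    path, raw_value = assignment.split("=", 1)
    *parents, leaf = path.split(".")
    node = cfg
    for name in parents:
        node = getattr(node, name)  # AttributeError on an unknown path
    if not hasattr(node, leaf):
        raise AttributeError(f"unknown parameter: {path}")
    try:
        value = ast.literal_eval(raw_value)  # "32" becomes the int 32
    except (ValueError, SyntaxError):
        value = raw_value  # keep plain strings such as file paths
    setattr(node, leaf, value)


def toggle_masking_sketch(cfg: SimpleNamespace, enabled: bool) -> None:
    """Fiddler-style mutator: keeps two coupled parameters in sync."""
    cfg.data.dataset_factory.return_mask = enabled
    cfg.pl_module.masked_loss = enabled


cfg = make_stand_in_config()
apply_set_override(cfg, "data.batch_size=32")
apply_set_override(cfg, "trainer.max_epochs=50")
toggle_masking_sketch(cfg, enabled=True)
print(cfg.data.batch_size, cfg.trainer.max_epochs, cfg.pl_module.masked_loss)
```

A scalar override touches exactly one leaf, while the fiddler-style function changes both coupled values in one call, so they cannot drift apart.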

The diagram below shows the full default config graph as built by
[`training_experiment`](src/mlcast/config/base.py):

![training_experiment config graph](docs/config_diagram.svg)

### Command-line interface

Install the package and run:

```bash
mlcast train
```

This trains with the built-in [`training_experiment`](src/mlcast/config/base.py) defaults. All parameters
are controlled via `--config` flags:

| Prefix | Purpose | Example |
|--------|---------|---------|
| *(none)* | Use the built-in default config | `mlcast train` |
| `set:` | Override a single parameter | `--config set:data.batch_size=32` |
| `fiddler:` | Apply a semantic mutator (multi-param change) | `--config fiddler:use_random_sampler` |
| `config:` | Switch to a different `@auto_config` function | `--config config:my_experiment` |
| `path/to/config.yaml` | Load a previously saved config | `--config saved.yaml` |

Multiple `--config` flags are applied in order and can be combined freely.
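
As a rough illustration of the prefix scheme in the table, the sketch below classifies each `--config` value and keeps the flags in the order given. It is not mlcast's actual parser: `classify_config_flag` is a hypothetical helper, and treating every unprefixed value as a YAML path is an assumption made for the example.

```python
def classify_config_flag(value: str) -> tuple[str, str]:
    """Return (kind, payload) for one --config value."""
    for prefix in ("set", "fiddler", "config"):
        marker = prefix + ":"
        if value.startswith(marker):
            return prefix, value[len(marker):]
    # Assumption: anything without a known prefix is a saved YAML file.
    return "yaml", value


flags = [
    "logs/mlcast/version_0/config.yaml",
    "fiddler:use_random_sampler",
    "set:trainer.max_epochs=50",
]
# The list keeps command-line order, so later entries act on the
# result of earlier ones.
plan = [classify_config_flag(flag) for flag in flags]
for kind, payload in plan:
    print(f"{kind:8s} {payload}")
```

Because flags are applied in order, a `set:` override placed after a YAML path changes the loaded config rather than being overwritten by it.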

**Examples:**

```bash
# Override dataset path and batch size
mlcast train \
    --config set:data.dataset_factory.zarr_path=/data/radar.zarr \
    --config set:data.batch_size=32

# Switch to random sampler and log to MLflow
mlcast train \
    --config fiddler:use_random_sampler \
    --config fiddler:use_mlflow_logger

# Resume from a saved config with an epoch override
mlcast train \
    --config logs/mlcast/version_0/config.yaml \
    --config set:trainer.max_epochs=50

# Inspect the fully resolved config without starting training
mlcast train --config fiddler:use_random_sampler --print_config_and_exit
```

Run `mlcast train --help` for a full list of examples and available fiddlers.

### Python API

The Python API gives you full programmatic control over the config graph before
anything is instantiated.

**Run the default experiment with tweaks:**

```python
import fiddle as fdl
from mlcast.config import training_experiment, train_from_config
from mlcast.config.fiddlers import use_random_sampler

cfg = training_experiment.as_buildable()  # returns a fdl.Config graph — see src/mlcast/config/base.py

# Apply a fiddler to switch the dataset sampler
use_random_sampler(cfg)

# Override individual parameters directly on the config graph
cfg.data.batch_size = 32
cfg.trainer.max_epochs = 50

# Validates cross-parameter contracts, builds all objects, persists config
# YAML to the active logger, then calls trainer.fit() + trainer.test()
train_from_config(cfg)
```

**Custom network architecture:**

You can swap in any architecture by replacing `cfg.pl_module.network` with a
`fdl.Config` node.  The network must implement the nowcasting forward
interface — see [Custom network interface](#custom-network-interface) below.

As an example, here is how to wrap an
[mfai](https://github.com/meteofrance/mfai) `HalfUNet` (a plain single-step
U-Net) to satisfy the interface.  The wrapper channel-stacks the past frames
and runs the U-Net autoregressively for each requested forecast step:

> **Note** — `input_steps` equals `dataset_factory.input_steps` (6 by
> default) and is directly readable from the config graph before building.

```python
import einops
import fiddle as fdl
import torch
import torch.nn as nn
from jaxtyping import Float
from mfai.torch.models import HalfUNet
from mlcast.config import training_experiment, train_from_config
from mlcast.config.fiddlers import use_random_sampler

# Minimal adapter: channel-stack past frames → HalfUNet → one step at a time.
# NowcastLightningModule calls network(x, steps=N, ensemble_size=M), so any
# custom network must accept those keyword arguments.
class HalfUNetNowcaster(nn.Module):
    def __init__(self, input_steps: int = 6, num_vars: int = 1):
        super().__init__()
        self.input_steps = input_steps
        self.num_vars = num_vars
        self.unet = HalfUNet(
            input_shape=(256, 256),
            in_channels=input_steps * num_vars,
            out_channels=num_vars,
            settings=fdl.Config(HalfUNet.settings_kls),
        )

    @property
    def input_channels(self) -> int:
        # Externally, the HalfUNetNowcaster respects the required input shape structure
        # (batch, input_steps, num_vars, H, W), even though the internal U-Net is channel-stacked.
        # Adding this property allows the config consistency checks to verify that
        # the dataset and model agree on the expected number of input channels.
        return self.num_vars

    def forward(
        self,
        x: Float[torch.Tensor, "batch input_steps in_channels H W"],
        steps: int,
        ensemble_size: int = 1,
    ) -> Float[torch.Tensor, "batch steps out_channels H W"]:
        # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w)
        x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w")
        preds = []
        for _ in range(steps):
            y = self.unet(x_flat)   # [B, num_vars, H, W]
            preds.append(y.unsqueeze(1))
            # slide window: drop the oldest timestep (first num_vars channels),
            # append the latest prediction as the newest timestep
            x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1)
        return torch.cat(preds, dim=1)

cfg = training_experiment.as_buildable()
use_random_sampler(cfg)

cfg.pl_module.network = fdl.Config(
    HalfUNetNowcaster,
    input_steps=cfg.data.dataset_factory.input_steps,
    num_vars=len(cfg.data.dataset_factory.standard_names),
)

train_from_config(cfg)
```
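
The sliding-window update inside `forward` can be traced with plain lists in place of tensors. This is only a sketch: the channel labels and the `slide_window` helper are hypothetical and exist purely to show the ordering.

```python
def slide_window(channels: list[str], prediction: list[str]) -> list[str]:
    """Drop the oldest frame's channels and append the newest prediction."""
    num_vars = len(prediction)
    return channels[num_vars:] + prediction


input_steps, num_vars = 3, 2
# Time-major layout, matching the rearrange pattern "b t c h w -> b (t c) h w":
# all variables of frame t0 come first, then those of t1, and so on.
window = [f"t{t}/v{v}" for t in range(input_steps) for v in range(num_vars)]
for step in range(2):
    prediction = [f"p{step}/v{v}" for v in range(num_vars)]
    window = slide_window(window, prediction)
print(window)
```

After two steps the window holds the last observed frame followed by the two predicted frames, which is what the U-Net sees when it produces the third forecast step.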

For lower-level control you can call the steps of `train_from_config` individually:

```python
import fiddle as fdl
from mlcast.config.consistency_checks import validate_config

# cfg is the fdl.Config graph prepared as in the examples above
validate_config(cfg)          # raises ValueError on any contract violation
experiment = fdl.build(cfg)   # instantiates all objects
experiment.run()              # trainer.fit() + trainer.test()
```

### Available fiddlers

| Fiddler | Arguments | What it does |
|---------|-----------|--------------|
| `use_mlflow_logger` | *(none)* | Replaces the default `TensorBoardLogger` with `MLFlowLogger` and appends `LogSystemInfoCallback`; respects the `MLFLOW_TRACKING_URI` environment variable |
| `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match |
| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value |
| `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together |
| `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) |

## Project Structure

```
mlcast/
├── src/mlcast/
│   ├── __main__.py                      # CLI entry point (mlcast train)
│   ├── nowcasting_module.py             # Generic Lightning module for nowcasting
│   ├── losses.py                        # CRPS, AFCRPS, MSE loss functions
│   ├── callbacks.py                     # Training callbacks
│   ├── visualization.py                 # TensorBoard image logging helpers
│   ├── config/
│   │   ├── base.py                      # Default training_experiment @auto_config
│   │   ├── fiddlers.py                  # Semantic config mutators
│   │   ├── consistency_checks.py        # Cross-parameter validation
│   │   ├── loader.py                    # YAML config loader
│   │   └── orchestrator.py              # train_from_config, config persistence
│   ├── data/
│   │   ├── source_data_datamodule.py    # Lightning DataModule
│   │   ├── source_data_datasets.py      # Zarr-backed PyTorch datasets
│   │   └── normalization.py             # Normalisation registry
│   └── models/
│       └── convgru.py                   # ConvGRU encoder-decoder
├── tests/
├── pyproject.toml
└── README.md
```

## Implemented architectures

### ConvGruModel

`ConvGruModel` (in `src/mlcast/models/convgru.py`) is an **encoder-decoder**
architecture.  It is **not autoregressive at forecast time**: rather than
generating each forecast frame from the previous predicted frame, the decoder
performs a temporal roll-out entirely in **latent space** — the ConvGRU at
each spatial scale unrolls over `forecast_steps` steps driven by noise or
zeros, with its hidden state initialised from the encoder.  Forecast frames
are only materialised at the end, by upsampling the final decoder hidden
states back to the original spatial resolution.

**Encoding** — a stack of `EncoderBlock` layers unrolls a ConvGRU
sequentially over the `input_steps` real observed frames.  Each block halves
the spatial resolution via `PixelUnshuffle(2)`.  The last hidden state of
each block is retained.

**Decoding** — a stack of `DecoderBlock` layers performs a latent-space
roll-out at each spatial scale.  Each decoder block's ConvGRU is initialised
with the final hidden state from the corresponding encoder block, then unrolls
over `forecast_steps` steps with noise or zeros as input — so the forecast
sequence emerges from the evolution of hidden states across multiple spatial
scales, never from feeding predictions back as inputs.  Spatial resolution is
doubled at each block via `PixelShuffle(2)`.
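
The resolution changes described above follow from the shape rules of pixel (un)shuffling. The sketch below tracks only that rearrangement step, using the shape semantics of `torch.nn.PixelUnshuffle` and `torch.nn.PixelShuffle`. The three-block depth and the single input channel are assumptions, and the real blocks also contain ConvGRU layers that may change the channel count.

```python
def pixel_unshuffle_shape(c: int, h: int, w: int, r: int = 2) -> tuple[int, int, int]:
    """Shape after space-to-depth: (C, H, W) -> (C*r*r, H/r, W/r)."""
    if h % r or w % r:
        raise ValueError("height and width must be divisible by r")
    return c * r * r, h // r, w // r


def pixel_shuffle_shape(c: int, h: int, w: int, r: int = 2) -> tuple[int, int, int]:
    """Shape after depth-to-space: (C, H, W) -> (C/(r*r), H*r, W*r)."""
    if c % (r * r):
        raise ValueError("channel count must be divisible by r*r")
    return c // (r * r), h * r, w * r


shape = (1, 256, 256)  # one variable on a 256 x 256 grid
encoder_shapes = []
for _ in range(3):  # hypothetical number of encoder blocks
    shape = pixel_unshuffle_shape(*shape)
    encoder_shapes.append(shape)

decoder_shapes = []
for _ in range(3):  # one decoder block per encoder block
    shape = pixel_shuffle_shape(*shape)
    decoder_shapes.append(shape)

print(encoder_shapes)
print(decoder_shapes)
```

Each unshuffle trades a factor of two in height and width for a factor of four in channels, and each shuffle reverses it, so the decoder ends at the original spatial resolution.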

**Ensemble** — when `ensemble_size > 1` the decoder is run `ensemble_size`
times, each time with freshly sampled Gaussian noise.  The results are
concatenated along the channel dimension.

**Deterministic variant** ([diagram source](https://docs.google.com/presentation/d/1U2Y9vZADXTsgQBNiWYAgOwYeMPVu7TOk/edit?slide=id.p6#slide=id.p6)):

![ConvGruModel deterministic architecture](docs/architectures/convgru-deterministic.png)

**Stochastic / ensemble variant** ([diagram source](https://docs.google.com/presentation/d/1U2Y9vZADXTsgQBNiWYAgOwYeMPVu7TOk/edit?slide=id.p7#slide=id.p7)):

![ConvGruModel stochastic architecture](docs/architectures/convgru-stochastic.png)


### Custom network interface

Any network architecture can be used by replacing `cfg.pl_module.network`
with a `fdl.Config` node pointing at your class.  The only requirement is
that `forward` accepts the following signature:

```python
# from jaxtyping import Float
# import torch

def forward(
    self,
    x: Float[torch.Tensor, "batch input_steps in_channels H W"],
    steps: int,          # number of forecast steps to produce
    ensemble_size: int,  # number of stochastic ensemble members
) -> Float[torch.Tensor, "batch steps out_channels H W"]:
    ...
```

If your network uses a different parameter name for the input channel count
than `input_channels` (the default assumed by `ConvGruModel` and the
`set_variables` fiddler), set it explicitly on the config node.
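
Before wiring a custom class into the config, it can be useful to check that its `forward` accepts the keyword arguments listed above. The following is a minimal sketch using the standard-library `inspect` module. `missing_forward_keywords` is a hypothetical helper, not something mlcast provides, and the two example classes are plain Python stand-ins for `nn.Module` subclasses.

```python
import inspect

REQUIRED_KEYWORDS = ("steps", "ensemble_size")


def missing_forward_keywords(network_cls: type) -> list[str]:
    """List required keyword names that network_cls.forward does not declare."""
    params = inspect.signature(network_cls.forward).parameters
    # A **kwargs catch-all can absorb any keyword argument.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [name for name in REQUIRED_KEYWORDS if name not in params]


class MultiStepNet:
    def forward(self, x, steps, ensemble_size=1):
        return x


class SingleStepNet:
    def forward(self, x):
        return x


print(missing_forward_keywords(MultiStepNet))
print(missing_forward_keywords(SingleStepNet))
```

The check is name-based only. It does not verify tensor shapes or that the network actually honours `steps` and `ensemble_size`.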

## Contributing

Please feel free to raise issues or PRs if you have any suggestions or questions.

## Links to presentations for discussion about the API

- [2025/02/04 first design discussions](https://docs.google.com/presentation/d/1oWmnyxOfUMWgeQi0XyX4fX9YDMX1vl6h/edit?usp=drive_link&rtpof=true&sd=true)

## License

This project is dual-licensed under either:

* Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
* BSD 3-Clause License ([LICENSE-BSD](LICENSE-BSD) or https://opensource.org/licenses/BSD-3-Clause)

at your option.

See [LICENSE](LICENSE) for more details.
