Metadata-Version: 2.4
Name: mlx-genkit
Version: 0.4.4
Summary: MLX generation parity utils with HF/Torch-compatible sampling, processors, and steering hooks
Author: Your Team
Project-URL: Homepage, https://github.com/strangeloopcanon/mlx-genkit
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: mlx>=0.17.0
Requires-Dist: mlx-lm>=0.17.0
Requires-Dist: transformers>=4.40.0; python_version >= "3.9"

# mlx-genkit

Small, reusable MLX generation and training toolkit that brings HF/Torch `generate()` feature parity and clean persona steering to Apple Silicon. It reuses mlx-lm primitives (caches, projections, speculative) and fills the missing parity pieces.

Features
- HF-style `GenerationConfig` with processors/warpers: repetition penalty, no-repeat-ngrams, frequency/presence, bad-words, min_new_tokens, `typical_p`, `epsilon_cutoff`.
- Constraints: `force_words_ids` (strict start + continuation), `suppress_tokens`, `begin_suppress_tokens`, multiple `eos_token_ids`, forced BOS/EOS and per-position `forced_decoder_ids`.
- Modes: sampling (fast path via mlx-lm), beam (`num_beams`, `length_penalty`, `early_stopping`), speculative (mlx-lm), sliding KV (`max_kv_size`).
- Hooks: ResidualInjectionHook (sampling) and LogitBiasHook (sampling/beam); SoftPromptHook for training.
- Structured output enforcement: JSON-schema adherence, retry policies, semantic checks, and grammar stubs.
- Streaming with incremental validation: token callbacks, schema-aware early exits, and detailed result metadata.
- Batch helpers, JSONL adherence logging, and an eval harness for prompt suites.
- Training (MLX): `loss_forward`, `xent_loss` (label smoothing), mixed-precision compute (bf16) with fp32 master weights.
- Training utilities: `sequence_logprob`, `token_kl` for scoring and policy KL.
- Model helpers: `ema_update`, `build_action_mask`, `stable_softmax`; best-effort `clone_reference`.

Install
- From PyPI (recommended):
```
pip install mlx-genkit
```
- Dependencies (if not already installed):
```
pip install mlx mlx-lm transformers
```
- From source (editable):
```
pip install -e .
```

Models from Hugging Face
- If the repo provides MLX weights (e.g., in `mlx-community`), you can load directly: `load('mlx-community/<model>')`.
- For standard HF (PyTorch) repos, convert once using mlx-lm:
  - Python: `from mlx_genkit.interop import convert_hf_to_mlx; convert_hf_to_mlx('Qwen/Qwen3-0.6B', quantize=False, local_out='mlx_qwen3_0_6b')`
  - CLI: `mlx_lm.convert --hf-path Qwen/Qwen3-0.6B --mlx-path mlx_qwen3_0_6b`
  - Then load with `load('mlx_qwen3_0_6b')`.

Auto-convert loader
- You can pass either an HF repo id or a local MLX path to `auto_load`, which will convert once and cache under `./mlx_cache/<sanitized_repo_id>`:
```
from mlx_genkit.loader import auto_load
model, tokenizer, local_path = auto_load('Qwen/Qwen3-0.6B')
print('Loaded from', local_path)  # e.g., ./mlx_cache/Qwen_Qwen3-0.6B
```

Basic usage
```
from mlx_genkit import GenerationConfig, generate
from mlx_lm import load

model, tokenizer = load('mlx_qwen3_0_6b')
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95, seed=17)
out = generate(model, tokenizer, 'Hello MLX parity', cfg)
print(out['text'])
```

Chat prompts (auto chat template)
```
# If you pass a list of HF-style messages, mlx-genkit will automatically
# apply the tokenizer's chat template when available.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize MLX in 3 bullets."},
]
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, messages, cfg)
print(out['text'])
```

Auto-apply chat template for plain prompts
```
# You can also provide a plain string and have mlx-genkit wrap it
# using the model's chat template when available. This is enabled
# automatically if the tokenizer defines a chat_template; you can
# force or disable it via GenerationConfig, or force with assume_user_chat.
cfg = GenerationConfig(max_tokens=64, temperature=0.7, auto_chat_template=True, system_prompt="You are helpful")
# Equivalent explicit flag:
# cfg = GenerationConfig(max_tokens=64, temperature=0.7, assume_user_chat=True, system_prompt="You are helpful")
out = generate(model, tokenizer, "Summarize MLX in 3 bullets.", cfg)
print(out['text'])
```

Beam and constraints
```
cfg = GenerationConfig(max_tokens=64, temperature=0.0, num_beams=4, early_stopping=True, length_penalty=0.2,
                       force_words_ids=[tokenizer.encode(' cat')], min_new_tokens=8,
                       bad_words_ids=[[tokenizer.eos_token_id]], suppress_tokens=[tokenizer.eos_token_id])
out = generate(model, tokenizer, 'The', cfg)
```

Speculative decoding
```
cfg = GenerationConfig(max_tokens=64, temperature=0.7, top_p=0.95,
                       use_speculative=True, draft_model_id='mlx_qwen3_0_6b', num_draft_tokens=3)
out = generate(model, tokenizer, 'Speculative test', cfg)
```

Structured JSON output
```
from mlx_genkit import GenerationConfig, JsonAdherence, generate

schema = {
    "type": "object",
    "properties": {
        "summary": {"type": "string"},
        "confidence": {"type": "number"}
    },
    "required": ["summary", "confidence"],
}

cfg = GenerationConfig(max_tokens=220, temperature=0.0)
result = generate(
    model,
    tokenizer,
    "Summarise the change log in JSON",
    cfg,
    json_schema=schema,
    adherence=JsonAdherence(retries=2, strict_only_json=True),
)
print(result.json)
```
- Structured adherence automatically strips common ```json fenced output and keeps
  retrying with escalating prompts when `JsonAdherence(retries=...)` is configured.
  Use `strict_only_json=True` to reject any non-JSON commentary.

Structured DSL
```
from mlx_genkit import StructuredSpec, generate_structured

spec = StructuredSpec(
    schema=schema,
    fields=["summary", "confidence"],
    examples=[{"input": "Bug fix", "output": {"summary": "...", "confidence": 0.9}}],
)
res = generate_structured(model, tokenizer, task="Summarise the diff", spec=spec)
print(res.json)
```

Streaming & incremental validation
```
from mlx_genkit import (
    GenerationConfig,
    JsonAdherence,
    StreamCallbacks,
    generate_stream,
)

def stream_json(model, tokenizer, prompt, schema):
    emitted = []

    def on_token(token, idx):
        emitted.append(token)
        piece = tokenizer.decode([token]) if isinstance(token, int) else str(token)
        if piece:
            print(piece, end="", flush=True)

    def on_invalid_path(info):
        print("\n[invalid path detected]", info["message"])

    callbacks = StreamCallbacks(on_token=on_token, on_invalid_path=on_invalid_path)
    result = generate_stream(
        model,
        tokenizer,
        prompt,
        GenerationConfig(max_tokens=128, temperature=0.0),
        json_schema=schema,
        adherence=JsonAdherence(strict_only_json=True, retries=1),
        on_token=callbacks.on_token,
        on_invalid_path=callbacks.on_invalid_path,
        stop_on_invalid=False,
    )
    print("\nAttempts:", result.attempts, " tokens:", len(emitted))
    return result
```
`StreamCallbacks` let you observe every token while the incremental monitor enforces the schema. Leave `stop_on_invalid=True` for hard stops, or set it to `False` to keep streaming after the warning while JSON retries repair the output.

Persona steering
```
import mlx.core as mx
from mlx_genkit import LogitBiasHook
H = model.args.hidden_size
model['_persona_v'] = mx.random.normal((H,)) * (1.0/(H**0.5))
cfg = GenerationConfig(max_tokens=64, temperature=0.7)
out = generate(model, tokenizer, 'Summarize MLX', cfg, hooks=[LogitBiasHook(param_key='_persona_v', alpha=1.2)])
```

Training (MLX)
```
from mlx_genkit import TrainingConfig, train_step, SoftPromptHook
from mlx.optimizers import AdamW
pad_id = getattr(tokenizer, 'pad_token_id', -100) or -100
opt = AdamW(learning_rate=2e-4)
batch = {'tokens': ...}  # mx.array [B, T]
cfg = TrainingConfig(dtype='bf16', loss_scale=1024.0)
loss = train_step(model, batch, opt, cfg, hooks=[SoftPromptHook(n_virtual=10, param_key='_soft_prompt')], pad_id=pad_id)
```

Utilities
```
from mlx_genkit import sequence_logprob, token_kl, ema_update, build_action_mask

# Per-sample mean log-prob on supervised positions (labels == -100 are ignored)
lp = sequence_logprob(model, batch_tokens, labels)  # [B]

# KL(pi || pref) averaged over supervised positions
kl = token_kl(model, ref_model, batch_tokens, labels)  # [B]

# EMA update of a target model from a source model
ema_update(target_model, model, decay=0.999)

# Supervised mask after prompt
mask = build_action_mask(prompt_lens=[12, 20], seq_len=T)  # [B, T] bool
```

Parity testing
- Torch vs MLX: `python -m mlx_genkit.tests.parity_hf --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt 'hello'`
- Suite (8 prompts): `python -m mlx_genkit.tests.parity_suite --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b`

CLI wrapper
```
mlxgk-generate \
  --model Qwen/Qwen3-0.6B \
  --prompt "Hello MLX" \
  --max-tokens 64 --temp 0.7 --top-p 0.95 \
  --num-beams 1 --no-repeat-ngram-size 2
```

Structured CLI flags
- `--json-schema schema.json` enable schema validation (optionally combine with `--retries 2`).
- `--strict-only-json` forbid commentary; `--stream` prints incremental output while validating.
- `--validator jsonschema --validator mypkg.validators:custom` layer pluggable validators.
- `--semantic-checks checks.json` applies `must_contain`, `enum_in`, and `regex_on_field` predicates.
- `--log-jsonl adherence.jsonl --log-raw-on-fail` capture adherence diagnostics per run.
- `--backend` selects the decoding backend (`mlx`, `transformers`, or `vllm`) and validates grammar support.

CLI chat and stop strings
- Chat: `--messages-json '[{"role":"user","content":"hi"}]'` (auto-applies template)
- Auto chat for plain prompts: add `--auto-chat` (or disable with `--no-auto-chat`); optional `--system "You are helpful"`
- Force treating plain prompts as user messages: `--assume-user-chat` (equivalent to `--auto-chat`)
- Stop strings: use `--stop` or the alias `--stop-strings` (comma-separated)

Defaults
- The CLI will, by default, auto-apply chat templates when the loaded tokenizer exposes a chat template (has `apply_chat_template` and a non-empty `chat_template`). Use `--no-auto-chat` to turn this off.

Prefetch and convert (download)
```
# Download an HF repo and convert to MLX format without loading into memory.
# Prints the local path, e.g., ./mlx_cache/Qwen_Qwen2-7B-Instruct
mlxgk-download --model Qwen/Qwen2-7B-Instruct

# Options
#  --cache-dir DIR           Cache location (default: ./mlx_cache)
#  --quantize                Quantize during conversion
#  --trust-remote-code       Allow custom code from the repo
#  --force                   Reconvert and overwrite existing cache
```

Adherence eval harness
```
mlxgk-eval --suite adherence_suite.yaml --markdown report.md --json report.json
```
Example suite (`adherence_suite.yaml`):
```
name: adherence_smoke
model: Qwen/Qwen3-0.6B
cases:
  - name: summary
    prompt: "Return a JSON object with summary and confidence"
    json_schema:
      type: object
      properties:
        summary: {type: string}
        confidence: {type: number}
      required: [summary, confidence]
    retries: 2
    strict_only_json: true
```

Performance bench
```
python -m mlx_genkit.tests.perf_bench --hf-model Qwen/Qwen3-0.6B --mlx-model ./mlx_qwen3_0_6b --prompt "Hello performance" --max-tokens 64
```

Releases
- Bump version across files (defaults to patch):
  - `make bump-version` (use `PART=minor` or `PART=major` to override)
- Create and push a git tag (vX.Y.Z):
  - `make git-release`
  - This tags and pushes the repo; PyPI packaging can be added later.

Notes
- Parity targets control‑surface equivalence: constraints, stops, finish reasons, determinism; token streams may differ across frameworks/devices.
- Sampling fast path reuses mlx-lm’s decoding loop and caches for best performance on Apple Silicon.

Known limitations
- Residual injection uses Python-level patching; highly optimized/compiled paths may bypass it. Use `forward_with_hidden(..., strict=True)` when you need deterministic capture/injection semantics.
- Some MLX model classes may not accept `input_embeddings` (used for soft prompts in training). In those cases, the library now falls back gracefully to standard token-only forward.
- Beam search applies processors on raw logits and then normalizes (HF behavior). Earlier parity reports in this repo may reflect the previous implementation on normalized logprobs.

Tips
- When running examples directly from the repo, make sure you’re using the local sources: `pip install -e .` or run with `PYTHONPATH=.`.
- Parity/perf harnesses will download HF models; ensure network access and sufficient disk space.
