# 1. DSL Versioning
#    Quoted so the version stays a string: an unquoted 1.0 is loaded as a
#    float, and a later value such as 1.10 would silently become 1.1.
version: "1.0"

# 2. Global flags
#    NOTE(review): the ${...} values look like placeholders filled in before
#    the rules are evaluated -- confirm against the DSL loader.

# boolean: true = MaxText→HF, false = HF→MaxText
saving_to_hf: ${saving_to_hf}
# boolean: whether to unroll per-layer hooks
scan_layers: ${scan_layers}

# 3. Model Configurations
#    - Text-only models: use `config` directly.
#    - Multimodal models: include both `text_config` and `vision_config`.
#    NOTE(review): the trailing `?` on the section keys presumably marks the
#    section as optional for the DSL loader -- confirm against the loader.
config:
  text_config?:
    num_hidden_layers: ${config.text_config.num_hidden_layers}
    hidden_size: ${config.text_config.hidden_size}
    head_dim: ${config.text_config.head_dim}
  vision_config?:
    num_hidden_layers: ${config.vision_config.num_hidden_layers}

# 4. Reusable Operations
#    Each op can inspect saving_to_hf to choose its behavior.
#    The rules shown here each provide two ordered op lists under `logic`:
#    `to_hf` and `from_hf`.
rules:
  pad_and_scale_embedding:
    # Operation: Scale embeddings, then pad/truncate the vocabulary size.
    # The two directions use reciprocal scale factors (1.0 / normalizer vs.
    # normalizer), so their scaling cancels out.
    params:
      # Expression string; the scale ops below refer to it as `normalizer`.
      normalizer: sqrt(config.text_config.hidden_size)
    logic:
      to_hf:
        ops:
          # For HF, we divide by the sqrt of hidden_size
          - scale: 1.0 / normalizer
          # Pad or truncate the first dimension (vocab size).
          # The expression is quoted because `[` and `]` are flow indicators
          # and cannot appear in a plain scalar inside a flow sequence.
          # NOTE(review): the null second entry presumably leaves that
          # dimension unchanged -- confirm against the DSL loader.
          - pad: ['target_shape[0]', null]
      from_hf:
        ops:
          # From HF, we multiply back
          - scale: normalizer
          # Same pad/truncate step as in to_hf (expression quoted, see above).
          - pad: ['target_shape[0]', null]

  scale_rmsnorm:
    # Operation: Adjust the value by 1, then reshape.
    # The two directions use opposite offsets (-1.0 vs. +1.0), so they cancel.
    logic:
      to_hf:
        ops:
          # Offset by -1.0 when exporting to HF.
          - adjust: -1.0
          - reshape: target_shape
      from_hf:
        ops:
          # Offset by +1.0 when importing from HF.
          - adjust: 1.0
          - reshape: target_shape

  reshape_kernel:
    # To HF: reshape, then transpose. From HF: transpose, then reshape.
    # The op order is deliberately mirrored between the two directions.
    logic:
      to_hf:
        ops:
          # Reshape to the reversed target shape first, so that the
          # two-axis swap below yields target_shape.
          - reshape: reverse(target_shape)
          - transpose: [1, 0]
      from_hf:
        ops:
          # Swap the two axes first, then reshape to target_shape.
          - transpose: [1, 0]
          - reshape: target_shape

  patch:
    # Operation: A single, multi-axis transpose.
    # The two permutations are inverses of each other: applying
    # [3, 2, 0, 1] and then [2, 3, 1, 0] restores the original axis order.
    logic:
      to_hf:
        ops:
          - transpose: [3, 2, 0, 1]
      from_hf:
        ops:
          - transpose: [2, 3, 1, 0]

  bias:
    # Operation: Flatten in one direction, reshape in the other.
    logic:
      to_hf:
        ops:
          # Argument-less op, written as a bare string rather than a mapping.
          - flatten
      from_hf:
        ops:
          # Restore the shape given by target_shape.
          - reshape: target_shape

  pos_embed:
    # Operation: Squeeze or expand a dimension.
    # Both directions act on axis 0, so they mirror each other.
    logic:
      to_hf:
        ops:
          - squeeze: 0  # Squeeze the first axis
      from_hf:
        ops:
          - expand_dims: 0  # Add an axis at the first position
