amachine.am_fast

from __future__ import annotations

import importlib.util
import json
import os
import pprint
import sys
from pathlib import Path
from typing import Any

import networkx as nx
import numpy as np
import numpy.typing as npt
import pyarrow as pa
import pyarrow.parquet as pq

from automata.fa.dfa import DFA

 20
 21_build_dir = Path(__file__).resolve().parent / "build"
 22
 23if str(_build_dir) not in sys.path:
 24    sys.path.insert(0, str(_build_dir))
 25
 26try:
 27    from . import _am_fast
 28    from ._am_fast import *
 29except ImportError as e:
 30    print(f"\n--- DEBUG START ---")
 31    print(f"Error: {e}")
 32    print(f"Current Directory: {os.getcwd()}")
 33    print(f"File Location: {__file__}")
 34    # This helps see if the .so file is actually in the folder
 35    print(f"Contents of this folder: {os.listdir(os.path.dirname(__file__))}")
 36    print(f"--- DEBUG END ---\n")
 37    raise e
 38
 39# from ._am_fast import generate_cpp, block_entropy_convergence_cpp, strongly_connected_components_cpp, minify_dfa_cpp
 40from .json_utils import save_json
 41
 42def block_entropy_convergence(
 43    h_mu: float,
 44    n_states: int,
 45    n_symbols : int,
 46    convergence_tol: float,
 47    precision: float,
 48    eps: float,
 49    branches: list[tuple[float, list[float]]],
 50    trans: list[list[tuple[int, float, int]]],
 51    max_branches: int = 30_000_000
 52) -> any :
 53    return block_entropy_convergence_cpp(
 54        h_mu            = float( h_mu ),
 55        n_states        = n_states,
 56        n_symbols       = n_symbols,
 57        convergence_tol = convergence_tol,
 58        precision       = float(precision),
 59        eps             = eps,
 60        branches        = branches,
 61        trans           = trans,
 62        max_branches    = max_branches
 63    )
 64
 65def strongly_connected_components( T ) :
 66    return strongly_connected_components_cpp( T )
 67
 68def generate_data(
 69    n_gen        : int,
 70    start_state  : int,
 71    transitions  : list[list[tuple[int, float, int]]],
 72    alphabet     : list[str],
 73    include_states: bool = False,
 74    random_seed : int = 42 ) -> dict[str, npt.NDArray]:
 75
 76    n_states  = len(transitions)
 77    n_symbols = len(alphabet)
 78
 79    if n_states > 65536:
 80        raise ValueError(f"Max states supported is 65536, got {n_states}")
 81    if n_symbols > 65536:
 82        raise ValueError(f"Max alphabet size supported is 65536, got {n_symbols}")
 83
 84    symbol_indices, state_indices = generate_cpp(
 85        n_gen             = n_gen,
 86        start_state_index = start_state,
 87        transitions       = transitions,
 88        include_states    = include_states,
 89        random_seed       = random_seed
 90    )
 91
 92    res: dict[str, npt.NDArray] = {
 93        "symbol_index": np.from_dlpack(symbol_indices)
 94    }
 95    if include_states:
 96        res["state_index"] = np.from_dlpack(state_indices)
 97
 98    return res
 99
100def save_data(
101    data             : dict[str,any],
102    file_prefix      : str,
103    alphabet         : list[str],
104    n_states         : int,
105    start_state      : int,
106    random_seed      : int,
107    machine_metadata : dict[str,any] | None = None,
108) -> None:
109
110    def _uint_type(n: int) -> pa.DataType:
111        return pa.uint8() if n <= 256 else pa.uint16()
112
113    sym_type   = _uint_type(len(alphabet))
114    state_type = _uint_type(n_states)
115
116    columns: dict[str, pa.Array] = {
117        "symbol_index": pa.array(data["symbol_index"]).cast(sym_type)
118    }
119    
120    if "state_index" in data:
121        columns["state_index"] = pa.array(data["state_index"]).cast(state_type)
122
123    iso_shifts = data.get("isomorphic_shifts", {})
124    for shift, shifted in iso_shifts.items():
125        columns[f"symbol_index_isoshift_{shift}"] = pa.array(shifted["symbol_index"]).cast(sym_type)
126        columns[f"state_index_isoshift_{shift}"]  = pa.array(shifted["state_index"]).cast(state_type)
127
128    parquet_meta: dict[str,any] = {
129        "alphabet"          : alphabet,
130        "machine_metadata"  : machine_metadata or {},
131        "start_state"       : start_state,
132        "random_seed"       : random_seed,
133        "isomorphic_shifts" : sorted(iso_shifts.keys()),
134    }
135
136    table = pa.table(columns)
137    table = table.replace_schema_metadata({
138        **(table.schema.metadata or {}),
139        "am_metadata": json.dumps(parquet_meta),
140    })
141
142    pq.write_table(table, f"{file_prefix}.parquet")
143    save_json(parquet_meta, f"{file_prefix}.json")
144
145
146def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
147    state_list = list(dfa.states)
148    state_idx = {s: i for i, s in enumerate(state_list)}
149    symbol_list = sorted(dfa.input_symbols)
150    symbol_idx = {sym: i for i, sym in enumerate(symbol_list)}
151
152    adj = [[] for _ in range(len(state_list))]
153    for state, paths in dfa.transitions.items():
154        u = state_idx[state]
155        for sym, nxt in paths.items():
156            v = state_idx.get(nxt, -1)
157            if v != -1:
158                adj[u].append((symbol_idx[sym], v))
159
160    is_final = [s in dfa.final_states for s in state_list]
161    init_idx = state_idx[dfa.initial_state]
162
163    # result names in res now match the C++ struct fields exactly
164    res = minify_dfa_cpp(len(state_list), adj, init_idx, is_final)
165
166    if res.is_empty_language:
167        return dfa.__class__.empty_language(dfa.input_symbols)
168
169    if retain_names:
170        class_members = {}
171        for old_idx, nc in enumerate(res.eq_class):
172            if nc >= 0:
173                class_members.setdefault(nc, []).append(state_list[old_idx])
174        class_map = {nc: frozenset(m) for nc, m in class_members.items()}
175    else:
176        class_map = {nc: nc for nc in range(res.n_classes)}
177
178    new_states = set(class_map.values())
179    new_initial = class_map[res.new_initial]
180    new_final = {class_map[c] for c, f in enumerate(res.class_is_final) if f}
181
182    new_trans = {}
183    for nc in range(res.n_classes):
184        new_trans[class_map[nc]] = {
185            symbol_list[trans[0]]: class_map[trans[1]]
186            for trans in res.class_trans[nc]
187        }
188
189    return dfa.__class__(
190        states=new_states,
191        input_symbols=dfa.input_symbols,
192        transitions=new_trans,
193        initial_state=new_initial,
194        final_states=new_final,
195        allow_partial=any(len(t) < len(symbol_list) for t in new_trans.values())
196    )
def block_entropy_convergence( h_mu: float, n_states: int, n_symbols: int, convergence_tol: float, precision: float, eps: float, branches: list[tuple[float, list[float]]], trans: list[list[tuple[int, float, int]]], max_branches: int = 30000000) -> <built-in function any>:
43def block_entropy_convergence(
44    h_mu: float,
45    n_states: int,
46    n_symbols : int,
47    convergence_tol: float,
48    precision: float,
49    eps: float,
50    branches: list[tuple[float, list[float]]],
51    trans: list[list[tuple[int, float, int]]],
52    max_branches: int = 30_000_000
53) -> any :
54    return block_entropy_convergence_cpp(
55        h_mu            = float( h_mu ),
56        n_states        = n_states,
57        n_symbols       = n_symbols,
58        convergence_tol = convergence_tol,
59        precision       = float(precision),
60        eps             = eps,
61        branches        = branches,
62        trans           = trans,
63        max_branches    = max_branches
64    )
def strongly_connected_components(T):
66def strongly_connected_components( T ) :
67    return strongly_connected_components_cpp( T )
def generate_data( n_gen: int, start_state: int, transitions: list[list[tuple[int, float, int]]], alphabet: list[str], include_states: bool = False, random_seed: int = 42) -> dict[str, numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[~_ScalarT]]]:
69def generate_data(
70    n_gen        : int,
71    start_state  : int,
72    transitions  : list[list[tuple[int, float, int]]],
73    alphabet     : list[str],
74    include_states: bool = False,
75    random_seed : int = 42 ) -> dict[str, npt.NDArray]:
76
77    n_states  = len(transitions)
78    n_symbols = len(alphabet)
79
80    if n_states > 65536:
81        raise ValueError(f"Max states supported is 65536, got {n_states}")
82    if n_symbols > 65536:
83        raise ValueError(f"Max alphabet size supported is 65536, got {n_symbols}")
84
85    symbol_indices, state_indices = generate_cpp(
86        n_gen             = n_gen,
87        start_state_index = start_state,
88        transitions       = transitions,
89        include_states    = include_states,
90        random_seed       = random_seed
91    )
92
93    res: dict[str, npt.NDArray] = {
94        "symbol_index": np.from_dlpack(symbol_indices)
95    }
96    if include_states:
97        res["state_index"] = np.from_dlpack(state_indices)
98
99    return res
def save_data( data: dict[str, any], file_prefix: str, alphabet: list[str], n_states: int, start_state: int, random_seed: int, machine_metadata: dict[str, any] | None = None) -> None:
101def save_data(
102    data             : dict[str,any],
103    file_prefix      : str,
104    alphabet         : list[str],
105    n_states         : int,
106    start_state      : int,
107    random_seed      : int,
108    machine_metadata : dict[str,any] | None = None,
109) -> None:
110
111    def _uint_type(n: int) -> pa.DataType:
112        return pa.uint8() if n <= 256 else pa.uint16()
113
114    sym_type   = _uint_type(len(alphabet))
115    state_type = _uint_type(n_states)
116
117    columns: dict[str, pa.Array] = {
118        "symbol_index": pa.array(data["symbol_index"]).cast(sym_type)
119    }
120    
121    if "state_index" in data:
122        columns["state_index"] = pa.array(data["state_index"]).cast(state_type)
123
124    iso_shifts = data.get("isomorphic_shifts", {})
125    for shift, shifted in iso_shifts.items():
126        columns[f"symbol_index_isoshift_{shift}"] = pa.array(shifted["symbol_index"]).cast(sym_type)
127        columns[f"state_index_isoshift_{shift}"]  = pa.array(shifted["state_index"]).cast(state_type)
128
129    parquet_meta: dict[str,any] = {
130        "alphabet"          : alphabet,
131        "machine_metadata"  : machine_metadata or {},
132        "start_state"       : start_state,
133        "random_seed"       : random_seed,
134        "isomorphic_shifts" : sorted(iso_shifts.keys()),
135    }
136
137    table = pa.table(columns)
138    table = table.replace_schema_metadata({
139        **(table.schema.metadata or {}),
140        "am_metadata": json.dumps(parquet_meta),
141    })
142
143    pq.write_table(table, f"{file_prefix}.parquet")
144    save_json(parquet_meta, f"{file_prefix}.json")
def minify_cpp( dfa: automata.fa.dfa.DFA, retain_names: bool = True) -> automata.fa.dfa.DFA:
147def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
148    state_list = list(dfa.states)
149    state_idx = {s: i for i, s in enumerate(state_list)}
150    symbol_list = sorted(dfa.input_symbols)
151    symbol_idx = {sym: i for i, sym in enumerate(symbol_list)}
152
153    adj = [[] for _ in range(len(state_list))]
154    for state, paths in dfa.transitions.items():
155        u = state_idx[state]
156        for sym, nxt in paths.items():
157            v = state_idx.get(nxt, -1)
158            if v != -1:
159                adj[u].append((symbol_idx[sym], v))
160
161    is_final = [s in dfa.final_states for s in state_list]
162    init_idx = state_idx[dfa.initial_state]
163
164    # result names in res now match the C++ struct fields exactly
165    res = minify_dfa_cpp(len(state_list), adj, init_idx, is_final)
166
167    if res.is_empty_language:
168        return dfa.__class__.empty_language(dfa.input_symbols)
169
170    if retain_names:
171        class_members = {}
172        for old_idx, nc in enumerate(res.eq_class):
173            if nc >= 0:
174                class_members.setdefault(nc, []).append(state_list[old_idx])
175        class_map = {nc: frozenset(m) for nc, m in class_members.items()}
176    else:
177        class_map = {nc: nc for nc in range(res.n_classes)}
178
179    new_states = set(class_map.values())
180    new_initial = class_map[res.new_initial]
181    new_final = {class_map[c] for c, f in enumerate(res.class_is_final) if f}
182
183    new_trans = {}
184    for nc in range(res.n_classes):
185        new_trans[class_map[nc]] = {
186            symbol_list[trans[0]]: class_map[trans[1]]
187            for trans in res.class_trans[nc]
188        }
189
190    return dfa.__class__(
191        states=new_states,
192        input_symbols=dfa.input_symbols,
193        transitions=new_trans,
194        initial_state=new_initial,
195        final_states=new_final,
196        allow_partial=any(len(t) < len(symbol_list) for t in new_trans.values())
197    )