Generated by Cython 3.2.4
Yellow lines hint at Python interaction.
Click on a line that starts with a "+" to see the C code that Cython generated for it.
Raw output: alphazero_cy.c
+001: from libc.math cimport sqrt
__pyx_t_2 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 1, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); if (PyDict_SetItem(__pyx_mstate_global->__pyx_d, __pyx_mstate_global->__pyx_n_u_test, __pyx_t_2) < (0)) __PYX_ERR(0, 1, __pyx_L1_error) __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
002: from lean_reinforcement.agent.mcts.mcts_cy.base_mcts_cy cimport Node, BaseMCTS
+003: import math
__pyx_t_1 = __Pyx_Import(__pyx_mstate_global->__pyx_n_u_math, 0, 0, NULL, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 3, __pyx_L1_error) __pyx_t_2 = __pyx_t_1; __Pyx_GOTREF(__pyx_t_2); if (PyDict_SetItem(__pyx_mstate_global->__pyx_d, __pyx_mstate_global->__pyx_n_u_math, __pyx_t_2) < (0)) __PYX_ERR(0, 3, __pyx_L1_error) __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+004: import torch
__pyx_t_1 = __Pyx_Import(__pyx_mstate_global->__pyx_n_u_torch, 0, 0, NULL, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 4, __pyx_L1_error) __pyx_t_2 = __pyx_t_1; __Pyx_GOTREF(__pyx_t_2); if (PyDict_SetItem(__pyx_mstate_global->__pyx_d, __pyx_mstate_global->__pyx_n_u_torch, __pyx_t_2) < (0)) __PYX_ERR(0, 4, __pyx_L1_error) __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+005: from lean_dojo import TacticState, ProofFinished, LeanError, ProofGivenUp
{
PyObject* const __pyx_imported_names[] = {__pyx_mstate_global->__pyx_n_u_TacticState,__pyx_mstate_global->__pyx_n_u_ProofFinished,__pyx_mstate_global->__pyx_n_u_LeanError,__pyx_mstate_global->__pyx_n_u_ProofGivenUp};
__pyx_t_1 = __Pyx_Import(__pyx_mstate_global->__pyx_n_u_lean_dojo, __pyx_imported_names, 4, NULL, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 5, __pyx_L1_error)
}
__pyx_t_2 = __pyx_t_1;
__Pyx_GOTREF(__pyx_t_2);
{
PyObject* const __pyx_imported_names[] = {__pyx_mstate_global->__pyx_n_u_TacticState,__pyx_mstate_global->__pyx_n_u_ProofFinished,__pyx_mstate_global->__pyx_n_u_LeanError,__pyx_mstate_global->__pyx_n_u_ProofGivenUp};
for (__pyx_t_3=0; __pyx_t_3 < 4; __pyx_t_3++) {
__pyx_t_4 = __Pyx_ImportFrom(__pyx_t_2, __pyx_imported_names[__pyx_t_3]); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 5, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
if (PyDict_SetItem(__pyx_mstate_global->__pyx_d, __pyx_imported_names[__pyx_t_3], __pyx_t_4) < (0)) __PYX_ERR(0, 5, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
}
}
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
006:
+007: cdef class MCTS_AlphaZero(BaseMCTS):
/* Instance layout of the MCTS_AlphaZero extension type (`cdef class
 * MCTS_AlphaZero(BaseMCTS)`). The BaseMCTS instance data is embedded as the
 * first member, followed by the single field this subclass declares. */
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero {
/* Embedded base-class instance data. */
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS __pyx_base;
/* Storage for `cdef public object value_head`; read, replaced and reset by
 * the generated __get__/__set__/__del__ accessors. */
PyObject *value_head;
};
/* … */
/* C-level method table for MCTS_AlphaZero. The BaseMCTS table is embedded
 * first, then the one method slot this class adds. */
struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero {
/* Embedded base-class method table. */
struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS __pyx_base;
/* Slot for `cpdef float _puct_score(self, Node node)`. The trailing int is
 * the skip-dispatch flag: the implementation skips its "overridden in
 * Python?" check when it is non-zero. */
float (*_puct_score)(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *, int __pyx_skip_dispatch);
};
/* File-local pointer to the MCTS_AlphaZero method table; it is not assigned
 * anywhere in this excerpt. */
static struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_vtabptr_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero;
+008: cdef public object value_head
/* Python wrapper */
/* Getter entry point for the `value_head` attribute: casts `self` to the
 * MCTS_AlphaZero struct and forwards to the implementation function.
 * Returns whatever the implementation returns. */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_1__get__(PyObject *__pyx_v_self); /*proto*/
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_1__get__(PyObject *__pyx_v_self) {
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("__get__ (wrapper)", 0);
/* NOTE(review): __pyx_args / __pyx_nargs are not parameters of this function;
 * presumably the macro discards its arguments -- confirm against the Cython
 * utility code. The result is never read here. */
__pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head___get__(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self));
/* function exit code */
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Getter implementation for `value_head`: takes a new reference to the object
 * stored in self->value_head and returns it. There is no error path. */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head___get__(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self) {
PyObject *__pyx_r = NULL;
/* __pyx_r is still NULL here, so this releases nothing. */
__Pyx_XDECREF(__pyx_r);
/* The caller receives its own reference; the field keeps the original one. */
__Pyx_INCREF(__pyx_v_self->value_head);
__pyx_r = __pyx_v_self->value_head;
goto __pyx_L0;
/* function exit code */
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/* Setter entry point for the `value_head` attribute: casts `self` to the
 * MCTS_AlphaZero struct and forwards `value` to the implementation function.
 * Returns the implementation's int status. */
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value); /*proto*/
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_3__set__(PyObject *__pyx_v_self, PyObject *__pyx_v_value) {
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
int __pyx_r;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("__set__ (wrapper)", 0);
/* NOTE(review): __pyx_args / __pyx_nargs are not parameters of this function;
 * presumably the macro discards its arguments -- confirm against the Cython
 * utility code. The result is never read here. */
__pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_2__set__(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), ((PyObject *)__pyx_v_value));
/* function exit code */
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Setter implementation for `value_head`: stores `value` in self->value_head,
 * taking a new reference to it and dropping the reference to the previous
 * occupant. No type check is applied to `value`. Always returns 0. */
static int __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_2__set__(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_value) {
int __pyx_r;
/* The new value is referenced before the old one is released. */
__Pyx_INCREF(__pyx_v_value);
__Pyx_GIVEREF(__pyx_v_value);
__Pyx_GOTREF(__pyx_v_self->value_head);
__Pyx_DECREF(__pyx_v_self->value_head);
__pyx_v_self->value_head = __pyx_v_value;
/* function exit code */
__pyx_r = 0;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/* Deleter entry point for the `value_head` attribute: casts `self` to the
 * MCTS_AlphaZero struct and forwards to the implementation function.
 * Returns the implementation's int status. */
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_5__del__(PyObject *__pyx_v_self); /*proto*/
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_5__del__(PyObject *__pyx_v_self) {
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
int __pyx_r;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("__del__ (wrapper)", 0);
/* NOTE(review): __pyx_args / __pyx_nargs are not parameters of this function;
 * presumably the macro discards its arguments -- confirm against the Cython
 * utility code. The result is never read here. */
__pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_4__del__(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self));
/* function exit code */
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Deleter implementation for `value_head`: the slot is not cleared to NULL;
 * it is reset to Py_None (with a new reference) and the previous occupant is
 * released. Always returns 0. */
static int __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10value_head_4__del__(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self) {
int __pyx_r;
/* Reference Py_None before releasing the old value. */
__Pyx_INCREF(Py_None);
__Pyx_GIVEREF(Py_None);
__Pyx_GOTREF(__pyx_v_self->value_head);
__Pyx_DECREF(__pyx_v_self->value_head);
__pyx_v_self->value_head = Py_None;
/* function exit code */
__pyx_r = 0;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
009:
+010: def __init__(
/* Python wrapper */
/* Argument-unpacking wrapper for MCTS_AlphaZero.__init__.
 *
 * Accepts 3 to 9 positional arguments plus keywords, in this order:
 *   value_head, env, transformer            -- required, passed as objects
 *   exploration_weight (float)              -- default 1.41421356
 *   max_tree_nodes (int)                    -- default 0x2710 (10000)
 *   batch_size (int)                        -- default 8
 *   num_tactics_to_expand (int)             -- default 8
 *   max_rollout_depth (int)                 -- default 30
 *   max_time (double)                       -- default 300.0
 *
 * On success the converted values are forwarded to the implementation
 * function and its int result is returned. On any parsing or conversion
 * failure the references collected in values[] are released, a traceback
 * entry is added and -1 is returned.
 *
 * Fix relative to the generated text: the keyword-count error check was
 * written as `unlikely(len) < 0`. When unlikely() normalizes its argument to
 * 0/1 (as the `__builtin_expect(!!(x), 0)` form does) that comparison is
 * always false, so a failed count query was silently ignored. The comparison
 * now sits inside the hint. */
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static int __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
PyObject *__pyx_v_value_head = 0;
PyObject *__pyx_v_env = 0;
PyObject *__pyx_v_transformer = 0;
float __pyx_v_exploration_weight;
int __pyx_v_max_tree_nodes;
int __pyx_v_batch_size;
int __pyx_v_num_tactics_to_expand;
int __pyx_v_max_rollout_depth;
CYTHON_UNUSED double __pyx_v_max_time;
CYTHON_UNUSED PyObject *__pyx_v_kwargs = 0;
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
int __pyx_r;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("__init__ (wrapper)", 0);
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1;
#endif
__pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs);
{
/* Keyword names in positional order; the array is 0-terminated. */
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_value_head,&__pyx_mstate_global->__pyx_n_u_env,&__pyx_mstate_global->__pyx_n_u_transformer,&__pyx_mstate_global->__pyx_n_u_exploration_weight,&__pyx_mstate_global->__pyx_n_u_max_tree_nodes,&__pyx_mstate_global->__pyx_n_u_batch_size,&__pyx_mstate_global->__pyx_n_u_num_tactics_to_expand,&__pyx_mstate_global->__pyx_n_u_max_rollout_depth,&__pyx_mstate_global->__pyx_n_u_max_time,0};
/* One slot per declared parameter; a slot left at 0 means "not supplied". */
PyObject* values[9] = {0,0,0,0,0,0,0,0,0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_VARARGS(__pyx_kwds) : 0;
/* A negative count signals that the keyword-count query failed. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 10, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take the positionals that were given (any count from
 * 0 to 9), then let the keyword parser fill the remaining slots. */
switch (__pyx_nargs) {
case 9:
values[8] = __Pyx_ArgRef_VARARGS(__pyx_args, 8);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[8])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 8:
values[7] = __Pyx_ArgRef_VARARGS(__pyx_args, 7);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[7])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 7:
values[6] = __Pyx_ArgRef_VARARGS(__pyx_args, 6);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[6])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 6:
values[5] = __Pyx_ArgRef_VARARGS(__pyx_args, 5);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[5])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 5:
values[4] = __Pyx_ArgRef_VARARGS(__pyx_args, 4);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[4])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 4:
values[3] = __Pyx_ArgRef_VARARGS(__pyx_args, 3);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[3])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 3:
values[2] = __Pyx_ArgRef_VARARGS(__pyx_args, 2);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[2])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 2:
values[1] = __Pyx_ArgRef_VARARGS(__pyx_args, 1);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[1])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 1:
values[0] = __Pyx_ArgRef_VARARGS(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
/* NOTE(review): __pyx_v_kwargs is still 0 here and the final argument is 1;
 * presumably that flag tells the parser to ignore unknown keywords since
 * **kwargs is unused -- confirm against the Cython utility code. */
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, __pyx_v_kwargs, values, kwd_pos_args, __pyx_kwds_len, "__init__", 1) < (0)) __PYX_ERR(0, 10, __pyx_L3_error)
/* The three required parameters must have been supplied one way or another. */
for (Py_ssize_t i = __pyx_nargs; i < 3; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("__init__", 0, 3, 9, i); __PYX_ERR(0, 10, __pyx_L3_error) }
}
} else {
/* No keywords: only 3..9 positionals are acceptable, so the case for 3
 * loads all required slots itself and fewer than 3 hits `default`. */
switch (__pyx_nargs) {
case 9:
values[8] = __Pyx_ArgRef_VARARGS(__pyx_args, 8);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[8])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 8:
values[7] = __Pyx_ArgRef_VARARGS(__pyx_args, 7);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[7])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 7:
values[6] = __Pyx_ArgRef_VARARGS(__pyx_args, 6);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[6])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 6:
values[5] = __Pyx_ArgRef_VARARGS(__pyx_args, 5);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[5])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 5:
values[4] = __Pyx_ArgRef_VARARGS(__pyx_args, 4);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[4])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 4:
values[3] = __Pyx_ArgRef_VARARGS(__pyx_args, 3);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[3])) __PYX_ERR(0, 10, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 3:
values[2] = __Pyx_ArgRef_VARARGS(__pyx_args, 2);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[2])) __PYX_ERR(0, 10, __pyx_L3_error)
values[1] = __Pyx_ArgRef_VARARGS(__pyx_args, 1);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[1])) __PYX_ERR(0, 10, __pyx_L3_error)
values[0] = __Pyx_ArgRef_VARARGS(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 10, __pyx_L3_error)
break;
default: goto __pyx_L5_argtuple_error;
}
}
/* Required object arguments are passed through unconverted. */
__pyx_v_value_head = values[0];
__pyx_v_env = values[1];
__pyx_v_transformer = values[2];
/* Optional arguments: convert when supplied, otherwise use the declared
 * default. A conversion result of -1 is an error only if an exception is
 * also pending. */
if (values[3]) {
__pyx_v_exploration_weight = __Pyx_PyFloat_AsFloat(values[3]); if (unlikely((__pyx_v_exploration_weight == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 15, __pyx_L3_error)
} else {
__pyx_v_exploration_weight = ((float)1.41421356);
}
if (values[4]) {
__pyx_v_max_tree_nodes = __Pyx_PyLong_As_int(values[4]); if (unlikely((__pyx_v_max_tree_nodes == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 16, __pyx_L3_error)
} else {
__pyx_v_max_tree_nodes = ((int)0x2710);
}
if (values[5]) {
__pyx_v_batch_size = __Pyx_PyLong_As_int(values[5]); if (unlikely((__pyx_v_batch_size == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 17, __pyx_L3_error)
} else {
__pyx_v_batch_size = ((int)8);
}
if (values[6]) {
__pyx_v_num_tactics_to_expand = __Pyx_PyLong_As_int(values[6]); if (unlikely((__pyx_v_num_tactics_to_expand == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 18, __pyx_L3_error)
} else {
__pyx_v_num_tactics_to_expand = ((int)8);
}
if (values[7]) {
__pyx_v_max_rollout_depth = __Pyx_PyLong_As_int(values[7]); if (unlikely((__pyx_v_max_rollout_depth == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 19, __pyx_L3_error)
} else {
__pyx_v_max_rollout_depth = ((int)30);
}
if (values[8]) {
__pyx_v_max_time = __Pyx_PyFloat_AsDouble(values[8]); if (unlikely((__pyx_v_max_time == (double)-1) && PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error)
} else {
__pyx_v_max_time = ((double)300.0);
}
}
goto __pyx_L6_skip;
/* Wrong number of positional arguments. */
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("__init__", 0, 3, 9, __pyx_nargs); __PYX_ERR(0, 10, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
/* Unpacking failed: release everything collected so far and report -1. */
__pyx_L3_error:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_XDECREF(__pyx_v_kwargs); __pyx_v_kwargs = 0;
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return -1;
__pyx_L4_argument_unpacking_done:;
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero___init__(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_value_head, __pyx_v_env, __pyx_v_transformer, __pyx_v_exploration_weight, __pyx_v_max_tree_nodes, __pyx_v_batch_size, __pyx_v_num_tactics_to_expand, __pyx_v_max_rollout_depth, __pyx_v_max_time, __pyx_v_kwargs);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_XDECREF(__pyx_v_kwargs);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Implementation of MCTS_AlphaZero.__init__, receiving the already-converted
 * arguments. The body is elided in this listing; only the exit paths are
 * shown. max_time and kwargs are marked CYTHON_UNUSED.
 * Returns 0 on success; on error releases the temporaries, adds a traceback
 * entry and returns -1. */
static int __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero___init__(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_value_head, PyObject *__pyx_v_env, PyObject *__pyx_v_transformer, float __pyx_v_exploration_weight, int __pyx_v_max_tree_nodes, int __pyx_v_batch_size, int __pyx_v_num_tactics_to_expand, int __pyx_v_max_rollout_depth, CYTHON_UNUSED double __pyx_v_max_time, CYTHON_UNUSED PyObject *__pyx_v_kwargs) {
int __pyx_r;
/* … */
/* function exit code */
__pyx_r = 0;
goto __pyx_L0;
/* Error exit: drop any object temporaries that are still held. */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_XDECREF(__pyx_t_6);
__Pyx_XDECREF(__pyx_t_7);
__Pyx_XDECREF(__pyx_t_8);
__Pyx_XDECREF(__pyx_t_9);
__Pyx_XDECREF(__pyx_t_10);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = -1;
__pyx_L0:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
011: self,
012: value_head,
013: env,
014: transformer,
015: float exploration_weight=1.41421356,
016: int max_tree_nodes=10000,
017: int batch_size=8,
018: int num_tactics_to_expand=8,
019: int max_rollout_depth=30,
020: max_time: float = 300.0,
021: **kwargs,
022: ):
+023: super().__init__(
__pyx_t_4 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[3] = {__pyx_t_4, ((PyObject *)__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero), ((PyObject *)__pyx_v_self)};
__pyx_t_3 = __Pyx_PyObject_FastCall((PyObject*)__pyx_builtin_super, __pyx_callargs+__pyx_t_5, (3-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 23, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
}
__pyx_t_2 = __pyx_t_3;
__Pyx_INCREF(__pyx_t_2);
024: env=env,
025: transformer=transformer,
+026: exploration_weight=exploration_weight,
__pyx_t_4 = PyFloat_FromDouble(__pyx_v_exploration_weight); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 26, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4);
+027: max_tree_nodes=max_tree_nodes,
__pyx_t_6 = __Pyx_PyLong_From_int(__pyx_v_max_tree_nodes); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 27, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_6);
+028: batch_size=batch_size,
__pyx_t_7 = __Pyx_PyLong_From_int(__pyx_v_batch_size); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 28, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_7);
+029: num_tactics_to_expand=num_tactics_to_expand,
__pyx_t_8 = __Pyx_PyLong_From_int(__pyx_v_num_tactics_to_expand); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 29, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_8);
+030: max_rollout_depth=max_rollout_depth,
__pyx_t_9 = __Pyx_PyLong_From_int(__pyx_v_max_rollout_depth); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 30, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_9); __pyx_t_5 = 0; { PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 7 : 0)] = {__pyx_t_2, NULL}; __pyx_t_10 = __Pyx_MakeVectorcallBuilderKwds(7); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 23, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_10); if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_env, __pyx_v_env, __pyx_t_10, __pyx_callargs+1, 0) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_transformer, __pyx_v_transformer, __pyx_t_10, __pyx_callargs+1, 1) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_exploration_weight, __pyx_t_4, __pyx_t_10, __pyx_callargs+1, 2) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_max_tree_nodes, __pyx_t_6, __pyx_t_10, __pyx_callargs+1, 3) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_batch_size, __pyx_t_7, __pyx_t_10, __pyx_callargs+1, 4) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_num_tactics_to_expand, __pyx_t_8, __pyx_t_10, __pyx_callargs+1, 5) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_max_rollout_depth, __pyx_t_9, __pyx_t_10, __pyx_callargs+1, 6) < (0)) __PYX_ERR(0, 23, __pyx_L1_error) __pyx_t_1 = __Pyx_Object_VectorcallMethod_CallFromBuilder((PyObject*)__pyx_mstate_global->__pyx_n_u_init, __pyx_callargs+__pyx_t_5, (1-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_10); __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; 
__Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); } __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
031: )
+032: self.value_head = value_head
__Pyx_INCREF(__pyx_v_value_head); __Pyx_GIVEREF(__pyx_v_value_head); __Pyx_GOTREF(__pyx_v_self->value_head); __Pyx_DECREF(__pyx_v_self->value_head); __pyx_v_self->value_head = __pyx_v_value_head;
033:
+034: cpdef float _puct_score(self, Node node):
/* Forward declaration of the Python-callable wrapper; the C implementation
 * below compares against its address to detect an override. */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/* C-level entry point for `cpdef float _puct_score(self, Node node)`.
 *
 * When __pyx_skip_dispatch is zero and the instance's type passes the
 * "might be overridden" test, the attribute is looked up on the instance;
 * if it is not this module's own wrapper function, that attribute is called
 * with `node` and its result, converted to float, is returned. Otherwise
 * execution continues into the native body, which is elided in this listing.
 *
 * On error a traceback entry is added and 0 is returned; the Python-facing
 * caller distinguishes this from a real 0 by checking PyErr_Occurred(). */
static float __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__puct_score(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node, int __pyx_skip_dispatch) {
float __pyx_v_q_value;
float __pyx_v_exploration;
int __pyx_v_v_loss;
int __pyx_v_visit_count;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_parent = 0;
float __pyx_r;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* Static version tags: the attribute lookup below is skipped while they
 * still match the object's current dict versions. */
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_puct_score); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 34, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Anything other than this module's own wrapper counts as an override. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score)) {
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
/* __pyx_t_5 == 1 skips the first call-argument slot (no explicit self). */
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
/* Bound method: call the underlying function with its self passed
 * explicitly in the first argument slot. */
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, ((PyObject *)__pyx_v_node)};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 34, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* Convert the override's result to a C float; -1 is an error only if an
 * exception is pending. */
__pyx_t_6 = __Pyx_PyFloat_AsFloat(__pyx_t_2); if (unlikely((__pyx_t_6 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 34, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__pyx_r = __pyx_t_6;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* Not overridden: record the current versions, but reset both tags if the
 * type dict version changed since the guard was taken. */
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
/* Error exit: release temporaries, record the traceback, return 0. */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._puct_score", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
__pyx_L0:;
/* `parent` is a local object reference; release it on every exit. */
__Pyx_XDECREF((PyObject *)__pyx_v_parent);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/* Python-callable wrapper for MCTS_AlphaZero._puct_score(node).
 *
 * Accepts exactly one argument, `node`, positionally or by keyword. It
 * type-checks the argument against the Node extension type, then forwards to
 * the wrapper implementation function, which produces the Python result.
 * Returns NULL after adding a traceback entry if argument unpacking fails,
 * and NULL if the type test fails.
 *
 * Fix relative to the generated text: the keyword-count error check was
 * written as `unlikely(len) < 0`. When unlikely() normalizes its argument to
 * 0/1 (as the `__builtin_expect(!!(x), 0)` form does) that comparison is
 * always false, so a failed count query was silently ignored. The comparison
 * now sits inside the hint. */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/* Method-table entry exposing the wrapper under the name "_puct_score". */
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score = {"_puct_score", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_puct_score (wrapper)", 0);
#if !CYTHON_METH_FASTCALL
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
/* Keyword names in positional order; the array is 0-terminated. */
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_node,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* A negative count signals that the keyword-count query failed. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 34, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take the positional if given, then parse keywords. */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 34, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_puct_score", 0) < (0)) __PYX_ERR(0, 34, __pyx_L3_error)
/* `node` must have been supplied positionally or by keyword. */
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_puct_score", 1, 1, 1, i); __PYX_ERR(0, 34, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
goto __pyx_L5_argtuple_error;
} else {
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 34, __pyx_L3_error)
}
/* Cast only; the actual type check happens after unpacking. */
__pyx_v_node = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)values[0]);
}
goto __pyx_L6_skip;
/* Wrong number of positional arguments. */
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_puct_score", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 34, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
/* Unpacking failed: release collected references and report NULL. */
__pyx_L3_error:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._puct_score", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
/* Enforce the declared `Node node` parameter type before dispatching. */
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_node), __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, 1, "node", 0))) __PYX_ERR(0, 34, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_2_puct_score(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_node);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
/* Type test failed: no traceback is added on this path. */
__pyx_L1_error:;
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
__pyx_L0:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Wrapper implementation for `_puct_score` when called from Python: invokes
 * the C implementation with the skip-dispatch flag set to 1, then boxes the
 * C float result as a Python float.
 * Returns a new reference, or NULL (with a traceback entry added) if an
 * exception is pending after the call or the float object cannot be created. */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_2_puct_score(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node) {
PyObject *__pyx_r = NULL;
/* __pyx_r is still NULL here, so this releases nothing. */
__Pyx_XDECREF(__pyx_r);
/* The C function returns a plain float, so failure is detected through the
 * pending-exception state rather than a sentinel return value. */
__pyx_t_1 = __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__puct_score(__pyx_v_self, __pyx_v_node, 1); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 34, __pyx_L1_error)
__pyx_t_2 = PyFloat_FromDouble(__pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 34, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
/* Ownership of the temporary moves to the return value. */
__pyx_r = __pyx_t_2;
__pyx_t_2 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_2);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._puct_score", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
/*
 * Method registration for _puct_score: build a CyFunction object from the
 * method definition and store it in the MCTS_AlphaZero type dict under the
 * interned name __pyx_n_u_puct_score. Any failure jumps to __pyx_L1_error.
 */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_3_puct_score, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__puct_score, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[0])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 34, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
/* Only compiled for CPython with PY_VERSION_HEX >= 0x030E0000 (3.14). */
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_puct_score, __pyx_t_2) < (0)) __PYX_ERR(0, 34, __pyx_L1_error)
/* Release the temporary reference once the type-dict store has succeeded. */
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
035: cdef float q_value
036: cdef float exploration
037: cdef int v_loss
038: cdef int visit_count
+039: cdef Node parent = node.parent
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_node), __pyx_mstate_global->__pyx_n_u_parent); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 39, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 39, __pyx_L1_error) __pyx_v_parent = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_1); __pyx_t_1 = 0;
040:
+041: if parent is None:
__pyx_t_7 = (((PyObject *)__pyx_v_parent) == Py_None);
if (__pyx_t_7) {
/* … */
}
+042: return 0.0
__pyx_r = 0.0;
goto __pyx_L0;
043:
+044: v_loss = self._get_virtual_loss(node)
__pyx_t_8 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._get_virtual_loss(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), __pyx_v_node, 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 44, __pyx_L1_error)
__pyx_v_v_loss = __pyx_t_8;
+045: visit_count = node.visit_count + v_loss
__pyx_v_visit_count = (__pyx_v_node->visit_count + __pyx_v_v_loss);
046:
+047: if visit_count == 0:
__pyx_t_7 = (__pyx_v_visit_count == 0);
if (__pyx_t_7) {
/* … */
goto __pyx_L4;
}
+048: q_value = 0.0
__pyx_v_q_value = 0.0;
049: else:
+050: q_value = node.max_value - (v_loss / <float>visit_count)
/*else*/ {
__pyx_v_q_value = (__pyx_v_node->max_value - (((float)__pyx_v_v_loss) / ((float)__pyx_v_visit_count)));
}
__pyx_L4:;
051:
052: exploration = (
053: self.exploration_weight
054: * node.prior_p
+055: * (sqrt(parent.visit_count) / (1 + visit_count))
__pyx_v_exploration = ((__pyx_v_self->__pyx_base.exploration_weight * __pyx_v_node->prior_p) * (sqrt(__pyx_v_parent->visit_count) / ((double)(1 + __pyx_v_visit_count))));
056: )
057:
+058: return q_value + exploration
__pyx_r = (__pyx_v_q_value + __pyx_v_exploration); goto __pyx_L0;
059:
+060: cpdef Node _get_best_child(self, Node node):
/* Forward declaration of the Python wrapper; the dispatch check below compares against its address. */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/*
 * C-level entry point of the cpdef method MCTS_AlphaZero._get_best_child.
 *
 * When __pyx_skip_dispatch is zero and the type of `self` passes the checks
 * below, the method is looked up on `self` through the interned name
 * __pyx_n_u_get_best_child. If the attribute found is not this module's own
 * wrapper, it is called with `node`, its result is required to be None or a
 * Node, and that result is returned. Otherwise control falls through to the
 * method body, which is elided in this excerpt.
 *
 * Returns a Node reference (or None), or 0 with a traceback entry added on
 * error.
 *
 * NOTE(review): the temporaries __pyx_t_1..__pyx_t_5 are declared in lines
 * that are not visible in this excerpt.
 */
static struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__get_best_child(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node, int __pyx_skip_dispatch) {
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_child = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_best_child = 0;
float __pyx_v_max_score;
float __pyx_v_score;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_r = NULL;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
/* The static version stamps guard the lookup: it is skipped while they still match the object. */
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_get_best_child); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 60, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Only dispatch through Python when the attribute is something other than our own wrapper. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child)) {
__Pyx_XDECREF((PyObject *)__pyx_r);
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
__pyx_t_5 = 1;
/* For a bound method, split it into self + function; __pyx_t_5 then selects the argument offset. */
#if CYTHON_UNPACK_METHODS
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, ((PyObject *)__pyx_v_node)};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 60, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* The override must return None or a Node instance. */
if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 60, __pyx_L1_error)
__pyx_r = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_2);
__pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
/* Not overridden: refresh the stamps, or reset them if the type dict changed during the lookup. */
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
/* Error exit: drop any live temporaries, record a traceback entry, return 0. */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._get_best_child", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
/* Common exit: release the local Node variables on every path. */
__pyx_L0:;
__Pyx_XDECREF((PyObject *)__pyx_v_child);
__Pyx_XDECREF((PyObject *)__pyx_v_best_child);
__Pyx_XGIVEREF((PyObject *)__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/*
 * Python-callable wrapper for MCTS_AlphaZero._get_best_child.
 *
 * Accepts exactly one argument, `node`, given positionally or by keyword,
 * checks it with __Pyx_ArgTypeTest against the Node type, and forwards to
 * __pyx_pf_..._4_get_best_child.
 *
 * Returns the implementation's result, or NULL with a traceback entry added
 * when argument unpacking fails. A failed type test also returns NULL.
 *
 * NOTE(review): this excerpt shows some declarations (`values`, the
 * line/filename bookkeeping variables) after or outside their first use; that
 * ordering is kept exactly as found.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/* Method table entry exposing the wrapper under the Python name "_get_best_child". */
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child = {"_get_best_child", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_get_best_child (wrapper)", 0);
/* Without fastcall the positional count has to be read from the args tuple. */
#if !CYTHON_METH_FASTCALL
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_node,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* The comparison must sit inside unlikely() so that the count itself is
 * tested. Written as `unlikely(__pyx_kwds_len) < 0`, the `< 0` applies to
 * the result of the branch hint, which is never negative when unlikely()
 * normalises its argument to 0/1, so a negative (presumably failed) count
 * would go unnoticed. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 60, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take any positional value first, then let the keyword parser fill the rest. */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 60, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_get_best_child", 0) < (0)) __PYX_ERR(0, 60, __pyx_L3_error)
/* Any slot still empty after keyword parsing is a missing required argument. */
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_get_best_child", 1, 1, 1, i); __PYX_ERR(0, 60, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
goto __pyx_L5_argtuple_error;
} else {
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 60, __pyx_L3_error)
}
__pyx_v_node = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)values[0]);
}
goto __pyx_L6_skip;
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_get_best_child", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 60, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
/* Argument-unpacking failure: release collected values and report the error. */
__pyx_L3_error:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._get_best_child", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_node), __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, 1, "node", 0))) __PYX_ERR(0, 60, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_4_get_best_child(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_node);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
__pyx_L1_error:;
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
/* Normal exit: the argument references are released here as well. */
__pyx_L0:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/*
 * Implementation body behind the Python-visible MCTS_AlphaZero._get_best_child.
 *
 * Calls the C-level __pyx_f_..._get_best_child with __pyx_skip_dispatch set
 * to 1, so the callee does not look for a Python override, and returns its
 * result as a generic PyObject pointer.
 *
 * Returns the callee's object, or NULL after adding a traceback entry when
 * the callee returns 0.
 *
 * NOTE(review): the declaration of __pyx_t_1 and the line/filename
 * bookkeeping variables are not visible in this excerpt.
 */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_4_get_best_child(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node) {
PyObject *__pyx_r = NULL;
__Pyx_XDECREF(__pyx_r);
__pyx_t_1 = ((PyObject *)__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__get_best_child(__pyx_v_self, __pyx_v_node, 1)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 60, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Hand the result over to the return slot and clear the temporary. */
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._get_best_child", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
/* Shared exit for the success and error paths. */
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
/*
 * Method registration for _get_best_child: build a CyFunction object from the
 * method definition and store it in the MCTS_AlphaZero type dict under the
 * interned name __pyx_n_u_get_best_child. Any failure jumps to
 * __pyx_L1_error.
 */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_5_get_best_child, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__get_best_child, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[1])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 60, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
/* Only compiled for CPython with PY_VERSION_HEX >= 0x030E0000 (3.14). */
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_get_best_child, __pyx_t_2) < (0)) __PYX_ERR(0, 60, __pyx_L1_error)
/* Release the temporary reference once the type-dict store has succeeded. */
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
061: cdef Node child
+062: cdef Node best_child = None
__Pyx_INCREF(Py_None);
__pyx_v_best_child = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)Py_None);
+063: cdef float max_score = -1e9
__pyx_v_max_score = -1e9;
064: cdef float score
065:
+066: if not node.children:
if (__pyx_v_node->children == Py_None) __pyx_t_6 = 0;
else
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_v_node->children);
if (unlikely(((!CYTHON_ASSUME_SAFE_SIZE) && __pyx_temp < 0))) __PYX_ERR(0, 66, __pyx_L1_error)
__pyx_t_6 = (__pyx_temp != 0);
}
__pyx_t_7 = (!__pyx_t_6);
if (unlikely(__pyx_t_7)) {
/* … */
}
+067: raise ValueError("Node has no children")
__pyx_t_2 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[2] = {__pyx_t_2, __pyx_mstate_global->__pyx_kp_u_Node_has_no_children};
__pyx_t_1 = __Pyx_PyObject_FastCall((PyObject*)(((PyTypeObject*)PyExc_ValueError)), __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 67, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
__Pyx_Raise(__pyx_t_1, 0, 0, 0);
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__PYX_ERR(0, 67, __pyx_L1_error)
068:
+069: for child in node.children:
if (unlikely(__pyx_v_node->children == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
__PYX_ERR(0, 69, __pyx_L1_error)
}
__pyx_t_1 = __pyx_v_node->children; __Pyx_INCREF(__pyx_t_1);
__pyx_t_8 = 0;
for (;;) {
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 69, __pyx_L1_error)
#endif
if (__pyx_t_8 >= __pyx_temp) break;
}
__pyx_t_2 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_8, __Pyx_ReferenceSharing_OwnStrongReference);
++__pyx_t_8;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 69, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 69, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_child, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_2));
__pyx_t_2 = 0;
/* … */
}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+070: score = self._puct_score(child)
__pyx_t_9 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->_puct_score(__pyx_v_self, __pyx_v_child, 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 70, __pyx_L1_error)
__pyx_v_score = __pyx_t_9;
+071: if best_child is None or score > max_score:
__pyx_t_6 = (((PyObject *)__pyx_v_best_child) == Py_None);
if (!__pyx_t_6) {
} else {
__pyx_t_7 = __pyx_t_6;
goto __pyx_L7_bool_binop_done;
}
__pyx_t_6 = (__pyx_v_score > __pyx_v_max_score);
__pyx_t_7 = __pyx_t_6;
__pyx_L7_bool_binop_done:;
if (__pyx_t_7) {
/* … */
}
+072: max_score = score
__pyx_v_max_score = __pyx_v_score;
+073: best_child = child
__Pyx_INCREF((PyObject *)__pyx_v_child); __Pyx_DECREF_SET(__pyx_v_best_child, __pyx_v_child);
074:
+075: return best_child
__Pyx_XDECREF((PyObject *)__pyx_r); __Pyx_INCREF((PyObject *)__pyx_v_best_child); __pyx_r = __pyx_v_best_child; goto __pyx_L0;
076:
+077: cpdef Node _expand(self, Node node):
/* Forward declaration of the Python wrapper; the dispatch check below compares against its address. */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/*
 * C-level entry point of the cpdef method MCTS_AlphaZero._expand.
 *
 * When __pyx_skip_dispatch is zero and the type of `self` passes the checks
 * below, the method is looked up on `self` through the interned name
 * __pyx_n_u_expand. If the attribute found is not this module's own wrapper,
 * it is called with `node`, its result is required to be None or a Node, and
 * that result is returned. Otherwise control falls through to the method
 * body, which is elided in this excerpt.
 *
 * Returns a Node reference (or None), or 0 with a traceback entry added on
 * error.
 *
 * NOTE(review): the temporaries (__pyx_t_1..__pyx_t_5, __pyx_t_10) are
 * declared in lines that are not visible in this excerpt.
 */
static struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__expand(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node, int __pyx_skip_dispatch) {
PyObject *__pyx_v_state_key = 0;
PyObject *__pyx_v_next_state = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_child = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_existing_node = 0;
PyObject *__pyx_v_state_str = NULL;
PyObject *__pyx_v_tactics_with_probs = NULL;
PyObject *__pyx_v_tactic = NULL;
PyObject *__pyx_v_prob = NULL;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_r = NULL;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
/* The static version stamps guard the lookup: it is skipped while they still match the object. */
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_expand); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 77, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Only dispatch through Python when the attribute is something other than our own wrapper. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand)) {
__Pyx_XDECREF((PyObject *)__pyx_r);
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
__pyx_t_5 = 1;
/* For a bound method, split it into self + function; __pyx_t_5 then selects the argument offset. */
#if CYTHON_UNPACK_METHODS
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, ((PyObject *)__pyx_v_node)};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 77, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* The override must return None or a Node instance. */
if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 77, __pyx_L1_error)
__pyx_r = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_2);
__pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
/* Not overridden: refresh the stamps, or reset them if the type dict changed during the lookup. */
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
/* Error exit: drop any live temporaries, record a traceback entry, return 0. */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_XDECREF(__pyx_t_10);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
/* Common exit: release every local object variable on every path. */
__pyx_L0:;
__Pyx_XDECREF(__pyx_v_state_key);
__Pyx_XDECREF(__pyx_v_next_state);
__Pyx_XDECREF((PyObject *)__pyx_v_child);
__Pyx_XDECREF((PyObject *)__pyx_v_existing_node);
__Pyx_XDECREF(__pyx_v_state_str);
__Pyx_XDECREF(__pyx_v_tactics_with_probs);
__Pyx_XDECREF(__pyx_v_tactic);
__Pyx_XDECREF(__pyx_v_prob);
__Pyx_XGIVEREF((PyObject *)__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/*
 * Python-callable wrapper for MCTS_AlphaZero._expand.
 *
 * Accepts exactly one argument, `node`, given positionally or by keyword,
 * checks it with __Pyx_ArgTypeTest against the Node type, and forwards to
 * __pyx_pf_..._6_expand.
 *
 * Returns the implementation's result, or NULL with a traceback entry added
 * when argument unpacking fails. A failed type test also returns NULL.
 *
 * NOTE(review): this excerpt shows some declarations (`values`, the
 * line/filename bookkeeping variables) after or outside their first use; that
 * ordering is kept exactly as found.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/* Method table entry exposing the wrapper under the Python name "_expand". */
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand = {"_expand", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_expand (wrapper)", 0);
/* Without fastcall the positional count has to be read from the args tuple. */
#if !CYTHON_METH_FASTCALL
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_node,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* The comparison must sit inside unlikely() so that the count itself is
 * tested. Written as `unlikely(__pyx_kwds_len) < 0`, the `< 0` applies to
 * the result of the branch hint, which is never negative when unlikely()
 * normalises its argument to 0/1, so a negative (presumably failed) count
 * would go unnoticed. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 77, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take any positional value first, then let the keyword parser fill the rest. */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 77, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_expand", 0) < (0)) __PYX_ERR(0, 77, __pyx_L3_error)
/* Any slot still empty after keyword parsing is a missing required argument. */
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_expand", 1, 1, 1, i); __PYX_ERR(0, 77, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
goto __pyx_L5_argtuple_error;
} else {
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 77, __pyx_L3_error)
}
__pyx_v_node = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)values[0]);
}
goto __pyx_L6_skip;
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_expand", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 77, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
/* Argument-unpacking failure: release collected values and report the error. */
__pyx_L3_error:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_node), __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, 1, "node", 0))) __PYX_ERR(0, 77, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_6_expand(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_node);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
__pyx_L1_error:;
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
/* Normal exit: the argument references are released here as well. */
__pyx_L0:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/*
 * Implementation body behind the Python-visible MCTS_AlphaZero._expand.
 *
 * Calls the C-level __pyx_f_..._expand with __pyx_skip_dispatch set to 1, so
 * the callee does not look for a Python override, and returns its result as a
 * generic PyObject pointer.
 *
 * Returns the callee's object, or NULL after adding a traceback entry when
 * the callee returns 0.
 *
 * NOTE(review): the declaration of __pyx_t_1 and the line/filename
 * bookkeeping variables are not visible in this excerpt.
 */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_6_expand(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node) {
PyObject *__pyx_r = NULL;
__Pyx_XDECREF(__pyx_r);
__pyx_t_1 = ((PyObject *)__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__expand(__pyx_v_self, __pyx_v_node, 1)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 77, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Hand the result over to the return slot and clear the temporary. */
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
/* Shared exit for the success and error paths. */
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
/*
 * Method registration for _expand: build a CyFunction object from the method
 * definition and store it in the MCTS_AlphaZero type dict under the interned
 * name __pyx_n_u_expand. Any failure jumps to __pyx_L1_error.
 */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_7_expand, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__expand, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[2])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 77, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
/* Only compiled for CPython with PY_VERSION_HEX >= 0x030E0000 (3.14). */
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_expand, __pyx_t_2) < (0)) __PYX_ERR(0, 77, __pyx_L1_error)
/* Release the temporary reference once the type-dict store has succeeded. */
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
078: cdef object state_key
079: cdef object next_state
080: cdef Node child
081: cdef Node existing_node
082:
+083: if not isinstance(node.state, TacticState):
__pyx_t_1 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_1); __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 83, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_6 = PyObject_IsInstance(__pyx_t_1, __pyx_t_2); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(0, 83, __pyx_L1_error) __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; __pyx_t_7 = (!__pyx_t_6); if (unlikely(__pyx_t_7)) { /* … */ }
+084: raise TypeError("Cannot expand a node without a TacticState.")
__pyx_t_1 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[2] = {__pyx_t_1, __pyx_mstate_global->__pyx_kp_u_Cannot_expand_a_node_without_a_T};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)(((PyTypeObject*)PyExc_TypeError)), __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 84, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
__Pyx_Raise(__pyx_t_2, 0, 0, 0);
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__PYX_ERR(0, 84, __pyx_L1_error)
085:
+086: state_str = node.state.pp
__pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_node->state, __pyx_mstate_global->__pyx_n_u_pp); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 86, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_v_state_str = __pyx_t_2; __pyx_t_2 = 0;
087:
+088: if node.encoder_features is None:
__pyx_t_7 = (__pyx_v_node->encoder_features == Py_None);
if (__pyx_t_7) {
/* … */
}
+089: node.encoder_features = self.value_head.encode_states([state_str])
__pyx_t_1 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = PyList_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 89, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__Pyx_INCREF(__pyx_v_state_str);
__Pyx_GIVEREF(__pyx_v_state_str);
if (__Pyx_PyList_SET_ITEM(__pyx_t_4, 0, __pyx_v_state_str) != (0)) __PYX_ERR(0, 89, __pyx_L1_error);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_1, __pyx_t_4};
__pyx_t_2 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_encode_states, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 89, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
__Pyx_GIVEREF(__pyx_t_2);
__Pyx_GOTREF(__pyx_v_node->encoder_features);
__Pyx_DECREF(__pyx_v_node->encoder_features);
__pyx_v_node->encoder_features = __pyx_t_2;
__pyx_t_2 = 0;
090:
+091: tactics_with_probs = self.transformer.generate_tactics_with_probs(
__pyx_t_4 = __pyx_v_self->__pyx_base.transformer;
__Pyx_INCREF(__pyx_t_4);
+092: state_str, n=self.num_tactics_to_expand
__pyx_t_1 = __Pyx_PyLong_From_int(__pyx_v_self->__pyx_base.num_tactics_to_expand); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 92, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_5 = 0; { PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 1 : 0)] = {__pyx_t_4, __pyx_v_state_str}; __pyx_t_3 = __Pyx_MakeVectorcallBuilderKwds(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 91, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_n, __pyx_t_1, __pyx_t_3, __pyx_callargs+2, 0) < (0)) __PYX_ERR(0, 91, __pyx_L1_error) __pyx_t_2 = __Pyx_Object_VectorcallMethod_CallFromBuilder((PyObject*)__pyx_mstate_global->__pyx_n_u_generate_tactics_with_probs, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_3); __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 91, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); } __pyx_v_tactics_with_probs = __pyx_t_2; __pyx_t_2 = 0;
093: )
094:
+095: for tactic, prob in tactics_with_probs:
if (likely(PyList_CheckExact(__pyx_v_tactics_with_probs)) || PyTuple_CheckExact(__pyx_v_tactics_with_probs)) { __pyx_t_2 = __pyx_v_tactics_with_probs; __Pyx_INCREF(__pyx_t_2); __pyx_t_8 = 0; __pyx_t_9 = NULL; } else { __pyx_t_8 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_tactics_with_probs); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_9 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 95, __pyx_L1_error) } for (;;) { if (likely(!__pyx_t_9)) { if (likely(PyList_CheckExact(__pyx_t_2))) { { Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 95, __pyx_L1_error) #endif if (__pyx_t_8 >= __pyx_temp) break; } __pyx_t_3 = __Pyx_PyList_GetItemRefFast(__pyx_t_2, __pyx_t_8, __Pyx_ReferenceSharing_OwnStrongReference); ++__pyx_t_8; } else { { Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 95, __pyx_L1_error) #endif if (__pyx_t_8 >= __pyx_temp) break; } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS __pyx_t_3 = __Pyx_NewRef(PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_8)); #else __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_8); #endif ++__pyx_t_8; } if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 95, __pyx_L1_error) } else { __pyx_t_3 = __pyx_t_9(__pyx_t_2); if (unlikely(!__pyx_t_3)) { PyObject* exc_type = PyErr_Occurred(); if (exc_type) { if (unlikely(!__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) __PYX_ERR(0, 95, __pyx_L1_error) PyErr_Clear(); } break; } } __Pyx_GOTREF(__pyx_t_3); if ((likely(PyTuple_CheckExact(__pyx_t_3))) || (PyList_CheckExact(__pyx_t_3))) { PyObject* sequence = __pyx_t_3; Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); if (unlikely(size != 2)) { if (size > 2) __Pyx_RaiseTooManyValuesError(2); else if (size >= 0) 
__Pyx_RaiseNeedMoreValuesError(size); __PYX_ERR(0, 95, __pyx_L1_error) } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS if (likely(PyTuple_CheckExact(sequence))) { __pyx_t_1 = PyTuple_GET_ITEM(sequence, 0); __Pyx_INCREF(__pyx_t_1); __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); __Pyx_INCREF(__pyx_t_4); } else { __pyx_t_1 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_1); __pyx_t_4 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_4); } #else __pyx_t_1 = __Pyx_PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_4 = __Pyx_PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4); #endif __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; } else { Py_ssize_t index = -1; __pyx_t_10 = PyObject_GetIter(__pyx_t_3); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 95, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_10); __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; __pyx_t_11 = (CYTHON_COMPILING_IN_LIMITED_API) ? 
PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_10); index = 0; __pyx_t_1 = __pyx_t_11(__pyx_t_10); if (unlikely(!__pyx_t_1)) goto __pyx_L7_unpacking_failed; __Pyx_GOTREF(__pyx_t_1); index = 1; __pyx_t_4 = __pyx_t_11(__pyx_t_10); if (unlikely(!__pyx_t_4)) goto __pyx_L7_unpacking_failed; __Pyx_GOTREF(__pyx_t_4); if (__Pyx_IternextUnpackEndCheck(__pyx_t_11(__pyx_t_10), 2) < (0)) __PYX_ERR(0, 95, __pyx_L1_error) __pyx_t_11 = NULL; __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; goto __pyx_L8_unpacking_done; __pyx_L7_unpacking_failed:; __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; __pyx_t_11 = NULL; if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index); __PYX_ERR(0, 95, __pyx_L1_error) __pyx_L8_unpacking_done:; } __Pyx_XDECREF_SET(__pyx_v_tactic, __pyx_t_1); __pyx_t_1 = 0; __Pyx_XDECREF_SET(__pyx_v_prob, __pyx_t_4); __pyx_t_4 = 0; /* … */ __pyx_L5_continue:; } __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+096: next_state = self.env.run_tactic_stateless(node.state, tactic)
__pyx_t_4 = __pyx_v_self->__pyx_base.env;
__Pyx_INCREF(__pyx_t_4);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[3] = {__pyx_t_4, __pyx_v_node->state, __pyx_v_tactic};
__pyx_t_3 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_run_tactic_stateless, __pyx_callargs+__pyx_t_5, (3-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 96, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
}
__Pyx_XDECREF_SET(__pyx_v_next_state, __pyx_t_3);
__pyx_t_3 = 0;
097:
098: # Check for duplicate states
+099: state_key = self._get_state_key(next_state)
__pyx_t_3 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._get_state_key(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), __pyx_v_next_state, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 99, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); __Pyx_XDECREF_SET(__pyx_v_state_key, __pyx_t_3); __pyx_t_3 = 0;
+100: if state_key is not None and state_key in self.seen_states:
__pyx_t_6 = (__pyx_v_state_key != Py_None);
if (__pyx_t_6) {
} else {
__pyx_t_7 = __pyx_t_6;
goto __pyx_L10_bool_binop_done;
}
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
__PYX_ERR(0, 100, __pyx_L1_error)
}
__pyx_t_6 = (__Pyx_PyDict_ContainsTF(__pyx_v_state_key, __pyx_v_self->__pyx_base.seen_states, Py_EQ)); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(0, 100, __pyx_L1_error)
__pyx_t_7 = __pyx_t_6;
__pyx_L10_bool_binop_done:;
if (__pyx_t_7) {
/* … */
}
101: # Reuse existing node - add as child with additional parent edge
+102: existing_node = self.seen_states[state_key]
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 102, __pyx_L1_error)
}
__pyx_t_3 = __Pyx_PyDict_GetItem(__pyx_v_self->__pyx_base.seen_states, __pyx_v_state_key); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 102, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 102, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_existing_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_3));
__pyx_t_3 = 0;
+103: existing_node.add_parent(node, tactic)
__pyx_t_12.__pyx_n = 1;
__pyx_t_12.action = __pyx_v_tactic;
((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_v_existing_node->__pyx_vtab)->add_parent(__pyx_v_existing_node, __pyx_v_node, 0, &__pyx_t_12); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 103, __pyx_L1_error)
+104: if existing_node not in node.children:
__pyx_t_7 = (__Pyx_PySequence_ContainsTF(((PyObject *)__pyx_v_existing_node), __pyx_v_node->children, Py_NE)); if (unlikely((__pyx_t_7 < 0))) __PYX_ERR(0, 104, __pyx_L1_error) if (__pyx_t_7) { /* … */ }
+105: node.children.append(existing_node)
if (unlikely(__pyx_v_node->children == Py_None)) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "append");
__PYX_ERR(0, 105, __pyx_L1_error)
}
__pyx_t_13 = __Pyx_PyList_Append(__pyx_v_node->children, ((PyObject *)__pyx_v_existing_node)); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 105, __pyx_L1_error)
+106: continue
goto __pyx_L5_continue;
107:
+108: child = Node(next_state, parent=node, action=tactic)
__pyx_t_4 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 2 : 0)] = {__pyx_t_4, __pyx_v_next_state};
__pyx_t_1 = __Pyx_MakeVectorcallBuilderKwds(2); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 108, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_parent, ((PyObject *)__pyx_v_node), __pyx_t_1, __pyx_callargs+2, 0) < (0)) __PYX_ERR(0, 108, __pyx_L1_error)
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_action, __pyx_v_tactic, __pyx_t_1, __pyx_callargs+2, 1) < (0)) __PYX_ERR(0, 108, __pyx_L1_error)
__pyx_t_3 = __Pyx_Object_Vectorcall_CallFromBuilder((PyObject*)__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_1);
__Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 108, __pyx_L1_error)
__Pyx_GOTREF((PyObject *)__pyx_t_3);
}
__Pyx_XDECREF_SET(__pyx_v_child, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_3));
__pyx_t_3 = 0;
+109: child.prior_p = prob
__pyx_t_14 = __Pyx_PyFloat_AsFloat(__pyx_v_prob); if (unlikely((__pyx_t_14 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 109, __pyx_L1_error) __pyx_v_child->prior_p = __pyx_t_14;
+110: node.children.append(child)
if (unlikely(__pyx_v_node->children == Py_None)) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "append");
__PYX_ERR(0, 110, __pyx_L1_error)
}
__pyx_t_13 = __Pyx_PyList_Append(__pyx_v_node->children, ((PyObject *)__pyx_v_child)); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 110, __pyx_L1_error)
+111: self.node_count += 1
__pyx_v_self->__pyx_base.node_count = (__pyx_v_self->__pyx_base.node_count + 1);
112:
113: # Register new state in seen_states and encode features
+114: if isinstance(next_state, TacticState):
__Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 114, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); __pyx_t_7 = PyObject_IsInstance(__pyx_v_next_state, __pyx_t_3); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 114, __pyx_L1_error) __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; if (__pyx_t_7) { /* … */ }
+115: self.seen_states[state_key] = child
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 115, __pyx_L1_error)
}
if (unlikely((PyDict_SetItem(__pyx_v_self->__pyx_base.seen_states, __pyx_v_state_key, ((PyObject *)__pyx_v_child)) < 0))) __PYX_ERR(0, 115, __pyx_L1_error)
+116: child.encoder_features = self.value_head.encode_states([next_state.pp])
__pyx_t_1 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_v_next_state, __pyx_mstate_global->__pyx_n_u_pp); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 116, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__pyx_t_10 = PyList_New(1); if (unlikely(!__pyx_t_10)) __PYX_ERR(0, 116, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_10);
__Pyx_GIVEREF(__pyx_t_4);
if (__Pyx_PyList_SET_ITEM(__pyx_t_10, 0, __pyx_t_4) != (0)) __PYX_ERR(0, 116, __pyx_L1_error);
__pyx_t_4 = 0;
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_1, __pyx_t_10};
__pyx_t_3 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_encode_states, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
__Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 116, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
}
__Pyx_GIVEREF(__pyx_t_3);
__Pyx_GOTREF(__pyx_v_child->encoder_features);
__Pyx_DECREF(__pyx_v_child->encoder_features);
__pyx_v_child->encoder_features = __pyx_t_3;
__pyx_t_3 = 0;
117:
+118: node.untried_actions = []
__pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 118, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __Pyx_GIVEREF(__pyx_t_2); __Pyx_GOTREF(__pyx_v_node->untried_actions); __Pyx_DECREF(__pyx_v_node->untried_actions); __pyx_v_node->untried_actions = ((PyObject*)__pyx_t_2); __pyx_t_2 = 0;
+119: return node
__Pyx_XDECREF((PyObject *)__pyx_r); __Pyx_INCREF((PyObject *)__pyx_v_node); __pyx_r = __pyx_v_node; goto __pyx_L0;
120:
+121: cpdef list _expand_batch(self, list nodes):
/*
 * Forward declaration of the Python-visible wrapper for _expand_batch.
 * The C implementation that follows refers to this symbol in its
 * "overridden in Python" check, so it has to be declared first.
 * The parameter list depends on CYTHON_METH_FASTCALL: an argument array plus
 * count when it is set, an argument object otherwise.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/*
 * C-level implementation of the cpdef method MCTS_AlphaZero._expand_batch
 * (declared in the listing as "cpdef list _expand_batch(self, list nodes)").
 *
 * Parameters:
 *   __pyx_v_self          - the MCTS_AlphaZero instance.
 *   __pyx_v_nodes         - the "nodes" argument.
 *   __pyx_skip_dispatch   - when non-zero the Python-override check below is
 *                           skipped.
 *
 * Returns a new reference to the result, or 0 after adding a traceback entry
 * for "MCTS_AlphaZero._expand_batch" on error.
 *
 * NOTE(review): the translated method body is elided in this annotated
 * listing (the "..." marker before the exit code); only the dispatch
 * prologue and the cleanup epilogue are visible here. Declarations of the
 * __pyx_t_* temporaries are likewise not shown.
 */
static PyObject *__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__expand_batch(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_nodes, int __pyx_skip_dispatch) {
/* One C local per variable of the Cython method; all object locals start
   empty so the epilogue can release them unconditionally. */
PyObject *__pyx_v_states = 0;
PyObject *__pyx_v_nodes_to_generate = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
PyObject *__pyx_v_batch_tactics_with_probs = 0;
PyObject *__pyx_v_tasks = 0;
PyObject *__pyx_v_results = 0;
PyObject *__pyx_v_new_children_nodes = 0;
int __pyx_v_i;
PyObject *__pyx_v_tactic = 0;
float __pyx_v_prob;
PyObject *__pyx_v_next_state = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_child = 0;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_existing_node = 0;
PyObject *__pyx_v_state_key = 0;
PyObject *__pyx_v_tactics_probs = NULL;
PyObject *__pyx_v_e = NULL;
PyObject *__pyx_r = NULL;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
/* When dict-version tracking is available, the attribute lookup below is
   only repeated if the cached type/instance dict versions no longer match. */
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
/* Look the method up on self (interned name __pyx_n_u_expand_batch,
   presumably "_expand_batch"). */
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_expand_batch); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 121, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* If the attribute is not this module's own wrapper, something else has
   replaced it: call that object instead of running the C body. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch)) {
__Pyx_XDECREF(__pyx_r);
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
/* For a bound method, split it into its self object and underlying function
   so self can be passed as the first call argument (offset 0 instead of 1). */
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, __pyx_v_nodes};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 121, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* The override's result must be an exact list or None, matching the
   declared "list" return type; anything else raises. */
if (!(likely(PyList_CheckExact(__pyx_t_2))||((__pyx_t_2) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_2))) __PYX_ERR(0, 121, __pyx_L1_error)
__pyx_r = ((PyObject*)__pyx_t_2);
__pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* Not overridden: remember the dict versions just observed; if the type dict
   changed while the lookup ran, reset both cached versions. */
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
__pyx_L1_error:;
/* Error path: drop any temporaries that may still hold references, add the
   traceback frame, and return 0. */
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_XDECREF(__pyx_t_14);
__Pyx_XDECREF(__pyx_t_17);
__Pyx_XDECREF(__pyx_t_21);
__Pyx_XDECREF(__pyx_t_22);
__Pyx_XDECREF(__pyx_t_23);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
__pyx_L0:;
/* Common exit: release every object-typed local, whether or not it was set. */
__Pyx_XDECREF(__pyx_v_states);
__Pyx_XDECREF(__pyx_v_nodes_to_generate);
__Pyx_XDECREF((PyObject *)__pyx_v_node);
__Pyx_XDECREF(__pyx_v_batch_tactics_with_probs);
__Pyx_XDECREF(__pyx_v_tasks);
__Pyx_XDECREF(__pyx_v_results);
__Pyx_XDECREF(__pyx_v_new_children_nodes);
__Pyx_XDECREF(__pyx_v_tactic);
__Pyx_XDECREF(__pyx_v_next_state);
__Pyx_XDECREF((PyObject *)__pyx_v_child);
__Pyx_XDECREF((PyObject *)__pyx_v_existing_node);
__Pyx_XDECREF(__pyx_v_state_key);
__Pyx_XDECREF(__pyx_v_tactics_probs);
__Pyx_XDECREF(__pyx_v_e);
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
/*
 * Prototype of the _expand_batch wrapper followed by its method-table record.
 * The record exposes the wrapper under the Python name "_expand_batch" with
 * the fastcall-with-keywords calling convention flags and no docstring
 * (last field is 0).
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch = {"_expand_batch", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
/*
 * Python-visible wrapper for MCTS_AlphaZero._expand_batch.
 *
 * Unpacks exactly one argument, "nodes", given either positionally or by
 * keyword, checks it against PyList_Type, and forwards it to the Python-call
 * body (__pyx_pf_..._8_expand_batch).
 *
 * Returns the body's result, or NULL with an exception set when argument
 * unpacking or the type check fails. Every exit path releases the references
 * collected in values[].
 *
 * NOTE(review): this is Cython-generated code; the parenthesisation fix below
 * will be lost on regeneration unless the generator emits it as well.
 * NOTE(review): in this annotated listing __pyx_lineno/__pyx_filename/
 * __pyx_clineno are shown after their first use and values[] is shown inside
 * the unpacking scope; presumably the listing reorders declarations. The
 * layout is left exactly as shown.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
PyObject *__pyx_v_nodes = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_expand_batch (wrapper)", 0);
/* Without fastcall the positional count has to be read from the args tuple. */
#if !CYTHON_METH_FASTCALL
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_nodes,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* The comparison belongs inside the branch hint. With the usual definition
   unlikely(x) == __builtin_expect(!!(x), 0), the previous form
   "unlikely(len) < 0" compared a 0/1 value with 0 and could never fire,
   so a negative (error) keyword count went undetected. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 121, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take whatever was passed positionally first ... */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 121, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
/* ... then fill the remaining slot from the keywords and make sure
   no required argument is still missing. */
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_expand_batch", 0) < (0)) __PYX_ERR(0, 121, __pyx_L3_error)
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_expand_batch", 1, 1, 1, i); __PYX_ERR(0, 121, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
/* No keywords: exactly one positional argument is required. */
goto __pyx_L5_argtuple_error;
} else {
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 121, __pyx_L3_error)
}
__pyx_v_nodes = ((PyObject*)values[0]);
}
goto __pyx_L6_skip;
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_expand_batch", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 121, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
__pyx_L3_error:;
/* Unpacking failed: release collected arguments, add the traceback frame. */
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
/* "nodes" must pass the list type test before the body is entered. */
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_nodes), (&PyList_Type), 1, "nodes", 1))) __PYX_ERR(0, 121, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_8_expand_batch(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_nodes);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
__pyx_L1_error:;
/* Type check failed: return NULL after releasing the argument references. */
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
__pyx_L0:;
/* Normal exit: the body has returned, so the argument references can go. */
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/*
 * Python-call body for MCTS_AlphaZero._expand_batch: forwards self and the
 * already-unpacked "nodes" argument to the C-level implementation.
 *
 * Returns the implementation's result, or NULL after a traceback entry for
 * "MCTS_AlphaZero._expand_batch" has been added when it returned NULL.
 *
 * NOTE(review): the declarations of __pyx_t_1, __pyx_lineno, __pyx_filename,
 * __pyx_clineno and the RefNanny context are not shown in this annotated
 * excerpt; they are presumably elided by the listing.
 */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_8_expand_batch(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_nodes) {
PyObject *__pyx_r = NULL;
/* __pyx_r was just initialised to NULL, so this release has nothing to drop. */
__Pyx_XDECREF(__pyx_r);
/* Passing 1 as __pyx_skip_dispatch makes the implementation bypass its
   "overridden in Python" check. */
__pyx_t_1 = __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__expand_batch(__pyx_v_self, __pyx_v_nodes, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 121, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Hand the reference over to the return slot and clear the temporary. */
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
/* Error path: release the temporary, record the traceback frame, return NULL. */
__Pyx_XDECREF(__pyx_t_1);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
__pyx_L0:;
/* Common exit shared by the success and error paths. */
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_9_expand_batch, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__expand_batch, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[3])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 121, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_expand_batch, __pyx_t_2) < (0)) __PYX_ERR(0, 121, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+122: cdef list states = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 122, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_states = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+123: cdef list nodes_to_generate = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 123, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_nodes_to_generate = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
124: cdef Node node
125: cdef list batch_tactics_with_probs
+126: cdef list tasks = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 126, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_tasks = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+127: cdef list results = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 127, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_results = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+128: cdef list new_children_nodes = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 128, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_new_children_nodes = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
129: cdef int i
130: cdef object tactic
131: cdef float prob
132: cdef object next_state
133: cdef Node child
134: cdef Node existing_node
135: cdef object state_key
136:
+137: for node in nodes:
if (unlikely(__pyx_v_nodes == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
__PYX_ERR(0, 137, __pyx_L1_error)
}
__pyx_t_1 = __pyx_v_nodes; __Pyx_INCREF(__pyx_t_1);
__pyx_t_6 = 0;
for (;;) {
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 137, __pyx_L1_error)
#endif
if (__pyx_t_6 >= __pyx_temp) break;
}
__pyx_t_2 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference);
++__pyx_t_6;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 137, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 137, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_2));
__pyx_t_2 = 0;
/* … */
}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+138: if isinstance(node.state, TacticState):
__pyx_t_2 = __pyx_v_node->state;
__Pyx_INCREF(__pyx_t_2);
__Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 138, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__pyx_t_7 = PyObject_IsInstance(__pyx_t_2, __pyx_t_4); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 138, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (__pyx_t_7) {
/* … */
}
+139: states.append(node.state.pp)
__pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_v_node->state, __pyx_mstate_global->__pyx_n_u_pp); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 139, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4); __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_states, __pyx_t_4); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 139, __pyx_L1_error) __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
+140: nodes_to_generate.append(node)
__pyx_t_8 = __Pyx_PyList_Append(__pyx_v_nodes_to_generate, ((PyObject *)__pyx_v_node)); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 140, __pyx_L1_error)
141:
+142: if not states:
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_v_states);
if (unlikely(((!CYTHON_ASSUME_SAFE_SIZE) && __pyx_temp < 0))) __PYX_ERR(0, 142, __pyx_L1_error)
__pyx_t_7 = (__pyx_temp != 0);
}
__pyx_t_9 = (!__pyx_t_7);
if (__pyx_t_9) {
/* … */
}
+143: return nodes
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_nodes); __pyx_r = __pyx_v_nodes; goto __pyx_L0;
144:
145: # Check timeout before expensive model call
+146: if self._is_timeout():
__pyx_t_9 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._is_timeout(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 146, __pyx_L1_error)
if (__pyx_t_9) {
/* … */
}
+147: return nodes
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_nodes); __pyx_r = __pyx_v_nodes; goto __pyx_L0;
148:
+149: batch_tactics_with_probs = self.transformer.generate_tactics_with_probs_batch(
__pyx_t_4 = __pyx_v_self->__pyx_base.transformer; __Pyx_INCREF(__pyx_t_4); /* … */ if (!(likely(PyList_CheckExact(__pyx_t_1))||((__pyx_t_1) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_1))) __PYX_ERR(0, 149, __pyx_L1_error) __pyx_v_batch_tactics_with_probs = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+150: states, n=self.num_tactics_to_expand
__pyx_t_2 = __Pyx_PyLong_From_int(__pyx_v_self->__pyx_base.num_tactics_to_expand); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 150, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_5 = 0; { PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 1 : 0)] = {__pyx_t_4, __pyx_v_states}; __pyx_t_3 = __Pyx_MakeVectorcallBuilderKwds(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 149, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_n, __pyx_t_2, __pyx_t_3, __pyx_callargs+2, 0) < (0)) __PYX_ERR(0, 149, __pyx_L1_error) __pyx_t_1 = __Pyx_Object_VectorcallMethod_CallFromBuilder((PyObject*)__pyx_mstate_global->__pyx_n_u_generate_tactics_with_probs_batc, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_3); __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 149, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); }
151: )
152:
153: # Check timeout after model call
+154: if self._is_timeout():
__pyx_t_9 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._is_timeout(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 154, __pyx_L1_error)
if (__pyx_t_9) {
/* … */
}
+155: return nodes
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_nodes); __pyx_r = __pyx_v_nodes; goto __pyx_L0;
156:
+157: for i in range(len(batch_tactics_with_probs)):
if (unlikely(__pyx_v_batch_tactics_with_probs == Py_None)) {
PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()");
__PYX_ERR(0, 157, __pyx_L1_error)
}
__pyx_t_6 = __Pyx_PyList_GET_SIZE(__pyx_v_batch_tactics_with_probs); if (unlikely(__pyx_t_6 == ((Py_ssize_t)-1))) __PYX_ERR(0, 157, __pyx_L1_error)
__pyx_t_10 = __pyx_t_6;
for (__pyx_t_11 = 0; __pyx_t_11 < __pyx_t_10; __pyx_t_11+=1) {
__pyx_v_i = __pyx_t_11;
+158: tactics_probs = batch_tactics_with_probs[i]
if (unlikely(__pyx_v_batch_tactics_with_probs == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 158, __pyx_L1_error)
}
__pyx_t_1 = __Pyx_PyList_GET_ITEM(__pyx_v_batch_tactics_with_probs, __pyx_v_i);
__Pyx_INCREF(__pyx_t_1);
__Pyx_XDECREF_SET(__pyx_v_tactics_probs, __pyx_t_1);
__pyx_t_1 = 0;
+159: node = nodes_to_generate[i]
__pyx_t_1 = __Pyx_PyList_GET_ITEM(__pyx_v_nodes_to_generate, __pyx_v_i); __Pyx_INCREF(__pyx_t_1); if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 159, __pyx_L1_error) __Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_1)); __pyx_t_1 = 0;
+160: for tactic, prob in tactics_probs:
if (likely(PyList_CheckExact(__pyx_v_tactics_probs)) || PyTuple_CheckExact(__pyx_v_tactics_probs)) { __pyx_t_1 = __pyx_v_tactics_probs; __Pyx_INCREF(__pyx_t_1); __pyx_t_12 = 0; __pyx_t_13 = NULL; } else { __pyx_t_12 = -1; __pyx_t_1 = PyObject_GetIter(__pyx_v_tactics_probs); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_13 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_1); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 160, __pyx_L1_error) } for (;;) { if (likely(!__pyx_t_13)) { if (likely(PyList_CheckExact(__pyx_t_1))) { { Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 160, __pyx_L1_error) #endif if (__pyx_t_12 >= __pyx_temp) break; } __pyx_t_3 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_12, __Pyx_ReferenceSharing_OwnStrongReference); ++__pyx_t_12; } else { { Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_1); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 160, __pyx_L1_error) #endif if (__pyx_t_12 >= __pyx_temp) break; } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS __pyx_t_3 = __Pyx_NewRef(PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_12)); #else __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_1, __pyx_t_12); #endif ++__pyx_t_12; } if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 160, __pyx_L1_error) } else { __pyx_t_3 = __pyx_t_13(__pyx_t_1); if (unlikely(!__pyx_t_3)) { PyObject* exc_type = PyErr_Occurred(); if (exc_type) { if (unlikely(!__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) __PYX_ERR(0, 160, __pyx_L1_error) PyErr_Clear(); } break; } } __Pyx_GOTREF(__pyx_t_3); if ((likely(PyTuple_CheckExact(__pyx_t_3))) || (PyList_CheckExact(__pyx_t_3))) { PyObject* sequence = __pyx_t_3; Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); if (unlikely(size != 2)) { if (size > 2) __Pyx_RaiseTooManyValuesError(2); else if (size >= 0) 
__Pyx_RaiseNeedMoreValuesError(size); __PYX_ERR(0, 160, __pyx_L1_error) } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS if (likely(PyTuple_CheckExact(sequence))) { __pyx_t_2 = PyTuple_GET_ITEM(sequence, 0); __Pyx_INCREF(__pyx_t_2); __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); __Pyx_INCREF(__pyx_t_4); } else { __pyx_t_2 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_2); __pyx_t_4 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_4); } #else __pyx_t_2 = __Pyx_PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_4 = __Pyx_PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4); #endif __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; } else { Py_ssize_t index = -1; __pyx_t_14 = PyObject_GetIter(__pyx_t_3); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; __pyx_t_15 = (CYTHON_COMPILING_IN_LIMITED_API) ? 
PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_14); index = 0; __pyx_t_2 = __pyx_t_15(__pyx_t_14); if (unlikely(!__pyx_t_2)) goto __pyx_L14_unpacking_failed; __Pyx_GOTREF(__pyx_t_2); index = 1; __pyx_t_4 = __pyx_t_15(__pyx_t_14); if (unlikely(!__pyx_t_4)) goto __pyx_L14_unpacking_failed; __Pyx_GOTREF(__pyx_t_4); if (__Pyx_IternextUnpackEndCheck(__pyx_t_15(__pyx_t_14), 2) < (0)) __PYX_ERR(0, 160, __pyx_L1_error) __pyx_t_15 = NULL; __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; goto __pyx_L15_unpacking_done; __pyx_L14_unpacking_failed:; __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; __pyx_t_15 = NULL; if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index); __PYX_ERR(0, 160, __pyx_L1_error) __pyx_L15_unpacking_done:; } __pyx_t_16 = __Pyx_PyFloat_AsFloat(__pyx_t_4); if (unlikely((__pyx_t_16 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 160, __pyx_L1_error) __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; __Pyx_XDECREF_SET(__pyx_v_tactic, __pyx_t_2); __pyx_t_2 = 0; __pyx_v_prob = __pyx_t_16; /* … */ } __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; }
+161: tasks.append((node, tactic, prob))
__pyx_t_3 = PyFloat_FromDouble(__pyx_v_prob); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 161, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); __pyx_t_4 = PyTuple_New(3); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 161, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4); __Pyx_INCREF((PyObject *)__pyx_v_node); __Pyx_GIVEREF((PyObject *)__pyx_v_node); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, ((PyObject *)__pyx_v_node)) != (0)) __PYX_ERR(0, 161, __pyx_L1_error); __Pyx_INCREF(__pyx_v_tactic); __Pyx_GIVEREF(__pyx_v_tactic); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_v_tactic) != (0)) __PYX_ERR(0, 161, __pyx_L1_error); __Pyx_GIVEREF(__pyx_t_3); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_t_3) != (0)) __PYX_ERR(0, 161, __pyx_L1_error); __pyx_t_3 = 0; __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_tasks, __pyx_t_4); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 161, __pyx_L1_error) __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
162:
+163: for node, tactic, prob in tasks:
__pyx_t_1 = __pyx_v_tasks; __Pyx_INCREF(__pyx_t_1); __pyx_t_6 = 0; for (;;) { { Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 163, __pyx_L1_error) #endif if (__pyx_t_6 >= __pyx_temp) break; } __pyx_t_4 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference); ++__pyx_t_6; if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_4); if ((likely(PyTuple_CheckExact(__pyx_t_4))) || (PyList_CheckExact(__pyx_t_4))) { PyObject* sequence = __pyx_t_4; Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); if (unlikely(size != 3)) { if (size > 3) __Pyx_RaiseTooManyValuesError(3); else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); __PYX_ERR(0, 163, __pyx_L1_error) } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS if (likely(PyTuple_CheckExact(sequence))) { __pyx_t_3 = PyTuple_GET_ITEM(sequence, 0); __Pyx_INCREF(__pyx_t_3); __pyx_t_2 = PyTuple_GET_ITEM(sequence, 1); __Pyx_INCREF(__pyx_t_2); __pyx_t_14 = PyTuple_GET_ITEM(sequence, 2); __Pyx_INCREF(__pyx_t_14); } else { __pyx_t_3 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_3); __pyx_t_2 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_2); __pyx_t_14 = __Pyx_PyList_GetItemRefFast(sequence, 2, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_14); } #else __pyx_t_3 = __Pyx_PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_3); __pyx_t_2 = __Pyx_PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_14 = __Pyx_PySequence_ITEM(sequence, 
2); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); #endif __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; } else { Py_ssize_t index = -1; __pyx_t_17 = PyObject_GetIter(__pyx_t_4); if (unlikely(!__pyx_t_17)) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_17); __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; __pyx_t_15 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_17); index = 0; __pyx_t_3 = __pyx_t_15(__pyx_t_17); if (unlikely(!__pyx_t_3)) goto __pyx_L19_unpacking_failed; __Pyx_GOTREF(__pyx_t_3); index = 1; __pyx_t_2 = __pyx_t_15(__pyx_t_17); if (unlikely(!__pyx_t_2)) goto __pyx_L19_unpacking_failed; __Pyx_GOTREF(__pyx_t_2); index = 2; __pyx_t_14 = __pyx_t_15(__pyx_t_17); if (unlikely(!__pyx_t_14)) goto __pyx_L19_unpacking_failed; __Pyx_GOTREF(__pyx_t_14); if (__Pyx_IternextUnpackEndCheck(__pyx_t_15(__pyx_t_17), 3) < (0)) __PYX_ERR(0, 163, __pyx_L1_error) __pyx_t_15 = NULL; __Pyx_DECREF(__pyx_t_17); __pyx_t_17 = 0; goto __pyx_L20_unpacking_done; __pyx_L19_unpacking_failed:; __Pyx_DECREF(__pyx_t_17); __pyx_t_17 = 0; __pyx_t_15 = NULL; if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index); __PYX_ERR(0, 163, __pyx_L1_error) __pyx_L20_unpacking_done:; } if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 163, __pyx_L1_error) __pyx_t_16 = __Pyx_PyFloat_AsFloat(__pyx_t_14); if (unlikely((__pyx_t_16 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 163, __pyx_L1_error) __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; __Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_3)); __pyx_t_3 = 0; __Pyx_XDECREF_SET(__pyx_v_tactic, __pyx_t_2); __pyx_t_2 = 0; __pyx_v_prob = __pyx_t_16; /* … */ } __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; goto __pyx_L41_for_end; __pyx_L18_break:; __Pyx_DECREF(__pyx_t_1); 
__pyx_t_1 = 0; goto __pyx_L41_for_end; __pyx_L41_for_end:;
164: # Check timeout before each Lean call
+165: if self._is_timeout():
__pyx_t_9 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._is_timeout(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 165, __pyx_L1_error)
if (__pyx_t_9) {
/* … */
}
+166: break
goto __pyx_L18_break;
+167: try:
{
/*try:*/ {
/* … */
}
__Pyx_XDECREF(__pyx_t_18); __pyx_t_18 = 0;
__Pyx_XDECREF(__pyx_t_19); __pyx_t_19 = 0;
__Pyx_XDECREF(__pyx_t_20); __pyx_t_20 = 0;
goto __pyx_L29_try_end;
__pyx_L22_error:;
__Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0;
__Pyx_XDECREF(__pyx_t_17); __pyx_t_17 = 0;
__Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0;
/* … */
__pyx_L24_except_error:;
__Pyx_XGIVEREF(__pyx_t_18);
__Pyx_XGIVEREF(__pyx_t_19);
__Pyx_XGIVEREF(__pyx_t_20);
__Pyx_ExceptionReset(__pyx_t_18, __pyx_t_19, __pyx_t_20);
goto __pyx_L1_error;
__pyx_L23_exception_handled:;
__Pyx_XGIVEREF(__pyx_t_18);
__Pyx_XGIVEREF(__pyx_t_19);
__Pyx_XGIVEREF(__pyx_t_20);
__Pyx_ExceptionReset(__pyx_t_18, __pyx_t_19, __pyx_t_20);
__pyx_L29_try_end:;
}
+168: next_state = self.env.run_tactic_stateless(node.state, tactic)
__pyx_t_14 = __pyx_v_self->__pyx_base.env;
__Pyx_INCREF(__pyx_t_14);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[3] = {__pyx_t_14, __pyx_v_node->state, __pyx_v_tactic};
__pyx_t_4 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_run_tactic_stateless, __pyx_callargs+__pyx_t_5, (3-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0;
if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 168, __pyx_L22_error)
__Pyx_GOTREF(__pyx_t_4);
}
__Pyx_XDECREF_SET(__pyx_v_next_state, __pyx_t_4);
__pyx_t_4 = 0;
+169: except Exception as e:
__pyx_t_11 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(((PyTypeObject*)PyExc_Exception)))); if (__pyx_t_11) { __Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._expand_batch", __pyx_clineno, __pyx_lineno, __pyx_filename); if (__Pyx_GetException(&__pyx_t_4, &__pyx_t_14, &__pyx_t_2) < 0) __PYX_ERR(0, 169, __pyx_L24_except_error) __Pyx_XGOTREF(__pyx_t_4); __Pyx_XGOTREF(__pyx_t_14); __Pyx_XGOTREF(__pyx_t_2); __Pyx_INCREF(__pyx_t_14); __pyx_v_e = __pyx_t_14; /*try:*/ { /* … */ /*finally:*/ { /*normal exit:*/{ __Pyx_DECREF(__pyx_v_e); __pyx_v_e = 0; goto __pyx_L36; } __pyx_L35_error:; /*exception exit:*/{ __Pyx_PyThreadState_declare __Pyx_PyThreadState_assign __pyx_t_26 = 0; __pyx_t_27 = 0; __pyx_t_28 = 0; __pyx_t_29 = 0; __pyx_t_30 = 0; __pyx_t_31 = 0; __Pyx_XDECREF(__pyx_t_17); __pyx_t_17 = 0; __Pyx_XDECREF(__pyx_t_21); __pyx_t_21 = 0; __Pyx_XDECREF(__pyx_t_22); __pyx_t_22 = 0; __Pyx_XDECREF(__pyx_t_23); __pyx_t_23 = 0; __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; __Pyx_ExceptionSwap(&__pyx_t_29, &__pyx_t_30, &__pyx_t_31); if ( unlikely(__Pyx_GetException(&__pyx_t_26, &__pyx_t_27, &__pyx_t_28) < 0)) __Pyx_ErrFetch(&__pyx_t_26, &__pyx_t_27, &__pyx_t_28); __Pyx_XGOTREF(__pyx_t_26); __Pyx_XGOTREF(__pyx_t_27); __Pyx_XGOTREF(__pyx_t_28); __Pyx_XGOTREF(__pyx_t_29); __Pyx_XGOTREF(__pyx_t_30); __Pyx_XGOTREF(__pyx_t_31); __pyx_t_11 = __pyx_lineno; __pyx_t_24 = __pyx_clineno; __pyx_t_25 = __pyx_filename; { __Pyx_DECREF(__pyx_v_e); __pyx_v_e = 0; } __Pyx_XGIVEREF(__pyx_t_29); __Pyx_XGIVEREF(__pyx_t_30); __Pyx_XGIVEREF(__pyx_t_31); __Pyx_ExceptionReset(__pyx_t_29, __pyx_t_30, __pyx_t_31); __Pyx_XGIVEREF(__pyx_t_26); __Pyx_XGIVEREF(__pyx_t_27); __Pyx_XGIVEREF(__pyx_t_28); __Pyx_ErrRestore(__pyx_t_26, __pyx_t_27, __pyx_t_28); __pyx_t_26 = 0; __pyx_t_27 = 0; __pyx_t_28 = 0; __pyx_t_29 = 0; __pyx_t_30 = 0; __pyx_t_31 = 0; __pyx_lineno = __pyx_t_11; __pyx_clineno = __pyx_t_24; __pyx_filename = __pyx_t_25; goto __pyx_L24_except_error; } __pyx_L36:; } 
__Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; __Pyx_XDECREF(__pyx_t_14); __pyx_t_14 = 0; __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; goto __pyx_L23_exception_handled; } goto __pyx_L24_except_error;
+170: next_state = LeanError(error=str(e))
__pyx_t_17 = NULL;
__Pyx_GetModuleGlobalName(__pyx_t_21, __pyx_mstate_global->__pyx_n_u_LeanError); if (unlikely(!__pyx_t_21)) __PYX_ERR(0, 170, __pyx_L35_error)
__Pyx_GOTREF(__pyx_t_21);
__pyx_t_22 = __Pyx_PyObject_Unicode(__pyx_v_e); if (unlikely(!__pyx_t_22)) __PYX_ERR(0, 170, __pyx_L35_error)
__Pyx_GOTREF(__pyx_t_22);
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
if (unlikely(PyMethod_Check(__pyx_t_21))) {
__pyx_t_17 = PyMethod_GET_SELF(__pyx_t_21);
assert(__pyx_t_17);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_21);
__Pyx_INCREF(__pyx_t_17);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_21, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 1 : 0)] = {__pyx_t_17, NULL};
__pyx_t_23 = __Pyx_MakeVectorcallBuilderKwds(1); if (unlikely(!__pyx_t_23)) __PYX_ERR(0, 170, __pyx_L35_error)
__Pyx_GOTREF(__pyx_t_23);
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_error, __pyx_t_22, __pyx_t_23, __pyx_callargs+1, 0) < (0)) __PYX_ERR(0, 170, __pyx_L35_error)
__pyx_t_3 = __Pyx_Object_Vectorcall_CallFromBuilder((PyObject*)__pyx_t_21, __pyx_callargs+__pyx_t_5, (1-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_23);
__Pyx_XDECREF(__pyx_t_17); __pyx_t_17 = 0;
__Pyx_DECREF(__pyx_t_22); __pyx_t_22 = 0;
__Pyx_DECREF(__pyx_t_23); __pyx_t_23 = 0;
__Pyx_DECREF(__pyx_t_21); __pyx_t_21 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 170, __pyx_L35_error)
__Pyx_GOTREF(__pyx_t_3);
}
__Pyx_XDECREF_SET(__pyx_v_next_state, __pyx_t_3);
__pyx_t_3 = 0;
}
+171: results.append((node, tactic, prob, next_state))
__pyx_t_2 = PyFloat_FromDouble(__pyx_v_prob); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 171, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __pyx_t_14 = PyTuple_New(4); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 171, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); __Pyx_INCREF((PyObject *)__pyx_v_node); __Pyx_GIVEREF((PyObject *)__pyx_v_node); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_14, 0, ((PyObject *)__pyx_v_node)) != (0)) __PYX_ERR(0, 171, __pyx_L1_error); __Pyx_INCREF(__pyx_v_tactic); __Pyx_GIVEREF(__pyx_v_tactic); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_14, 1, __pyx_v_tactic) != (0)) __PYX_ERR(0, 171, __pyx_L1_error); __Pyx_GIVEREF(__pyx_t_2); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_14, 2, __pyx_t_2) != (0)) __PYX_ERR(0, 171, __pyx_L1_error); __Pyx_INCREF(__pyx_v_next_state); __Pyx_GIVEREF(__pyx_v_next_state); if (__Pyx_PyTuple_SET_ITEM(__pyx_t_14, 3, __pyx_v_next_state) != (0)) __PYX_ERR(0, 171, __pyx_L1_error); __pyx_t_2 = 0; __pyx_t_8 = __Pyx_PyList_Append(__pyx_v_results, __pyx_t_14); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 171, __pyx_L1_error) __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0;
172:
173: # Create children (reusing existing nodes for duplicates - DAG structure)
+174: for node, tactic, prob, next_state in results:
__pyx_t_1 = __pyx_v_results; __Pyx_INCREF(__pyx_t_1); __pyx_t_6 = 0; for (;;) { { Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 174, __pyx_L1_error) #endif if (__pyx_t_6 >= __pyx_temp) break; } __pyx_t_14 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference); ++__pyx_t_6; if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); if ((likely(PyTuple_CheckExact(__pyx_t_14))) || (PyList_CheckExact(__pyx_t_14))) { PyObject* sequence = __pyx_t_14; Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); if (unlikely(size != 4)) { if (size > 4) __Pyx_RaiseTooManyValuesError(4); else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); __PYX_ERR(0, 174, __pyx_L1_error) } #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS if (likely(PyTuple_CheckExact(sequence))) { __pyx_t_2 = PyTuple_GET_ITEM(sequence, 0); __Pyx_INCREF(__pyx_t_2); __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); __Pyx_INCREF(__pyx_t_4); __pyx_t_3 = PyTuple_GET_ITEM(sequence, 2); __Pyx_INCREF(__pyx_t_3); __pyx_t_21 = PyTuple_GET_ITEM(sequence, 3); __Pyx_INCREF(__pyx_t_21); } else { __pyx_t_2 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_2); __pyx_t_4 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_4); __pyx_t_3 = __Pyx_PyList_GetItemRefFast(sequence, 2, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_3); __pyx_t_21 = __Pyx_PyList_GetItemRefFast(sequence, 3, __Pyx_ReferenceSharing_SharedReference); if (unlikely(!__pyx_t_21)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_XGOTREF(__pyx_t_21); } #else { Py_ssize_t i; PyObject** temps[4] = 
{&__pyx_t_2,&__pyx_t_4,&__pyx_t_3,&__pyx_t_21}; for (i=0; i < 4; i++) { PyObject* item = __Pyx_PySequence_ITEM(sequence, i); if (unlikely(!item)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_GOTREF(item); *(temps[i]) = item; } } #endif __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; } else { Py_ssize_t index = -1; PyObject** temps[4] = {&__pyx_t_2,&__pyx_t_4,&__pyx_t_3,&__pyx_t_21}; __pyx_t_23 = PyObject_GetIter(__pyx_t_14); if (unlikely(!__pyx_t_23)) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_23); __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; __pyx_t_15 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_23); for (index=0; index < 4; index++) { PyObject* item = __pyx_t_15(__pyx_t_23); if (unlikely(!item)) goto __pyx_L44_unpacking_failed; __Pyx_GOTREF(item); *(temps[index]) = item; } if (__Pyx_IternextUnpackEndCheck(__pyx_t_15(__pyx_t_23), 4) < (0)) __PYX_ERR(0, 174, __pyx_L1_error) __pyx_t_15 = NULL; __Pyx_DECREF(__pyx_t_23); __pyx_t_23 = 0; goto __pyx_L45_unpacking_done; __pyx_L44_unpacking_failed:; __Pyx_DECREF(__pyx_t_23); __pyx_t_23 = 0; __pyx_t_15 = NULL; if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index); __PYX_ERR(0, 174, __pyx_L1_error) __pyx_L45_unpacking_done:; } if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 174, __pyx_L1_error) __pyx_t_16 = __Pyx_PyFloat_AsFloat(__pyx_t_3); if (unlikely((__pyx_t_16 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 174, __pyx_L1_error) __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; __Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_2)); __pyx_t_2 = 0; __Pyx_XDECREF_SET(__pyx_v_tactic, __pyx_t_4); __pyx_t_4 = 0; __pyx_v_prob = __pyx_t_16; __Pyx_XDECREF_SET(__pyx_v_next_state, __pyx_t_21); __pyx_t_21 = 0; /* … */ __pyx_L42_continue:; } __Pyx_DECREF(__pyx_t_1); 
__pyx_t_1 = 0;
175: # Check for duplicate states
+176: state_key = self._get_state_key(next_state)
__pyx_t_14 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._get_state_key(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), __pyx_v_next_state, 0); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 176, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); __Pyx_XDECREF_SET(__pyx_v_state_key, __pyx_t_14); __pyx_t_14 = 0;
+177: if state_key is not None and state_key in self.seen_states:
__pyx_t_7 = (__pyx_v_state_key != Py_None);
if (__pyx_t_7) {
} else {
__pyx_t_9 = __pyx_t_7;
goto __pyx_L47_bool_binop_done;
}
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable");
__PYX_ERR(0, 177, __pyx_L1_error)
}
__pyx_t_7 = (__Pyx_PyDict_ContainsTF(__pyx_v_state_key, __pyx_v_self->__pyx_base.seen_states, Py_EQ)); if (unlikely((__pyx_t_7 < 0))) __PYX_ERR(0, 177, __pyx_L1_error)
__pyx_t_9 = __pyx_t_7;
__pyx_L47_bool_binop_done:;
if (__pyx_t_9) {
/* … */
}
178: # Reuse existing node - add as child with additional parent edge
+179: existing_node = self.seen_states[state_key]
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 179, __pyx_L1_error)
}
__pyx_t_14 = __Pyx_PyDict_GetItem(__pyx_v_self->__pyx_base.seen_states, __pyx_v_state_key); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 179, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_14);
if (!(likely(((__pyx_t_14) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_14, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 179, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_existing_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_14));
__pyx_t_14 = 0;
+180: existing_node.add_parent(node, tactic)
__pyx_t_32.__pyx_n = 1;
__pyx_t_32.action = __pyx_v_tactic;
((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_v_existing_node->__pyx_vtab)->add_parent(__pyx_v_existing_node, __pyx_v_node, 0, &__pyx_t_32); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 180, __pyx_L1_error)
+181: if existing_node not in node.children:
__pyx_t_9 = (__Pyx_PySequence_ContainsTF(((PyObject *)__pyx_v_existing_node), __pyx_v_node->children, Py_NE)); if (unlikely((__pyx_t_9 < 0))) __PYX_ERR(0, 181, __pyx_L1_error) if (__pyx_t_9) { /* … */ }
+182: node.children.append(existing_node)
if (unlikely(__pyx_v_node->children == Py_None)) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "append");
__PYX_ERR(0, 182, __pyx_L1_error)
}
__pyx_t_8 = __Pyx_PyList_Append(__pyx_v_node->children, ((PyObject *)__pyx_v_existing_node)); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 182, __pyx_L1_error)
+183: continue
goto __pyx_L42_continue;
184:
+185: child = Node(next_state, parent=node, action=tactic)
__pyx_t_21 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 2 : 0)] = {__pyx_t_21, __pyx_v_next_state};
__pyx_t_3 = __Pyx_MakeVectorcallBuilderKwds(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 185, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_parent, ((PyObject *)__pyx_v_node), __pyx_t_3, __pyx_callargs+2, 0) < (0)) __PYX_ERR(0, 185, __pyx_L1_error)
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_action, __pyx_v_tactic, __pyx_t_3, __pyx_callargs+2, 1) < (0)) __PYX_ERR(0, 185, __pyx_L1_error)
__pyx_t_14 = __Pyx_Object_Vectorcall_CallFromBuilder((PyObject*)__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_3);
__Pyx_XDECREF(__pyx_t_21); __pyx_t_21 = 0;
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 185, __pyx_L1_error)
__Pyx_GOTREF((PyObject *)__pyx_t_14);
}
__Pyx_XDECREF_SET(__pyx_v_child, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_14));
__pyx_t_14 = 0;
+186: child.prior_p = prob
__pyx_v_child->prior_p = __pyx_v_prob;
+187: node.children.append(child)
if (unlikely(__pyx_v_node->children == Py_None)) {
PyErr_Format(PyExc_AttributeError, "'NoneType' object has no attribute '%.30s'", "append");
__PYX_ERR(0, 187, __pyx_L1_error)
}
__pyx_t_8 = __Pyx_PyList_Append(__pyx_v_node->children, ((PyObject *)__pyx_v_child)); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 187, __pyx_L1_error)
+188: self.node_count += 1
__pyx_v_self->__pyx_base.node_count = (__pyx_v_self->__pyx_base.node_count + 1);
189:
190: # Register new state in seen_states
+191: if isinstance(next_state, TacticState):
__Pyx_GetModuleGlobalName(__pyx_t_14, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 191, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); __pyx_t_9 = PyObject_IsInstance(__pyx_v_next_state, __pyx_t_14); if (unlikely(__pyx_t_9 == ((int)-1))) __PYX_ERR(0, 191, __pyx_L1_error) __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; if (__pyx_t_9) { /* … */ }
+192: self.seen_states[state_key] = child
if (unlikely(__pyx_v_self->__pyx_base.seen_states == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 192, __pyx_L1_error)
}
if (unlikely((PyDict_SetItem(__pyx_v_self->__pyx_base.seen_states, __pyx_v_state_key, ((PyObject *)__pyx_v_child)) < 0))) __PYX_ERR(0, 192, __pyx_L1_error)
+193: new_children_nodes.append(child)
__pyx_t_8 = __Pyx_PyList_Append(__pyx_v_new_children_nodes, ((PyObject *)__pyx_v_child)); if (unlikely(__pyx_t_8 == ((int)-1))) __PYX_ERR(0, 193, __pyx_L1_error)
194:
+195: for node in nodes_to_generate:
__pyx_t_1 = __pyx_v_nodes_to_generate; __Pyx_INCREF(__pyx_t_1); __pyx_t_6 = 0; for (;;) { { Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1); #if !CYTHON_ASSUME_SAFE_SIZE if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 195, __pyx_L1_error) #endif if (__pyx_t_6 >= __pyx_temp) break; } __pyx_t_14 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference); ++__pyx_t_6; if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 195, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); if (!(likely(((__pyx_t_14) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_14, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 195, __pyx_L1_error) __Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_14)); __pyx_t_14 = 0; /* … */ } __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+196: node.untried_actions = []
__pyx_t_14 = PyList_New(0); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 196, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_14); __Pyx_GIVEREF(__pyx_t_14); __Pyx_GOTREF(__pyx_v_node->untried_actions); __Pyx_DECREF(__pyx_v_node->untried_actions); __pyx_v_node->untried_actions = ((PyObject*)__pyx_t_14); __pyx_t_14 = 0;
197:
+198: return nodes
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_nodes); __pyx_r = __pyx_v_nodes; goto __pyx_L0;
199:
+200: cpdef float _simulate(self, Node node):
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/*
 * C-level entry point generated for the cpdef method
 * MCTS_AlphaZero._simulate (source line 200); returns a C float.
 *
 * Only the cpdef dispatch prologue and the function exit code are shown
 * here; the method body is elided at the placeholder comment below.
 *
 * Dispatch: when __pyx_skip_dispatch is zero and the type checks indicate
 * that self may carry a Python-level override, the method attribute is
 * looked up on self. If it is not this module's compiled wrapper, the
 * override is called with `node` and its result is converted with
 * __Pyx_PyFloat_AsFloat and returned.
 *
 * Error path (__pyx_L1_error): releases the temporaries, records a
 * traceback entry and returns 0.
 */
static float __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__simulate(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node, int __pyx_skip_dispatch) {
float __pyx_v_value;
PyObject *__pyx_v_state_str = NULL;
float __pyx_r;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* The attribute lookup below only runs when the recorded dict versions no longer match. */
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_simulate); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* The attribute is not the compiled wrapper, so call whatever was found instead. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate)) {
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
/* Bound method: split it into self + function so self becomes the first call argument (offset 0). */
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, ((PyObject *)__pyx_v_node)};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* Convert the override's Python result to the declared C float return type. */
__pyx_t_6 = __Pyx_PyFloat_AsFloat(__pyx_t_2); if (unlikely((__pyx_t_6 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__pyx_r = __pyx_t_6;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* No override was found: record the current dict versions, resetting them if the type dict changed during the lookup. */
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
__pyx_L0:;
/* Shared exit for success and failure: drop the only object-typed local. */
__Pyx_XDECREF(__pyx_v_state_str);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate = {"_simulate", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
/*
 * Python-callable wrapper for MCTS_AlphaZero._simulate.
 *
 * Unpacks exactly one argument, `node`, supplied either positionally or by
 * keyword, checks it with __Pyx_ArgTypeTest against the Node extension type
 * (the `1` flag presumably permits None -- confirm against the helper), and
 * forwards to the __pyx_pf_ implementation.
 *
 * Returns the result object of the implementation, or NULL after recording a
 * traceback entry when argument unpacking or the type test fails.
 *
 * The keyword-count check is written as unlikely(len < 0). In the previous
 * form, unlikely(len) < 0, the comparison was applied to the result of
 * unlikely(); when that macro normalises its argument to 0/1 (as the usual
 * __builtin_expect(!!(x), 0) definition does) the test can never be true and
 * a failed keyword count would go unnoticed.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_simulate (wrapper)", 0);
#if !CYTHON_METH_FASTCALL
/* Tuple calling convention: the positional count comes from the args tuple. */
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_node,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* A negative count signals a failure while sizing the keywords. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 200, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take any positional value first, then let the keyword parser fill the rest. */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 200, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_simulate", 0) < (0)) __PYX_ERR(0, 200, __pyx_L3_error)
/* Every slot not supplied positionally must have been filled by a keyword. */
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_simulate", 1, 1, 1, i); __PYX_ERR(0, 200, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
goto __pyx_L5_argtuple_error;
} else {
/* No keywords: the single positional argument is `node`. */
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 200, __pyx_L3_error)
}
__pyx_v_node = ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)values[0]);
}
goto __pyx_L6_skip;
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_simulate", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 200, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
__pyx_L3_error:;
/* Unpacking failed: release whatever argument references were collected. */
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_node), __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node, 1, "node", 0))) __PYX_ERR(0, 200, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10_simulate(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_node);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
__pyx_L1_error:;
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
__pyx_L0:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/*
 * Python-facing implementation behind the _simulate wrapper.
 *
 * Calls the C-level _simulate with the skip-dispatch argument set to 1,
 * treats any pending Python exception after the call as failure, and boxes
 * the C float into a Python float with PyFloat_FromDouble.
 *
 * Returns the new float object, or NULL after recording a traceback entry.
 * The temporaries __pyx_t_1 / __pyx_t_2 are not declared in this listing.
 */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_10_simulate(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node) {
PyObject *__pyx_r = NULL;
__Pyx_XDECREF(__pyx_r);
/* The float return value carries no error sentinel here; failure is detected via PyErr_Occurred(). */
__pyx_t_1 = __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__simulate(__pyx_v_self, __pyx_v_node, 1); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 200, __pyx_L1_error)
__pyx_t_2 = PyFloat_FromDouble(__pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
/* Hand ownership of the boxed float to the return slot. */
__pyx_r = __pyx_t_2;
__pyx_t_2 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_2);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_11_simulate, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__simulate, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[4])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_simulate, __pyx_t_2) < (0)) __PYX_ERR(0, 200, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
201: cdef float value
202:
203: # Check timeout before evaluation
+204: if self._is_timeout():
/* _is_timeout is reached through the base-class slot of the vtable; a pending Python exception after the call is treated as failure. */
__pyx_t_7 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._is_timeout(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 204, __pyx_L1_error)
if (__pyx_t_7) {
/* … */
}
+205: return 0.0 # Neutral reward on timeout
__pyx_r = 0.0;
goto __pyx_L0;
206:
+207: if node.is_terminal:
/* is_terminal is read directly as a C struct field, with no Python attribute lookup. */
if (__pyx_v_node->is_terminal) {
/* … */
}
+208: if isinstance(node.state, ProofFinished):
/* ProofFinished is resolved as a module global on every call; an IsInstance result of -1 is propagated as an error. */
__pyx_t_1 = __pyx_v_node->state;
__Pyx_INCREF(__pyx_t_1);
__Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_mstate_global->__pyx_n_u_ProofFinished); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 208, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_7 = PyObject_IsInstance(__pyx_t_1, __pyx_t_2); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 208, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
if (__pyx_t_7) {
/* … */
}
+209: return 1.0
__pyx_r = 1.0;
goto __pyx_L0;
+210: if isinstance(node.state, (LeanError, ProofGivenUp)):
/* NOTE(review): unlike the single-class isinstance checks for source lines 208 and 213, neither PyObject_IsInstance result in this tuple check is compared against -1, so an error return is treated as a truthy match. */
__Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_mstate_global->__pyx_n_u_LeanError); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 210, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_mstate_global->__pyx_n_u_ProofGivenUp); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 210, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_4 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_4); __pyx_t_8 = PyObject_IsInstance(__pyx_t_4, __pyx_t_2); __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; if (!__pyx_t_8) { } else { __pyx_t_7 = __pyx_t_8; goto __pyx_L7_bool_binop_done; } __pyx_t_4 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_4); __pyx_t_8 = PyObject_IsInstance(__pyx_t_4, __pyx_t_1); __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; __pyx_t_7 = __pyx_t_8; __pyx_L7_bool_binop_done:; __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; if (__pyx_t_7) { /* … */ }
+211: return -1.0
__pyx_r = -1.0;
goto __pyx_L0;
212:
+213: if not isinstance(node.state, TacticState):
__pyx_t_2 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_2); __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 213, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_7 = PyObject_IsInstance(__pyx_t_2, __pyx_t_1); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 213, __pyx_L1_error) __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; __pyx_t_8 = (!__pyx_t_7); if (__pyx_t_8) { /* … */ }
+214: return -1.0
__pyx_r = -1.0;
goto __pyx_L0;
215:
+216: if node.encoder_features is not None:
/* Identity comparison against Py_None on the struct field; no Python call is involved. */
__pyx_t_8 = (__pyx_v_node->encoder_features != Py_None);
if (__pyx_t_8) {
/* … */
goto __pyx_L10;
}
+217: value = self.value_head.predict_from_features(node.encoder_features)
/* Features are present: call predict_from_features on value_head and convert the Python result to a C float. */
__pyx_t_2 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_2);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_2, __pyx_v_node->encoder_features};
__pyx_t_1 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_predict_from_features, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 217, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
__pyx_t_6 = __Pyx_PyFloat_AsFloat(__pyx_t_1); if (unlikely((__pyx_t_6 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 217, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__pyx_v_value = __pyx_t_6;
218: else:
+219: state_str = node.state.pp
/* No features: read the `pp` attribute of the state with a generic Python attribute lookup. */
/*else*/ {
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_node->state, __pyx_mstate_global->__pyx_n_u_pp); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 219, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
__pyx_v_state_str = __pyx_t_1;
__pyx_t_1 = 0;
+220: value = self.value_head.predict(state_str)
/* Same call-and-convert sequence as for source line 217, using predict() on the string. */
__pyx_t_2 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_2);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_2, __pyx_v_state_str};
__pyx_t_1 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_predict, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 220, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
__pyx_t_6 = __Pyx_PyFloat_AsFloat(__pyx_t_1); if (unlikely((__pyx_t_6 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 220, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__pyx_v_value = __pyx_t_6;
}
__pyx_L10:;
221:
+222: return value
__pyx_r = __pyx_v_value; goto __pyx_L0;
223:
+224: cpdef list _simulate_batch(self, list nodes):
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
/*
 * C-level entry point generated for the cpdef method
 * MCTS_AlphaZero._simulate_batch (source line 224); returns an object
 * typed as list.
 *
 * Only the cpdef dispatch prologue and the function exit code are shown
 * here; the method body is elided at the placeholder comment below.
 *
 * Dispatch: when __pyx_skip_dispatch is zero and the type checks indicate
 * that self may carry a Python-level override, the method attribute is
 * looked up on self. If it is not this module's compiled wrapper, the
 * override is called with `nodes`; its result must be an exact list or
 * None, otherwise an "unexpected type" error is raised.
 *
 * Error path (__pyx_L1_error): releases the temporaries, records a
 * traceback entry and returns 0 (NULL). Both exits release all object
 * locals declared at the top of the function.
 */
static PyObject *__pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__simulate_batch(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_nodes, int __pyx_skip_dispatch) {
PyObject *__pyx_v_results = 0;
PyObject *__pyx_v_features_list = 0;
PyObject *__pyx_v_indices_with_features = 0;
PyObject *__pyx_v_states_to_encode = 0;
PyObject *__pyx_v_indices_to_encode = 0;
int __pyx_v_i;
struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *__pyx_v_node = 0;
PyObject *__pyx_v_batch_features = 0;
PyObject *__pyx_v_values = 0;
PyObject *__pyx_v_idx = NULL;
PyObject *__pyx_v_val = NULL;
PyObject *__pyx_r = NULL;
/* Check if called by wrapper */
if (unlikely(__pyx_skip_dispatch)) ;
/* Check if overridden in Python */
else if (
#if !CYTHON_USE_TYPE_SLOTS
unlikely(Py_TYPE(((PyObject *)__pyx_v_self)) != __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero &&
__Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), Py_TPFLAGS_HAVE_GC))
#else
unlikely(Py_TYPE(((PyObject *)__pyx_v_self))->tp_dictoffset != 0 || __Pyx_PyType_HasFeature(Py_TYPE(((PyObject *)__pyx_v_self)), (Py_TPFLAGS_IS_ABSTRACT | Py_TPFLAGS_HEAPTYPE)))
#endif
) {
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* The attribute lookup below only runs when the recorded dict versions no longer match. */
static PY_UINT64_T __pyx_tp_dict_version = __PYX_DICT_VERSION_INIT, __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
if (unlikely(!__Pyx_object_dict_version_matches(((PyObject *)__pyx_v_self), __pyx_tp_dict_version, __pyx_obj_dict_version))) {
PY_UINT64_T __pyx_typedict_guard = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
#endif
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_mstate_global->__pyx_n_u_simulate_batch); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 224, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* The attribute is not the compiled wrapper, so call whatever was found instead. */
if (!__Pyx_IsSameCFunction(__pyx_t_1, (void(*)(void)) __pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch)) {
__Pyx_XDECREF(__pyx_r);
__pyx_t_3 = NULL;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_4 = __pyx_t_1;
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
/* Bound method: split it into self + function so self becomes the first call argument (offset 0). */
if (unlikely(PyMethod_Check(__pyx_t_4))) {
__pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4);
assert(__pyx_t_3);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_4);
__Pyx_INCREF(__pyx_t_3);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_4, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, __pyx_v_nodes};
__pyx_t_2 = __Pyx_PyObject_FastCall((PyObject*)__pyx_t_4, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 224, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
}
/* Enforce the declared `list` return type on the override's result (exact list or None). */
if (!(likely(PyList_CheckExact(__pyx_t_2))||((__pyx_t_2) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_2))) __PYX_ERR(0, 224, __pyx_L1_error)
__pyx_r = ((PyObject*)__pyx_t_2);
__pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
goto __pyx_L0;
}
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
/* No override was found: record the current dict versions, resetting them if the type dict changed during the lookup. */
__pyx_tp_dict_version = __Pyx_get_tp_dict_version(((PyObject *)__pyx_v_self));
__pyx_obj_dict_version = __Pyx_get_object_dict_version(((PyObject *)__pyx_v_self));
if (unlikely(__pyx_typedict_guard != __pyx_tp_dict_version)) {
__pyx_tp_dict_version = __pyx_obj_dict_version = __PYX_DICT_VERSION_INIT;
}
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_PYTYPE_LOOKUP && CYTHON_USE_TYPE_SLOTS
}
#endif
}
/* … */
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_XDECREF(__pyx_t_2);
__Pyx_XDECREF(__pyx_t_3);
__Pyx_XDECREF(__pyx_t_4);
__Pyx_XDECREF(__pyx_t_13);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = 0;
__pyx_L0:;
/* Shared exit for success and failure: drop every object-typed local. */
__Pyx_XDECREF(__pyx_v_results);
__Pyx_XDECREF(__pyx_v_features_list);
__Pyx_XDECREF(__pyx_v_indices_with_features);
__Pyx_XDECREF(__pyx_v_states_to_encode);
__Pyx_XDECREF(__pyx_v_indices_to_encode);
__Pyx_XDECREF((PyObject *)__pyx_v_node);
__Pyx_XDECREF(__pyx_v_batch_features);
__Pyx_XDECREF(__pyx_v_values);
__Pyx_XDECREF(__pyx_v_idx);
__Pyx_XDECREF(__pyx_v_val);
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* Python wrapper */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
); /*proto*/
static PyMethodDef __pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch = {"_simulate_batch", (PyCFunction)(void(*)(void))(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0};
/*
 * Python-callable wrapper for MCTS_AlphaZero._simulate_batch.
 *
 * Unpacks exactly one argument, `nodes`, supplied either positionally or by
 * keyword, checks it with __Pyx_ArgTypeTest against PyList_Type (the flags
 * presumably mean "None allowed, exact type" -- confirm against the helper),
 * and forwards to the __pyx_pf_ implementation.
 *
 * Returns the result object of the implementation, or NULL after recording a
 * traceback entry when argument unpacking or the type test fails.
 *
 * The keyword-count check is written as unlikely(len < 0). In the previous
 * form, unlikely(len) < 0, the comparison was applied to the result of
 * unlikely(); when that macro normalises its argument to 0/1 (as the usual
 * __builtin_expect(!!(x), 0) definition does) the test can never be true and
 * a failed keyword count would go unnoticed.
 */
static PyObject *__pyx_pw_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch(PyObject *__pyx_v_self,
#if CYTHON_METH_FASTCALL
PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds
#else
PyObject *__pyx_args, PyObject *__pyx_kwds
#endif
) {
PyObject *__pyx_v_nodes = 0;
#if !CYTHON_METH_FASTCALL
CYTHON_UNUSED Py_ssize_t __pyx_nargs;
#endif
CYTHON_UNUSED PyObject *const *__pyx_kwvalues;
PyObject *__pyx_r = 0;
__Pyx_RefNannyDeclarations
__Pyx_RefNannySetupContext("_simulate_batch (wrapper)", 0);
#if !CYTHON_METH_FASTCALL
/* Tuple calling convention: the positional count comes from the args tuple. */
#if CYTHON_ASSUME_SAFE_SIZE
__pyx_nargs = PyTuple_GET_SIZE(__pyx_args);
#else
__pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL;
#endif
#endif
__pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs);
{
PyObject ** const __pyx_pyargnames[] = {&__pyx_mstate_global->__pyx_n_u_nodes,0};
PyObject* values[1] = {0};
const Py_ssize_t __pyx_kwds_len = (__pyx_kwds) ? __Pyx_NumKwargs_FASTCALL(__pyx_kwds) : 0;
/* A negative count signals a failure while sizing the keywords. */
if (unlikely(__pyx_kwds_len < 0)) __PYX_ERR(0, 224, __pyx_L3_error)
if (__pyx_kwds_len > 0) {
/* Keywords present: take any positional value first, then let the keyword parser fill the rest. */
switch (__pyx_nargs) {
case 1:
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 224, __pyx_L3_error)
CYTHON_FALLTHROUGH;
case 0: break;
default: goto __pyx_L5_argtuple_error;
}
const Py_ssize_t kwd_pos_args = __pyx_nargs;
if (__Pyx_ParseKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values, kwd_pos_args, __pyx_kwds_len, "_simulate_batch", 0) < (0)) __PYX_ERR(0, 224, __pyx_L3_error)
/* Every slot not supplied positionally must have been filled by a keyword. */
for (Py_ssize_t i = __pyx_nargs; i < 1; i++) {
if (unlikely(!values[i])) { __Pyx_RaiseArgtupleInvalid("_simulate_batch", 1, 1, 1, i); __PYX_ERR(0, 224, __pyx_L3_error) }
}
} else if (unlikely(__pyx_nargs != 1)) {
goto __pyx_L5_argtuple_error;
} else {
/* No keywords: the single positional argument is `nodes`. */
values[0] = __Pyx_ArgRef_FASTCALL(__pyx_args, 0);
if (!CYTHON_ASSUME_SAFE_MACROS && unlikely(!values[0])) __PYX_ERR(0, 224, __pyx_L3_error)
}
__pyx_v_nodes = ((PyObject*)values[0]);
}
goto __pyx_L6_skip;
__pyx_L5_argtuple_error:;
__Pyx_RaiseArgtupleInvalid("_simulate_batch", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 224, __pyx_L3_error)
__pyx_L6_skip:;
goto __pyx_L4_argument_unpacking_done;
__pyx_L3_error:;
/* Unpacking failed: release whatever argument references were collected. */
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__Pyx_RefNannyFinishContext();
return NULL;
__pyx_L4_argument_unpacking_done:;
if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_nodes), (&PyList_Type), 1, "nodes", 1))) __PYX_ERR(0, 224, __pyx_L1_error)
__pyx_r = __pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_12_simulate_batch(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self), __pyx_v_nodes);
int __pyx_lineno = 0;
const char *__pyx_filename = NULL;
int __pyx_clineno = 0;
/* function exit code */
goto __pyx_L0;
__pyx_L1_error:;
__pyx_r = NULL;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
goto __pyx_L7_cleaned_up;
__pyx_L0:;
for (Py_ssize_t __pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) {
Py_XDECREF(values[__pyx_temp]);
}
__pyx_L7_cleaned_up:;
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/*
 * Python-facing implementation behind the _simulate_batch wrapper.
 *
 * Calls the C-level _simulate_batch with the skip-dispatch argument set to 1
 * and passes its result straight through; a NULL result is treated as
 * failure.
 *
 * Returns the result object, or NULL after recording a traceback entry.
 * The temporary __pyx_t_1 is not declared in this listing.
 */
static PyObject *__pyx_pf_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_12_simulate_batch(struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *__pyx_v_self, PyObject *__pyx_v_nodes) {
PyObject *__pyx_r = NULL;
__Pyx_XDECREF(__pyx_r);
__pyx_t_1 = __pyx_f_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero__simulate_batch(__pyx_v_self, __pyx_v_nodes, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 224, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
/* Hand ownership of the result to the return slot. */
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
/* function exit code */
__pyx_L1_error:;
__Pyx_XDECREF(__pyx_t_1);
__Pyx_AddTraceback("lean_reinforcement.agent.mcts.mcts_cy.alphazero_cy.MCTS_AlphaZero._simulate_batch", __pyx_clineno, __pyx_lineno, __pyx_filename);
__pyx_r = NULL;
__pyx_L0:;
__Pyx_XGIVEREF(__pyx_r);
__Pyx_RefNannyFinishContext();
return __pyx_r;
}
/* … */
__pyx_t_2 = __Pyx_CyFunction_New(&__pyx_mdef_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_14MCTS_AlphaZero_13_simulate_batch, __Pyx_CYFUNCTION_CCLASS, __pyx_mstate_global->__pyx_n_u_MCTS_AlphaZero__simulate_batch, NULL, __pyx_mstate_global->__pyx_n_u_lean_reinforcement_agent_mcts_mc, __pyx_mstate_global->__pyx_d, ((PyObject *)__pyx_mstate_global->__pyx_codeobj_tab[5])); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 224, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030E0000
PyUnstable_Object_EnableDeferredRefcount(__pyx_t_2);
#endif
if (__Pyx_SetItemOnTypeDict(__pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero, __pyx_mstate_global->__pyx_n_u_simulate_batch, __pyx_t_2) < (0)) __PYX_ERR(0, 224, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
+225: cdef list results = [0.0] * len(nodes)
if (unlikely(__pyx_v_nodes == Py_None)) {
PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()");
__PYX_ERR(0, 225, __pyx_L1_error)
}
__pyx_t_6 = __Pyx_PyList_GET_SIZE(__pyx_v_nodes); if (unlikely(__pyx_t_6 == ((Py_ssize_t)-1))) __PYX_ERR(0, 225, __pyx_L1_error)
__pyx_t_1 = PyList_New(1 * ((__pyx_t_6<0) ? 0:__pyx_t_6)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 225, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
{ Py_ssize_t __pyx_temp;
for (__pyx_temp=0; __pyx_temp < __pyx_t_6; __pyx_temp++) {
__Pyx_INCREF(__pyx_mstate_global->__pyx_float_0_0);
__Pyx_GIVEREF(__pyx_mstate_global->__pyx_float_0_0);
if (__Pyx_PyList_SET_ITEM(__pyx_t_1, __pyx_temp, __pyx_mstate_global->__pyx_float_0_0) != (0)) __PYX_ERR(0, 225, __pyx_L1_error);
}
}
__pyx_v_results = ((PyObject*)__pyx_t_1);
__pyx_t_1 = 0;
+226: cdef list features_list = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 226, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_features_list = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+227: cdef list indices_with_features = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 227, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_indices_with_features = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+228: cdef list states_to_encode = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 228, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_states_to_encode = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
+229: cdef list indices_to_encode = []
__pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 229, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_v_indices_to_encode = ((PyObject*)__pyx_t_1); __pyx_t_1 = 0;
230: cdef int i
231: cdef Node node
232: cdef object batch_features
233: cdef list values
234:
235: # Check timeout before batch evaluation
+236: if self._is_timeout():
__pyx_t_7 = ((struct __pyx_vtabstruct_18lean_reinforcement_5agent_4mcts_7mcts_cy_12alphazero_cy_MCTS_AlphaZero *)__pyx_v_self->__pyx_base.__pyx_vtab)->__pyx_base._is_timeout(((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_BaseMCTS *)__pyx_v_self), 0); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 236, __pyx_L1_error)
if (__pyx_t_7) {
/* … */
}
+237: return results # Return neutral rewards on timeout
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_results); __pyx_r = __pyx_v_results; goto __pyx_L0;
238:
+239: for i in range(len(nodes)):
if (unlikely(__pyx_v_nodes == Py_None)) {
PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()");
__PYX_ERR(0, 239, __pyx_L1_error)
}
__pyx_t_6 = __Pyx_PyList_GET_SIZE(__pyx_v_nodes); if (unlikely(__pyx_t_6 == ((Py_ssize_t)-1))) __PYX_ERR(0, 239, __pyx_L1_error)
__pyx_t_8 = __pyx_t_6;
for (__pyx_t_9 = 0; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_i = __pyx_t_9;
+240: node = nodes[i]
if (unlikely(__pyx_v_nodes == Py_None)) {
PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable");
__PYX_ERR(0, 240, __pyx_L1_error)
}
__pyx_t_1 = __Pyx_PyList_GET_ITEM(__pyx_v_nodes, __pyx_v_i);
__Pyx_INCREF(__pyx_t_1);
if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_mstate_global->__pyx_ptype_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node))))) __PYX_ERR(0, 240, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_node, ((struct __pyx_obj_18lean_reinforcement_5agent_4mcts_7mcts_cy_12base_mcts_cy_Node *)__pyx_t_1));
__pyx_t_1 = 0;
+241: if node.is_terminal:
if (__pyx_v_node->is_terminal) {
/* … */
}
+242: if isinstance(node.state, ProofFinished):
__pyx_t_1 = __pyx_v_node->state;
__Pyx_INCREF(__pyx_t_1);
__Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_mstate_global->__pyx_n_u_ProofFinished); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 242, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_7 = PyObject_IsInstance(__pyx_t_1, __pyx_t_2); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 242, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
if (__pyx_t_7) {
/* … */
goto __pyx_L7;
}
+243: results[i] = 1.0
if (unlikely((__Pyx_SetItemInt(__pyx_v_results, __pyx_v_i, __pyx_mstate_global->__pyx_float_1_0, int, 1, __Pyx_PyLong_From_int, 1, 0, 0, 1, __Pyx_ReferenceSharing_OwnStrongReference) < 0))) __PYX_ERR(0, 243, __pyx_L1_error)
+244: elif isinstance(node.state, (LeanError, ProofGivenUp)):
__Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_mstate_global->__pyx_n_u_LeanError); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 244, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_2); __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_mstate_global->__pyx_n_u_ProofGivenUp); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 244, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_4 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_4); __pyx_t_10 = PyObject_IsInstance(__pyx_t_4, __pyx_t_2); __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; if (!__pyx_t_10) { } else { __pyx_t_7 = __pyx_t_10; goto __pyx_L8_bool_binop_done; } __pyx_t_4 = __pyx_v_node->state; __Pyx_INCREF(__pyx_t_4); __pyx_t_10 = PyObject_IsInstance(__pyx_t_4, __pyx_t_1); __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; __pyx_t_7 = __pyx_t_10; __pyx_L8_bool_binop_done:; __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; if (__pyx_t_7) { /* … */ } __pyx_L7:;
+245: results[i] = -1.0
if (unlikely((__Pyx_SetItemInt(__pyx_v_results, __pyx_v_i, __pyx_mstate_global->__pyx_float_neg_1_0, int, 1, __Pyx_PyLong_From_int, 1, 0, 0, 1, __Pyx_ReferenceSharing_OwnStrongReference) < 0))) __PYX_ERR(0, 245, __pyx_L1_error)
+246: continue
goto __pyx_L4_continue;
247:
+248: if not isinstance(node.state, TacticState):
__pyx_t_2 = __pyx_v_node->state;
__Pyx_INCREF(__pyx_t_2);
__Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_mstate_global->__pyx_n_u_TacticState); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 248, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
__pyx_t_7 = PyObject_IsInstance(__pyx_t_2, __pyx_t_1); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(0, 248, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_10 = (!__pyx_t_7);
if (__pyx_t_10) {
/* … */
}
+249: results[i] = -1.0
if (unlikely((__Pyx_SetItemInt(__pyx_v_results, __pyx_v_i, __pyx_mstate_global->__pyx_float_neg_1_0, int, 1, __Pyx_PyLong_From_int, 1, 0, 0, 1, __Pyx_ReferenceSharing_OwnStrongReference) < 0))) __PYX_ERR(0, 249, __pyx_L1_error)
+250: continue
goto __pyx_L4_continue;
251:
+252: if node.encoder_features is not None:
__pyx_t_10 = (__pyx_v_node->encoder_features != Py_None);
if (__pyx_t_10) {
/* … */
goto __pyx_L11;
}
+253: features_list.append(node.encoder_features)
__pyx_t_1 = __pyx_v_node->encoder_features;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_11 = __Pyx_PyList_Append(__pyx_v_features_list, __pyx_t_1); if (unlikely(__pyx_t_11 == ((int)-1))) __PYX_ERR(0, 253, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+254: indices_with_features.append(i)
__pyx_t_1 = __Pyx_PyLong_From_int(__pyx_v_i); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 254, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_11 = __Pyx_PyList_Append(__pyx_v_indices_with_features, __pyx_t_1); if (unlikely(__pyx_t_11 == ((int)-1))) __PYX_ERR(0, 254, __pyx_L1_error) __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
255: else:
+256: states_to_encode.append(node.state.pp)
/*else*/ {
__pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_node->state, __pyx_mstate_global->__pyx_n_u_pp); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 256, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
__pyx_t_11 = __Pyx_PyList_Append(__pyx_v_states_to_encode, __pyx_t_1); if (unlikely(__pyx_t_11 == ((int)-1))) __PYX_ERR(0, 256, __pyx_L1_error)
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+257: indices_to_encode.append(i)
__pyx_t_1 = __Pyx_PyLong_From_int(__pyx_v_i); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 257, __pyx_L1_error) __Pyx_GOTREF(__pyx_t_1); __pyx_t_11 = __Pyx_PyList_Append(__pyx_v_indices_to_encode, __pyx_t_1); if (unlikely(__pyx_t_11 == ((int)-1))) __PYX_ERR(0, 257, __pyx_L1_error) __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; } __pyx_L11:; __pyx_L4_continue:; }
258:
+259: if features_list:
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_v_features_list);
if (unlikely(((!CYTHON_ASSUME_SAFE_SIZE) && __pyx_temp < 0))) __PYX_ERR(0, 259, __pyx_L1_error)
__pyx_t_10 = (__pyx_temp != 0);
}
if (__pyx_t_10) {
/* … */
}
+260: batch_features = torch.cat(features_list, dim=0)
__pyx_t_2 = NULL;
__Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_mstate_global->__pyx_n_u_torch); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 260, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_mstate_global->__pyx_n_u_cat); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 260, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
__pyx_t_5 = 1;
#if CYTHON_UNPACK_METHODS
if (unlikely(PyMethod_Check(__pyx_t_3))) {
__pyx_t_2 = PyMethod_GET_SELF(__pyx_t_3);
assert(__pyx_t_2);
PyObject* __pyx__function = PyMethod_GET_FUNCTION(__pyx_t_3);
__Pyx_INCREF(__pyx_t_2);
__Pyx_INCREF(__pyx__function);
__Pyx_DECREF_SET(__pyx_t_3, __pyx__function);
__pyx_t_5 = 0;
}
#endif
{
PyObject *__pyx_callargs[2 + ((CYTHON_VECTORCALL) ? 1 : 0)] = {__pyx_t_2, __pyx_v_features_list};
__pyx_t_4 = __Pyx_MakeVectorcallBuilderKwds(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 260, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
if (__Pyx_VectorcallBuilder_AddArg(__pyx_mstate_global->__pyx_n_u_dim, __pyx_mstate_global->__pyx_int_0, __pyx_t_4, __pyx_callargs+2, 0) < (0)) __PYX_ERR(0, 260, __pyx_L1_error)
__pyx_t_1 = __Pyx_Object_Vectorcall_CallFromBuilder((PyObject*)__pyx_t_3, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET), __pyx_t_4);
__Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0;
__Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 260, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
__pyx_v_batch_features = __pyx_t_1;
__pyx_t_1 = 0;
+261: values = self.value_head.predict_from_features_batch(batch_features)
__pyx_t_3 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_3);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_3, __pyx_v_batch_features};
__pyx_t_1 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_predict_from_features_batch, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 261, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
if (!(likely(PyList_CheckExact(__pyx_t_1))||((__pyx_t_1) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_1))) __PYX_ERR(0, 261, __pyx_L1_error)
__pyx_v_values = ((PyObject*)__pyx_t_1);
__pyx_t_1 = 0;
+262: for idx, val in zip(indices_with_features, values):
__pyx_t_3 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[3] = {__pyx_t_3, __pyx_v_indices_with_features, __pyx_v_values};
__pyx_t_1 = __Pyx_PyObject_FastCall((PyObject*)__pyx_builtin_zip, __pyx_callargs+__pyx_t_5, (3-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0;
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
}
if (likely(PyList_CheckExact(__pyx_t_1)) || PyTuple_CheckExact(__pyx_t_1)) {
__pyx_t_3 = __pyx_t_1; __Pyx_INCREF(__pyx_t_3);
__pyx_t_6 = 0;
__pyx_t_12 = NULL;
} else {
__pyx_t_6 = -1; __pyx_t_3 = PyObject_GetIter(__pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
__pyx_t_12 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_3); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 262, __pyx_L1_error)
}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
for (;;) {
if (likely(!__pyx_t_12)) {
if (likely(PyList_CheckExact(__pyx_t_3))) {
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_3);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 262, __pyx_L1_error)
#endif
if (__pyx_t_6 >= __pyx_temp) break;
}
__pyx_t_1 = __Pyx_PyList_GetItemRefFast(__pyx_t_3, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference);
++__pyx_t_6;
} else {
{
Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_3);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 262, __pyx_L1_error)
#endif
if (__pyx_t_6 >= __pyx_temp) break;
}
#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
__pyx_t_1 = __Pyx_NewRef(PyTuple_GET_ITEM(__pyx_t_3, __pyx_t_6));
#else
__pyx_t_1 = __Pyx_PySequence_ITEM(__pyx_t_3, __pyx_t_6);
#endif
++__pyx_t_6;
}
if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 262, __pyx_L1_error)
} else {
__pyx_t_1 = __pyx_t_12(__pyx_t_3);
if (unlikely(!__pyx_t_1)) {
PyObject* exc_type = PyErr_Occurred();
if (exc_type) {
if (unlikely(!__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) __PYX_ERR(0, 262, __pyx_L1_error)
PyErr_Clear();
}
break;
}
}
__Pyx_GOTREF(__pyx_t_1);
if ((likely(PyTuple_CheckExact(__pyx_t_1))) || (PyList_CheckExact(__pyx_t_1))) {
PyObject* sequence = __pyx_t_1;
Py_ssize_t size = __Pyx_PySequence_SIZE(sequence);
if (unlikely(size != 2)) {
if (size > 2) __Pyx_RaiseTooManyValuesError(2);
else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size);
__PYX_ERR(0, 262, __pyx_L1_error)
}
#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
if (likely(PyTuple_CheckExact(sequence))) {
__pyx_t_4 = PyTuple_GET_ITEM(sequence, 0);
__Pyx_INCREF(__pyx_t_4);
__pyx_t_2 = PyTuple_GET_ITEM(sequence, 1);
__Pyx_INCREF(__pyx_t_2);
} else {
__pyx_t_4 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference);
if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_XGOTREF(__pyx_t_4);
__pyx_t_2 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference);
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_XGOTREF(__pyx_t_2);
}
#else
__pyx_t_4 = __Pyx_PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
__pyx_t_2 = __Pyx_PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
#endif
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
} else {
Py_ssize_t index = -1;
__pyx_t_13 = PyObject_GetIter(__pyx_t_1); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 262, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_13);
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_14 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_13);
index = 0; __pyx_t_4 = __pyx_t_14(__pyx_t_13); if (unlikely(!__pyx_t_4)) goto __pyx_L15_unpacking_failed;
__Pyx_GOTREF(__pyx_t_4);
index = 1; __pyx_t_2 = __pyx_t_14(__pyx_t_13); if (unlikely(!__pyx_t_2)) goto __pyx_L15_unpacking_failed;
__Pyx_GOTREF(__pyx_t_2);
if (__Pyx_IternextUnpackEndCheck(__pyx_t_14(__pyx_t_13), 2) < (0)) __PYX_ERR(0, 262, __pyx_L1_error)
__pyx_t_14 = NULL;
__Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0;
goto __pyx_L16_unpacking_done;
__pyx_L15_unpacking_failed:;
__Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0;
__pyx_t_14 = NULL;
if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index);
__PYX_ERR(0, 262, __pyx_L1_error)
__pyx_L16_unpacking_done:;
}
__Pyx_XDECREF_SET(__pyx_v_idx, __pyx_t_4);
__pyx_t_4 = 0;
__Pyx_XDECREF_SET(__pyx_v_val, __pyx_t_2);
__pyx_t_2 = 0;
/* … */
}
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
+263: results[idx] = val
if (unlikely((PyObject_SetItem(__pyx_v_results, __pyx_v_idx, __pyx_v_val) < 0))) __PYX_ERR(0, 263, __pyx_L1_error)
264:
+265: if states_to_encode:
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_v_states_to_encode);
if (unlikely(((!CYTHON_ASSUME_SAFE_SIZE) && __pyx_temp < 0))) __PYX_ERR(0, 265, __pyx_L1_error)
__pyx_t_10 = (__pyx_temp != 0);
}
if (__pyx_t_10) {
/* … */
}
+266: values = self.value_head.predict_batch(states_to_encode)
__pyx_t_1 = __pyx_v_self->value_head;
__Pyx_INCREF(__pyx_t_1);
__pyx_t_5 = 0;
{
PyObject *__pyx_callargs[2] = {__pyx_t_1, __pyx_v_states_to_encode};
__pyx_t_3 = __Pyx_PyObject_FastCallMethod((PyObject*)__pyx_mstate_global->__pyx_n_u_predict_batch, __pyx_callargs+__pyx_t_5, (2-__pyx_t_5) | (1*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 266, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
}
if (!(likely(PyList_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_3))) __PYX_ERR(0, 266, __pyx_L1_error)
__Pyx_XDECREF_SET(__pyx_v_values, ((PyObject*)__pyx_t_3));
__pyx_t_3 = 0;
+267: for idx, val in zip(indices_to_encode, values):
__pyx_t_1 = NULL;
__pyx_t_5 = 1;
{
PyObject *__pyx_callargs[3] = {__pyx_t_1, __pyx_v_indices_to_encode, __pyx_v_values};
__pyx_t_3 = __Pyx_PyObject_FastCall((PyObject*)__pyx_builtin_zip, __pyx_callargs+__pyx_t_5, (3-__pyx_t_5) | (__pyx_t_5*__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET));
__Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0;
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_3);
}
if (likely(PyList_CheckExact(__pyx_t_3)) || PyTuple_CheckExact(__pyx_t_3)) {
__pyx_t_1 = __pyx_t_3; __Pyx_INCREF(__pyx_t_1);
__pyx_t_6 = 0;
__pyx_t_12 = NULL;
} else {
__pyx_t_6 = -1; __pyx_t_1 = PyObject_GetIter(__pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_1);
__pyx_t_12 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_1); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 267, __pyx_L1_error)
}
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
for (;;) {
if (likely(!__pyx_t_12)) {
if (likely(PyList_CheckExact(__pyx_t_1))) {
{
Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_1);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 267, __pyx_L1_error)
#endif
if (__pyx_t_6 >= __pyx_temp) break;
}
__pyx_t_3 = __Pyx_PyList_GetItemRefFast(__pyx_t_1, __pyx_t_6, __Pyx_ReferenceSharing_OwnStrongReference);
++__pyx_t_6;
} else {
{
Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_1);
#if !CYTHON_ASSUME_SAFE_SIZE
if (unlikely((__pyx_temp < 0))) __PYX_ERR(0, 267, __pyx_L1_error)
#endif
if (__pyx_t_6 >= __pyx_temp) break;
}
#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
__pyx_t_3 = __Pyx_NewRef(PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_6));
#else
__pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_1, __pyx_t_6);
#endif
++__pyx_t_6;
}
if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 267, __pyx_L1_error)
} else {
__pyx_t_3 = __pyx_t_12(__pyx_t_1);
if (unlikely(!__pyx_t_3)) {
PyObject* exc_type = PyErr_Occurred();
if (exc_type) {
if (unlikely(!__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) __PYX_ERR(0, 267, __pyx_L1_error)
PyErr_Clear();
}
break;
}
}
__Pyx_GOTREF(__pyx_t_3);
if ((likely(PyTuple_CheckExact(__pyx_t_3))) || (PyList_CheckExact(__pyx_t_3))) {
PyObject* sequence = __pyx_t_3;
Py_ssize_t size = __Pyx_PySequence_SIZE(sequence);
if (unlikely(size != 2)) {
if (size > 2) __Pyx_RaiseTooManyValuesError(2);
else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size);
__PYX_ERR(0, 267, __pyx_L1_error)
}
#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
if (likely(PyTuple_CheckExact(sequence))) {
__pyx_t_2 = PyTuple_GET_ITEM(sequence, 0);
__Pyx_INCREF(__pyx_t_2);
__pyx_t_4 = PyTuple_GET_ITEM(sequence, 1);
__Pyx_INCREF(__pyx_t_4);
} else {
__pyx_t_2 = __Pyx_PyList_GetItemRefFast(sequence, 0, __Pyx_ReferenceSharing_SharedReference);
if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_XGOTREF(__pyx_t_2);
__pyx_t_4 = __Pyx_PyList_GetItemRefFast(sequence, 1, __Pyx_ReferenceSharing_SharedReference);
if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_XGOTREF(__pyx_t_4);
}
#else
__pyx_t_2 = __Pyx_PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_2);
__pyx_t_4 = __Pyx_PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_4);
#endif
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
} else {
Py_ssize_t index = -1;
__pyx_t_13 = PyObject_GetIter(__pyx_t_3); if (unlikely(!__pyx_t_13)) __PYX_ERR(0, 267, __pyx_L1_error)
__Pyx_GOTREF(__pyx_t_13);
__Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0;
__pyx_t_14 = (CYTHON_COMPILING_IN_LIMITED_API) ? PyIter_Next : __Pyx_PyObject_GetIterNextFunc(__pyx_t_13);
index = 0; __pyx_t_2 = __pyx_t_14(__pyx_t_13); if (unlikely(!__pyx_t_2)) goto __pyx_L21_unpacking_failed;
__Pyx_GOTREF(__pyx_t_2);
index = 1; __pyx_t_4 = __pyx_t_14(__pyx_t_13); if (unlikely(!__pyx_t_4)) goto __pyx_L21_unpacking_failed;
__Pyx_GOTREF(__pyx_t_4);
if (__Pyx_IternextUnpackEndCheck(__pyx_t_14(__pyx_t_13), 2) < (0)) __PYX_ERR(0, 267, __pyx_L1_error)
__pyx_t_14 = NULL;
__Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0;
goto __pyx_L22_unpacking_done;
__pyx_L21_unpacking_failed:;
__Pyx_DECREF(__pyx_t_13); __pyx_t_13 = 0;
__pyx_t_14 = NULL;
if (__Pyx_IterFinish() == 0) __Pyx_RaiseNeedMoreValuesError(index);
__PYX_ERR(0, 267, __pyx_L1_error)
__pyx_L22_unpacking_done:;
}
__Pyx_XDECREF_SET(__pyx_v_idx, __pyx_t_2);
__pyx_t_2 = 0;
__Pyx_XDECREF_SET(__pyx_v_val, __pyx_t_4);
__pyx_t_4 = 0;
/* … */
}
__Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
+268: results[idx] = val
if (unlikely((PyObject_SetItem(__pyx_v_results, __pyx_v_idx, __pyx_v_val) < 0))) __PYX_ERR(0, 268, __pyx_L1_error)
269:
+270: return results
__Pyx_XDECREF(__pyx_r); __Pyx_INCREF(__pyx_v_results); __pyx_r = __pyx_v_results; goto __pyx_L0;