pybind11 Integration
If you’re using pybind11, you can incorporate your own custom linear operator / matrix function pair using primate’s binding headers.
Native pybind11 types
Suppose you have a custom class LinOp understood by pybind11, which looks something like:
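class LinOp {
public:
  int nrow, ncol;
  LinOp(int nr, int nc) : nrow(nr), ncol(nc) {}
  // Computes output = A * input
  void matvec(const float* input, float* output) const {
    ... // implementation details
  }
  // Returns (number of rows, number of columns)
  auto shape() const -> std::pair< size_t, size_t > {
    return std::make_pair((size_t) nrow, (size_t) ncol);
  }
};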
Since pybind11 natively understands how to pass a pointer to this type, creating an extension module that calls primate’s SLQ trace estimator with LinOp only requires calling the _trace_wrapper function:
#include <binders/pb11_trace_bind.h> // _trace_wrapper binding
#include "LinOp.h" // custom LinOp class
PYBIND11_MODULE(_custom_trace, m) {
  // m is the actual py::module (exported as _custom_trace)
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, LinOp >(m);
}
The final extension module _custom_trace will have a method trace_slq as an attribute that can be called from Python to initiate the SLQ method with the corresponding LinOp type.
Wrapping types
If you’re trying to create bindings for a class that isn’t known to pybind11, or that doesn’t natively satisfy the matvec and shape constraints, you can optionally provide a wrapper function as the last template parameter of _trace_wrapper:
struct WrappedLinOp {
  LinOp op;
  WrappedLinOp(LinOp& _op) : op(_op){ }
  void matvec(const float* input, float* output) const {
    ... // implementation details, e.g. op.dot(input, output)
  }
  auto shape() const -> std::pair< size_t, size_t > { ... }
};
auto linop_wrapper(LinOp* op){ // or py::object, for non-native types
  return WrappedLinOp(*op);
}
PYBIND11_MODULE(_custom_trace, m) {
  // m is the actual py::module (exported as _custom_trace)
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, LinOp, linop_wrapper >(m);
}
As a side effect, this also gives access to matrix types that implement matrix-vector multiplication but whose method names or signatures don’t match those required by the underlying LinearOperator concept.
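The adapter idea above can be sketched without any binding code. In the example below, ForeignMatrix, ForeignMatrixAdapter, and foreign_wrapper are hypothetical names invented for illustration; the sketch assumes only that the wrapped type offers some form of matrix-vector product (here called dot) and row/column accessors.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical third-party type: multiplies via dot() and reports its size via
// rows()/cols(), so it lacks the matvec/shape names.
struct ForeignMatrix {
  std::size_t n_rows, n_cols;
  std::vector< float > data; // row-major entries
  std::size_t rows() const { return n_rows; }
  std::size_t cols() const { return n_cols; }
  void dot(const float* x, float* y) const {
    for (std::size_t i = 0; i < n_rows; ++i) {
      float acc = 0.0f;
      for (std::size_t j = 0; j < n_cols; ++j) { acc += data[i * n_cols + j] * x[j]; }
      y[i] = acc;
    }
  }
};

// Adapter that forwards to the foreign API under the matvec/shape names.
struct ForeignMatrixAdapter {
  const ForeignMatrix& op; // non-owning: the wrapped matrix must outlive the adapter
  void matvec(const float* input, float* output) const { op.dot(input, output); }
  auto shape() const -> std::pair< std::size_t, std::size_t > {
    return std::make_pair(op.rows(), op.cols());
  }
};

// Wrapper function in the spirit of linop_wrapper: pointer in, adapter out.
inline auto foreign_wrapper(const ForeignMatrix* A) { return ForeignMatrixAdapter{ *A }; }

// Usage check on a 2 x 3 matrix applied to the all-ones vector (row sums).
inline void foreign_wrapper_demo() {
  const ForeignMatrix A{ 2, 3, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f } };
  const auto wrapped = foreign_wrapper(&A);
  const float x[3] = { 1.0f, 1.0f, 1.0f };
  float y[2] = { 0.0f, 0.0f };
  wrapped.matvec(x, y);
  assert(y[0] == 6.0f && y[1] == 15.0f);
}
```

The adapter here holds a reference to avoid copying the matrix; storing a copy, as WrappedLinOp does, trades memory for simpler lifetime management.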
Full Example using Eigen
Here’s a real example of code that wraps an Eigen SparseMatrix for use with primate. Eigen supports matrix-vector multiplication out of the box with the overloaded operator*() and is understood natively by pybind11, so it suffices to define a wrapper class that respects the LinearOperator concept:
#include <concepts> // std::floating_point
#include <utility> // std::pair, std::make_pair
#include <Eigen/Sparse> // Eigen::SparseMatrix
template< std::floating_point F >
struct SparseEigenLinearOperator {
  using value_type = F;
  using float_vec = Eigen::Matrix< F, Eigen::Dynamic, 1 >;
  const Eigen::SparseMatrix< F > A;
  SparseEigenLinearOperator(const Eigen::SparseMatrix< F >& _mat) : A(_mat){}
  void matvec(const F* inp, F* out) const noexcept {
    // Map the raw pointers to Eigen vectors without copying
    auto input = Eigen::Map< const float_vec >(inp, A.cols(), 1);
    auto output = Eigen::Map< float_vec >(out, A.rows(), 1);
    output = A * input;
  }
  auto shape() const noexcept -> std::pair< size_t, size_t > {
    return std::make_pair((size_t) A.rows(), (size_t) A.cols());
  }
};
Then write a short wrapper function that converts an Eigen::SparseMatrix< F >* to a SparseEigenLinearOperator< F > and pass it to _trace_wrapper:
template< std::floating_point F >
auto eigen_sparse_wrapper(const Eigen::SparseMatrix< F >* A){
  return SparseEigenLinearOperator< F >(*A);
}
PYBIND11_MODULE(_custom_trace, m) {
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, Eigen::SparseMatrix< float > >(
    m, eigen_sparse_wrapper< float >
  );
}
That’s it. The module _custom_trace will now have an exported trace_slq method that you can call from Python. For the list of arguments, see the slq declaration in trace.py.
Example: Log determinant
For explanatory purposes, the following code outlines how to call the trace estimator to compute the log determinant using a custom user-implemented operator LinOp:
#include <cmath> // std::log
#include <vector> // std::vector
#include <_linear_operator/linear_operator.h> // LinearOperator
#include <_lanczos/lanczos.h> // sl_trace
#include "LinOp.h" // custom class
void slq_log_det(LinOp A, ...){
  static_assert(LinearOperator< LinOp >); // Constraint check
  // std::log is overloaded, so wrap it in a lambda; any invocable works here
  const auto matrix_func = [](float x) { return std::log(x); };
  auto rbg = ThreadedRNG64(); // default RNG
  auto estimates = std::vector< float >(n, 0); // output estimates
  sl_trace< float >( // specific precision
    A, matrix_func, rbg, // main arguments
    ..., // other inputs
    estimates.data() // output
  );
}
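The outline above relies on the identity log det(A) = tr(log(A)) for a symmetric positive definite A, with the trace estimated by averaging quadratic forms v^T f(A) v over random sign vectors v. The sketch below illustrates only that averaging step on a diagonal operator, where f(A) is known in closed form, so no Lanczos quadrature is involved. It does not call primate; diag_trace_estimate and diag_log_det are hypothetical names, and the sample count and seed are arbitrary choices.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hutchinson-style estimate of tr(f(A)) for a diagonal A = diag(d) with d_i > 0.
// Because A is diagonal, f(A) = diag(f(d_i)), so each quadratic form can be
// evaluated directly instead of being approximated by Lanczos quadrature.
template < typename Func >
double diag_trace_estimate(const std::vector< double >& d, Func f, int n_samples, unsigned seed) {
  std::mt19937 rng(seed);
  std::bernoulli_distribution coin(0.5);
  double total = 0.0;
  for (int s = 0; s < n_samples; ++s) {
    double quad = 0.0; // accumulates v^T f(A) v for this sample
    for (std::size_t i = 0; i < d.size(); ++i) {
      const double v = coin(rng) ? 1.0 : -1.0; // Rademacher entry
      quad += v * f(d[i]) * v;
    }
    total += quad;
  }
  return total / n_samples; // average over the random vectors
}

// log det(A) = tr(log(A)); the lambda sidesteps the overloaded std::log.
inline double diag_log_det(const std::vector< double >& d) {
  const auto matrix_func = [](double x) { return std::log(x); };
  return diag_trace_estimate(d, matrix_func, 8, 1234u);
}

// Usage check: det(diag(1, 2, 4)) = 8, so the log determinant is log(8).
inline void diag_log_det_demo() {
  const std::vector< double > d = { 1.0, 2.0, 4.0 };
  assert(std::fabs(diag_log_det(d) - std::log(8.0)) < 1e-12);
}
```

For a diagonal operator every sign-vector sample already equals the exact trace, which is why the check can use a tight tolerance. For a general matrix the samples vary, and each quadratic form has to be approximated, which is the role the Lanczos quadrature plays in SLQ.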