pybind11 Integration

If you’re using pybind11, you can easily incorporate your own custom linear operator / matrix function pair using primate’s binding headers.

Native pybind11 types

Suppose you have a custom class LinOp understood by pybind11, which looks something like:

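#include <utility>  // std::pair, std::make_pair

class LinOp {
public:
  int nrow, ncol;

  LinOp(int nr, int nc) : nrow(nr), ncol(nc) {}

  // computes output = A * input
  void matvec(const float* input, float* output) const {
    ... // implementation details 
  }

  // returns (number of rows, number of columns)
  std::pair< int, int > shape() const { return std::make_pair(nrow, ncol); }
};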
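To make the interface concrete, here is a minimal sketch of one operator that fills in the elided matvec body. DiagonalOp is a hypothetical example type, not part of primate; it assumes a square diagonal matrix stored as a vector of its diagonal entries.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical example operator: a square diagonal matrix,
// stored as the vector of its diagonal entries.
struct DiagonalOp {
  std::vector<float> diag;

  explicit DiagonalOp(std::vector<float> d) : diag(std::move(d)) {}

  // output = A * input; both buffers must hold diag.size() floats
  void matvec(const float* input, float* output) const {
    assert(input != nullptr && output != nullptr);
    for (std::size_t i = 0; i < diag.size(); ++i) {
      output[i] = diag[i] * input[i];
    }
  }

  // (number of rows, number of columns)
  std::pair<std::size_t, std::size_t> shape() const {
    return std::make_pair(diag.size(), diag.size());
  }
};
```

Note that matvec works on raw pointers, so the caller is responsible for allocating buffers whose lengths match shape().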

Since pybind11 understands how to pass a pointer to this type natively, creating an extension module that calls primate’s SLQ trace estimator with LinOp can be done by just calling the _trace_wrapper function:

#include <binders/pb11_trace_bind.h>  // _trace_wrapper binding  
#include "LinOp.h"                    // custom LinOp class  

PYBIND11_MODULE(_custom_trace, m) {   
  // m is the actual py::module (exported as _custom_trace)
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, LinOp >(m); 
}

The final extension module _custom_trace will have a method trace_slq as an attribute that can be called from Python to initiate the SLQ method with the corresponding LinOp type.

Wrapping types

If you’re trying to create bindings for a class that isn’t known to pybind11, or that doesn’t natively provide the matvec and shape methods, you can optionally provide a wrapper function as the last template parameter of _trace_wrapper:

struct WrappedLinOp {
  LinOp op;

  WrappedLinOp(LinOp& _op) : op(_op) {}

  void matvec(const float* input, float* output) const {
    ... // implementation details, e.g. op.dot(input, output)
  }

  auto shape() const -> std::pair< size_t, size_t > { ... }
};

auto linop_wrapper(LinOp* op){ // or py::object, for non-native types
  return WrappedLinOp(*op); 
}

PYBIND11_MODULE(_custom_trace, m) {   
  // m is the actual py::module (exported as _custom_trace)
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, LinOp, linop_wrapper >(m); 
}

As a side effect, this also enables full access to matrix types that implement matrix-vector multiplication but whose method names or signatures don’t match the underlying LinearOperator concept.
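The adapter idea can be sketched without pybind11 or primate. In the snippet below, ScaleOp, WrappedScaleOp, scale_wrapper and has_linop_interface are all hypothetical names introduced for illustration. The trait is only a rough C++17 stand-in for a LinearOperator-style check and is not primate’s actual concept definition, which may require more.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Rough stand-in for a LinearOperator-style check (NOT primate's definition):
// true when a const T exposes matvec(const float*, float*) and shape().
template <typename T, typename = void>
struct has_linop_interface : std::false_type {};

template <typename T>
struct has_linop_interface<
    T, std::void_t<
           decltype(std::declval<const T&>().matvec(
               std::declval<const float*>(), std::declval<float*>())),
           decltype(std::declval<const T&>().shape())>> : std::true_type {};

// Hypothetical third-party type: it can multiply, but under another name.
struct ScaleOp {
  float factor;
  std::size_t n;
  void dot(const float* in, float* out) const {
    assert(in != nullptr && out != nullptr);
    for (std::size_t i = 0; i < n; ++i) out[i] = factor * in[i];
  }
};

// Wrapper that adapts ScaleOp to the matvec / shape interface.
struct WrappedScaleOp {
  ScaleOp op;
  explicit WrappedScaleOp(const ScaleOp& _op) : op(_op) {}
  void matvec(const float* input, float* output) const { op.dot(input, output); }
  std::pair<std::size_t, std::size_t> shape() const { return {op.n, op.n}; }
};

// Wrapper function in the style of linop_wrapper above.
inline WrappedScaleOp scale_wrapper(const ScaleOp* op) {
  return WrappedScaleOp(*op);
}

static_assert(!has_linop_interface<ScaleOp>::value, "raw type lacks matvec/shape");
static_assert(has_linop_interface<WrappedScaleOp>::value, "wrapper provides both");
```

The two static_assert lines show the point of the wrapper: the raw type fails the interface check, while the wrapped type passes it without any change to the original class.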

Full Example using Eigen

Here’s a real example of what simple code might look like that wraps an Eigen SparseMatrix for use with primate. Eigen supports matrix-vector multiplication out of the box with the overloaded operator*() and is understood natively by pybind11, so it suffices to define a wrapper class that respects the LinearOperator concept:

template< std::floating_point F >
struct SparseEigenLinearOperator {
  using value_type = F;
  using float_vec = Eigen::Matrix< F, Eigen::Dynamic, 1 >;

  const Eigen::SparseMatrix< F > A;  
  SparseEigenLinearOperator(const Eigen::SparseMatrix< F >& _mat) : A(_mat){}

  void matvec(const F* inp, F* out) const noexcept {
    auto input = Eigen::Map< const float_vec >(inp, A.cols(), 1); 
    auto output = Eigen::Map< float_vec >(out, A.rows(), 1);
    output = A * input; 
  }

  auto shape() const noexcept -> std::pair< size_t, size_t > {
    return std::make_pair((size_t) A.rows(), (size_t) A.cols());
  }
};

Then, simply write a quick wrapper function that converts an Eigen::SparseMatrix< F >* into a SparseEigenLinearOperator< F > and pass it to _trace_wrapper:

template< std::floating_point F >
auto eigen_sparse_wrapper(const Eigen::SparseMatrix< F >* A){
  return SparseEigenLinearOperator< F >(*A);
}

PYBIND11_MODULE(_custom_trace, m) {
  m.doc() = "custom trace estimator module";
  _trace_wrapper< false, float, Eigen::SparseMatrix< float > >(
    m, eigen_sparse_wrapper< float >
  ); 
}

That’s it. The module _custom_trace will now have an exported trace_slq method that you can call from Python. For a list of arguments, see the slq declaration in trace.py.

Example: Log determinant

For explanatory purposes, the following code outlines how to call the trace estimator to compute the log determinant using a custom user-implemented operator LinOp:

#include <cmath>                              // std::log
#include <vector>                             // std::vector
#include <_linear_operator/linear_operator.h> // LinearOperator
#include <_lanczos/lanczos.h>                 // sl_trace
#include "LinOp.h"                            // custom class

void slq_log_det(LinOp A, ...){ 
  static_assert(LinearOperator< LinOp >);  // Constraint check
  // std::log is overloaded, so wrap it in a lambda; any invocable works
  const auto matrix_func = [](float x){ return std::log(x); };
  auto rbg = ThreadedRNG64();              // default RNG
  auto estimates = std::vector< float >(n, 0);  // output estimates
  sl_trace< float >(                       // specific precision
    A, matrix_func, rbg,                   // main arguments
    ...,                                   // other inputs 
    estimates.data()                       // output 
  );
}
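To illustrate why matvec and shape are all a randomized trace estimator needs from the operator, here is a minimal, self-contained sketch. It is deliberately not primate’s sl_trace: it implements the plain Girard-Hutchinson estimator of tr(A) with Rademacher (+1 / -1) probe vectors and applies no matrix function. SLQ additionally runs a Lanczos quadrature per probe vector to approximate v^T f(A) v, which is what turns a trace estimate into an estimate of tr(log(A)) = log det(A) for a positive definite A. The names hutchinson_trace and DiagOperator are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Hypothetical operator used only to exercise the estimator.
struct DiagOperator {
  std::vector<float> diag;
  void matvec(const float* input, float* output) const {
    for (std::size_t i = 0; i < diag.size(); ++i) output[i] = diag[i] * input[i];
  }
  std::pair<std::size_t, std::size_t> shape() const {
    return {diag.size(), diag.size()};
  }
};

// Girard-Hutchinson estimate of tr(A): the average of v^T A v over random
// Rademacher vectors v. Only A.matvec and A.shape are used.
template <typename Op>
float hutchinson_trace(const Op& A, std::size_t n_samples, unsigned seed) {
  const std::size_t n = A.shape().first;
  assert(n == A.shape().second && n_samples > 0);  // square operator required
  std::mt19937 rng(seed);
  std::bernoulli_distribution coin(0.5);
  std::vector<float> v(n), Av(n);
  double total = 0.0;
  for (std::size_t s = 0; s < n_samples; ++s) {
    for (auto& vi : v) vi = coin(rng) ? 1.0f : -1.0f;  // Rademacher probe
    A.matvec(v.data(), Av.data());                     // Av = A * v
    double quad = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
      quad += static_cast<double>(v[i]) * static_cast<double>(Av[i]);
    }
    total += quad;                                     // accumulate v^T A v
  }
  return static_cast<float>(total / static_cast<double>(n_samples));
}
```

For a diagonal operator every probe gives v^T A v equal to the sum of the diagonal, so the estimate is exact regardless of the seed; for general matrices the estimate is only unbiased and its variance shrinks as n_samples grows.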