SLQ Trace guide

One of the base offerings of primate is an extensible implementation of the stochastic Lanczos quadrature (SLQ) method. Despite its simplicity (only a few lines of pseudocode!), many different methods go by the name ‘SLQ’, each computing a related but distinct quantity. For example, sl_quad computes Gaussian quadrature nodes and weights, while sl_trace computes an unbiased estimator for the trace of a matrix function.
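The trace estimator rests on the Girard-Hutchinson identity E[vᵀAv] = tr(A), which holds for any random vector v with E[vvᵀ] = I. Below is a minimal sketch for a plain dense matrix with f taken as the identity. SLQ keeps this outer averaging loop but replaces each quadratic form vᵀAv with a Lanczos quadrature approximation of vᵀf(A)v. The function `hutchinson_trace` and the dense `Matrix` alias are illustrative only and are not part of primate.

```cpp
#include <cstddef>
#include <random>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // illustrative dense storage

// Girard-Hutchinson estimator: tr(A) ~ (1/nv) * sum_j v_j^T A v_j, where each v_j
// has i.i.d. Rademacher (+1/-1) entries, so E[v v^T] = I and every term is unbiased.
double hutchinson_trace(const Matrix& A, int nv, unsigned seed) {
  const std::size_t n = A.size();
  std::mt19937_64 gen(seed);
  std::bernoulli_distribution coin(0.5);
  std::vector<double> v(n);
  double total = 0.0;
  for (int j = 0; j < nv; ++j) {
    for (double& x : v) x = coin(gen) ? 1.0 : -1.0;  // Rademacher sample
    double quad = 0.0;                               // accumulates v^T A v
    for (std::size_t r = 0; r < n; ++r) {
      double Av_r = 0.0;
      for (std::size_t c = 0; c < n; ++c) Av_r += A[r][c] * v[c];
      quad += v[r] * Av_r;
    }
    total += quad;
  }
  return total / nv;
}
```

For a diagonal matrix every Rademacher sample is exact, since vᵢ² = 1; off-diagonal entries are what contribute variance.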

SLQ parameters

Below is the full signature of the SLQ function template:

// Stochastic Lanczos quadrature method
template< std::floating_point F, LinearOperator Matrix, ThreadSafeRBG RBG >
void slq (
  const Matrix& A,                         // Any *LinearOperator*
  const std::function< F(int,F*,F*) >& f,  // Matrix function
  const std::function< bool(int) >& stop,  // Early-stop function
  const int nv,                            // Num. of sample vectors
  const Distribution dist,                 // Sample vector distribution
  RBG& rng,                                // Random bit generator
  const int lanczos_degree,                // Krylov subspace degree
  const F lanczos_rtol,                    // Lanczos residual tolerance
  const int orth,                          // Num. of additional vectors to orthogonalize
  const int ncv,                           // Num. of Lanczos vectors
  const int num_threads,                   // Num. of threads to allocate
  const int seed                           // Seed for RNG
)
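The `stop` argument receives the index of the sample just processed and returns whether to terminate early. As a hedged illustration, the sketch below builds such a predicate from a buffer of per-sample estimates and stops once the standard error of their mean drops below a tolerance. The helper `make_stop` and its standard-error rule are hypothetical, not primate's convergence criterion, and the sketch assumes samples 0..i have all been written when `stop(i)` is called (true for serial execution, not guaranteed under a dynamic parallel schedule).

```cpp
#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

// Hypothetical early-stop predicate over a buffer of per-sample estimates.
// Assumes est[0..i] are filled when stop(i) is called; `est` must outlive the predicate.
std::function<bool(int)> make_stop(const std::vector<double>& est, int min_samples, double atol) {
  return [&est, min_samples, atol](int i) {
    const int m = i + 1;                             // number of samples seen so far
    if (m < std::max(2, min_samples)) return false;  // need >= 2 samples for a variance
    double mean = 0.0;
    for (int j = 0; j < m; ++j) mean += est[j];
    mean /= m;
    double var = 0.0;
    for (int j = 0; j < m; ++j) var += (est[j] - mean) * (est[j] - mean);
    var /= (m - 1);                                  // unbiased sample variance
    return std::sqrt(var / m) < atol;                // standard error of the mean
  };
}
```

Capturing the buffer by reference keeps the predicate's signature at `bool(int)`, matching the `stop` parameter above.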

Many of the runtime arguments are documented in the lanczos or sl_trace docs; the compile-time (template) parameters are:

  • The floating point type (e.g. float, double, long double)
  • The operator type (e.g. Eigen::MatrixXf, torch::Tensor, LinOp)
  • The multi-threaded random number generator (e.g. ThreadedRNG64)

Any combination of types satisfying the concepts above (std::floating_point, LinearOperator, ThreadSafeRBG) generates, at compile time, a function specialized to that combination; this process is known as template instantiation.

SLQ Implementation

The pseudo-code for the abstract stochastic Lanczos quadrature procedure is given below.

\begin{algorithm}
\caption{Stochastic Lanczos Quadrature}
\begin{algorithmic}
\Input Symmetric operator ($A \in \mathbb{R}^{n \times n}$), function ($f: \mathbb{R} \to \mathbb{R}$)
\Require Number of queries ($n_v$), Degree of quadrature ($k$)
\Function{SLQ}{$A$, $n_v$, $k$}
  \State $\Gamma \gets 0$
  \For{$j = 1, 2, \dots, n_v$}
    \State $v_j \sim \mathcal{D}$ where $\mathcal{D}$ satisfies $\mathbb{E}(v v^\top) = I$
    \State $T^{(j)}(\alpha, \beta) \gets \mathrm{Lanczos}(A, v_j, k+1)$
    \State $[\Theta, Y] \gets \mathrm{eigh\_tridiag}(T^{(j)}(\alpha, \beta))$
    \State Update $\Gamma$ using the node/weight pairs $(\theta_i, \tau_i^2)$, where $\tau_i = \langle e_1, y_i \rangle$
  \EndFor
  \State \Return $n \cdot \Gamma$
\EndFunction
\end{algorithmic}
\end{algorithm}
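To make one iteration of the loop concrete, here is a minimal, self-contained sketch for a small dense symmetric matrix: Lanczos with full reorthogonalization builds T, a cyclic Jacobi eigensolver diagonalizes it, and the quadrature nodes θᵢ and weights τᵢ² are read off the eigenpairs. The Jacobi method is swapped in here purely for brevity in place of a dedicated tridiagonal solver such as eigh_tridiag, and the names `lanczos`, `jacobi_eigen`, and `lanczos_quadrature` are illustrative, not primate's API.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

double dot(const Vec& a, const Vec& b) {
  double s = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
  return s;
}

// k-step Lanczos on symmetric A from start vector v, with full reorthogonalization.
// Fills alpha (diagonal of T) and beta (off-diagonal of T); stops early on breakdown.
void lanczos(const Mat& A, Vec v, int k, Vec& alpha, Vec& beta) {
  const std::size_t n = v.size();
  const double norm = std::sqrt(dot(v, v));
  for (double& x : v) x /= norm;
  std::vector<Vec> Q{v};
  alpha.clear();
  beta.clear();
  for (int j = 0; j < k; ++j) {
    Vec w(n);
    for (std::size_t r = 0; r < n; ++r) w[r] = dot(A[r], Q[j]);  // w = A q_j
    alpha.push_back(dot(w, Q[j]));
    for (const Vec& q : Q) {  // Gram-Schmidt against q_0, ..., q_j
      const double c = dot(w, q);
      for (std::size_t r = 0; r < n; ++r) w[r] -= c * q[r];
    }
    if (j + 1 == k) break;
    const double b = std::sqrt(dot(w, w));
    if (b < 1e-12) break;  // invariant Krylov subspace: T is already complete
    beta.push_back(b);
    for (double& x : w) x /= b;
    Q.push_back(w);
  }
}

// Cyclic Jacobi eigensolver for a small symmetric S (overwritten in place).
// On return, S's diagonal holds the eigenvalues and V's columns the eigenvectors.
void jacobi_eigen(Mat& S, Mat& V) {
  const std::size_t m = S.size();
  V.assign(m, Vec(m, 0.0));
  for (std::size_t i = 0; i < m; ++i) V[i][i] = 1.0;
  for (int sweep = 0; sweep < 50; ++sweep) {
    double off = 0.0;
    for (std::size_t p = 0; p < m; ++p)
      for (std::size_t q = p + 1; q < m; ++q) off += S[p][q] * S[p][q];
    if (off < 1e-24) break;
    for (std::size_t p = 0; p < m; ++p)
      for (std::size_t q = p + 1; q < m; ++q) {
        if (S[p][q] == 0.0) continue;
        const double theta = (S[q][q] - S[p][p]) / (2.0 * S[p][q]);
        const double t = (theta >= 0 ? 1.0 : -1.0) / (std::fabs(theta) + std::sqrt(theta * theta + 1.0));
        const double c = 1.0 / std::sqrt(t * t + 1.0), s = t * c;
        for (std::size_t r = 0; r < m; ++r) {  // rotate columns p, q of S and V
          const double sp = S[r][p], sq = S[r][q], vp = V[r][p], vq = V[r][q];
          S[r][p] = c * sp - s * sq;  S[r][q] = s * sp + c * sq;
          V[r][p] = c * vp - s * vq;  V[r][q] = s * vp + c * vq;
        }
        for (std::size_t r = 0; r < m; ++r) {  // rotate rows p, q of S
          const double sp = S[p][r], sq = S[q][r];
          S[p][r] = c * sp - s * sq;  S[q][r] = s * sp + c * sq;
        }
      }
  }
}

// Gaussian quadrature for the spectral measure of (A, v): nodes are the eigenvalues
// theta_i of T, weights are tau_i^2 = <e_1, y_i>^2 (they sum to 1).
void lanczos_quadrature(const Mat& A, const Vec& v, int k, Vec& nodes, Vec& weights) {
  Vec alpha, beta;
  lanczos(A, v, k, alpha, beta);
  const std::size_t m = alpha.size();
  Mat T(m, Vec(m, 0.0)), Y;
  for (std::size_t i = 0; i < m; ++i) {
    T[i][i] = alpha[i];
    if (i + 1 < m) T[i][i + 1] = T[i + 1][i] = beta[i];
  }
  jacobi_eigen(T, Y);
  nodes.resize(m);
  weights.resize(m);
  for (std::size_t i = 0; i < m; ++i) {
    nodes[i] = T[i][i];
    weights[i] = Y[0][i] * Y[0][i];
  }
}
```

For a Rademacher vector ‖v‖² = n, so vᵀf(A)v ≈ n Σᵢ τᵢ² f(θᵢ); accumulating the average of Σᵢ τᵢ² f(θᵢ) over the n_v samples in Γ and multiplying by n gives the trace estimate returned by the pseudocode.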

Given a valid set of parameters, the main body of the SLQ function looks roughly as follows:

  bool stop_flag = false; // convergence checking
  #pragma omp parallel shared(stop_flag)
  {
    int tid = omp_get_thread_num(); // thread-id 

    // < allocations for Q, alpha, beta, etc. > 

    #pragma omp for schedule(dynamic, chunk_size)
    for (int i = 0; i < nv; ++i){
      if (stop_flag){ continue; } // an omp for loop cannot 'break'; skip remaining iterations

      // Generate isotropic vector (populates q)
      generate_isotropic< F >(...);

      // Lanczos iteration (populates alpha + beta)
      lanczos_recurrence< F >(...); 

      // Gaussian quad. (populates nodes + weights)
      lanczos_quadrature< F >(...);

      // Run the user-supplied function 
      f(i, q, Q, nodes, weights);

      // Checks early-stopping condition
      #pragma omp critical
      {
        stop_flag = stop(i);
      }
    } // end for
  } // end parallel 
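The loop above skips the remaining iterations once stop_flag is set rather than breaking out, because an OpenMP worksharing loop cannot be exited with break. The sketch below isolates that pattern in a form that compiles with or without OpenMP (without -fopenmp the pragmas are ignored and the loop runs serially), and it makes every access to the shared flag atomic so the check at the top of an iteration does not race with another thread setting it. The function `run_samples` and its `sample` callback are hypothetical stand-ins for the per-vector Lanczos and quadrature work.

```cpp
#include <functional>
#include <vector>

// Hypothetical stand-in for the SLQ main loop: `sample(i)` plays the role of the
// per-vector Lanczos + quadrature work, `stop(i)` is the user's early-stop predicate.
// Returns how many samples were actually computed.
int run_samples(int nv, const std::function<double(int)>& sample,
                const std::function<bool(int)>& stop, std::vector<double>& out) {
  bool stop_flag = false;
  int processed = 0;
  #pragma omp parallel for schedule(dynamic, 1)
  for (int i = 0; i < nv; ++i) {
    bool done;
    #pragma omp atomic read
    done = stop_flag;
    if (done) continue;            // cannot 'break' out of an omp for; skip instead
    const double val = sample(i);  // independent per-sample work runs in parallel
    bool hit = false;
    #pragma omp critical
    {
      out[i] = val;
      ++processed;
      hit = stop(i);               // the user predicate is called one thread at a time
    }
    if (hit) {
      #pragma omp atomic write
      stop_flag = true;            // once set, the flag is never cleared
    }
  }
  return processed;
}
```

Run serially with `stop = [](int i){ return i >= 4; }`, this computes exactly five samples; in parallel the count can differ, since iterations already in flight still finish.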

Example: Log determinant

For illustration, the following code outlines how to call the trace estimator to compute the log-determinant of a custom, user-implemented operator LinOp:

#include <cmath>                              // std::log
#include <vector>                             // std::vector
#include <_linear_operator/linear_operator.h> // LinearOperator
#include <_lanczos/lanczos.h>                 // sl_trace
#include "LinOp.h"                            // custom class

void slq_log_det(LinOp A, ...){
  static_assert(LinearOperator< LinOp >);  // Constraint check
  // any invocable; std::log is overloaded, so wrap it in a lambda
  const auto matrix_func = [](float x){ return std::log(x); };
  auto rbg = ThreadedRNG64();              // default RNG
  auto estimates = std::vector< float >(n, 0);  // output estimates
  sl_trace< float >(                       // specific precision
    A, matrix_func, rbg,                   // main arguments
    ...,                                   // other inputs
    estimates.data()                       // output
  );
}
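When testing on small symmetric positive-definite matrices, it can help to compare the SLQ estimate against the exact value. For such matrices log det(A) = tr(log A), and with the Cholesky factorization A = LLᵀ this equals 2 Σᵢ log Lᵢᵢ. The helper `logdet_cholesky` below is a hypothetical dense reference implementation for that check, not part of primate.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Exact log-determinant of a small symmetric positive-definite matrix via Cholesky:
// A = L L^T  =>  log det(A) = 2 * sum_i log(L_ii). Works on a copy of A; the lower
// triangle of that copy is overwritten with L column by column.
double logdet_cholesky(std::vector<std::vector<double>> A) {
  const std::size_t n = A.size();
  double logdet = 0.0;
  for (std::size_t j = 0; j < n; ++j) {
    double d = A[j][j];
    for (std::size_t k = 0; k < j; ++k) d -= A[j][k] * A[j][k];
    if (d <= 0.0) throw std::runtime_error("matrix is not positive definite");
    const double Ljj = std::sqrt(d);
    A[j][j] = Ljj;
    for (std::size_t i = j + 1; i < n; ++i) {  // fill column j of L below the diagonal
      double s = A[i][j];
      for (std::size_t k = 0; k < j; ++k) s -= A[i][k] * A[j][k];
      A[i][j] = s / Ljj;
    }
    logdet += 2.0 * std::log(Ljj);
  }
  return logdet;
}
```

Unlike SLQ, this costs O(n³) time and needs explicit matrix entries, so it is only practical as a reference on small inputs.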