SLQ Trace guide

This guide walks through the SLQ method implemented in primate, which can be specialized for different purposes.

SLQ as a function template

Below is the full signature of the SLQ function template:

// Stochastic Lanczos quadrature method
template< std::floating_point F, LinearOperator Matrix, ThreadSafeRBG RBG >
void slq (
  const Matrix& A,                    // Any *LinearOperator*
  const function< F(int,F*,F*) >& f,  // Generic function
  const function< bool(int) >& stop,  // Early-stop function
  const int nv,                       // Num. of sample vectors
  const Distribution dist,            // Sample vector distribution
  RBG& rng,                           // Random bit generator
  const int lanczos_degree,           // Krylov subspace degree
  const F lanczos_rtol,               // Lanczos residual tolerance
  const int orth,                     // Add. vectors to orthogonalize
  const int ncv,                      // Num. of Lanczos vectors
  const int num_threads,              // # threads to allocate 
  const int seed                      // Seed for RNG 
)

Many of the runtime arguments are documented in the lanczos or sl_trace docs; the compile-time (template) parameters are:

  • The floating point type (e.g. float, double, long double)
  • The operator type (e.g. Eigen::MatrixXf, torch::Tensor, LinOp)
  • The multi-threaded random number generator (e.g. ThreadedRNG64)

Note that any combination of types satisfying these concepts (e.g. std::floating_point, LinearOperator) generates a function specialized for those types at compile time; this is known as template instantiation.

Generality via function passing

Given a valid set of parameters, the main body of the SLQ looks something like this:

  bool stop_flag = false;
  #pragma omp parallel shared(stop_flag)
  {
    // < allocations for Q, alpha, beta, etc. > 
    int tid = omp_get_thread_num(); // thread-id 
    
    #pragma omp for
    for (int i = 0; i < nv; ++i){
      if (stop_flag){ continue; }
      generate_isotropic< F >(...); // populates q
      lanczos_recurrence< F >(...); // populates alpha + beta
      lanczos_quadrature< F >(...); // populates nodes + weights
      f(i, q, Q, nodes, weights);   // Run user-supplied function 
      #pragma omp critical
      {
        stop_flag = stop(i);        // Checks for early-stopping
      }
    } // end for
  } // end parallel 

There are two functions that can be used for generalizing SLQ for different purposes.

The first generic function f can read, save, or modify the information available from the iteration index i, the isotropic vector q, the Lanczos vectors Q, and/or the quadrature information nodes, weights. Note this function is run in the parallel section.

The second is a boolean-valued function stop which can be used to stop the iteration early, for example if convergence has been achieved according to some rule. Since this is run in the critical section, it is called sequentially.
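The two-callback pattern described above can be sketched without OpenMP or Lanczos. This is a simplified, sequential illustration under stated assumptions: sample_loop and its draw argument are hypothetical stand-ins for the per-sample work, and the callback signatures are reduced to a single scalar per sample.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Sequential sketch of the callback pattern (no OpenMP, no Lanczos).
// `draw` is a hypothetical stand-in for the per-sample work done by slq.
// Returns the number of samples actually processed.
inline int sample_loop(
  const int nv,                                // Num. of samples requested
  const std::function<double(int)>& draw,      // Produces one sample result
  const std::function<void(int, double)>& f,   // User-supplied function
  const std::function<bool(int)>& stop         // Early-stop function
) {
  assert(nv >= 0);
  int processed = 0;
  for (int i = 0; i < nv; ++i) {
    const double x = draw(i);
    f(i, x);                  // read, save, or modify the sample's output
    ++processed;
    if (stop(i)) { break; }   // sequential analogue of setting stop_flag
  }
  return processed;
}
```

In the parallel version, iterations that have already started still run to completion after the flag is set, so the number of processed samples may exceed the index at which stop first returned true.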

Using SLQ to estimate \mathrm{tr}(f(A))

The SLQ method is often used to estimate the trace of an arbitrary matrix function:

\mathrm{tr}(f(A)), \quad \text{ where } f(A) = U f(\Lambda) U^T

It has been shown [1] that the information obtained by the Lanczos method is sufficient to obtain a Gaussian quadrature approximation of the empirical spectral measure of A. By sampling zero-mean vectors satisfying \mathbb{E}[v v^T] = I, one can obtain estimates of the trace above:

\operatorname{tr}(f(A)) \approx \frac{n}{\mathrm{n}_{\mathrm{v}}} \sum_{l=1}^{\mathrm{n}_{\mathrm{v}}}\left(\sum_{k=0}^m\left(\tau_k^{(l)}\right)^2 f\left(\theta_k^{(l)}\right)\right)

Here \theta_k^{(l)} are the quadrature nodes (Ritz values) and \left(\tau_k^{(l)}\right)^2 the quadrature weights obtained from the l-th sample vector.

Averaging these per-vector estimates yields an unbiased Girard-Hutchinson estimator of the trace. To see why this estimator is unbiased, note that for any such vector v:

\mathrm{tr}(A) = \mathbb{E}[v^T A v] \approx \frac{1}{n_v}\sum\limits_{i=1}^{n_v} v_i^T A v_i

Thus, all we need to do to estimate the trace of a matrix function is multiply and sum the quadrature nodes and weights output by SLQ.
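The Girard-Hutchinson identity above can be checked numerically on a small dense matrix. The sketch below is an illustration only: hutchinson_trace is a hypothetical helper, it forms v^T A v directly instead of using Lanczos quadrature, and it assumes Rademacher (+1/-1) sample vectors, which are zero-mean and satisfy \mathbb{E}[v v^T] = I.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Girard-Hutchinson estimate of tr(A) for a dense, row-major n x n matrix.
// Assumption: Rademacher (+1/-1) sample vectors, drawn from the low bit
// of a 64-bit Mersenne Twister.
inline double hutchinson_trace(
  const std::vector<double>& A, const std::size_t n,
  const int nv, const std::uint64_t seed
) {
  assert(A.size() == n * n && nv > 0);
  std::mt19937_64 rng(seed);
  std::vector<double> v(n);
  double total = 0.0;
  for (int s = 0; s < nv; ++s) {
    for (std::size_t i = 0; i < n; ++i) { v[i] = (rng() & 1u) ? 1.0 : -1.0; }
    double quad = 0.0; // accumulates v^T A v
    for (std::size_t i = 0; i < n; ++i) {
      for (std::size_t j = 0; j < n; ++j) { quad += v[i] * A[i * n + j] * v[j]; }
    }
    total += quad;
  }
  return total / nv; // sample mean of the quadratic forms
}
```

For a diagonal matrix every Rademacher sample reproduces the trace exactly, since each v_i^2 equals 1; off-diagonal entries are what introduce variance into the estimate.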

sl_trace method

To see how these formulas are actually implemented with the generic SLQ implementation, here’s an abbreviated form of the sl_trace function implemented by primate:

template< std::floating_point F, LinearOperator Matrix, ThreadSafeRBG RBG >
void sl_trace(
  const Matrix& A, const std::function< F(F) > sf, RBG& rbg, 
  const int nv, const int dist, const int engine_id, const int seed,
  const int deg, const float lanczos_rtol, const int orth, const int ncv,
  const F atol, const F rtol,
  F* estimates
){  
  using VectorF = Eigen::Array< F, Dynamic, 1>;

  // Parameterize the trace function (runs in parallel)
  auto trace_f = [&](int i, F* q, F* Q, F* nodes, F* weights){
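    Map< VectorF > nodes_v(nodes, deg, 1);     // wraps the raw pointer, no copy
    Map< VectorF > weights_v(weights, deg, 1); // wraps the raw pointer, no copy
    nodes_v = nodes_v.unaryExpr(sf);           // unaryExpr returns an expression; assign it back
    estimates[i] = (nodes_v * weights_v).sum();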
  };
  
  // Convergence checking like scipy.integrate.quadrature
  int n = 0;
  F mu_est = 0.0, mu_pre = 0.0;
  const auto early_stop = [&](int i) -> bool {
    ++n; // Number of estimates
    mu_est = (1.0 / F(n)) * (estimates[i] + (n - 1) * mu_pre); 
    bool atol_check = std::abs(mu_est - mu_pre) <= atol;
    bool rtol_check = std::abs(mu_est - mu_pre) / std::abs(mu_est) <= rtol; 
    mu_pre = mu_est; 
    return atol_check || rtol_check;
  };

  // Execute the stochastic Lanczos quadrature with the trace function 
  slq< F >(A, trace_f, early_stop, ...);
}

As before, two functions are used to parameterize the slq method.

The first (trace_f) applies an arbitrary spectral function sf to the Rayleigh-Ritz values obtained by the Lanczos tridiagonalization of A (or equivalently, the nodes of the Gaussian quadrature). These are the \theta’s in the trace formula above. When multiplied by the weights of the quadrature, the corresponding sum forms an estimate of the trace of the matrix function.
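For readers without Eigen at hand, the same node-and-weight sum can be written as a plain loop. This is a sketch under assumptions: quadrature_sum is a hypothetical helper, and it assumes the weights array already holds the squared \tau values.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Computes sum_k weights[k] * sf(nodes[k]) over `deg` quadrature points.
// Assumption: `weights` already holds the squared tau values.
inline double quadrature_sum(
  const double* nodes, const double* weights, const int deg,
  const std::function<double(double)>& sf
) {
  assert(deg > 0);
  double s = 0.0;
  for (int k = 0; k < deg; ++k) { s += weights[k] * sf(nodes[k]); }
  return s;
}
```

With nodes {1, 4}, weights {0.25, 0.75}, and sf set to the square root, the sum is 0.25 * 1 + 0.75 * 2 = 1.75.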

The second function early_stop is used to check for convergence of the estimator. First, it uses the trace estimate x_n to update the sample mean \mu_n via the formula:

\mu_n = n^{-1} [x_n + (n - 1)\mu_{n-1}]

Then, much in the same way the quadrature function from scipy.integrate approximates a definite integral, it checks for convergence using the absolute and relative tolerances supplied by the user. Returning true signals convergence, stopping the iteration early.
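The running-mean update and tolerance checks can be isolated into a small, testable unit. The sketch below is one possible arrangement, not primate's code: RunningMean is a hypothetical struct, and requiring at least two samples before reporting convergence is an added assumption that the early_stop lambda above does not make.

```cpp
#include <cassert>
#include <cmath>

// Tracks the sample mean mu_n = (x_n + (n - 1) * mu_{n-1}) / n and reports
// convergence when consecutive means agree to within atol or rtol.
struct RunningMean {
  int n = 0;        // Number of estimates seen so far
  double mu = 0.0;  // Current sample mean

  bool update(const double x, const double atol, const double rtol) {
    assert(atol >= 0.0 && rtol >= 0.0);
    const double mu_prev = mu;
    ++n;
    mu = (x + (n - 1) * mu_prev) / n;
    const double delta = std::abs(mu - mu_prev);
    const bool atol_check = delta <= atol;
    // Guard the division so a zero mean cannot produce NaN or infinity
    const bool rtol_check = mu != 0.0 && delta / std::abs(mu) <= rtol;
    // Assumption: never report convergence from a single sample
    return n > 1 && (atol_check || rtol_check);
  }
};
```

Feeding the estimates 4, 2, 3 with atol = 0.5 gives means 4, 3, 3; the third update returns true because the mean no longer moves.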

References