Parameterizing SLQ
This guide walks through how to parameterize the SLQ method implemented in primate
on the C++ side to approximate some spectral quantity of interest.
SLQ as a function template
Below is the full signature of the SLQ function template:
// Stochastic Lanczos quadrature method
template< std::floating_point F, LinearOperator Matrix, ThreadSafeRBG RBG >
void slq (
  const Matrix& A,                      // Any *LinearOperator*
  const function< F(int,F*,F*) >& f,    // Generic function
  const function< bool(int) >& stop,    // Early-stop function
  const int nv,                         // Num. of sample vectors
  const Distribution dist,              // Sample vector distribution
  RBG& rng,                             // Random bit generator
  const int lanczos_degree,             // Krylov subspace degree
  const F lanczos_rtol,                 // Lanczos residual tolerance
  const int orth,                       // Add. vectors to orthogonalize
  const int ncv,                        // Num. of Lanczos vectors
  const int num_threads,                // # threads to allocate
  const int seed                        // Seed for RNG
)
Many of the runtime arguments are documented in the lanczos or sl_trace docs; the compile-time (template) parameters are:
- The floating point type (e.g. float, double, long double)
- The operator type (e.g. Eigen::MatrixXf, torch::Tensor, LinOp)
- The multi-threaded random number generator (e.g. ThreadedRNG64)
Note that any type combination satisfying these concepts (e.g. std::floating_point, LinearOperator) generates a function specialized for those types at compile time; this is known as template instantiation.
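As a rough illustration, the sketch below defines a custom diagonal operator that could serve as the Matrix template argument. The member functions matvec and shape are assumptions about what the LinearOperator concept requires (the actual concept used by primate may differ), and DiagonalOp is a hypothetical name rather than part of the library.

#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical operator representing y = diag(d) * x.
// Assumes the LinearOperator concept requires matvec() and shape() members.
struct DiagonalOp {
  std::vector< float > d;    // diagonal entries of the operator

  void matvec(const float* x, float* y) const noexcept {
    for (std::size_t i = 0; i < d.size(); ++i){ y[i] = d[i] * x[i]; }
  }

  auto shape() const noexcept -> std::pair< std::size_t, std::size_t > {
    return { d.size(), d.size() };
  }
};

Instantiating slq with DiagonalOp, a floating point type, and a thread-safe generator then produces a specialization compiled specifically for that combination of types.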
Generality via function passing
Given a valid set of parameters, the main body of the SLQ method looks something like this:
bool stop_flag = false;
#pragma omp parallel shared(stop_flag)
{
  // < allocations for Q, alpha, beta, etc. >
  int tid = omp_get_thread_num();       // thread-id
  #pragma omp for
  for (i = 0; i < nv; ++i){
    if (stop_flag){ continue; }
    generate_isotropic< F >(...);       // populates q
    lanczos_recurrence< F >(...);       // populates alpha + beta
    lanczos_quadrature< F >(...);       // populates nodes + weights
    f(i, q, Q, nodes, weights);         // Run user-supplied function
    #pragma omp critical
    {
      stop_flag = stop(i);              // Checks for early-stopping
    }
  } // end for
} // end parallel
There are two functions that can be used to generalize SLQ for different purposes.
The first generic function f can read, save, or modify the information available at each iteration: the iteration index i, the isotropic vector q, the Lanczos vectors Q, and the quadrature information (nodes and weights). Note this function runs inside the parallel section.
The second is a boolean-valued function stop, which can be used to stop the iteration early, for example once convergence has been achieved according to some rule. Since it runs in the critical section, it is called sequentially.
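For instance, the hypothetical parameterization below records the largest quadrature node seen for each sample vector and halts unconditionally after 50 samples. It assumes nv, deg (the Lanczos degree), A, and the remaining runtime arguments are already in scope, and the lambda signature simply mirrors the call f(i, q, Q, nodes, weights) in the loop body above.

// Hypothetical parameterization; requires <algorithm> and <vector>
std::vector< float > max_nodes(nv);        // one entry per sample vector

// f: record the largest quadrature node (Ritz value) of each sample
const auto f = [&](int i, float* q, float* Q, float* nodes, float* weights){
  max_nodes[i] = *std::max_element(nodes, nodes + deg);
};

// stop: ignore convergence and simply halt after 50 samples
const auto stop = [](int i) -> bool { return i >= 49; };

// slq< float >(A, f, stop, ...);          // remaining arguments as in the signature above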
Using SLQ to estimate \mathrm{tr}(f(A))
The SLQ method is often used to estimate the trace of an arbitrary matrix function:
\mathrm{tr}(f(A)), \quad \text{ where } f(A) = U f(\Lambda) U^T
It has been shown [1] that the information obtained by the Lanczos method is sufficient to obtain a Gaussian quadrature approximation of the empirical spectral measure of A. By sampling zero-mean vectors satisfying \mathbb{E}[v v^T] = I, one can obtain estimates of the trace above: \operatorname{tr}(f(A)) \approx \frac{n}{n_v} \sum_{l=1}^{n_v}\left(\sum_{k=0}^m\left(\tau_k^{(l)}\right)^2 f\left(\theta_k^{(l)}\right)\right)
Averaging these trace estimates yields an unbiased, Girard-Hutchinson estimator of the trace. To see why this estimator is unbiased, note that: \mathrm{tr}(A) = \mathbb{E}[v^\top A v] \approx \frac{1}{n_v}\sum\limits_{i=1}^{n_v} v_i^\top A v_i
Thus, all we need to do to estimate the trace of a matrix function is apply f to the quadrature nodes, multiply by the corresponding weights, and sum the results output by SLQ.
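Concretely, each per-sample estimate is a weighted sum over the Lanczos degree. A minimal sketch, where nodes holds the \theta_k's, weights holds the corresponding (\tau_k)^2's, deg is the Lanczos degree, and sf is the scalar function f (all names illustrative):

// One term of the outer sum above: the Gaussian quadrature rule for sample vector v_l
float est_l = 0.0f;
for (int k = 0; k < deg; ++k){
  est_l += weights[k] * sf(nodes[k]);    // (tau_k)^2 * f(theta_k)
}
// Averaging the nv per-sample estimates and scaling by n yields the trace estimate:
//   tr(f(A)) ≈ (n / nv) * (est_1 + est_2 + ... + est_nv)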
sl_trace method
To see how these formulas are implemented with the generic SLQ function above, here is an abbreviated form of the sl_trace function implemented by primate:
template< std::floating_point F, LinearOperator Matrix, ThreadSafeRBG RBG >
void sl_trace(
  const Matrix& A, const std::function< F(F) > sf, RBG& rbg,
  const int nv, const int dist, const int engine_id, const int seed,
  const int deg, const float lanczos_rtol, const int orth, const int ncv,
  const F atol, const F rtol,
  F* estimates
){
  using VectorF = Eigen::Array< F, Dynamic, 1 >;

  // Parameterize the trace function (runs in parallel)
  auto trace_f = [&](int i, F* q, F* Q, F* nodes, F* weights){
    Map< VectorF > nodes_v(nodes, deg, 1);        // no-op
    Map< VectorF > weights_v(weights, deg, 1);    // no-op
    nodes_v = nodes_v.unaryExpr(sf);              // apply sf to the quadrature nodes
    estimates[i] = (nodes_v * weights_v).sum();
  };

  // Convergence checking like scipy.integrate.quadrature
  int n = 0;
  F mu_est = 0.0, mu_pre = 0.0;
  const auto early_stop = [&](int i) -> bool {
    ++n;                                          // Number of estimates
    mu_est = (1.0 / F(n)) * (estimates[i] + (n - 1) * mu_pre);
    bool atol_check = abs(mu_est - mu_pre) <= atol;
    bool rtol_check = abs(mu_est - mu_pre) / mu_est <= rtol;
    mu_pre = mu_est;
    return atol_check || rtol_check;
  };

  // Execute the stochastic Lanczos quadrature with the trace function
  slq< F >(A, trace_f, early_stop, ...);
}
As before, two functions are used to parameterize the slq method.
The first (trace_f) applies an arbitrary spectral function sf to the Rayleigh-Ritz values obtained by the Lanczos tridiagonalization of A (or equivalently, the nodes of the Gaussian quadrature); these are the \theta's in the trace approximation above. When multiplied by the weights of the quadrature, the corresponding sum forms an estimate of the trace of the matrix function.
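The choice of sf determines which spectral quantity is estimated. The lambdas below are standard SLQ use cases written here only for illustration (they require <cmath>), not an exhaustive list of what primate exposes:

// Hypothetical spectral functions to pass as sf
const auto sf_logdet  = [](float ev) -> float { return std::log(ev); };   // tr(log(A)) = logdet(A), for positive definite A
const auto sf_inverse = [](float ev) -> float { return 1.0f / ev; };      // tr(A^{-1})
const auto sf_exp     = [](float ev) -> float { return std::exp(ev); };   // tr(exp(A))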
The second function early_stop is used to check for convergence of the estimator. First, it uses the trace estimate x_n to update the sample mean \mu_n via the formula:
\mu_n = n^{-1} [x_n + (n - 1)\mu_{n-1}]
Then, much in the same way the quadrature function from scipy.integrate approximates a definite integral, it checks for convergence using the absolute and relative tolerances supplied by the user. Returning true signals convergence, stopping the iteration early.
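Putting the pieces together, a call to sl_trace from C++ might look like the sketch below. It assumes Eigen::MatrixXf satisfies the LinearOperator concept and ThreadedRNG64 satisfies ThreadSafeRBG (both listed as example types earlier), that ThreadedRNG64 is default-constructible, and that the relevant primate, Eigen, <cmath>, <algorithm>, and <vector> headers are included; the argument values are purely illustrative.

using Eigen::MatrixXf;

// Build a symmetric positive semi-definite test matrix
MatrixXf B = MatrixXf::Random(100, 100);
MatrixXf A = B * B.transpose();

ThreadedRNG64 rng;                               // thread-safe random bit generator
std::vector< float > estimates(150, 0.0f);       // one slot per sample vector

// Spectral function: f(x) = sqrt(x)  =>  estimates tr(A^{1/2})
const auto sf = [](float ev) -> float { return std::sqrt(std::max(ev, 0.0f)); };

sl_trace< float >(
  A, sf, rng,
  150, 0, 0, -1,            // nv, dist, engine_id, seed (illustrative values)
  20, 1e-8f, 0, 22,         // deg, lanczos_rtol, orth, ncv
  1e-4f, 1e-4f,             // atol, rtol
  estimates.data()          // per-sample estimates written here
);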
Footnotes
[1] Ubaru, S., Chen, J., & Saad, Y. (2017). Fast estimation of tr(f(A)) via stochastic Lanczos quadrature. SIAM Journal on Matrix Analysis and Applications, 38(4), 1075-1099.