Coverage for src/pqlattice/lattice/_bkz.py: 98%

50 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-01-12 21:36 +0100

1import logging 

2 

3from .._utils import as_integer, as_rational 

4from ..typing import Matrix, SquareMatrix, Vector, validate_aliases 

5from ._gso import gso 

6from ._lll import lll 

7from ._svp import schnorr_euchner_svp 

8 

# Module-level logger. It is deliberately looked up under the fixed name "BKZ"
# (not __name__), so logging configuration must target "BKZ" to control it.
logger = logging.getLogger(name="BKZ")

10 

11 

@validate_aliases
def update_block(lattice_basis: SquareMatrix, new_vector: Vector, start_index: int, block_size: int) -> SquareMatrix:
    """Insert ``new_vector`` into a block of rows and re-reduce that block.

    The rows ``start_index .. start_index + block_size - 1`` of the basis are
    stacked underneath ``new_vector`` and LLL-reduced together. ``new_vector``
    is expected to be an integer combination of those rows, so the stacked
    rows are linearly dependent. This function relies on ``lll`` reporting
    that dependency as an all-zero first row. It drops that row and writes
    the remaining ``block_size`` rows back in place of the original block.

    Parameters
    ----------
    lattice_basis : SquareMatrix
        Basis whose rows are the lattice vectors. It is not modified.
    new_vector : Vector
        Vector to insert; expected to lie in the lattice spanned by the block.
    start_index : int
        Index of the first row of the block.
    block_size : int
        Number of rows in the block (at least 1).

    Returns
    -------
    SquareMatrix
        A copy of the basis with the block replaced by the re-reduced rows.

    Raises
    ------
    ValueError
        If the block does not fit inside the basis, or if the first row
        returned by ``lll`` is not the zero vector ("block update failed").
    """
    # as_integer may hand back its argument unchanged (NOTE(review): confirm);
    # copy so the caller's matrix is never modified by the assignment below.
    B = as_integer(lattice_basis).copy()
    n = len(B)
    if block_size < 1 or start_index < 0 or start_index + block_size > n:
        raise ValueError(
            f"block [{start_index}, {start_index + block_size}) does not fit in a basis with {n} rows"
        )

    end_index = start_index + block_size
    local_basis = as_integer([new_vector] + B[start_index:end_index].tolist())
    reduced_local = lll(local_basis)

    # The inserted vector makes the stacked rows dependent; a successful
    # reduction leaves that dependency as an all-zero first row.
    if any(s != 0 for s in reduced_local[0]):
        raise ValueError("block update failed")

    # The remaining rows span the same lattice as the old block; swap them in.
    B[start_index:end_index] = as_integer(reduced_local[1:])
    return B

28 

29 

@validate_aliases
def bkz(lattice_basis: SquareMatrix, block_size: int = 10) -> SquareMatrix:
    """Block-reduce a lattice basis with a BKZ-style sweep.

    The basis is first LLL-reduced. A window of ``block_size`` consecutive
    rows is then swept across the basis. For each window, ``schnorr_euchner_svp``
    is asked for a short integer combination of the window's rows. When that
    combination is anything other than plus or minus the window's first row,
    it is inserted with ``update_block``. Sweeps ("tours") repeat until a
    complete tour makes no insertion.

    Parameters
    ----------
    lattice_basis : SquareMatrix
        Basis whose rows are the lattice vectors.
    block_size : int, optional
        Window size, at least 1. Windows near the end of the basis are
        truncated to the rows that remain. By default 10.

    Returns
    -------
    SquareMatrix
        The basis after the final (unchanged) tour.

    Raises
    ------
    ValueError
        If ``block_size`` is smaller than 1, or if ``update_block`` fails to
        eliminate the dependency introduced by an inserted vector.
    """
    # A non-positive size would give empty (0) or wrongly sliced (negative) windows.
    if block_size < 1:
        raise ValueError(f"block_size must be at least 1, got {block_size}")

    n, m = lattice_basis.shape
    B = lll(lattice_basis)

    tour = 0
    is_changed = True
    while is_changed:
        is_changed = False
        tour += 1
        logger.debug("starting BKZ tour %d", tour)
        # Start indices stop at n - 2: a one-row window at the very end could
        # only ever return plus/minus that row.
        for k in range(n - 1):
            h = min(block_size, n - k)
            local_basis: Matrix = B[k : k + h]
            # NOTE(review): gso() only receives the window's rows, so the
            # enumeration works on the sublattice spanned by B[k:k+h], not on
            # its projection orthogonal to B[:k] as in textbook BKZ -- confirm
            # this is intended.
            B_star, U = gso(local_basis)
            B_norms2 = as_rational([sum(x * x for x in v) for v in B_star])
            coeffs = schnorr_euchner_svp(U.T, B_norms2)

            # Coefficients of +/-e_1 mean the window's first row is the
            # returned vector, so there is nothing to insert.
            # NOTE(review): any other coefficient vector counts as a change even
            # if it is only as long as B[k]; if the enumeration can return such
            # ties, tours may keep alternating -- confirm it never does.
            if abs(coeffs[0]) == 1 and all(c == 0 for c in coeffs[1:]):
                continue

            # new_vector = sum_i coeffs[i] * local_basis[i], in integer arithmetic.
            new_vector = as_integer([0] * m)
            for c, row in zip(coeffs, local_basis):
                if c == 0:
                    continue
                for d in range(m):
                    new_vector[d] += c * row[d]

            logger.debug("tour %d: inserting vector into block starting at %d", tour, k)
            B = update_block(B, new_vector, k, h)
            is_changed = True
    return B