Coverage for src/twofas/core.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-22 17:34 +0100

1import json 

2import sys 

3import typing 

4from collections import defaultdict 

5from pathlib import Path 

6from typing import Optional 

7 

8from ._security import decrypt, keyring_manager 

9from ._types import TwoFactorAuthDetails, into_class 

10from .utils import flatten, fuzzy_match 

11 

# Type variable for the entry type held by TwoFactorStorage / returned by
# new_auth_storage; bound to TwoFactorAuthDetails so only that class or its
# subclasses can parameterize the storage.
T_TwoFactorAuthDetails = typing.TypeVar("T_TwoFactorAuthDetails", bound=TwoFactorAuthDetails)

13 

14 

class TwoFactorStorage(typing.Generic[T_TwoFactorAuthDetails]):
    """Case-insensitive multimap from service name to 2FA entries.

    Several entries may share one name, so each lowercased name maps to a
    list of entries (entries without a name are stored under ``""``).
    ``count`` tracks the total number of stored entries, not distinct names.
    """

    _multidict: defaultdict[str, list[T_TwoFactorAuthDetails]]
    count: int

    def __init__(self, _klass: Optional[typing.Type[T_TwoFactorAuthDetails]] = None) -> None:
        """Create an empty storage.

        Args:
            _klass: unused at runtime; purely for annotation, so callers can
                pin the generic type (e.g. ``TwoFactorStorage(SomeDetails)``).
        """
        self._multidict = defaultdict(list)  # one name can map to multiple entries
        self.count = 0

    def __len__(self) -> int:
        """Return the total number of stored entries."""
        return self.count

    def __bool__(self) -> bool:
        """Return True when at least one entry is stored."""
        return self.count > 0

    def add(self, entries: typing.Iterable[T_TwoFactorAuthDetails]) -> None:
        """Store each entry under its lowercased name (``""`` if it has none).

        Args:
            entries: any iterable of entries; one-shot iterables such as
                generators are supported.
        """
        added = 0
        for entry in entries:
            name = (entry.name or "").lower()
            self._multidict[name].append(entry)
            added += 1
        # count while iterating instead of len(entries) so generators work too
        self.count += added

    def __getitem__(self, item: str) -> "list[T_TwoFactorAuthDetails]":
        """Return the entries stored under ``item`` (case-insensitive).

        Unknown names yield an empty list. ``.get`` is used instead of indexing
        the defaultdict so a lookup miss does not insert an empty key that
        would then appear in ``keys()``, ``items()`` and ``repr()``.
        """
        return self._multidict.get(item.lower(), [])

    def keys(self) -> list[str]:
        """Return the stored (lowercased) names."""
        return list(self._multidict.keys())

    def items(self) -> typing.Generator[tuple[str, list[T_TwoFactorAuthDetails]], None, None]:
        """Yield ``(lowercased name, entries)`` pairs."""
        yield from self._multidict.items()

    def _fuzzy_find(self, find: Optional[str], fuzz_threshold: int) -> list[T_TwoFactorAuthDetails]:
        """Fuzzy-search entries, first by name and then by each entry's repr.

        Args:
            find: search term; when empty/None every entry is returned.
            fuzz_threshold: a candidate matches when ``fuzzy_match`` scores
                strictly above this value.

        Returns:
            The entries of every fuzzily matching name if there are any,
            otherwise the entries whose repr matches fuzzily.
        """
        if not find:
            # nothing to search for: skip the scoring loops
            return self.all()

        find = find.lower()

        # 1. search in the keys (names)
        by_name = [
            entries
            for name, entries in self._multidict.items()
            if fuzzy_match(name.lower(), find) > fuzz_threshold
        ]
        if by_name and (flat := flatten(by_name)):
            return flat

        # 2. no name matched: search in each entry's repr instead
        # (original note: str is short, repr is json)
        return [entry for entry in self if fuzzy_match(repr(entry).lower(), find) > fuzz_threshold]

    def generate(self) -> list[tuple[str, str]]:
        """Return ``(name, entry.generate())`` for every stored entry."""
        return [(entry.name, entry.generate()) for entry in self]

    def find(
        self, target: Optional[str] = None, fuzz_threshold: int = 75
    ) -> "TwoFactorStorage[T_TwoFactorAuthDetails]":
        """Return a new storage holding the entries that match ``target``.

        An empty/None ``target`` matches everything. Otherwise an exact
        (case-insensitive) name match wins; failing that, a fuzzy search is
        done (see ``_fuzzy_find``).
        """
        target = (target or "").lower()
        if not target:
            # Without this guard the exact-match lookup below would hit the ""
            # bucket (entries without a name) and return only those.
            return new_auth_storage(self.all())
        # first try exact match:
        if items := self._multidict.get(target):
            return new_auth_storage(items)
        # else: fuzzy match:
        return new_auth_storage(self._fuzzy_find(target, fuzz_threshold))

    def all(self) -> list[T_TwoFactorAuthDetails]:
        """Return every stored entry as a flat list."""
        return list(self)

    def __iter__(self) -> typing.Generator[T_TwoFactorAuthDetails, None, None]:
        """Iterate over all entries, grouped by name in insertion order."""
        for entries in self._multidict.values():
            yield from entries

    def __repr__(self) -> str:
        return f"<TwoFactorStorage with {len(self._multidict)} keys and {self.count} entries>"

98 

99 

def new_auth_storage(
    initial_items: Optional[list[T_TwoFactorAuthDetails]] = None,
) -> TwoFactorStorage[T_TwoFactorAuthDetails]:
    """Create a ``TwoFactorStorage``, optionally pre-filled.

    Args:
        initial_items: entries to add right away; ``None`` or an empty list
            gives an empty storage.

    Returns:
        A new, independent storage instance.
    """
    storage: TwoFactorStorage[T_TwoFactorAuthDetails] = TwoFactorStorage()

    if initial_items:
        storage.add(initial_items)

    return storage

107 

108 

def load_services(filename: str, _max_retries: int = 0) -> TwoFactorStorage[TwoFactorAuthDetails]:
    """Load 2FA entries from a JSON backup file into a ``TwoFactorStorage``.

    If the ``"services"`` list is non-empty it is used directly. Otherwise the
    ``"servicesEncrypted"`` value is decrypted with a password taken from
    ``keyring_manager.retrieve_credentials``, falling back to the value
    returned by ``keyring_manager.save_credentials``.

    Args:
        filename: path to the backup file; ``~`` is expanded when opening the
            file, while the keyring calls receive ``filename`` unchanged.
        _max_retries: number of failed decryption attempts tolerated before
            the error is re-raised; ``0`` (default) keeps retrying forever.

    Returns:
        Storage holding the loaded entries.

    Raises:
        PermissionError: from ``decrypt`` once ``_max_retries`` is exceeded.
    """
    filepath = Path(filename).expanduser()
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding
    with filepath.open(encoding="utf-8") as f:
        data = json.load(f)

    storage: TwoFactorStorage[TwoFactorAuthDetails] = new_auth_storage()

    if plain_services := data["services"]:
        storage.add(into_class(plain_services, TwoFactorAuthDetails))
        return storage

    encrypted = data["servicesEncrypted"]

    retries = 0
    while True:
        password = keyring_manager.retrieve_credentials(filename) or keyring_manager.save_credentials(filename)
        try:
            entries = decrypt(encrypted, password)
        except PermissionError as e:
            retries += 1  # only really useful for pytest
            print(e, file=sys.stderr)
            # forget the rejected password; presumably this makes the next
            # retrieve_credentials miss so save_credentials is asked again
            keyring_manager.delete_credentials(filename)

            if _max_retries and retries > _max_retries:
                raise
        else:
            storage.add(entries)
            return storage