Coverage for pattern_lens\prompts.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

1"implements `load_text_data` for loading prompts" 

2 

3import json 

4import random 

5from pathlib import Path 

6 

7 

8def load_text_data( 

9 fname: Path, 

10 min_chars: int | None = None, 

11 max_chars: int | None = None, 

12 shuffle: bool = False, 

13) -> list[dict]: 

14 """given `fname`, the path to a jsonl file, split prompts up into more reasonable sizes 

15 

16 # Parameters: 

17 - `fname : Path` 

18 jsonl file with prompts. Expects a list of dicts with a "text" key 

19 - `min_chars : int | None` 

20 (defaults to `None`) 

21 - `max_chars : int | None` 

22 (defaults to `None`) 

23 - `shuffle : bool` 

24 (defaults to `False`) 

25 

26 # Returns: 

27 - `list[dict]` 

28 new, processed list of prompts. Each prompt has a "text" key with a string value, and some metadata. this is not guaranteed to be the same length as the input list! 

29 """ 

30 # read raw data 

31 with open(fname, "r") as f: 

32 data_raw: list[dict] = [json.loads(d) for d in f.readlines()] 

33 

34 # add fname metadata 

35 for d in data_raw: 

36 d["source_fname"] = fname.as_posix() 

37 

38 # trim too-short samples 

39 if min_chars is not None: 

40 data_raw = list( 

41 filter( 

42 lambda x: len(x["text"]) >= min_chars, 

43 data_raw, 

44 ) 

45 ) 

46 

47 # split up too-long samples 

48 if max_chars is not None: 

49 data_new: list[dict] = [] 

50 for d in data_raw: 

51 d_text: str = d["text"] 

52 while len(d_text) > max_chars: 

53 data_new.append( 

54 { 

55 **d, 

56 "text": d_text[:max_chars], 

57 } 

58 ) 

59 d_text = d_text[max_chars:] 

60 data_new.append( 

61 { 

62 **d, 

63 "text": d_text, 

64 } 

65 ) 

66 data_raw = data_new 

67 

68 # trim too-short samples again 

69 if min_chars is not None: 

70 data_raw = list( 

71 filter( 

72 lambda x: len(x["text"]) >= min_chars, 

73 data_raw, 

74 ) 

75 ) 

76 

77 # shuffle 

78 if shuffle: 

79 random.shuffle(data_raw) 

80 

81 return data_raw