Coverage for pattern_lens / prompts.py: 96%

24 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"implements `load_text_data` for loading prompts" 

2 

3import json 

4import random 

5from pathlib import Path 

6 

7 

8def load_text_data( 

9 fname: Path, 

10 min_chars: int | None = None, 

11 max_chars: int | None = None, 

12 shuffle: bool = False, 

13) -> list[dict]: 

14 """given `fname`, the path to a jsonl file, split prompts up into more reasonable sizes 

15 

16 # Parameters: 

17 - `fname : Path` 

18 jsonl file with prompts. Expects a list of dicts with a "text" key 

19 - `min_chars : int | None` 

20 (defaults to `None`) 

21 - `max_chars : int | None` 

22 (defaults to `None`) 

23 - `shuffle : bool` 

24 (defaults to `False`) 

25 

26 # Returns: 

27 - `list[dict]` 

28 processed list of prompts. Each prompt has a "text" key w/ a string value and some metadata. 

29 this is not guaranteed to be the same length as the input list! 

30 """ 

31 # read raw data 

32 with open(fname, "r") as f: 

33 data_raw: list[dict] = [json.loads(d) for d in f.readlines()] 

34 

35 # add fname metadata 

36 for d in data_raw: 

37 d["source_fname"] = fname.as_posix() 

38 

39 # trim too-short samples 

40 if min_chars is not None: 

41 data_raw = list( 

42 filter( 

43 lambda x: len(x["text"]) >= min_chars, 

44 data_raw, 

45 ), 

46 ) 

47 

48 # split up too-long samples 

49 if max_chars is not None: 

50 data_new: list[dict] = [] 

51 for d in data_raw: 

52 d_text: str = d["text"] 

53 while len(d_text) > max_chars: 

54 data_new.append( 

55 { 

56 **d, 

57 "text": d_text[:max_chars], 

58 }, 

59 ) 

60 d_text = d_text[max_chars:] 

61 data_new.append( 

62 { 

63 **d, 

64 "text": d_text, 

65 }, 

66 ) 

67 data_raw = data_new 

68 

69 # trim too-short samples again 

70 if min_chars is not None: 

71 data_raw = list( 

72 filter( 

73 lambda x: len(x["text"]) >= min_chars, 

74 data_raw, 

75 ), 

76 ) 

77 

78 # shuffle 

79 if shuffle: 

80 random.shuffle(data_raw) 

81 

82 return data_raw