Coverage for pattern_lens/prompts.py: 96%
24 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1"implements `load_text_data` for loading prompts"
3import json
4import random
5from pathlib import Path
8def load_text_data(
9 fname: Path,
10 min_chars: int | None = None,
11 max_chars: int | None = None,
12 shuffle: bool = False,
13) -> list[dict]:
14 """given `fname`, the path to a jsonl file, split prompts up into more reasonable sizes
16 # Parameters:
17 - `fname : Path`
18 jsonl file with prompts. Expects a list of dicts with a "text" key
19 - `min_chars : int | None`
20 (defaults to `None`)
21 - `max_chars : int | None`
22 (defaults to `None`)
23 - `shuffle : bool`
24 (defaults to `False`)
26 # Returns:
27 - `list[dict]`
28 processed list of prompts. Each prompt has a "text" key w/ a string value and some metadata.
29 this is not guaranteed to be the same length as the input list!
30 """
31 # read raw data
32 with open(fname, "r") as f:
33 data_raw: list[dict] = [json.loads(d) for d in f.readlines()]
35 # add fname metadata
36 for d in data_raw:
37 d["source_fname"] = fname.as_posix()
39 # trim too-short samples
40 if min_chars is not None:
41 data_raw = list(
42 filter(
43 lambda x: len(x["text"]) >= min_chars,
44 data_raw,
45 ),
46 )
48 # split up too-long samples
49 if max_chars is not None:
50 data_new: list[dict] = []
51 for d in data_raw:
52 d_text: str = d["text"]
53 while len(d_text) > max_chars:
54 data_new.append(
55 {
56 **d,
57 "text": d_text[:max_chars],
58 },
59 )
60 d_text = d_text[max_chars:]
61 data_new.append(
62 {
63 **d,
64 "text": d_text,
65 },
66 )
67 data_raw = data_new
69 # trim too-short samples again
70 if min_chars is not None:
71 data_raw = list(
72 filter(
73 lambda x: len(x["text"]) >= min_chars,
74 data_raw,
75 ),
76 )
78 # shuffle
79 if shuffle:
80 random.shuffle(data_raw)
82 return data_raw