Coverage for pattern_lens\prompts.py: 96%
24 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
1"implements `load_text_data` for loading prompts"
3import json
4import random
5from pathlib import Path
8def load_text_data(
9 fname: Path,
10 min_chars: int | None = None,
11 max_chars: int | None = None,
12 shuffle: bool = False,
13) -> list[dict]:
14 """given `fname`, the path to a jsonl file, split prompts up into more reasonable sizes
16 # Parameters:
17 - `fname : Path`
18 jsonl file with prompts. Expects a list of dicts with a "text" key
19 - `min_chars : int | None`
20 (defaults to `None`)
21 - `max_chars : int | None`
22 (defaults to `None`)
23 - `shuffle : bool`
24 (defaults to `False`)
26 # Returns:
27 - `list[dict]`
28 new, processed list of prompts. Each prompt has a "text" key with a string value, and some metadata. this is not guaranteed to be the same length as the input list!
29 """
30 # read raw data
31 with open(fname, "r") as f:
32 data_raw: list[dict] = [json.loads(d) for d in f.readlines()]
34 # add fname metadata
35 for d in data_raw:
36 d["source_fname"] = fname.as_posix()
38 # trim too-short samples
39 if min_chars is not None:
40 data_raw = list(
41 filter(
42 lambda x: len(x["text"]) >= min_chars,
43 data_raw,
44 )
45 )
47 # split up too-long samples
48 if max_chars is not None:
49 data_new: list[dict] = []
50 for d in data_raw:
51 d_text: str = d["text"]
52 while len(d_text) > max_chars:
53 data_new.append(
54 {
55 **d,
56 "text": d_text[:max_chars],
57 }
58 )
59 d_text = d_text[max_chars:]
60 data_new.append(
61 {
62 **d,
63 "text": d_text,
64 }
65 )
66 data_raw = data_new
68 # trim too-short samples again
69 if min_chars is not None:
70 data_raw = list(
71 filter(
72 lambda x: len(x["text"]) >= min_chars,
73 data_raw,
74 )
75 )
77 # shuffle
78 if shuffle:
79 random.shuffle(data_raw)
81 return data_raw