kiln_ai.adapters.repair.repair_task
import json
from typing import Type

from pydantic import BaseModel, Field

from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_registry
from kiln_ai.datamodel import Priority, Project, Task, TaskRequirement, TaskRun


# Structured input for the repair task: what the original task run was given
# and produced, plus evaluator feedback describing how to improve it.
# NOTE: deliberately no class docstring -- pydantic copies a model docstring
# into model_json_schema() as "description", which would change the
# input_json_schema handed to RepairTaskRun.
# TODO add evaluator rating
class RepairTaskInput(BaseModel):
    # Prompt of the original run, rebuilt by RepairTaskRun._original_prompt.
    original_prompt: str
    # Input the original task run was given.
    original_input: str
    # Output the original task run produced; this is what gets repaired.
    original_output: str
    evaluator_feedback: str = Field(
        min_length=1,
        description="Feedback from an evaluator on how to repair the task run.",
    )


# A synthetic Task that rewrites another task's output according to evaluator
# feedback. Its input schema is RepairTaskInput; its output schema is copied
# from the original task, so a repaired output has the same shape as the output
# it replaces. (Comment rather than docstring: this is a pydantic model.)
class RepairTaskRun(Task, parent_of={}):
    def __init__(self, original_task: Task):
        """Build the repair task for ``original_task``.

        Args:
            original_task: The task whose runs will be repaired. Only its
                ``output_json_schema`` is read here.
        """
        # Keep the typechecker happy: the Task gets a throwaway parent Project
        # (presumably Task expects a Project parent -- confirm in datamodel).
        tmp_project = Project(name="Repair")
        super().__init__(
            name="Repair",
            parent=tmp_project,
            description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
            # Implicit concatenation yields the same text as the former
            # backslash-continued literal, without being whitespace-fragile.
            instruction=(
                "You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), "
                "the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide "
                "feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response."
            ),
            requirements=[
                TaskRequirement(
                    name="Follow Eval Feedback",
                    instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
                    priority=Priority.p0,
                )
            ],
            input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
            # Repaired output must match the original task's output schema.
            output_json_schema=original_task.output_json_schema,
        )

    @classmethod
    def _original_prompt(cls, run: TaskRun, task: Task) -> str:
        """Rebuild the prompt that was used to generate ``run``'s output.

        Reads the ``prompt_builder_name`` property recorded on the run's output
        source, looks the builder class up in ``prompt_builder_registry``,
        instantiates it for ``task`` and returns its ``build_prompt()`` result.

        Raises:
            ValueError: if no string ``prompt_builder_name`` is recorded, no
                builder is registered under that name, or the registered entry
                does not produce a ``BasePromptBuilder`` instance.
        """
        prompt_builder_name = run.output.source.properties.get(
            "prompt_builder_name", None
        )
        prompt_builder_class: Type[BasePromptBuilder] | None = None
        # isinstance(..., str) already rules out None.
        if isinstance(prompt_builder_name, str):
            prompt_builder_class = prompt_builder_registry.get(
                prompt_builder_name, None
            )
        if prompt_builder_class is None:
            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
        prompt_builder = prompt_builder_class(task=task)
        # Defensive: the registry value is trusted only after instantiation.
        if not isinstance(prompt_builder, BasePromptBuilder):
            raise ValueError(
                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
            )
        return prompt_builder.build_prompt()

    @classmethod
    def build_repair_task_input(
        cls, original_task: Task, task_run: TaskRun, evaluator_feedback: str
    ) -> RepairTaskInput:
        """Assemble the repair task's input from an existing run and feedback.

        Args:
            original_task: Task that ``task_run`` belongs to; used to rebuild
                the original prompt.
            task_run: The run whose output should be repaired.
            evaluator_feedback: Non-empty feedback on how to improve the output.

        Returns:
            A validated ``RepairTaskInput``.

        Raises:
            ValueError: propagated from ``_original_prompt``.
        """
        original_prompt = cls._original_prompt(task_run, original_task)
        return RepairTaskInput(
            original_prompt=original_prompt,
            original_input=task_run.input,
            original_output=task_run.output.output,
            evaluator_feedback=evaluator_feedback,
        )
# Structured input for the repair task: what the original task run was given
# and produced, plus evaluator feedback describing how to improve it.
# NOTE: deliberately no class docstring -- pydantic copies a model docstring
# into model_json_schema() as "description", which would alter the schema.
class RepairTaskInput(BaseModel):
    # Prompt the original run was generated with.
    original_prompt: str
    # Input the original run was given.
    original_input: str
    # Output the original run produced; this is what gets repaired.
    original_output: str
    # Must be non-empty (min_length=1).
    evaluator_feedback: str = Field(
        min_length=1,
        description="Feedback from an evaluator on how to repair the task run.",
    )
Usage docs: https://docs.pydantic.dev/2.8/concepts/models/
A base class for creating Pydantic models.
Attributes:
__class_vars__: The names of classvars defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The signature for instantiating the model.
__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, and `__parameters__` in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a `RootModel`.
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_extra__: An instance attribute with the values of extra fields from validation when `model_config['extra'] == 'allow'`.
__pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
__pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
# A synthetic Task that rewrites another task's output according to evaluator
# feedback. Its input schema is RepairTaskInput; its output schema is copied
# from the original task, so a repaired output has the same shape as the output
# it replaces. (Comment rather than docstring: this is a pydantic model.)
class RepairTaskRun(Task, parent_of={}):
    def __init__(self, original_task: Task):
        """Build the repair task for ``original_task``.

        Args:
            original_task: The task whose runs will be repaired. Only its
                ``output_json_schema`` is read here.
        """
        # Keep the typechecker happy: the Task gets a throwaway parent Project
        # (presumably Task expects a Project parent -- confirm in datamodel).
        tmp_project = Project(name="Repair")
        super().__init__(
            name="Repair",
            parent=tmp_project,
            description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
            # Implicit concatenation yields the same text as the former
            # backslash-continued literal, without being whitespace-fragile.
            instruction=(
                "You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), "
                "the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide "
                "feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response."
            ),
            requirements=[
                TaskRequirement(
                    name="Follow Eval Feedback",
                    instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
                    priority=Priority.p0,
                )
            ],
            input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
            # Repaired output must match the original task's output schema.
            output_json_schema=original_task.output_json_schema,
        )

    @classmethod
    def _original_prompt(cls, run: TaskRun, task: Task) -> str:
        """Rebuild the prompt that was used to generate ``run``'s output.

        Reads the ``prompt_builder_name`` property recorded on the run's output
        source, looks the builder class up in ``prompt_builder_registry``,
        instantiates it for ``task`` and returns its ``build_prompt()`` result.

        Raises:
            ValueError: if no string ``prompt_builder_name`` is recorded, no
                builder is registered under that name, or the registered entry
                does not produce a ``BasePromptBuilder`` instance.
        """
        prompt_builder_name = run.output.source.properties.get(
            "prompt_builder_name", None
        )
        prompt_builder_class: Type[BasePromptBuilder] | None = None
        # isinstance(..., str) already rules out None.
        if isinstance(prompt_builder_name, str):
            prompt_builder_class = prompt_builder_registry.get(
                prompt_builder_name, None
            )
        if prompt_builder_class is None:
            raise ValueError(f"No prompt builder found for name: {prompt_builder_name}")
        prompt_builder = prompt_builder_class(task=task)
        # Defensive: the registry value is trusted only after instantiation.
        if not isinstance(prompt_builder, BasePromptBuilder):
            raise ValueError(
                f"Prompt builder {prompt_builder_name} is not a valid prompt builder"
            )
        return prompt_builder.build_prompt()

    @classmethod
    def build_repair_task_input(
        cls, original_task: Task, task_run: TaskRun, evaluator_feedback: str
    ) -> RepairTaskInput:
        """Assemble the repair task's input from an existing run and feedback.

        Args:
            original_task: Task that ``task_run`` belongs to; used to rebuild
                the original prompt.
            task_run: The run whose output should be repaired.
            evaluator_feedback: Non-empty feedback on how to improve the output.

        Returns:
            A validated ``RepairTaskInput``.

        Raises:
            ValueError: propagated from ``_original_prompt``.
        """
        original_prompt = cls._original_prompt(task_run, original_task)
        return RepairTaskInput(
            original_prompt=original_prompt,
            original_input=task_run.input,
            original_output=task_run.output.output,
            evaluator_feedback=evaluator_feedback,
        )
Usage docs: https://docs.pydantic.dev/2.8/concepts/models/
A base class for creating Pydantic models.
Attributes:
__class_vars__: The names of classvars defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The signature for instantiating the model.
__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to `__args__`, `__origin__`, and `__parameters__` in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a `RootModel`.
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
__pydantic_extra__: An instance attribute with the values of extra fields from validation when `model_config['extra'] == 'allow'`.
__pydantic_fields_set__: An instance attribute with the names of fields explicitly set.
__pydantic_private__: Instance attribute with the values of private attributes set on the model instance.
def __init__(self, original_task: Task):
    """Configure this Task as a repair task for ``original_task``.

    Only ``original_task.output_json_schema`` is read: repaired output must
    have the same schema as the output being repaired.
    """
    # Throwaway parent project, kept only to satisfy the type checker.
    parent_project = Project(name="Repair")
    feedback_requirements = [
        TaskRequirement(
            name="Follow Eval Feedback",
            instruction="The evaluator's feedback is the most important thing to consider. If it conflicts with the original task instruction or prompt, prioritize the evaluator's feedback.",
            priority=Priority.p0,
        )
    ]
    super().__init__(
        name="Repair",
        parent=parent_project,
        description="Repair a task run, given feedback from an evaluator about how the response can be improved.",
        instruction=(
            "You are an assistant which helps improve output from another assistant (original assistant). You'll be provided a task that the original assistant executed (prompt), "
            "the input it was given, and the output it generated. An evaluator has determined that the output it generated did not satisfy the task and should be improved. The evaluator will provide "
            "feedback describing what should be improved. Your job is to understand the evaluator's feedback and improve the response."
        ),
        requirements=feedback_requirements,
        input_json_schema=json.dumps(RepairTaskInput.model_json_schema()),
        output_json_schema=original_task.output_json_schema,
    )
Create a new model by parsing and validating input data from keyword arguments.
Raises `pydantic_core.ValidationError` if the input data cannot be validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
@classmethod
def build_repair_task_input(
    cls, original_task: Task, task_run: TaskRun, evaluator_feedback: str
) -> RepairTaskInput:
    """Collect everything the repair task needs about ``task_run``.

    The original prompt is rebuilt through ``cls._original_prompt`` (which may
    raise ``ValueError``); input and output are copied from the run itself.
    """
    prompt = cls._original_prompt(task_run, original_task)
    run_input = task_run.input
    run_output = task_run.output.output
    return RepairTaskInput(
        original_prompt=prompt,
        original_input=run_input,
        original_output=run_output,
        evaluator_feedback=evaluator_feedback,
    )
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
    """Initialize private attributes, then run the user's ``model_post_init``.

    The private-attribute setup deliberately runs first (presumably so the
    user hook can rely on private attributes existing).
    """
    init_private_attributes(self, context)
    original_model_post_init(self, context)
We need to both initialize private attributes and call the user-defined `model_post_init` method.