Coverage for pydantic_ai_jupyter / display.py: 96%
75 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 11:36 -0800
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-26 11:36 -0800
1"""Main display runner for pydantic-ai agents in Jupyter notebooks."""
3from __future__ import annotations
5import json
6from typing import TYPE_CHECKING, Any
8from IPython.display import display
9from pydantic_ai import AgentRunResult
10from pydantic_ai.messages import (
11 FunctionToolCallEvent,
12 FunctionToolResultEvent,
13 PartDeltaEvent,
14 PartStartEvent,
15 TextPart,
16 TextPartDelta,
17 ThinkingPart,
18 ThinkingPartDelta,
19 ToolCallPart,
20 ToolCallPartDelta,
21)
22from pydantic_ai.run import AgentRunResultEvent
24from .markdown import Markdown
25from .views import (
26 DebugEventView,
27 ErrorView,
28 StreamingToolCallView,
29 ThinkingView,
30 ToolResultView,
31)
33if TYPE_CHECKING:
34 from pydantic_ai import Agent
async def run_with_display(
    agent: Agent[Any, Any],
    user_prompt: str | None = None,
    *,
    debug: bool = False,
    **kwargs: Any,
) -> AgentRunResult[Any] | None:
    """Run an agent with live Jupyter display of tool calls and streaming text.

    All arguments except `debug` are passed directly to `agent.run_stream_events()`.

    Text parts stream into a `Markdown` view, thinking parts into a
    `ThinkingView`, tool calls into a `StreamingToolCallView` whose arguments
    fill in as they stream, and tool results are shown as `ToolResultView`s.

    Args:
        agent: The pydantic-ai Agent to run
        user_prompt: The user's prompt/question
        debug: If True, display unhandled lifecycle events
        **kwargs: Passed to agent.run_stream_events() - includes deps, message_history,
            model_settings, usage_limits, toolsets, etc.

    Returns:
        The result carried by the `AgentRunResultEvent`, or None if the event
        stream ends without producing one.

    Raises:
        Exception: Any exception raised while streaming the run is displayed
            as an `ErrorView` and then re-raised unchanged.

    Example:
        ```python
        result = await run_with_display(agent, "What's the weather?")

        # Multi-turn conversation
        result = await run_with_display(
            agent, "What about London?",
            message_history=result.all_messages(),
        )

        # With dependencies and settings
        result = await run_with_display(
            agent, "Analyze this",
            deps=my_deps,
            model_settings={"temperature": 0.5},
        )
        ```
    """
    # Views currently receiving streamed text / thinking. None means the next
    # chunk of that kind creates (and displays) a fresh view.
    current_markdown: Markdown | None = None
    current_thinking: ThinkingView | None = None
    # Streaming tool call views, keyed by the part index of their PartStartEvent.
    streaming_tool_calls: dict[int, StreamingToolCallView] = {}

    def args_to_text(args: str | dict[str, Any] | None) -> str:
        # Tool call args arrive either as raw (possibly partial) JSON text or
        # as a dict. None means "no args yet": render it as "" rather than
        # json.dumps(None) == "null", which later string deltas would be
        # appended onto.
        if args is None:
            return ""
        if isinstance(args, str):
            return args
        return json.dumps(args)

    def get_or_create_markdown() -> Markdown:
        nonlocal current_markdown
        if current_markdown is None:
            current_markdown = Markdown(content="")
            current_markdown.display()
        return current_markdown

    def finish_markdown() -> None:
        nonlocal current_markdown
        current_markdown = None

    def get_or_create_thinking() -> ThinkingView:
        nonlocal current_thinking
        if current_thinking is None:
            current_thinking = ThinkingView(content="")
            current_thinking.display()
        return current_thinking

    def finish_thinking() -> None:
        nonlocal current_thinking
        current_thinking = None

    def finish_streaming_tool_calls() -> None:
        # Cleared in place, so no `nonlocal` rebinding is required.
        streaming_tool_calls.clear()

    try:
        async for event in agent.run_stream_events(user_prompt, **kwargs):
            # Handle streaming tool call parts (args streaming in)
            if isinstance(event, PartStartEvent) and isinstance(
                event.part, ToolCallPart
            ):
                # Close open text/thinking views so output after the tool
                # call gets new views below it instead of extending ones above.
                finish_markdown()
                finish_thinking()
                view = StreamingToolCallView(
                    tool_name=event.part.tool_name,
                    args=args_to_text(event.part.args),
                    tool_call_id=event.part.tool_call_id,
                )
                view.display()
                view.update()
                streaming_tool_calls[event.index] = view

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, ToolCallPartDelta
            ):
                view = streaming_tool_calls.get(event.index)
                if view is not None:
                    if event.delta.args_delta:
                        # TODO(rgbkrk): Determine if a dict args_delta is fully
                        # parsed JSON from the args at this point in the
                        # Pydantic AI Event Cycle
                        view.append_args(args_to_text(event.delta.args_delta))
                    if event.delta.tool_name_delta:
                        view.append_tool_name(event.delta.tool_name_delta)

            elif isinstance(event, FunctionToolCallEvent):
                # Intentionally not displayed; presumably the call is already
                # visible via its StreamingToolCallView.
                pass

            elif isinstance(event, FunctionToolResultEvent):
                display(ToolResultView.from_part(event.result))

            elif isinstance(event, PartStartEvent) and isinstance(
                event.part, ThinkingPart
            ):
                # Close the current Markdown view: text following this thinking
                # part must start a new view below it, not be appended to a
                # view displayed above it.
                finish_markdown()
                if event.part.content:
                    get_or_create_thinking().append(event.part.content)

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, ThinkingPartDelta
            ):
                if event.delta.content_delta:
                    get_or_create_thinking().append(event.delta.content_delta)

            elif isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                finish_thinking()
                finish_streaming_tool_calls()
                if event.part.content:
                    get_or_create_markdown().append(event.part.content)

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, TextPartDelta
            ):
                if event.delta.content_delta:
                    get_or_create_markdown().append(event.delta.content_delta)

            elif isinstance(event, AgentRunResultEvent):
                return event.result

            elif debug:
                display(DebugEventView.from_event(event))

    except Exception as e:
        display(ErrorView.from_exception(e))
        raise

    # The stream finished without an AgentRunResultEvent.
    return None