Coverage for pydantic_ai_jupyter / display.py: 96%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-26 11:36 -0800

1"""Main display runner for pydantic-ai agents in Jupyter notebooks.""" 

2 

3from __future__ import annotations 

4 

5import json 

6from typing import TYPE_CHECKING, Any 

7 

8from IPython.display import display 

9from pydantic_ai import AgentRunResult 

10from pydantic_ai.messages import ( 

11 FunctionToolCallEvent, 

12 FunctionToolResultEvent, 

13 PartDeltaEvent, 

14 PartStartEvent, 

15 TextPart, 

16 TextPartDelta, 

17 ThinkingPart, 

18 ThinkingPartDelta, 

19 ToolCallPart, 

20 ToolCallPartDelta, 

21) 

22from pydantic_ai.run import AgentRunResultEvent 

23 

24from .markdown import Markdown 

25from .views import ( 

26 DebugEventView, 

27 ErrorView, 

28 StreamingToolCallView, 

29 ThinkingView, 

30 ToolResultView, 

31) 

32 

33if TYPE_CHECKING: 

34 from pydantic_ai import Agent 

35 

36 

async def run_with_display(
    agent: Agent[Any, Any],
    user_prompt: str | None = None,
    *,
    debug: bool = False,
    **kwargs: Any,
) -> AgentRunResult[Any] | None:
    """Run an agent with live Jupyter display of tool calls and streaming text.

    ``user_prompt`` and ``**kwargs`` are forwarded unchanged to
    ``agent.run_stream_events()``; ``debug`` only affects what is displayed.

    Args:
        agent: The pydantic-ai Agent to run.
        user_prompt: The user's prompt/question.
        debug: If True, display events that are not otherwise rendered
            (via ``DebugEventView``).
        **kwargs: Passed to agent.run_stream_events() - includes deps,
            message_history, model_settings, usage_limits, toolsets, etc.

    Returns:
        The result carried by the stream's ``AgentRunResultEvent``, or None if
        the event stream ended without producing one.

    Raises:
        Exception: Any exception raised while streaming is displayed as an
            ``ErrorView`` and then re-raised unchanged.

    Example:
        ```python
        result = await run_with_display(agent, "What's the weather?")

        # Multi-turn conversation
        result = await run_with_display(
            agent, "What about London?",
            message_history=result.all_messages(),
        )

        # With dependencies and settings
        result = await run_with_display(
            agent, "Analyze this",
            deps=my_deps,
            model_settings={"temperature": 0.5},
        )
        ```
    """
    # Views currently receiving streamed text/thinking. Each is reset to None
    # when other output starts, so later content gets a fresh view displayed
    # below whatever was shown in between.
    current_markdown: Markdown | None = None
    current_thinking: ThinkingView | None = None
    # Streaming tool call views, keyed by the event's part index.
    streaming_tool_calls: dict[int, StreamingToolCallView] = {}

    def args_to_text(args: str | dict[str, Any] | None) -> str:
        """Render tool call args (raw JSON text, a dict, or None) as text."""
        if args is None:
            # No args yet: json.dumps(None) would seed the view with a literal
            # "null" that subsequent streamed string deltas get appended to.
            return ""
        if isinstance(args, str):
            return args
        return json.dumps(args)

    def get_or_create_markdown() -> Markdown:
        """Return the active markdown view, displaying a new one if needed."""
        nonlocal current_markdown
        if current_markdown is None:
            current_markdown = Markdown(content="")
            current_markdown.display()
        return current_markdown

    def finish_markdown() -> None:
        """Stop appending to the active markdown view."""
        nonlocal current_markdown
        current_markdown = None

    def get_or_create_thinking() -> ThinkingView:
        """Return the active thinking view, displaying a new one if needed."""
        nonlocal current_thinking
        if current_thinking is None:
            current_thinking = ThinkingView(content="")
            current_thinking.display()
        return current_thinking

    def finish_thinking() -> None:
        """Stop appending to the active thinking view."""
        nonlocal current_thinking
        current_thinking = None

    def finish_streaming_tool_calls() -> None:
        """Forget all tracked streaming tool call views."""
        streaming_tool_calls.clear()

    try:
        async for event in agent.run_stream_events(user_prompt, **kwargs):
            if isinstance(event, PartStartEvent) and isinstance(
                event.part, ToolCallPart
            ):
                # A tool call ends any text/thinking in progress; start a new
                # view that will receive the streamed args.
                finish_markdown()
                finish_thinking()
                view = StreamingToolCallView(
                    tool_name=event.part.tool_name,
                    args=args_to_text(event.part.args),
                    tool_call_id=event.part.tool_call_id,
                )
                view.display()
                view.update()
                streaming_tool_calls[event.index] = view

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, ToolCallPartDelta
            ):
                view = streaming_tool_calls.get(event.index)
                if view is not None:
                    if event.delta.args_delta:
                        # TODO(rgbkrk): Determine if a dict args_delta is fully parsed JSON from the args at this point in the Pydantic AI Event Cycle
                        view.append_args(args_to_text(event.delta.args_delta))
                    if event.delta.tool_name_delta:
                        view.append_tool_name(event.delta.tool_name_delta)

            elif isinstance(event, FunctionToolCallEvent):
                # The call is expected to have been shown already by the
                # StreamingToolCallView created at its PartStartEvent; matching
                # it here keeps it out of the debug output.
                pass

            elif isinstance(event, FunctionToolResultEvent):
                display(ToolResultView.from_part(event.result))

            elif isinstance(event, PartStartEvent) and isinstance(
                event.part, ThinkingPart
            ):
                # End the active markdown so text streamed after this thinking
                # part is displayed below it, not appended to earlier markdown
                # that sits above the thinking view.
                finish_markdown()
                if event.part.content:
                    get_or_create_thinking().append(event.part.content)

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, ThinkingPartDelta
            ):
                if event.delta.content_delta:
                    get_or_create_thinking().append(event.delta.content_delta)

            elif isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                finish_thinking()
                finish_streaming_tool_calls()
                if event.part.content:
                    get_or_create_markdown().append(event.part.content)

            elif isinstance(event, PartDeltaEvent) and isinstance(
                event.delta, TextPartDelta
            ):
                if event.delta.content_delta:
                    get_or_create_markdown().append(event.delta.content_delta)

            elif isinstance(event, AgentRunResultEvent):
                return event.result

            elif debug:
                display(DebugEventView.from_event(event))

    except Exception as exc:
        # Surface the failure in the notebook output, then propagate it.
        display(ErrorView.from_exception(exc))
        raise

    # The stream ended without an AgentRunResultEvent.
    return None