Coverage for src/documint_mcp/jobs.py: 0%

150 statements

coverage.py v7.13.5, created at 2026-03-30 22:30 -0400

1"""Queue-backed job dispatch for Documint long-running workflows.""" 

2 

3from __future__ import annotations 

4 

5from collections.abc import Callable 

6from typing import Any 

7 

8from .config import settings 

9from .models import DriftJobRequest, PullRequestCreateRequest, QueuedJob 

10from .repository import get_service 

11 

# Holds arq's RedisSettings class when arq is importable; stays None otherwise
# so callers can detect the missing optional dependency at runtime.
_redis_settings_factory: Any = None

try:
    from arq import create_pool
    from arq.connections import RedisSettings

    _redis_settings_factory = RedisSettings
except ModuleNotFoundError:  # pragma: no cover - dependency is installed in runtime.
    # Importing this module must not fail just because the queue backend is
    # absent; get_arq_pool() checks this sentinel and raises a clear error.
    create_pool = None  # type: ignore[assignment]

# Process-wide arq connection pool, created lazily by get_arq_pool() and
# released by close_arq_pool().
_pool: Any | None = None

23 

24 

def _redis_settings() -> Any:
    """Build arq ``RedisSettings`` from the configured Redis DSN.

    Raises:
        RuntimeError: if arq is not installed.
    """
    factory = _redis_settings_factory
    if factory is None:
        raise RuntimeError("arq is not installed")
    return factory.from_dsn(settings.redis_url)

29 

30 

async def get_arq_pool() -> Any:
    """Return the shared arq Redis pool, creating it lazily on first use.

    Raises:
        RuntimeError: if arq is not installed.
    """
    global _pool
    if _pool is not None:
        return _pool
    if create_pool is None:
        raise RuntimeError("arq is not installed")
    _pool = await create_pool(_redis_settings())
    return _pool

38 

39 

async def close_arq_pool() -> None:
    """Close the shared arq pool, if any, and clear the cached reference.

    The cached pool is cleared even when ``close()`` raises, so a failed
    shutdown cannot leave a stale, half-closed pool behind for the next
    ``get_arq_pool()`` call.
    """
    global _pool
    if _pool is None:
        return
    try:
        await _pool.close()
    finally:
        # Drop the reference regardless of close() outcome.
        _pool = None

45 

46 

47def _drift_result_payload(result: Any) -> tuple[str, str | None, str, dict[str, object]]: 

48 return ( 

49 "verification_run", 

50 str(result.id), 

51 f"{result.findings_count} findings returned.", 

52 {"run_id": result.id, "findings_count": result.findings_count}, 

53 ) 

54 

55 

56def _patch_result_payload(result: Any) -> tuple[str, str | None, str, dict[str, object]]: 

57 return ( 

58 "doc_patch", 

59 str(result.id), 

60 str(result.summary), 

61 {"patch_id": result.id, "status": result.status}, 

62 ) 

63 

64 

65def _publish_result_payload( 

66 result: Any, 

67) -> tuple[str, str | None, str, dict[str, object]]: 

68 return ( 

69 "publish_deployment", 

70 str(result.id), 

71 f"Published {result.docs_count} pages.", 

72 {"deployment_id": result.id, "docs_count": result.docs_count, "site_url": result.site_url}, 

73 ) 

74 

75 

76def _pull_request_result_payload( 

77 result: Any, 

78) -> tuple[str, str | None, str, dict[str, object]]: 

79 return ( 

80 "pull_request", 

81 str(result.id), 

82 str(result.title), 

83 {"pull_request_id": result.id, "state": result.state, "patch_id": result.patch_id}, 

84 ) 

85 

86 

87def _installation_result_payload( 

88 result: Any, 

89) -> tuple[str, str | None, str, dict[str, object]]: 

90 count = len(result) if isinstance(result, list) else 0 

91 return ( 

92 "installation_sync", 

93 None, 

94 f"{count} repositories synced.", 

95 {"repository_count": count}, 

96 ) 

97 

98 

async def _enqueue_or_fail(function_name: str, payload: dict[str, object]) -> None:
    """Enqueue an arq job; mark the tracked job failed and raise if enqueue fails.

    arq's ``enqueue_job`` returns ``None`` when the job could not be queued,
    which we surface as a ``RuntimeError`` after recording the failure.
    """
    pool = await get_arq_pool()
    enqueued = await pool.enqueue_job(function_name, payload)
    if enqueued is not None:
        return
    message = f"Failed to enqueue {function_name}"
    get_service().mark_job_failed(str(payload["job_id"]), error_summary=message)
    raise RuntimeError(message)

108 

109 

def _complete_job(
    job_id: str,
    result: Any,
    payload_builder: Callable[[Any], tuple[str, str | None, str, dict[str, object]]],
) -> QueuedJob:
    """Record a successful job outcome built from *result* via *payload_builder*."""
    kind, resource_id, summary, details = payload_builder(result)
    return get_service().mark_job_completed(
        job_id,
        resource_type=kind,
        resource_id=resource_id,
        result_summary=summary,
        result_json=details,
    )

123 

124 

def _fail_job(job_id: str, exc: Exception) -> QueuedJob:
    """Mark *job_id* failed, using the exception text as the error summary."""
    summary = str(exc)
    return get_service().mark_job_failed(job_id, error_summary=summary)

127 

128 

async def _run_inline_job(
    *,
    job_id: str,
    runner: Callable[[], Any],
    payload_builder: Callable[[Any], tuple[str, str | None, str, dict[str, object]]],
) -> QueuedJob:
    """Execute *runner* synchronously, tracking running/failed/completed state.

    Used when ``settings.job_execution_mode`` is ``"inline"`` and no worker
    process exists. Exceptions are recorded on the job and then re-raised.
    """
    get_service().mark_job_running(job_id)
    try:
        outcome = runner()
    except Exception as exc:
        # Record the failure before propagating to the caller.
        _fail_job(job_id, exc)
        raise
    else:
        return _complete_job(job_id, outcome, payload_builder)

142 

143 

async def dispatch_drift(
    request: DriftJobRequest,
    *,
    user_id: str | None = None,
) -> QueuedJob:
    """Create a drift-verification job and run it inline or enqueue it.

    Returns the tracked job record (completed when inline, pending when queued).
    """
    payload = request.model_dump(mode="json")
    job = get_service().create_job(
        job_kind="drift",
        project_id=request.project_id,
        payload_json=payload,
        user_id=user_id,
    )
    if settings.job_execution_mode != "inline":
        # Queue mode: hand the work to the arq worker and return the pending job.
        queue_payload = {**payload, "job_id": job.job_id, "user_id": user_id}
        await _enqueue_or_fail("run_drift_job", queue_payload)
        return get_service().get_job(job.job_id, user_id=user_id)
    return await _run_inline_job(
        job_id=job.job_id,
        runner=lambda: get_service().run_drift(request, user_id=user_id),
        payload_builder=_drift_result_payload,
    )

164 

165 

async def dispatch_patch(
    *,
    project_id: str,
    finding_id: str | None = None,
    artifact_id: str | None = None,
    policy: str = "on_demand",
    user_id: str | None = None,
) -> QueuedJob:
    """Create a documentation-patch generation job and run it inline or enqueue it."""
    payload = {
        "project_id": project_id,
        "finding_id": finding_id,
        "artifact_id": artifact_id,
        "policy": policy,
    }
    job = get_service().create_job(
        job_kind="patch",
        project_id=project_id,
        payload_json=payload,
        user_id=user_id,
    )
    if settings.job_execution_mode != "inline":
        queue_payload = {**payload, "job_id": job.job_id, "user_id": user_id}
        await _enqueue_or_fail("generate_patch_job", queue_payload)
        return get_service().get_job(job.job_id, user_id=user_id)

    def _run() -> Any:
        # Inline execution: generate the patch directly in this process.
        return get_service().generate_doc_patch(
            project_id=project_id,
            finding_id=finding_id,
            artifact_id=artifact_id,
            policy=policy,
            user_id=user_id,
        )

    return await _run_inline_job(
        job_id=job.job_id,
        runner=_run,
        payload_builder=_patch_result_payload,
    )

200 

201 

async def dispatch_publish(project_id: str, *, user_id: str | None = None) -> QueuedJob:
    """Create a preview-site publish job for *project_id* and run or enqueue it."""
    payload = {"project_id": project_id}
    job = get_service().create_job(
        job_kind="publish",
        project_id=project_id,
        payload_json=payload,
        user_id=user_id,
    )
    if settings.job_execution_mode != "inline":
        await _enqueue_or_fail(
            "publish_job",
            {**payload, "job_id": job.job_id, "user_id": user_id},
        )
        return get_service().get_job(job.job_id, user_id=user_id)
    return await _run_inline_job(
        job_id=job.job_id,
        runner=lambda: get_service().publish_preview(project_id, user_id=user_id),
        payload_builder=_publish_result_payload,
    )

218 

219 

async def dispatch_pull_request(
    *,
    project_id: str,
    patch_id: str,
    title: str | None,
    user_id: str | None = None,
) -> QueuedJob:
    """Create a job that opens a pull request for *patch_id* and run or enqueue it."""
    payload = {"project_id": project_id, "patch_id": patch_id, "title": title}
    job = get_service().create_job(
        job_kind="pull_request",
        project_id=project_id,
        payload_json=payload,
        user_id=user_id,
    )
    if settings.job_execution_mode != "inline":
        queue_payload = {**payload, "job_id": job.job_id, "user_id": user_id}
        await _enqueue_or_fail("open_pull_request_job", queue_payload)
        return get_service().get_job(job.job_id, user_id=user_id)

    def _run() -> Any:
        # Inline execution: open the pull request directly in this process.
        return get_service().open_pull_request(
            project_id,
            patch_id,
            PullRequestCreateRequest(title=title),
            user_id=user_id,
        )

    return await _run_inline_job(
        job_id=job.job_id,
        runner=_run,
        payload_builder=_pull_request_result_payload,
    )

250 

251 

async def dispatch_installation_sync(
    installation_id: str,
    *,
    user_id: str | None = None,
) -> QueuedJob:
    """Create a repository-sync job for a GitHub installation and run or enqueue it."""
    # The job is tracked under the workspace owning the installation.
    workspace_id = get_service().get_installation_workspace_id(installation_id)
    payload = {"installation_id": installation_id}
    job = get_service().create_job(
        job_kind="installation_sync",
        workspace_id=workspace_id,
        payload_json=payload,
        user_id=user_id,
    )
    if settings.job_execution_mode != "inline":
        await _enqueue_or_fail(
            "sync_installation_job",
            {**payload, "job_id": job.job_id, "user_id": user_id},
        )
        return get_service().get_job(job.job_id, user_id=user_id)
    return await _run_inline_job(
        job_id=job.job_id,
        runner=lambda: get_service().sync_installation(installation_id, user_id=user_id),
        payload_builder=_installation_result_payload,
    )

278 

279 

async def run_drift_job(ctx: dict[str, Any], payload: dict[str, Any]) -> dict[str, Any]:
    """arq worker entry point: run the drift verification described by *payload*."""
    del ctx  # worker context is unused
    job_id = str(payload["job_id"])
    raw_user = payload.get("user_id")
    resolved_user = str(raw_user) if isinstance(raw_user, str) else None
    get_service().mark_job_running(job_id)
    try:
        # Request construction stays inside the try so validation errors on a
        # malformed payload are recorded on the job as well.
        request = DriftJobRequest(
            project_id=str(payload["project_id"]),
            signal_type=payload.get("signal_type", "manual"),
            changed_files=payload.get("changed_files"),
        )
        outcome = get_service().run_drift(request, user_id=resolved_user)
    except Exception as exc:
        _fail_job(job_id, exc)
        raise
    completed = _complete_job(job_id, outcome, _drift_result_payload)
    return completed.model_dump(mode="json")

298 

299 

async def generate_patch_job(
    ctx: dict[str, Any],
    payload: dict[str, Any],
) -> dict[str, Any]:
    """arq worker entry point: generate the documentation patch described by *payload*.

    Marks the tracked job running, delegates to the service, records
    failure (and re-raises) on any exception, and returns the completed
    job record serialized with ``model_dump(mode="json")``.
    """
    del ctx  # worker context is unused
    job_id = str(payload["job_id"])
    user_id = payload.get("user_id")
    get_service().mark_job_running(job_id)
    try:
        result = get_service().generate_doc_patch(
            # Require project_id like the sibling *_job handlers do: a
            # malformed payload now fails loudly (and is recorded on the
            # job) instead of silently passing project_id=None through.
            project_id=str(payload["project_id"]),
            finding_id=payload.get("finding_id"),
            artifact_id=payload.get("artifact_id"),
            policy=str(payload.get("policy") or "on_demand"),
            user_id=str(user_id) if isinstance(user_id, str) else None,
        )
    except Exception as exc:
        _fail_job(job_id, exc)
        raise
    return _complete_job(job_id, result, _patch_result_payload).model_dump(mode="json")

320 

321 

async def publish_job(ctx: dict[str, Any], payload: dict[str, Any]) -> dict[str, Any]:
    """arq worker entry point: publish the preview site for the payload's project."""
    del ctx  # worker context is unused
    job_id = str(payload["job_id"])
    raw_user = payload.get("user_id")
    resolved_user = str(raw_user) if isinstance(raw_user, str) else None
    get_service().mark_job_running(job_id)
    try:
        deployment = get_service().publish_preview(
            str(payload["project_id"]),
            user_id=resolved_user,
        )
    except Exception as exc:
        _fail_job(job_id, exc)
        raise
    completed = _complete_job(job_id, deployment, _publish_result_payload)
    return completed.model_dump(mode="json")

336 

337 

async def open_pull_request_job(
    ctx: dict[str, Any],
    payload: dict[str, Any],
) -> dict[str, Any]:
    """arq worker entry point: open a pull request for a generated patch."""
    del ctx  # worker context is unused
    job_id = str(payload["job_id"])
    raw_user = payload.get("user_id")
    resolved_user = str(raw_user) if isinstance(raw_user, str) else None
    get_service().mark_job_running(job_id)
    try:
        pull_request = get_service().open_pull_request(
            str(payload["project_id"]),
            str(payload["patch_id"]),
            PullRequestCreateRequest(title=payload.get("title")),
            user_id=resolved_user,
        )
    except Exception as exc:
        _fail_job(job_id, exc)
        raise
    completed = _complete_job(job_id, pull_request, _pull_request_result_payload)
    return completed.model_dump(mode="json")

359 

360 

async def sync_installation_job(
    ctx: dict[str, Any],
    payload: dict[str, Any],
) -> dict[str, Any]:
    """arq worker entry point: sync repositories for a GitHub installation."""
    del ctx  # worker context is unused
    job_id = str(payload["job_id"])
    raw_user = payload.get("user_id")
    resolved_user = str(raw_user) if isinstance(raw_user, str) else None
    get_service().mark_job_running(job_id)
    try:
        repositories = get_service().sync_installation(
            str(payload["installation_id"]),
            user_id=resolved_user,
        )
    except Exception as exc:
        _fail_job(job_id, exc)
        raise
    completed = _complete_job(job_id, repositories, _installation_result_payload)
    return completed.model_dump(mode="json")

380 

381 

class WorkerSettings:
    """arq worker configuration: registered job functions and Redis connection."""

    # Job callables the worker executes; their names must match the function
    # names the dispatch_* helpers pass to _enqueue_or_fail.
    functions = [
        run_drift_job,
        generate_patch_job,
        publish_job,
        open_pull_request_job,
        sync_installation_job,
    ]
    # Evaluated at import time; None when arq is not installed.
    redis_settings = _redis_settings() if _redis_settings_factory is not None else None