Coverage for /Users/antonigmitruk/golf/src/golf/telemetry/instrumentation.py: 0%
779 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-16 18:46 +0200
« prev ^ index » next coverage.py v7.6.12, created at 2025-08-16 18:46 +0200
1"""Component-level OpenTelemetry instrumentation for Golf-built servers."""
3import asyncio
4import functools
5import os
6import sys
7import time
8import json
9from collections.abc import Callable
10from contextlib import asynccontextmanager
11from typing import Any, TypeVar
12from collections.abc import AsyncGenerator
13from collections import OrderedDict
15from opentelemetry import baggage, trace
17# Import endpoints with fallback for dev mode
18try:
19 # In built wheels, this exists (generated from _endpoints.py.in)
20 from golf import _endpoints # type: ignore
21except ImportError:
22 # In editable/dev installs, fall back to env-based values
23 from golf import _endpoints_fallback as _endpoints # type: ignore
24from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
25from opentelemetry.sdk.resources import Resource
26from opentelemetry.sdk.trace import TracerProvider
27from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
28from opentelemetry.trace import Status, StatusCode
30from starlette.middleware.base import BaseHTTPMiddleware
32T = TypeVar("T")
34# Global tracer instance
35_tracer: trace.Tracer | None = None
36_provider: TracerProvider | None = None
37_detailed_tracing_enabled: bool = False
40def _safe_serialize(data: Any, max_length: int = 1000) -> str | None:
41 """Safely serialize data to string with length limit."""
42 try:
43 if isinstance(data, str):
44 serialized = data
45 else:
46 serialized = json.dumps(data, default=str, ensure_ascii=False)
48 if len(serialized) > max_length:
49 return serialized[:max_length] + "..." + f" (truncated from {len(serialized)} chars)"
50 return serialized
51 except (TypeError, ValueError):
52 # Fallback for non-serializable objects
53 try:
54 return str(data)[:max_length] + "..." if len(str(data)) > max_length else str(data)
55 except Exception:
56 return None
59def set_detailed_tracing(enabled: bool) -> None:
60 """Enable or disable detailed tracing with input/output capture."""
61 global _detailed_tracing_enabled
62 _detailed_tracing_enabled = enabled
65def init_telemetry(service_name: str = "golf-mcp-server") -> TracerProvider | None:
66 """Initialize OpenTelemetry with environment-based configuration.
68 Returns None if required environment variables are not set.
69 """
70 global _provider
72 # Check for Golf platform integration first
73 golf_api_key = os.environ.get("GOLF_API_KEY")
74 if golf_api_key:
75 # Auto-configure for Golf platform - always use OTLP when Golf API
76 # key is present
77 os.environ["OTEL_TRACES_EXPORTER"] = "otlp_http"
79 # Only set endpoint if not already configured (allow user override)
80 if not os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
81 os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = _endpoints.OTEL_ENDPOINT
83 # Set Golf platform headers (append to existing if present)
84 existing_headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS", "")
85 golf_header = f"X-Golf-Key={golf_api_key}"
87 if existing_headers:
88 # Check if Golf key is already in headers
89 if "X-Golf-Key=" not in existing_headers:
90 os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"{existing_headers},{golf_header}"
91 else:
92 os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = golf_header
94 # Check for required environment variables based on exporter type
95 exporter_type = os.environ.get("OTEL_TRACES_EXPORTER", "console").lower()
97 # For OTLP HTTP exporter, check if endpoint is configured
98 if exporter_type == "otlp_http":
99 endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
100 if not endpoint:
101 print(
102 "[WARNING] OpenTelemetry tracing is disabled: "
103 "OTEL_EXPORTER_OTLP_ENDPOINT is not set for OTLP HTTP exporter"
104 )
105 return None
107 # Create resource with service information
108 resource_attributes = {
109 "service.name": os.environ.get("OTEL_SERVICE_NAME", service_name),
110 "service.version": os.environ.get("SERVICE_VERSION", "1.0.0"),
111 "service.instance.id": os.environ.get("SERVICE_INSTANCE_ID", "default"),
112 }
114 # Add Golf-specific attributes if available
115 if golf_api_key:
116 golf_server_id = os.environ.get("GOLF_SERVER_ID")
117 if golf_server_id:
118 resource_attributes["golf.server.id"] = golf_server_id
119 resource_attributes["golf.platform.enabled"] = "true"
121 resource = Resource.create(resource_attributes)
123 # Create provider
124 provider = TracerProvider(resource=resource)
126 # Configure exporter based on type
127 try:
128 if exporter_type == "otlp_http":
129 endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318/v1/traces")
130 headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS", "")
132 # Parse headers if provided
133 header_dict = {}
134 if headers:
135 for header in headers.split(","):
136 if "=" in header:
137 key, value = header.split("=", 1)
138 header_dict[key.strip()] = value.strip()
140 exporter = OTLPSpanExporter(endpoint=endpoint, headers=header_dict if header_dict else None)
142 else:
143 # Default to console exporter
144 exporter = ConsoleSpanExporter(out=sys.stderr)
145 except Exception:
146 import traceback
148 traceback.print_exc()
149 raise
151 # Add batch processor for better performance
152 try:
153 processor = BatchSpanProcessor(
154 exporter,
155 max_queue_size=2048,
156 schedule_delay_millis=1000, # Export every 1 second instead of
157 # default 5 seconds
158 max_export_batch_size=512,
159 export_timeout_millis=5000,
160 )
161 provider.add_span_processor(processor)
162 except Exception:
163 import traceback
165 traceback.print_exc()
166 raise
168 # Set as global provider
169 try:
170 # Check if a provider is already set to avoid the warning
171 existing_provider = trace.get_tracer_provider()
172 if existing_provider is None or str(type(existing_provider).__name__) == "ProxyTracerProvider":
173 # Only set if no provider exists or it's the default proxy provider
174 trace.set_tracer_provider(provider)
175 _provider = provider
176 except Exception:
177 import traceback
179 traceback.print_exc()
180 raise
182 return provider
185def get_tracer() -> trace.Tracer:
186 """Get or create the global tracer instance."""
187 global _tracer, _provider
189 # If no provider is set, telemetry is disabled - return no-op tracer
190 if _provider is None:
191 return trace.get_tracer("golf.mcp.components.noop", "1.0.0")
193 if _tracer is None:
194 _tracer = trace.get_tracer("golf.mcp.components", "1.0.0")
195 return _tracer
198def instrument_tool(func: Callable[..., T], tool_name: str) -> Callable[..., T]:
199 """Instrument a tool function with OpenTelemetry tracing."""
200 global _provider
202 # If telemetry is disabled, return the original function
203 if _provider is None:
204 return func
206 tracer = get_tracer()
208 @functools.wraps(func)
209 async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
210 # Record metrics timing
211 import time
213 start_time = time.time()
215 # Create a more descriptive span name
216 span_name = f"mcp.tool.{tool_name}.execute"
218 # start_as_current_span automatically uses the current context and manages it
219 with tracer.start_as_current_span(span_name) as span:
220 # Add essential attributes only
221 span.set_attribute("mcp.component.type", "tool")
222 span.set_attribute("mcp.tool.name", tool_name)
223 span.set_attribute(
224 "mcp.tool.module",
225 func.__module__ if hasattr(func, "__module__") else "unknown",
226 )
228 # Add minimal execution context
229 if args or kwargs:
230 span.set_attribute("mcp.execution.has_params", True)
232 # Capture inputs if detailed tracing is enabled
233 if _detailed_tracing_enabled and (args or kwargs):
234 input_data = {"args": args, "kwargs": kwargs} if args or kwargs else None
235 if input_data:
236 input_str = _safe_serialize(input_data)
237 if input_str:
238 span.set_attribute("mcp.tool.input", input_str)
240 # Extract Context parameter if present
241 ctx = kwargs.get("ctx")
242 if ctx:
243 # Only extract known MCP context attributes
244 ctx_attrs = [
245 "request_id",
246 "session_id",
247 "client_id",
248 "user_id",
249 "tenant_id",
250 ]
251 for attr in ctx_attrs:
252 if hasattr(ctx, attr):
253 value = getattr(ctx, attr)
254 if value is not None:
255 span.set_attribute(f"mcp.context.{attr}", str(value))
257 # Also check baggage for session ID
258 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
259 if session_id_from_baggage:
260 span.set_attribute("mcp.session.id", session_id_from_baggage)
262 # Add event for tool execution start
263 span.add_event("tool.execution.started", {"tool.name": tool_name})
265 try:
266 result = await func(*args, **kwargs)
267 span.set_status(Status(StatusCode.OK))
269 # Add event for successful completion
270 span.add_event("tool.execution.completed", {"tool.name": tool_name})
272 # Record metrics for successful execution
273 try:
274 from golf.metrics import get_metrics_collector
276 metrics_collector = get_metrics_collector()
277 metrics_collector.increment_tool_execution(tool_name, "success")
278 metrics_collector.record_tool_duration(tool_name, time.time() - start_time)
279 except ImportError:
280 # Metrics not available, continue without metrics
281 pass
283 # Capture result metadata
284 if result is not None:
285 span.set_attribute("mcp.tool.result.type", type(result).__name__)
287 if isinstance(result, list | dict) and hasattr(result, "__len__"):
288 span.set_attribute("mcp.tool.result.size", len(result))
289 elif isinstance(result, str):
290 span.set_attribute("mcp.tool.result.length", len(result))
292 # Capture full output if detailed tracing is enabled
293 if _detailed_tracing_enabled:
294 output_str = _safe_serialize(result)
295 if output_str:
296 span.set_attribute("mcp.tool.output", output_str)
298 return result
299 except Exception as e:
300 span.record_exception(e)
301 span.set_status(Status(StatusCode.ERROR, str(e)))
303 # Add event for error
304 span.add_event(
305 "tool.execution.error",
306 {
307 "tool.name": tool_name,
308 "error.type": type(e).__name__,
309 "error.message": str(e),
310 },
311 )
313 # Record metrics for failed execution
314 try:
315 from golf.metrics import get_metrics_collector
317 metrics_collector = get_metrics_collector()
318 metrics_collector.increment_tool_execution(tool_name, "error")
319 metrics_collector.increment_error("tool", type(e).__name__)
320 except ImportError:
321 # Metrics not available, continue without metrics
322 pass
324 raise
326 @functools.wraps(func)
327 def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
328 # Record metrics timing
329 import time
331 start_time = time.time()
333 # Create a more descriptive span name
334 span_name = f"mcp.tool.{tool_name}.execute"
336 # start_as_current_span automatically uses the current context and manages it
337 with tracer.start_as_current_span(span_name) as span:
338 # Add essential attributes only
339 span.set_attribute("mcp.component.type", "tool")
340 span.set_attribute("mcp.tool.name", tool_name)
341 span.set_attribute(
342 "mcp.tool.module",
343 func.__module__ if hasattr(func, "__module__") else "unknown",
344 )
346 # Add execution context
347 span.set_attribute("mcp.execution.args_count", len(args))
348 span.set_attribute("mcp.execution.kwargs_count", len(kwargs))
350 # Extract Context parameter if present
351 ctx = kwargs.get("ctx")
352 if ctx:
353 # Only extract known MCP context attributes
354 ctx_attrs = [
355 "request_id",
356 "session_id",
357 "client_id",
358 "user_id",
359 "tenant_id",
360 ]
361 for attr in ctx_attrs:
362 if hasattr(ctx, attr):
363 value = getattr(ctx, attr)
364 if value is not None:
365 span.set_attribute(f"mcp.context.{attr}", str(value))
367 # Also check baggage for session ID
368 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
369 if session_id_from_baggage:
370 span.set_attribute("mcp.session.id", session_id_from_baggage)
372 # Add event for tool execution start
373 span.add_event("tool.execution.started", {"tool.name": tool_name})
375 try:
376 result = func(*args, **kwargs)
377 span.set_status(Status(StatusCode.OK))
379 # Add event for successful completion
380 span.add_event("tool.execution.completed", {"tool.name": tool_name})
382 # Record metrics for successful execution
383 try:
384 from golf.metrics import get_metrics_collector
386 metrics_collector = get_metrics_collector()
387 metrics_collector.increment_tool_execution(tool_name, "success")
388 metrics_collector.record_tool_duration(tool_name, time.time() - start_time)
389 except ImportError:
390 # Metrics not available, continue without metrics
391 pass
393 # Capture result metadata
394 if result is not None:
395 span.set_attribute("mcp.tool.result.type", type(result).__name__)
397 if isinstance(result, list | dict) and hasattr(result, "__len__"):
398 span.set_attribute("mcp.tool.result.size", len(result))
399 elif isinstance(result, str):
400 span.set_attribute("mcp.tool.result.length", len(result))
402 # Capture full output if detailed tracing is enabled
403 if _detailed_tracing_enabled:
404 output_str = _safe_serialize(result)
405 if output_str:
406 span.set_attribute("mcp.tool.output", output_str)
408 return result
409 except Exception as e:
410 span.record_exception(e)
411 span.set_status(Status(StatusCode.ERROR, str(e)))
413 # Add event for error
414 span.add_event(
415 "tool.execution.error",
416 {
417 "tool.name": tool_name,
418 "error.type": type(e).__name__,
419 "error.message": str(e),
420 },
421 )
423 # Record metrics for failed execution
424 try:
425 from golf.metrics import get_metrics_collector
427 metrics_collector = get_metrics_collector()
428 metrics_collector.increment_tool_execution(tool_name, "error")
429 metrics_collector.increment_error("tool", type(e).__name__)
430 except ImportError:
431 # Metrics not available, continue without metrics
432 pass
434 raise
436 # Return appropriate wrapper based on function type
437 if asyncio.iscoroutinefunction(func):
438 return async_wrapper
439 else:
440 return sync_wrapper
443def instrument_resource(func: Callable[..., T], resource_uri: str) -> Callable[..., T]:
444 """Instrument a resource function with OpenTelemetry tracing."""
445 global _provider
447 # If telemetry is disabled, return the original function
448 if _provider is None:
449 return func
451 tracer = get_tracer()
453 # Determine if this is a template based on URI pattern
454 is_template = "{" in resource_uri
456 @functools.wraps(func)
457 async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
458 # Create a more descriptive span name
459 span_name = f"mcp.resource.{'template' if is_template else 'static'}.read"
460 with tracer.start_as_current_span(span_name) as span:
461 # Add essential attributes only
462 span.set_attribute("mcp.component.type", "resource")
463 span.set_attribute("mcp.resource.uri", resource_uri)
464 span.set_attribute("mcp.resource.is_template", is_template)
465 span.set_attribute(
466 "mcp.resource.module",
467 func.__module__ if hasattr(func, "__module__") else "unknown",
468 )
470 # Extract Context parameter if present
471 ctx = kwargs.get("ctx")
472 if ctx:
473 # Only extract known MCP context attributes
474 ctx_attrs = [
475 "request_id",
476 "session_id",
477 "client_id",
478 "user_id",
479 "tenant_id",
480 ]
481 for attr in ctx_attrs:
482 if hasattr(ctx, attr):
483 value = getattr(ctx, attr)
484 if value is not None:
485 span.set_attribute(f"mcp.context.{attr}", str(value))
487 # Also check baggage for session ID
488 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
489 if session_id_from_baggage:
490 span.set_attribute("mcp.session.id", session_id_from_baggage)
492 # Add event for resource read start
493 span.add_event("resource.read.started", {"resource.uri": resource_uri})
495 try:
496 result = await func(*args, **kwargs)
497 span.set_status(Status(StatusCode.OK))
499 # Add event for successful read
500 span.add_event("resource.read.completed", {"resource.uri": resource_uri})
502 # Add result metadata
503 if hasattr(result, "__len__"):
504 span.set_attribute("mcp.resource.result.size", len(result))
506 # Determine content type if possible
507 if isinstance(result, str):
508 span.set_attribute("mcp.resource.result.type", "text")
509 span.set_attribute("mcp.resource.result.length", len(result))
510 elif isinstance(result, bytes):
511 span.set_attribute("mcp.resource.result.type", "binary")
512 span.set_attribute("mcp.resource.result.size_bytes", len(result))
513 elif isinstance(result, dict):
514 span.set_attribute("mcp.resource.result.type", "object")
515 span.set_attribute("mcp.resource.result.keys_count", len(result))
516 elif isinstance(result, list):
517 span.set_attribute("mcp.resource.result.type", "array")
518 span.set_attribute("mcp.resource.result.items_count", len(result))
520 return result
521 except Exception as e:
522 span.record_exception(e)
523 span.set_status(Status(StatusCode.ERROR, str(e)))
525 # Add event for error
526 span.add_event(
527 "resource.read.error",
528 {
529 "resource.uri": resource_uri,
530 "error.type": type(e).__name__,
531 "error.message": str(e),
532 },
533 )
534 raise
536 @functools.wraps(func)
537 def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
538 # Create a more descriptive span name
539 span_name = f"mcp.resource.{'template' if is_template else 'static'}.read"
540 with tracer.start_as_current_span(span_name) as span:
541 # Add essential attributes only
542 span.set_attribute("mcp.component.type", "resource")
543 span.set_attribute("mcp.resource.uri", resource_uri)
544 span.set_attribute("mcp.resource.is_template", is_template)
545 span.set_attribute(
546 "mcp.resource.module",
547 func.__module__ if hasattr(func, "__module__") else "unknown",
548 )
550 # Extract Context parameter if present
551 ctx = kwargs.get("ctx")
552 if ctx:
553 # Only extract known MCP context attributes
554 ctx_attrs = [
555 "request_id",
556 "session_id",
557 "client_id",
558 "user_id",
559 "tenant_id",
560 ]
561 for attr in ctx_attrs:
562 if hasattr(ctx, attr):
563 value = getattr(ctx, attr)
564 if value is not None:
565 span.set_attribute(f"mcp.context.{attr}", str(value))
567 # Also check baggage for session ID
568 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
569 if session_id_from_baggage:
570 span.set_attribute("mcp.session.id", session_id_from_baggage)
572 # Add event for resource read start
573 span.add_event("resource.read.started", {"resource.uri": resource_uri})
575 try:
576 result = func(*args, **kwargs)
577 span.set_status(Status(StatusCode.OK))
579 # Add event for successful read
580 span.add_event("resource.read.completed", {"resource.uri": resource_uri})
582 # Add result metadata
583 if hasattr(result, "__len__"):
584 span.set_attribute("mcp.resource.result.size", len(result))
586 # Determine content type if possible
587 if isinstance(result, str):
588 span.set_attribute("mcp.resource.result.type", "text")
589 span.set_attribute("mcp.resource.result.length", len(result))
590 elif isinstance(result, bytes):
591 span.set_attribute("mcp.resource.result.type", "binary")
592 span.set_attribute("mcp.resource.result.size_bytes", len(result))
593 elif isinstance(result, dict):
594 span.set_attribute("mcp.resource.result.type", "object")
595 span.set_attribute("mcp.resource.result.keys_count", len(result))
596 elif isinstance(result, list):
597 span.set_attribute("mcp.resource.result.type", "array")
598 span.set_attribute("mcp.resource.result.items_count", len(result))
600 return result
601 except Exception as e:
602 span.record_exception(e)
603 span.set_status(Status(StatusCode.ERROR, str(e)))
605 # Add event for error
606 span.add_event(
607 "resource.read.error",
608 {
609 "resource.uri": resource_uri,
610 "error.type": type(e).__name__,
611 "error.message": str(e),
612 },
613 )
614 raise
616 if asyncio.iscoroutinefunction(func):
617 return async_wrapper
618 else:
619 return sync_wrapper
622def instrument_elicitation(func: Callable[..., T], elicitation_type: str = "elicit") -> Callable[..., T]:
623 """Instrument an elicitation function with OpenTelemetry tracing."""
624 global _provider
626 # If telemetry is disabled, return the original function
627 if _provider is None:
628 return func
630 tracer = get_tracer()
632 @functools.wraps(func)
633 async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
634 # If telemetry is disabled at runtime, call original function
635 global _provider
636 if _provider is None:
637 return await func(*args, **kwargs)
639 # Record metrics timing
640 start_time = time.time()
642 # Create a more descriptive span name
643 span_name = f"mcp.elicitation.{elicitation_type}.request"
644 with tracer.start_as_current_span(span_name) as span:
645 # Add essential attributes
646 span.set_attribute("mcp.component.type", "elicitation")
647 span.set_attribute("mcp.elicitation.type", elicitation_type)
649 # Capture elicitation parameters if detailed tracing is enabled
650 if _detailed_tracing_enabled:
651 # Extract message from first argument (common pattern)
652 if args:
653 message = args[0] if isinstance(args[0], str) else None
654 if message:
655 span.set_attribute("mcp.elicitation.message", _safe_serialize(message, 500))
657 # Extract response_type from kwargs/args
658 response_type = kwargs.get("response_type") or (args[1] if len(args) > 1 else None)
659 if response_type is not None:
660 if isinstance(response_type, list):
661 span.set_attribute("mcp.elicitation.response_type", "choice")
662 span.set_attribute("mcp.elicitation.choices", str(response_type))
663 elif hasattr(response_type, "__name__"):
664 span.set_attribute("mcp.elicitation.response_type", response_type.__name__)
665 else:
666 span.set_attribute("mcp.elicitation.response_type", str(type(response_type).__name__))
668 # Extract Context parameter if present
669 ctx = kwargs.get("ctx")
670 if ctx:
671 ctx_attrs = ["request_id", "session_id", "client_id", "user_id", "tenant_id"]
672 for attr in ctx_attrs:
673 if hasattr(ctx, attr):
674 value = getattr(ctx, attr)
675 if value is not None:
676 span.set_attribute(f"mcp.context.{attr}", str(value))
678 # Add event for elicitation start
679 span.add_event("elicitation.request.started")
681 try:
682 result = await func(*args, **kwargs)
683 span.set_status(Status(StatusCode.OK))
685 # Add event for successful completion
686 span.add_event("elicitation.request.completed")
688 # Capture result metadata
689 if result is not None and _detailed_tracing_enabled:
690 if isinstance(result, str):
691 span.set_attribute("mcp.elicitation.result.content", _safe_serialize(result, 500))
692 elif isinstance(result, (list, dict)) and hasattr(result, "__len__"):
693 span.set_attribute("mcp.elicitation.result.size", len(result))
694 span.set_attribute("mcp.elicitation.result.content", _safe_serialize(result, 1000))
696 # Record metrics for successful elicitation
697 try:
698 from golf.metrics import get_metrics_collector
700 metrics_collector = get_metrics_collector()
701 metrics_collector.increment_elicitation(elicitation_type, "success")
702 metrics_collector.record_elicitation_duration(elicitation_type, time.time() - start_time)
703 except ImportError:
704 pass
706 return result
707 except Exception as e:
708 span.record_exception(e)
709 span.set_status(Status(StatusCode.ERROR, str(e)))
711 # Add event for error
712 span.add_event(
713 "elicitation.request.error",
714 {
715 "error.type": type(e).__name__,
716 "error.message": str(e),
717 },
718 )
720 # Record metrics for failed elicitation
721 try:
722 from golf.metrics import get_metrics_collector
724 metrics_collector = get_metrics_collector()
725 metrics_collector.increment_elicitation(elicitation_type, "error")
726 metrics_collector.increment_error("elicitation", type(e).__name__)
727 except ImportError:
728 pass
730 raise
732 @functools.wraps(func)
733 def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
734 # If telemetry is disabled at runtime, call original function
735 global _provider
736 if _provider is None:
737 return func(*args, **kwargs)
739 # Record metrics timing
740 start_time = time.time()
742 # Create a more descriptive span name
743 span_name = f"mcp.elicitation.{elicitation_type}.request"
744 with tracer.start_as_current_span(span_name) as span:
745 # Add essential attributes
746 span.set_attribute("mcp.component.type", "elicitation")
747 span.set_attribute("mcp.elicitation.type", elicitation_type)
749 # Capture elicitation parameters if detailed tracing is enabled
750 if _detailed_tracing_enabled:
751 if args:
752 message = args[0] if isinstance(args[0], str) else None
753 if message:
754 span.set_attribute("mcp.elicitation.message", _safe_serialize(message, 500))
756 # Add event for elicitation start
757 span.add_event("elicitation.request.started")
759 try:
760 result = func(*args, **kwargs)
761 span.set_status(Status(StatusCode.OK))
763 # Add event for successful completion
764 span.add_event("elicitation.request.completed")
766 # Record metrics for successful elicitation
767 try:
768 from golf.metrics import get_metrics_collector
770 metrics_collector = get_metrics_collector()
771 metrics_collector.increment_elicitation(elicitation_type, "success")
772 metrics_collector.record_elicitation_duration(elicitation_type, time.time() - start_time)
773 except ImportError:
774 pass
776 return result
777 except Exception as e:
778 span.record_exception(e)
779 span.set_status(Status(StatusCode.ERROR, str(e)))
781 # Add event for error
782 span.add_event(
783 "elicitation.request.error",
784 {
785 "error.type": type(e).__name__,
786 "error.message": str(e),
787 },
788 )
790 # Record metrics for failed elicitation
791 try:
792 from golf.metrics import get_metrics_collector
794 metrics_collector = get_metrics_collector()
795 metrics_collector.increment_elicitation(elicitation_type, "error")
796 metrics_collector.increment_error("elicitation", type(e).__name__)
797 except ImportError:
798 pass
800 raise
802 if asyncio.iscoroutinefunction(func):
803 return async_wrapper
804 else:
805 return sync_wrapper
808def instrument_sampling(func: Callable[..., T], sampling_type: str = "sample") -> Callable[..., T]:
809 """Instrument a sampling function with OpenTelemetry tracing."""
810 global _provider
812 # If telemetry is disabled, return the original function
813 if _provider is None:
814 return func
816 tracer = get_tracer()
818 @functools.wraps(func)
819 async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
820 # If telemetry is disabled at runtime, call original function
821 global _provider
822 if _provider is None:
823 return await func(*args, **kwargs)
825 # Record metrics timing
826 start_time = time.time()
828 # Create a more descriptive span name
829 span_name = f"mcp.sampling.{sampling_type}.request"
830 with tracer.start_as_current_span(span_name) as span:
831 # Add essential attributes
832 span.set_attribute("mcp.component.type", "sampling")
833 span.set_attribute("mcp.sampling.type", sampling_type)
835 # Capture sampling parameters
836 messages = kwargs.get("messages") or (args[0] if args else None)
837 if messages and _detailed_tracing_enabled:
838 if isinstance(messages, str):
839 span.set_attribute("mcp.sampling.messages.content", _safe_serialize(messages, 1000))
840 elif isinstance(messages, list):
841 span.set_attribute("mcp.sampling.messages.type", "list")
842 span.set_attribute("mcp.sampling.messages.count", len(messages))
843 span.set_attribute("mcp.sampling.messages.content", _safe_serialize(messages, 1000))
845 # Capture other sampling parameters
846 system_prompt = kwargs.get("system_prompt")
847 if system_prompt and _detailed_tracing_enabled:
848 span.set_attribute("mcp.sampling.system_prompt.length", len(str(system_prompt)))
849 span.set_attribute("mcp.sampling.system_prompt.content", _safe_serialize(system_prompt, 500))
851 temperature = kwargs.get("temperature")
852 if temperature is not None:
853 span.set_attribute("mcp.sampling.temperature", temperature)
855 max_tokens = kwargs.get("max_tokens")
856 if max_tokens is not None:
857 span.set_attribute("mcp.sampling.max_tokens", max_tokens)
859 model_preferences = kwargs.get("model_preferences")
860 if model_preferences:
861 if isinstance(model_preferences, str):
862 span.set_attribute("mcp.sampling.model_preferences", model_preferences)
863 elif isinstance(model_preferences, list):
864 span.set_attribute("mcp.sampling.model_preferences", ",".join(model_preferences))
866 # Extract Context parameter if present
867 ctx = kwargs.get("ctx")
868 if ctx:
869 ctx_attrs = ["request_id", "session_id", "client_id", "user_id", "tenant_id"]
870 for attr in ctx_attrs:
871 if hasattr(ctx, attr):
872 value = getattr(ctx, attr)
873 if value is not None:
874 span.set_attribute(f"mcp.context.{attr}", str(value))
876 # Add event for sampling start
877 span.add_event("sampling.request.started")
879 try:
880 result = await func(*args, **kwargs)
881 span.set_status(Status(StatusCode.OK))
883 # Add event for successful completion
884 span.add_event("sampling.request.completed")
886 # Capture result metadata
887 if result is not None and _detailed_tracing_enabled and isinstance(result, str):
888 span.set_attribute("mcp.sampling.result.content", _safe_serialize(result, 1000))
890 # Record metrics for successful sampling
891 try:
892 from golf.metrics import get_metrics_collector
894 metrics_collector = get_metrics_collector()
895 metrics_collector.increment_sampling(sampling_type, "success")
896 metrics_collector.record_sampling_duration(sampling_type, time.time() - start_time)
897 if isinstance(result, str):
898 metrics_collector.record_sampling_tokens(sampling_type, len(result.split()))
899 except ImportError:
900 pass
902 return result
903 except Exception as e:
904 span.record_exception(e)
905 span.set_status(Status(StatusCode.ERROR, str(e)))
907 # Add event for error
908 span.add_event(
909 "sampling.request.error",
910 {
911 "error.type": type(e).__name__,
912 "error.message": str(e),
913 },
914 )
916 # Record metrics for failed sampling
917 try:
918 from golf.metrics import get_metrics_collector
920 metrics_collector = get_metrics_collector()
921 metrics_collector.increment_sampling(sampling_type, "error")
922 metrics_collector.increment_error("sampling", type(e).__name__)
923 except ImportError:
924 pass
926 raise
928 @functools.wraps(func)
929 def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
930 # If telemetry is disabled at runtime, call original function
931 global _provider
932 if _provider is None:
933 return func(*args, **kwargs)
935 # Record metrics timing
936 start_time = time.time()
938 # Create a more descriptive span name
939 span_name = f"mcp.sampling.{sampling_type}.request"
940 with tracer.start_as_current_span(span_name) as span:
941 # Add essential attributes
942 span.set_attribute("mcp.component.type", "sampling")
943 span.set_attribute("mcp.sampling.type", sampling_type)
945 # Add event for sampling start
946 span.add_event("sampling.request.started")
948 try:
949 result = func(*args, **kwargs)
950 span.set_status(Status(StatusCode.OK))
952 # Add event for successful completion
953 span.add_event("sampling.request.completed")
955 # Record metrics for successful sampling
956 try:
957 from golf.metrics import get_metrics_collector
959 metrics_collector = get_metrics_collector()
960 metrics_collector.increment_sampling(sampling_type, "success")
961 metrics_collector.record_sampling_duration(sampling_type, time.time() - start_time)
962 except ImportError:
963 pass
965 return result
966 except Exception as e:
967 span.record_exception(e)
968 span.set_status(Status(StatusCode.ERROR, str(e)))
970 # Add event for error
971 span.add_event(
972 "sampling.request.error",
973 {
974 "error.type": type(e).__name__,
975 "error.message": str(e),
976 },
977 )
978 raise
980 if asyncio.iscoroutinefunction(func):
981 return async_wrapper
982 else:
983 return sync_wrapper
986def instrument_prompt(func: Callable[..., T], prompt_name: str) -> Callable[..., T]:
987 """Instrument a prompt function with OpenTelemetry tracing."""
988 global _provider
990 # If telemetry is disabled, return the original function
991 if _provider is None:
992 return func
994 tracer = get_tracer()
996 @functools.wraps(func)
997 async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
998 # Create a more descriptive span name
999 span_name = f"mcp.prompt.{prompt_name}.generate"
1000 with tracer.start_as_current_span(span_name) as span:
1001 # Add essential attributes only
1002 span.set_attribute("mcp.component.type", "prompt")
1003 span.set_attribute("mcp.prompt.name", prompt_name)
1004 span.set_attribute(
1005 "mcp.prompt.module",
1006 func.__module__ if hasattr(func, "__module__") else "unknown",
1007 )
1009 # Extract Context parameter if present
1010 ctx = kwargs.get("ctx")
1011 if ctx:
1012 # Only extract known MCP context attributes
1013 ctx_attrs = [
1014 "request_id",
1015 "session_id",
1016 "client_id",
1017 "user_id",
1018 "tenant_id",
1019 ]
1020 for attr in ctx_attrs:
1021 if hasattr(ctx, attr):
1022 value = getattr(ctx, attr)
1023 if value is not None:
1024 span.set_attribute(f"mcp.context.{attr}", str(value))
1026 # Also check baggage for session ID
1027 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
1028 if session_id_from_baggage:
1029 span.set_attribute("mcp.session.id", session_id_from_baggage)
1031 # Add event for prompt generation start
1032 span.add_event("prompt.generation.started", {"prompt.name": prompt_name})
1034 try:
1035 result = await func(*args, **kwargs)
1036 span.set_status(Status(StatusCode.OK))
1038 # Add event for successful generation
1039 span.add_event("prompt.generation.completed", {"prompt.name": prompt_name})
1041 # Add message count and type information
1042 if isinstance(result, list):
1043 span.set_attribute("mcp.prompt.result.message_count", len(result))
1044 span.set_attribute("mcp.prompt.result.type", "message_list")
1046 # Analyze message types if they have role attributes
1047 roles = []
1048 for msg in result:
1049 if hasattr(msg, "role"):
1050 roles.append(msg.role)
1051 elif isinstance(msg, dict) and "role" in msg:
1052 roles.append(msg["role"])
1054 if roles:
1055 unique_roles = list(set(roles))
1056 span.set_attribute("mcp.prompt.result.roles", ",".join(unique_roles))
1057 span.set_attribute(
1058 "mcp.prompt.result.role_counts",
1059 str({role: roles.count(role) for role in unique_roles}),
1060 )
1061 elif isinstance(result, str):
1062 span.set_attribute("mcp.prompt.result.type", "string")
1063 span.set_attribute("mcp.prompt.result.length", len(result))
1064 else:
1065 span.set_attribute("mcp.prompt.result.type", type(result).__name__)
1067 return result
1068 except Exception as e:
1069 span.record_exception(e)
1070 span.set_status(Status(StatusCode.ERROR, str(e)))
1072 # Add event for error
1073 span.add_event(
1074 "prompt.generation.error",
1075 {
1076 "prompt.name": prompt_name,
1077 "error.type": type(e).__name__,
1078 "error.message": str(e),
1079 },
1080 )
1081 raise
1083 @functools.wraps(func)
1084 def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
1085 # Create a more descriptive span name
1086 span_name = f"mcp.prompt.{prompt_name}.generate"
1087 with tracer.start_as_current_span(span_name) as span:
1088 # Add essential attributes only
1089 span.set_attribute("mcp.component.type", "prompt")
1090 span.set_attribute("mcp.prompt.name", prompt_name)
1091 span.set_attribute(
1092 "mcp.prompt.module",
1093 func.__module__ if hasattr(func, "__module__") else "unknown",
1094 )
1096 # Extract Context parameter if present
1097 ctx = kwargs.get("ctx")
1098 if ctx:
1099 # Only extract known MCP context attributes
1100 ctx_attrs = [
1101 "request_id",
1102 "session_id",
1103 "client_id",
1104 "user_id",
1105 "tenant_id",
1106 ]
1107 for attr in ctx_attrs:
1108 if hasattr(ctx, attr):
1109 value = getattr(ctx, attr)
1110 if value is not None:
1111 span.set_attribute(f"mcp.context.{attr}", str(value))
1113 # Also check baggage for session ID
1114 session_id_from_baggage = baggage.get_baggage("mcp.session.id")
1115 if session_id_from_baggage:
1116 span.set_attribute("mcp.session.id", session_id_from_baggage)
1118 # Add event for prompt generation start
1119 span.add_event("prompt.generation.started", {"prompt.name": prompt_name})
1121 try:
1122 result = func(*args, **kwargs)
1123 span.set_status(Status(StatusCode.OK))
1125 # Add event for successful generation
1126 span.add_event("prompt.generation.completed", {"prompt.name": prompt_name})
1128 # Add message count and type information
1129 if isinstance(result, list):
1130 span.set_attribute("mcp.prompt.result.message_count", len(result))
1131 span.set_attribute("mcp.prompt.result.type", "message_list")
1133 # Analyze message types if they have role attributes
1134 roles = []
1135 for msg in result:
1136 if hasattr(msg, "role"):
1137 roles.append(msg.role)
1138 elif isinstance(msg, dict) and "role" in msg:
1139 roles.append(msg["role"])
1141 if roles:
1142 unique_roles = list(set(roles))
1143 span.set_attribute("mcp.prompt.result.roles", ",".join(unique_roles))
1144 span.set_attribute(
1145 "mcp.prompt.result.role_counts",
1146 str({role: roles.count(role) for role in unique_roles}),
1147 )
1148 elif isinstance(result, str):
1149 span.set_attribute("mcp.prompt.result.type", "string")
1150 span.set_attribute("mcp.prompt.result.length", len(result))
1151 else:
1152 span.set_attribute("mcp.prompt.result.type", type(result).__name__)
1154 return result
1155 except Exception as e:
1156 span.record_exception(e)
1157 span.set_status(Status(StatusCode.ERROR, str(e)))
1159 # Add event for error
1160 span.add_event(
1161 "prompt.generation.error",
1162 {
1163 "prompt.name": prompt_name,
1164 "error.type": type(e).__name__,
1165 "error.message": str(e),
1166 },
1167 )
1168 raise
1170 if asyncio.iscoroutinefunction(func):
1171 return async_wrapper
1172 else:
1173 return sync_wrapper
1176# Add the BoundedSessionTracker class before SessionTracingMiddleware
1177class BoundedSessionTracker:
1178 """Memory-safe session tracker with automatic expiration."""
1180 def __init__(self, max_sessions: int = 1000, session_ttl: int = 3600) -> None:
1181 self.max_sessions = max_sessions
1182 self.session_ttl = session_ttl
1183 self.sessions: OrderedDict[str, float] = OrderedDict()
1184 self.last_cleanup = time.time()
1186 def track_session(self, session_id: str) -> bool:
1187 """Track a session, returns True if it's new."""
1188 current_time = time.time()
1190 # Periodic cleanup (every 5 minutes)
1191 if current_time - self.last_cleanup > 300:
1192 self._cleanup_expired(current_time)
1193 self.last_cleanup = current_time
1195 # Check if session exists and is still valid
1196 if session_id in self.sessions:
1197 # Move to end (mark as recently used)
1198 self.sessions.move_to_end(session_id)
1199 return False
1201 # New session
1202 self.sessions[session_id] = current_time
1204 # Enforce max size
1205 while len(self.sessions) > self.max_sessions:
1206 self.sessions.popitem(last=False) # Remove oldest
1208 return True
1210 def _cleanup_expired(self, current_time: float) -> None:
1211 """Remove expired sessions."""
1212 expired = [sid for sid, timestamp in self.sessions.items() if current_time - timestamp > self.session_ttl]
1213 for sid in expired:
1214 del self.sessions[sid]
1216 def get_active_session_count(self) -> int:
1217 return len(self.sessions)
1220class SessionTracingMiddleware(BaseHTTPMiddleware):
1221 def __init__(self, app: Any) -> None:
1222 super().__init__(app)
1223 # Use memory-safe session tracker instead of unbounded collections
1224 self.session_tracker = BoundedSessionTracker(max_sessions=1000, session_ttl=3600)
1226 async def dispatch(self, request: Any, call_next: Callable[..., Any]) -> Any:
1227 # Record HTTP request timing
1228 import time
1230 start_time = time.time()
1232 # Extract session ID from query params or headers
1233 session_id = request.query_params.get("session_id")
1234 if not session_id:
1235 # Check headers as fallback
1236 session_id = request.headers.get("x-session-id")
1238 # Track session metrics using memory-safe tracker
1239 if session_id:
1240 is_new_session = self.session_tracker.track_session(session_id)
1242 if is_new_session:
1243 try:
1244 from golf.metrics import get_metrics_collector
1246 metrics_collector = get_metrics_collector()
1247 metrics_collector.increment_session()
1248 except ImportError:
1249 pass
1250 else:
1251 # Record session duration for existing sessions
1252 try:
1253 from golf.metrics import get_metrics_collector
1255 metrics_collector = get_metrics_collector()
1256 # Use a default duration since we don't track exact start
1257 # times anymore
1258 # This is less precise but memory-safe
1259 metrics_collector.record_session_duration(300.0) # 5 min default
1260 except ImportError:
1261 pass
1263 # Create a descriptive span name based on the request
1264 method = request.method
1265 path = request.url.path
1267 # Determine the operation type from the path
1268 operation_type = "unknown"
1269 if "/mcp" in path:
1270 operation_type = "mcp.request"
1271 elif "/sse" in path:
1272 operation_type = "sse.stream"
1273 elif "/auth" in path:
1274 operation_type = "auth"
1276 span_name = f"{operation_type}.{method.lower()}"
1278 tracer = get_tracer()
1279 with tracer.start_as_current_span(span_name) as span:
1280 # Add essential HTTP attributes
1281 span.set_attribute("http.method", method)
1282 span.set_attribute("http.target", path)
1283 span.set_attribute("http.host", request.url.hostname or "unknown")
1285 # Add session tracking
1286 if session_id:
1287 span.set_attribute("mcp.session.id", session_id)
1288 span.set_attribute(
1289 "mcp.session.active_count",
1290 self.session_tracker.get_active_session_count(),
1291 )
1292 # Add to baggage for propagation
1293 ctx = baggage.set_baggage("mcp.session.id", session_id)
1294 from opentelemetry import context
1296 token = context.attach(ctx)
1297 else:
1298 token = None
1300 # Add request size if available
1301 content_length = request.headers.get("content-length")
1302 if content_length:
1303 span.set_attribute("http.request.size", int(content_length))
1305 # Add event for request start
1306 span.add_event("http.request.started", {"method": method, "path": path})
1308 try:
1309 response = await call_next(request)
1311 # Add response attributes
1312 span.set_attribute("http.status_code", response.status_code)
1314 # Set span status based on HTTP status
1315 if response.status_code >= 400:
1316 span.set_status(Status(StatusCode.ERROR, f"HTTP {response.status_code}"))
1317 else:
1318 span.set_status(Status(StatusCode.OK))
1320 # Add event for request completion
1321 span.add_event(
1322 "http.request.completed",
1323 {
1324 "method": method,
1325 "path": path,
1326 "status_code": response.status_code,
1327 },
1328 )
1330 # Record HTTP request metrics
1331 try:
1332 from golf.metrics import get_metrics_collector
1334 metrics_collector = get_metrics_collector()
1336 # Clean up path for metrics (remove query params, normalize)
1337 clean_path = path.split("?")[0] # Remove query parameters
1338 if clean_path.startswith("/"):
1339 clean_path = clean_path[1:] or "root" # Remove leading slash, handle root
1341 metrics_collector.increment_http_request(method, response.status_code, clean_path)
1342 metrics_collector.record_http_duration(method, clean_path, time.time() - start_time)
1343 except ImportError:
1344 # Metrics not available, continue without metrics
1345 pass
1347 return response
1348 except Exception as e:
1349 span.record_exception(e)
1350 span.set_status(Status(StatusCode.ERROR, str(e)))
1352 # Add event for error
1353 span.add_event(
1354 "http.request.error",
1355 {
1356 "method": method,
1357 "path": path,
1358 "error.type": type(e).__name__,
1359 "error.message": str(e),
1360 },
1361 )
1363 # Record HTTP error metrics
1364 try:
1365 from golf.metrics import get_metrics_collector
1367 metrics_collector = get_metrics_collector()
1369 # Clean up path for metrics
1370 clean_path = path.split("?")[0]
1371 if clean_path.startswith("/"):
1372 clean_path = clean_path[1:] or "root"
1374 metrics_collector.increment_http_request(method, 500, clean_path) # Assume 500 for exceptions
1375 metrics_collector.increment_error("http", type(e).__name__)
1376 except ImportError:
1377 pass
1379 raise
1380 finally:
1381 if token:
1382 context.detach(token)
1385@asynccontextmanager
1386async def telemetry_lifespan(mcp_instance: Any) -> AsyncGenerator[None, None]:
1387 """Simplified lifespan for telemetry initialization and cleanup."""
1388 global _provider
1390 # Initialize telemetry with the server name
1391 provider = init_telemetry(service_name=mcp_instance.name)
1393 # If provider is None, telemetry is disabled
1394 if provider is None:
1395 # Just yield without any telemetry setup
1396 yield
1397 return
1399 # Try to add session tracking middleware if possible
1400 try:
1401 # Try to add middleware to FastMCP app if it has Starlette app
1402 if hasattr(mcp_instance, "app") or hasattr(mcp_instance, "_app"):
1403 app = getattr(mcp_instance, "app", getattr(mcp_instance, "_app", None))
1404 if app and hasattr(app, "add_middleware"):
1405 app.add_middleware(SessionTracingMiddleware)
1407 # Also try to instrument FastMCP's internal handlers
1408 if hasattr(mcp_instance, "_tool_manager") and hasattr(mcp_instance._tool_manager, "tools"):
1409 # The tools should already be instrumented when they were registered
1410 pass
1412 # Try to patch FastMCP's request handling to ensure context propagation
1413 if hasattr(mcp_instance, "handle_request"):
1414 original_handle_request = mcp_instance.handle_request
1416 async def traced_handle_request(*args: Any, **kwargs: Any) -> Any:
1417 tracer = get_tracer()
1418 with tracer.start_as_current_span("mcp.handle_request") as span:
1419 span.set_attribute("mcp.request.handler", "handle_request")
1420 return await original_handle_request(*args, **kwargs)
1422 mcp_instance.handle_request = traced_handle_request
1424 except Exception:
1425 # Silently continue if middleware setup fails
1426 import traceback
1428 traceback.print_exc()
1430 try:
1431 # Yield control back to FastMCP
1432 yield
1433 finally:
1434 # Cleanup - shutdown the provider
1435 if _provider and hasattr(_provider, "shutdown"):
1436 _provider.force_flush()
1437 _provider.shutdown()
1438 _provider = None