Coverage for dj/api/metrics.py: 100%

43 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Metric related APIs. 

3""" 

4 

5from http import HTTPStatus 

6from typing import List 

7 

8from fastapi import APIRouter, Depends, HTTPException, Query 

9from sqlalchemy.exc import NoResultFound 

10from sqlmodel import Session, select 

11 

12from dj.api.helpers import get_node_by_name 

13from dj.errors import DJError, DJException, ErrorCode 

14from dj.models.metric import Metric 

15from dj.models.node import Node, NodeType 

16from dj.sql.dag import get_dimensions 

17from dj.utils import get_session 

18 

# Router on which the metric endpoints below are registered via @router.get.
router: APIRouter = APIRouter()

20 

21 

def get_metric(session: Session, name: str) -> Node:
    """
    Look up the node called ``name`` and ensure it is a metric node.

    Returns the node when its type is ``NodeType.METRIC``. If the node exists
    but has another type, an ``HTTPException`` with status 400 is raised.
    Failures of the lookup itself propagate from ``get_node_by_name``.
    """
    candidate = get_node_by_name(session, name)
    if candidate.type == NodeType.METRIC:
        return candidate
    raise HTTPException(
        status_code=HTTPStatus.BAD_REQUEST,
        detail=f"Not a metric node: `{name}`",
    )

33 

34 

@router.get("/metrics/", response_model=List[Metric])
def list_metrics(*, session: Session = Depends(get_session)) -> List[Metric]:
    """
    Return every node of type METRIC, converted with ``Metric.parse_node``.
    """
    metric_query = select(Node).where(Node.type == NodeType.METRIC)
    return list(map(Metric.parse_node, session.exec(metric_query)))

48 

49 

@router.get("/metrics/{name}/", response_model=Metric)
def get_a_metric(name: str, *, session: Session = Depends(get_session)) -> Metric:
    """
    Fetch a single metric by node name.

    Validation is delegated to ``get_metric``, which responds with HTTP 400
    when the named node is not a metric.
    """
    return Metric.parse_node(get_metric(session, name))

57 

58 

@router.get("/metrics/common/dimensions/", response_model=List[str])
async def get_common_dimensions(
    metric: List[str] = Query(
        title="List of metrics to find common dimensions for",
        default=[],
    ),
    session: Session = Depends(get_session),
) -> List[str]:
    """
    Return common dimensions for a set of metrics.

    Each name in ``metric`` is looked up as a node. Every unknown name
    (``ErrorCode.UNKNOWN_NODE``) and every node that is not a metric
    (``ErrorCode.NODE_TYPE_ERROR``) is collected. If any were found, a single
    ``DJException`` listing all of them is raised.

    Otherwise the dimensions shared by all requested metrics are returned,
    sorted so the response order is deterministic. When no metrics are
    requested, an empty list is returned.
    """
    # NOTE(review): this handler is ``async`` but ``session.exec`` is a
    # synchronous call, so the query blocks the event loop; the other
    # endpoints in this module are plain ``def``. Consider switching if no
    # caller awaits this function directly.
    metric_nodes: List[Node] = []
    errors: List[DJError] = []
    for node_name in metric:
        statement = select(Node).where(Node.name == node_name)
        try:
            node = session.exec(statement).one()
        except NoResultFound:
            errors.append(
                DJError(
                    message=f"Metric node not found: {node_name}",
                    code=ErrorCode.UNKNOWN_NODE,
                ),
            )
            continue
        if node.type != NodeType.METRIC:
            errors.append(
                DJError(
                    message=f"Not a metric node: {node_name}",
                    code=ErrorCode.NODE_TYPE_ERROR,
                ),
            )
            continue
        metric_nodes.append(node)

    if errors:
        raise DJException(errors=errors)

    # Nothing to intersect; previously metric_nodes[0] raised IndexError (500).
    if not metric_nodes:
        return []

    first, *rest = metric_nodes
    common = set(get_dimensions(first))
    # Skip repeated names so get_dimensions runs once per distinct metric;
    # re-intersecting the same dimensions would not change the result.
    seen = {first.name}
    for node in rest:
        if node.name in seen:
            continue
        seen.add(node.name)
        common.intersection_update(get_dimensions(node))

    # Set iteration order varies between processes (str hash randomization),
    # so sort to give clients a stable response.
    return sorted(common)