Coverage for dj/api/tags.py: 100%

57 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Tag related APIs. 

3""" 

4 

5from typing import List, Optional 

6 

7from fastapi import APIRouter, Depends 

8from sqlalchemy.orm import joinedload 

9from sqlmodel import Session, select 

10 

11from dj.errors import DJException 

12from dj.models.node import NodeType 

13from dj.models.tag import CreateTag, Tag, TagOutput, UpdateTag 

14from dj.utils import get_session 

15 

# Router that the tag endpoints in this module are registered on via the
# ``@router.get`` / ``@router.post`` / ``@router.patch`` decorators below.
router = APIRouter()

17 

18 

def get_tag_by_name(
    session: Session,
    name: str,
    raise_if_not_exists: bool = False,
    for_update: bool = False,
):
    """
    Look up a single tag by its name.

    Args:
        session: database session used to run the query.
        name: name of the tag to fetch.
        raise_if_not_exists: when true, a missing tag raises instead of
            returning ``None``.
        for_update: when true, the row is selected ``FOR UPDATE`` and any
            instance already held by the session is overwritten with the
            freshly loaded values (``populate_existing``).

    Returns:
        The matching ``Tag``, or ``None`` when no tag matches and
        ``raise_if_not_exists`` is false.

    Raises:
        DJException: with ``http_status_code=404`` when no tag matches and
            ``raise_if_not_exists`` is true.
    """
    query = select(Tag).where(Tag.name == name)
    if for_update:
        # Lock the row and discard any stale in-session copy so the caller
        # mutates the current database state.
        query = query.with_for_update().execution_options(populate_existing=True)
    found = session.exec(query).one_or_none()
    if found or not raise_if_not_exists:
        return found
    raise DJException(
        message=f"A tag with name `{name}` does not exist.",
        http_status_code=404,
    )

40 

41 

@router.get("/tags/", response_model=List[TagOutput])
def list_tags(
    tag_type: Optional[str] = None, *, session: Session = Depends(get_session)
) -> List[TagOutput]:
    """
    List the available tags.

    When ``tag_type`` is given (and non-empty), only tags of that type are
    returned; otherwise every tag is returned.
    """
    query = select(Tag)
    query = query.where(Tag.tag_type == tag_type) if tag_type else query
    results = session.exec(query)
    return results.all()

53 

54 

@router.get("/tags/{name}/", response_model=Tag)
def get_a_tag(name: str, *, session: Session = Depends(get_session)) -> Tag:
    """
    Fetch one tag by its name.

    Raises:
        DJException: with ``http_status_code=404`` if no such tag exists.
    """
    return get_tag_by_name(session, name, raise_if_not_exists=True)

62 

63 

@router.post("/tags/", response_model=Tag, status_code=201)
def create_a_tag(
    data: CreateTag,
    session: Session = Depends(get_session),
) -> Tag:
    """
    Create a new tag from the request payload.

    Args:
        data: the tag to create.
        session: database session used to persist the tag.

    Returns:
        The newly created ``Tag``, refreshed from the database.

    Raises:
        DJException: with ``http_status_code=409`` (Conflict) if a tag with
            the same name already exists.
    """
    # NOTE(review): this check-then-insert is not atomic; a concurrent request
    # could insert the same name between the lookup and the commit. Presumably
    # a unique constraint on ``Tag.name`` guards this at the DB level — confirm.
    if get_tag_by_name(session, data.name, raise_if_not_exists=False):
        # A duplicate name is a client-side conflict, not a server failure,
        # so report 409 rather than 500.
        raise DJException(
            message=f"A tag with name `{data.name}` already exists!",
            http_status_code=409,
        )
    tag = Tag.from_orm(data)
    session.add(tag)
    session.commit()
    session.refresh(tag)
    return tag

83 

84 

@router.patch("/tags/{name}/", response_model=Tag)
def update_a_tag(
    name: str,
    data: UpdateTag,
    session: Session = Depends(get_session),
) -> Tag:
    """
    Partially update a tag's description and/or metadata.

    A field is applied only if the client explicitly sent it with a non-null
    value. Omitted fields are left untouched, while empty values such as
    ``""`` or ``{}`` are applied, so a description or metadata can be cleared.

    Args:
        name: name of the tag to update.
        data: the fields to change.
        session: database session used to load and persist the tag.

    Returns:
        The updated ``Tag``, refreshed from the database.

    Raises:
        DJException: with ``http_status_code=404`` if no such tag exists.
    """
    tag = get_tag_by_name(session, name, raise_if_not_exists=True, for_update=True)

    # ``exclude_unset`` separates "not sent" from "sent as empty"; the previous
    # truthiness checks silently ignored legitimate updates like "" or {}.
    provided = data.dict(exclude_unset=True)
    if provided.get("description") is not None:
        tag.description = data.description
    if provided.get("tag_metadata") is not None:
        tag.tag_metadata = data.tag_metadata
    session.add(tag)
    session.commit()
    session.refresh(tag)
    return tag

104 

105 

@router.get("/tags/{name}/nodes/", response_model=List[str])
def list_nodes_for_a_tag(
    name: str,
    node_type: Optional[NodeType] = None,
    *,
    session: Session = Depends(get_session),
) -> List[str]:
    """
    Return the sorted names of the nodes tagged with ``name``.

    When ``node_type`` is given, only nodes of that type are included.

    Raises:
        DJException: with ``http_status_code=404`` if no such tag exists.
    """
    # Eager-load the tag's nodes in the same query; ``unique()`` collapses the
    # duplicate parent rows that a joined eager load of a collection produces.
    query = select(Tag).where(Tag.name == name).options(joinedload(Tag.nodes))
    found = session.exec(query).unique().one_or_none()
    if not found:
        raise DJException(
            message=f"A tag with name `{name}` does not exist.",
            http_status_code=404,
        )
    return sorted(
        node.name
        for node in found.nodes
        if not node_type or node.type == node_type
    )