Coverage for dj/models/database.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-17 20:05 -0700

1""" 

2Models for databases. 

3""" 

4 

5from datetime import datetime, timezone 

6from functools import partial 

7from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict 

8from uuid import UUID, uuid4 

9 

10from sqlalchemy import DateTime, String 

11from sqlalchemy.sql.schema import Column as SqlaColumn 

12from sqlalchemy_utils import UUIDType 

13from sqlmodel import JSON, Field, Relationship 

14 

15from dj.models.base import BaseSQLModel 

16from dj.typing import UTCDatetime 

17 

18if TYPE_CHECKING: 

19 from dj.models.catalog import Catalog 

20 from dj.models.table import Table 

21 

22 

# Schema of a database entry in the YAML file.
# All keys are optional (``total=False``). The functional TypedDict syntax is
# used because "read-only" is not a valid Python identifier and so cannot be
# written as a class-body annotation.
DatabaseYAML = TypedDict(
    "DatabaseYAML",
    {
        "description": str,
        "URI": str,
        "read-only": bool,
        "async_": bool,
        "cost": float,
    },
    total=False,
)

29 

30 

class Database(BaseSQLModel, table=True):  # type: ignore
    """
    A database.

    A simple example, as it would appear in the YAML file:

        name: druid
        description: An Apache Druid database
        URI: druid://localhost:8082/druid/v2/sql/
        read-only: true
        async_: false
        cost: 1.0

    Defaults when omitted: ``description=""``, ``read_only=True``,
    ``async_=False``, ``cost=1.0`` and ``extra_params={}``.
    """

    # Integer primary key; None until a value is assigned
    # (presumably by the database on insert — confirm).
    id: Optional[int] = Field(default=None, primary_key=True)
    # Random UUID generated at instantiation, stored in a UUIDType column.
    uuid: UUID = Field(default_factory=uuid4, sa_column=SqlaColumn(UUIDType()))

    # Column is declared unique, so two rows cannot share a name.
    name: str = Field(sa_column=SqlaColumn("name", String, unique=True))
    description: str = ""
    URI: str
    # Free-form JSON parameters. A factory is used instead of a ``{}`` literal
    # so every instance gets its own dict rather than depending on the model
    # layer copying a shared mutable default.
    extra_params: Dict = Field(default_factory=dict, sa_column=SqlaColumn(JSON))
    read_only: bool = True
    # ``async`` is a reserved keyword in Python, so the attribute is named
    # ``async_`` while the underlying column keeps the name "async".
    async_: bool = Field(default=False, sa_column_kwargs={"name": "async"})
    cost: float = 1.0

    # Both timestamps default to the current UTC time when the instance is
    # created. Nothing in this class refreshes ``updated_at`` on modification
    # (no ``onupdate`` is configured); callers presumably do so — verify.
    created_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )
    updated_at: UTCDatetime = Field(
        sa_column=SqlaColumn(DateTime(timezone=True)),
        default_factory=partial(datetime.now, timezone.utc),
    )

    # Tables belonging to this database; the "all, delete" cascade means
    # deleting a Database through the session also deletes its tables.
    tables: List["Table"] = Relationship(
        back_populates="database",
        sa_relationship_kwargs={"cascade": "all, delete"},
    )

    def __hash__(self) -> int:
        """
        Hash by primary key.

        Note: while ``id`` is None, every instance hashes to ``hash(None)``,
        and the hash changes once ``id`` is assigned.
        """
        return hash(self.id)