1
1
#!/usr/bin/env python
2
2
# -*- coding: utf-8 -*-
3
3
4
+ from typing import TypeVar , List , Dict
4
5
from app .core .config import settings
5
- from sqlalchemy import create_engine
6
+ from sqlalchemy import create_engine , text
6
7
from sqlalchemy .orm import sessionmaker
7
8
from app .models .base import Model
8
9
from app .models import system
9
10
from pathlib import Path
10
11
import orjson
11
12
12
13
14
+ ModelType = TypeVar ("ModelType" , bound = Model )
15
+
16
+
13
17
class InitializeData :
14
18
"""
15
19
初始化数据
16
20
"""
17
21
18
22
SCRIPT_DIR : Path = Path .joinpath (settings .BASE_DIR , 'scripts' , 'initialize' )
19
23
20
- def __init__ (self ):
24
+ def __init__ (self ) -> None :
21
25
self .engine = create_engine (self .__get_db_url (), echo = True , future = True )
22
26
self .DBSession = sessionmaker (bind = self .engine )
23
-
24
- def __get_db_url (self ):
27
+ self .prepare_init_models = [
28
+ system .DeptModel ,
29
+ system .UserModel ,
30
+ system .MenuModel ,
31
+ system .PositionModel ,
32
+ system .RoleModel ,
33
+ system .OperationLogModel ,
34
+ system .RoleDeptsModel ,
35
+ system .RoleMenusModel ,
36
+ system .UserPositionsModel ,
37
+ system .UserRolesModel
38
+ ]
39
+
40
+ def __get_db_url (self ) -> str :
25
41
scheme = settings .SQL_DB_URL .scheme .split ('+' )[0 ]
26
42
new_db_url = settings .SQL_DB_URL .unicode_string ().replace (settings .SQL_DB_URL .scheme , scheme )
27
43
return new_db_url
28
44
29
- def __init_model (self ):
45
+ def __init_model (self ) -> None :
30
46
print ("开始初始化数据库..." )
31
47
Model .metadata .create_all (
32
48
self .engine ,
33
- tables = [
34
- system .DeptModel .__table__ ,
35
- system .UserModel .__table__ ,
36
- system .MenuModel .__table__ ,
37
- system .PositionModel .__table__ ,
38
- system .RoleModel .__table__ ,
39
- system .OperationLogModel .__table__ ,
40
- system .RoleDeptsModel .__table__ ,
41
- system .RoleMenusModel .__table__ ,
42
- system .UserPositionsModel .__table__ ,
43
- system .UserRolesModel .__table__
44
- ]
49
+ tables = [modal .__table__ for modal in self .prepare_init_models ]
45
50
)
46
51
print ("数据库初始化完成!" )
47
52
48
- def __init_data (self ):
53
+ def __init_data (self ) -> None :
49
54
print ("开始初始化数据..." )
50
- self .__init_dept ()
51
- self .__init_user ()
52
- self .__init_menu ()
53
- self .__init_position ()
54
- self .__init_role ()
55
- self .__init_role_depts ()
56
- self .__init_role_menus ()
57
- self .__init_user_positions ()
58
- self .__init_user_roles ()
55
+
56
+ for model in self .prepare_init_models :
57
+ max_rows = self .__generate_data (model )
58
+ self .__update_sequence (model , max_rows )
59
+
59
60
print ("数据初始化完成!" )
60
61
61
- def __generate_data (self , table_name : str , model : Model ) :
62
+ def __generate_data (self , model : ModelType ) -> int :
62
63
session = self .DBSession ()
63
64
65
+ table_name = model .__tablename__
66
+
64
67
data = self .__get_data (table_name )
65
68
objs = [model (** item ) for item in data ]
66
69
session .add_all (objs )
@@ -69,40 +72,41 @@ def __generate_data(self, table_name: str, model: Model):
69
72
session .close ()
70
73
print (f"{ table_name } 表数据已生成!" )
71
74
72
- def __get_data (self , filename : str ):
73
- json_path = Path .joinpath (self .SCRIPT_DIR , 'data' , f'{ filename } .json' )
74
- with open (json_path , 'r' , encoding = 'utf-8' ) as f :
75
- data = orjson .loads (f .read ())
76
- return data
75
+ return len (objs )
77
76
78
- def __init_dept (self ):
79
- self .__generate_data ("system_dept" , system .DeptModel )
77
+ def __get_data (self , filename : str ) -> List [Dict ]:
78
+ try :
79
+ json_path = Path .joinpath (self .SCRIPT_DIR , 'data' , f'{ filename } .json' )
80
+ with open (json_path , 'r' , encoding = 'utf-8' ) as f :
81
+ data = orjson .loads (f .read ())
82
+ return data
80
83
81
- def __init_menu ( self ) :
82
- self . __generate_data ( "system_menu" , system . MenuModel )
84
+ except FileNotFoundError :
85
+ return []
83
86
84
- def __init_position (self ) :
85
- self . __generate_data ( "system_position" , system . PositionModel )
87
+ def __update_sequence (self , model : ModelType , max_rows : int ) -> None :
88
+ table_name = model . __tablename__
86
89
87
- def __init_role (self ):
88
- self .__generate_data ("system_role" , system .RoleModel )
90
+ # 检查模型中是否有自增字段
91
+ sequence_name = None
92
+ for col in model .__table__ .columns :
93
+ if col .autoincrement is True :
94
+ sequence_name = f"{ table_name } _{ col .name } _seq"
95
+ break
89
96
90
- def __init_role_depts (self ):
91
- self .__generate_data ("system_role_depts" , system .RoleDeptsModel )
92
-
93
- def __init_role_menus (self ):
94
- self .__generate_data ("system_role_menus" , system .RoleMenusModel )
95
-
96
- def __init_user (self ):
97
- self .__generate_data ("system_user" , system .UserModel )
97
+ if not sequence_name :
98
+ print (f"{ table_name } 表无需设置自增序列值" )
99
+ return
98
100
99
- def __init_user_positions (self ):
100
- self .__generate_data ("system_user_positions" , system .UserPositionsModel )
101
+ session = self .DBSession ()
101
102
102
- def __init_user_roles (self ):
103
- self .__generate_data ("system_user_roles" , system .UserRolesModel )
103
+ # 更新序列最大值
104
+ new_value = max_rows + 1
105
+ session .execute (text (f"ALTER SEQUENCE { sequence_name } RESTART WITH { new_value } " ))
106
+ session .commit ()
107
+ print (f"{ table_name } 表的自增序列值已更新!" )
104
108
105
- def run (self ):
109
+ def run (self ) -> None :
106
110
"""
107
111
执行初始化
108
112
"""
0 commit comments