8
8
from pprint import pformat
9
9
from argparse import Namespace
10
10
from operator import attrgetter
11
+ from itertools import filterfalse
11
12
from typing import Any , TextIO , cast
12
13
from tempfile import TemporaryDirectory
13
14
from configparser import DuplicateSectionError
@@ -180,11 +181,12 @@ def move_script(self, script: Script) -> Path:
180
181
return script_path
181
182
182
183
plugin_name = script_path .parent .name
183
- if version_location := self ._plugin_version_locations .get (plugin_name ):
184
- pass
185
- elif version_location := self ._plugin_version_locations .get ("" ):
186
- plugin_name = ""
187
- else :
184
+ version_location = self ._plugin_version_locations .get (plugin_name )
185
+
186
+ if not version_location :
187
+ version_location = self ._plugin_version_locations .get ("" )
188
+
189
+ if not version_location :
188
190
self .print_stdout (
189
191
f'无法找到 { plugin_name or "<default>" } 对应的版本目录, 忽略 "{ script .path } "' ,
190
192
fg = "yellow" ,
@@ -406,7 +408,7 @@ def revision(
406
408
config: `AlembicConfig` 对象
407
409
message: 迁移的描述
408
410
sql: 是否以 SQL 的形式输出迁移脚本
409
- head: 迁移的基准版本, 提供了 branch_label 时默认为 'base', 否则默认为 'head'
411
+ head: 迁移的基准版本, 如果提供了 branch_label 默认为 `branch_label@head`, 否则为主分支的头
410
412
splice: 是否将迁移作为一个新的分支的头; 当 `head` 不是一个分支的头时, 此项必须为 `True`
411
413
branch_label: 迁移的分支标签
412
414
version_path: 存放迁移脚本的目录
@@ -416,24 +418,12 @@ def revision(
416
418
"""
417
419
from . import _plugins
418
420
419
- if head is None :
420
- head = "base" if branch_label else "head"
421
-
422
- if not version_path and branch_label and (plugin := _plugins .get (branch_label )):
423
- version_path = str (
424
- config ._temp_dir .joinpath (
425
- * map (
426
- attrgetter ("name" ),
427
- reversed (list (get_parent_plugins (plugin ))),
428
- )
429
- )
430
- )
431
- elif version_path :
421
+ if version_path :
432
422
version_path = Path (version_path ).resolve ()
433
423
version_locations = config .get_main_option ("version_locations" , "" )
434
424
pathsep = _SPLIT_ON_PATH [config .get_main_option ("version_path_separator" )]
435
425
436
- if version_path in (
426
+ if version_path not in (
437
427
Path (path ).resolve () for path in version_locations .split (pathsep )
438
428
):
439
429
config .set_main_option (
@@ -442,9 +432,34 @@ def revision(
442
432
logger .warning (
443
433
f'临时将目录 "{ version_path } " 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
444
434
)
435
+ elif branch_label and (plugin := _plugins .get (branch_label )):
436
+ version_path = config ._temp_dir .joinpath (
437
+ * map (
438
+ attrgetter ("name" ),
439
+ reversed (list (get_parent_plugins (plugin ))),
440
+ )
441
+ )
442
+ else :
443
+ version_path = config ._temp_dir
445
444
446
445
script = ScriptDirectory .from_config (config )
447
446
447
+ if not head :
448
+ if branch_label :
449
+ head = f"{ branch_label } @head"
450
+ elif len (heads := script .get_heads ()) <= 1 :
451
+ head = "head"
452
+ else :
453
+ try :
454
+ head = next (
455
+ filterfalse (
456
+ attrgetter ("branch_labels" ),
457
+ script .get_revisions (heads ),
458
+ )
459
+ ).revision
460
+ except StopIteration :
461
+ head = "base"
462
+
448
463
revision_context = RevisionContext (
449
464
config ,
450
465
script ,
@@ -455,7 +470,7 @@ def revision(
455
470
head = head ,
456
471
splice = splice ,
457
472
branch_label = branch_label ,
458
- version_path = version_path ,
473
+ version_path = str ( version_path ) ,
459
474
rev_id = rev_id ,
460
475
depends_on = depends_on ,
461
476
),
0 commit comments