Skip to content

Commit 9a5c984

Browse files
committed
🚸 fix(alembic): intuitive default args for revision
- use main version location if neither `version_path` nor `branch_label` provided - smartly select `head` if not provided
1 parent 07f6eb4 commit 9a5c984

File tree

1 file changed

+36
-21
lines changed

1 file changed

+36
-21
lines changed

nonebot_plugin_orm/migrate.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pprint import pformat
99
from argparse import Namespace
1010
from operator import attrgetter
11+
from itertools import filterfalse
1112
from typing import Any, TextIO, cast
1213
from tempfile import TemporaryDirectory
1314
from configparser import DuplicateSectionError
@@ -180,11 +181,12 @@ def move_script(self, script: Script) -> Path:
180181
return script_path
181182

182183
plugin_name = script_path.parent.name
183-
if version_location := self._plugin_version_locations.get(plugin_name):
184-
pass
185-
elif version_location := self._plugin_version_locations.get(""):
186-
plugin_name = ""
187-
else:
184+
version_location = self._plugin_version_locations.get(plugin_name)
185+
186+
if not version_location:
187+
version_location = self._plugin_version_locations.get("")
188+
189+
if not version_location:
188190
self.print_stdout(
189191
f'无法找到 {plugin_name or "<default>"} 对应的版本目录, 忽略 "{script.path}"',
190192
fg="yellow",
@@ -406,7 +408,7 @@ def revision(
406408
config: `AlembicConfig` 对象
407409
message: 迁移的描述
408410
sql: 是否以 SQL 的形式输出迁移脚本
409-
head: 迁移的基准版本, 提供了 branch_label 时默认为 'base', 否则默认为 'head'
411+
head: 迁移的基准版本, 如果提供了 branch_label 默认为 `branch_label@head`, 否则为主分支的头
410412
splice: 是否将迁移作为一个新的分支的头; 当 `head` 不是一个分支的头时, 此项必须为 `True`
411413
branch_label: 迁移的分支标签
412414
version_path: 存放迁移脚本的目录
@@ -416,24 +418,12 @@ def revision(
416418
"""
417419
from . import _plugins
418420

419-
if head is None:
420-
head = "base" if branch_label else "head"
421-
422-
if not version_path and branch_label and (plugin := _plugins.get(branch_label)):
423-
version_path = str(
424-
config._temp_dir.joinpath(
425-
*map(
426-
attrgetter("name"),
427-
reversed(list(get_parent_plugins(plugin))),
428-
)
429-
)
430-
)
431-
elif version_path:
421+
if version_path:
432422
version_path = Path(version_path).resolve()
433423
version_locations = config.get_main_option("version_locations", "")
434424
pathsep = _SPLIT_ON_PATH[config.get_main_option("version_path_separator")]
435425

436-
if version_path in (
426+
if version_path not in (
437427
Path(path).resolve() for path in version_locations.split(pathsep)
438428
):
439429
config.set_main_option(
@@ -442,9 +432,34 @@ def revision(
442432
logger.warning(
443433
f'临时将目录 "{version_path}" 添加到版本目录中, 请稍后将其添加到 ALEMBIC_VERSION_LOCATIONS 中'
444434
)
435+
elif branch_label and (plugin := _plugins.get(branch_label)):
436+
version_path = config._temp_dir.joinpath(
437+
*map(
438+
attrgetter("name"),
439+
reversed(list(get_parent_plugins(plugin))),
440+
)
441+
)
442+
else:
443+
version_path = config._temp_dir
445444

446445
script = ScriptDirectory.from_config(config)
447446

447+
if not head:
448+
if branch_label:
449+
head = f"{branch_label}@head"
450+
elif len(heads := script.get_heads()) <= 1:
451+
head = "head"
452+
else:
453+
try:
454+
head = next(
455+
filterfalse(
456+
attrgetter("branch_labels"),
457+
script.get_revisions(heads),
458+
)
459+
).revision
460+
except StopIteration:
461+
head = "base"
462+
448463
revision_context = RevisionContext(
449464
config,
450465
script,
@@ -455,7 +470,7 @@ def revision(
455470
head=head,
456471
splice=splice,
457472
branch_label=branch_label,
458-
version_path=version_path,
473+
version_path=str(version_path),
459474
rev_id=rev_id,
460475
depends_on=depends_on,
461476
),

0 commit comments

Comments
 (0)