Skip to content

Instantly share code, notes, and snippets.

@charbonnierg
Created May 16, 2025 12:36
Show Gist options
  • Select an option

  • Save charbonnierg/b62a5d0fc405f5bbf93d7623d041a0e0 to your computer and use it in GitHub Desktop.

Select an option

Save charbonnierg/b62a5d0fc405f5bbf93d7623d041a0e0 to your computer and use it in GitHub Desktop.
"""Alembic command line wrapper."""
from collections.abc import Sequence
import importlib
import os
from typing import Any
from alembic.config import CommandLine
from alembic.config import Config
class CommandLineWrapper(CommandLine):
    """Alembic command line that also accepts a config *factory*.

    Normally ``--config`` is handed to ``Config`` as a file path. When its
    value is instead detected as an import string (``"<module>:<attribute>"``,
    see ``detect_resource_string``), the referenced object is imported and
    called with a file-less ``Config`` before the command runs.
    """

    def main(self, argv: Sequence[str] | None = None) -> None:
        """Parse *argv*, build the Alembic ``Config`` and run the command.

        Args:
            argv: Arguments to parse; ``None`` lets argparse read ``sys.argv``.

        Exits through ``self.parser.error`` when no sub-command is given,
        when the config factory cannot be imported, or when the imported
        object is not callable.
        """
        options = self.parser.parse_args(argv)
        if not hasattr(options, "cmd"):
            # see http://bugs.python.org/issue9253, argparse
            # behavior changed incompatibly in py3.3
            self.parser.error("too few arguments")

        if detect_resource_string(options.config):
            # Special case when config is a resource import string.
            factory = options.config
            # NOTE(review): presumably cleared so nothing reading
            # ``cmd_opts.config`` treats the import string as a file path.
            options.config = None
            cfg = Config(
                file_=None,
                ini_section=options.name,
                cmd_opts=options,
            )
            try:
                func = load_object_py(factory)
            except ImportFromStringError as exc:
                message = "Invalid config factory. {exc}"
                self.parser.error(message.format(exc=exc))
            if not callable(func):
                # Report a CLI usage error rather than a TypeError traceback
                # from calling e.g. a module-level dict or string.
                message = 'Invalid config factory. "{factory}" is not callable.'
                self.parser.error(message.format(factory=factory))
            # The factory receives the Config; its return value is ignored.
            func(cfg)
        else:
            # Default case: pass the value to Config as a file path.
            cfg = Config(
                file_=options.config,
                ini_section=options.name,
                cmd_opts=options,
            )
        self.run_cmd(cfg, options)
def main(
    argv: Sequence[str] | None = None,
    prog: str | None = None,
) -> None:
    """Entry point that runs the Alembic wrapper command line.

    Args:
        argv: Arguments forwarded to ``CommandLineWrapper.main``; ``None``
            lets argparse fall back to ``sys.argv``.
        prog: Program name forwarded to the ``CommandLine`` constructor.
    """
    wrapper = CommandLineWrapper(prog=prog)
    wrapper.main(argv=argv)
def detect_resource_string(import_str: str) -> bool:
"""Names that are non absolute paths and contain a colon
are detected as resource strings."""
return not os.path.isabs(import_str) and ":" in import_str
# Taken from uvicorn
# https://github.com/encode/uvicorn/blob/master/uvicorn/importer.py
class ImportFromStringError(Exception):
    """Error raised when an import string is invalid.

    Must derive from ``Exception``: it is raised with a message argument by
    ``load_object_py`` and caught with ``except ImportFromStringError`` in the
    command line wrapper, both of which require a ``BaseException`` subclass.
    """
def load_object_py(import_str: str) -> Any:  # noqa: ANN401
    """Import and return the object referenced by a ``module:attr`` string.

    The part before the first colon is imported as a module; the part after
    it is a dot-separated attribute path resolved on that module, so
    ``"pkg.mod:obj.attr"`` returns ``pkg.mod.obj.attr``.

    Args:
        import_str: Import string of the form ``"<module>:<attribute>"``.

    Returns:
        The resolved object.

    Raises:
        ImportFromStringError: If the string is malformed, the named module
            itself cannot be found, or an attribute on the path is missing.
        ModuleNotFoundError: If the missing module is not the named module
            itself (e.g. a parent package or one of its imports); re-raised
            with its context suppressed.
    """
    module_str, _, attrs_str = import_str.partition(":")
    if not module_str or not attrs_str:
        # Parenthesised so both literals form one message; without the
        # parentheses the second literal was a separate no-op statement.
        message = (
            'Import string "{import_str}" must be '
            'in format "<module>:<attribute>".'
        )
        raise ImportFromStringError(message.format(import_str=import_str))

    try:
        module = importlib.import_module(module_str)
    except ModuleNotFoundError as exc:
        # Only a missing *target* module is reported as a bad import string;
        # a failure naming any other module is surfaced as-is.
        if exc.name != module_str:
            raise exc from None
        message = 'Could not import module "{module_str}".'
        raise ImportFromStringError(
            message.format(module_str=module_str)
        ) from exc

    instance = module
    try:
        # Resolve dotted attribute paths such as "obj.attr" step by step.
        for attr_str in attrs_str.split("."):
            instance = getattr(instance, attr_str)
    except AttributeError as exc:
        message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
        raise ImportFromStringError(
            message.format(attrs_str=attrs_str, module_str=module_str)
        ) from exc
    return instance
if __name__ == "__main__":
    # Script entry point: arguments come from sys.argv via argparse.
    main(argv=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment