Skip to content

Commit

Permalink
Merge pull request #49 from k2bd/feat/migration-report
Browse files Browse the repository at this point in the history
Some improvements to CLI
  • Loading branch information
k2bd authored Oct 22, 2024
2 parents c254042 + 685b920 commit b1840f2
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 52 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ flux new "Initial tables"
There are two forms that migrations can take in ``flux`` - Python files and sql files.
Both forms define up and optional down migrations.

In each case, you can append ``--pre`` or ``--post`` to create pre-apply and post-apply migrations.
These will run before/after any batch of migrations are run (by default, they're also run before/after rollbacks, but this can be disabled)

### Migrations as Python files

By default ``flux`` creates Python migration files when you run ``flux new "My new migration"``.
Expand Down Expand Up @@ -118,7 +121,7 @@ def apply():
### Migrations as sql files

It may be that you prefer just writing sql files for your migrations, and you just want ``flux`` for its flexibility or testing functionality.
That's cool too, just run ``flux new --kind sql "My new migration"``.
That's cool too, just run ``flux new --sql "My new migration"``.

Up-migration files are just files ending with ``.sql``. They can have down-migration counterparts ending with ``.undo.sql``.

Expand Down
126 changes: 90 additions & 36 deletions src/flux/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import datetime as dt
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import typer
Expand All @@ -14,7 +13,12 @@

from flux.backend.get_backends import get_backend
from flux.config import FluxConfig
from flux.constants import FLUX_CONFIG_FILE, FLUX_DEFAULT_MIGRATION_DIRECTORY
from flux.constants import (
FLUX_CONFIG_FILE,
FLUX_DEFAULT_MIGRATION_DIRECTORY,
POST_APPLY_DIRECTORY,
PRE_APPLY_DIRECTORY,
)
from flux.exceptions import BackendNotInstalledError
from flux.runner import FluxRunner

Expand Down Expand Up @@ -105,18 +109,31 @@ def init(
print(f"Created {FLUX_CONFIG_FILE}")


class MigrationKind(str, Enum):
sql = "sql"
python = "python"


async def _new(ctx: typer.Context, name: str, kind: MigrationKind):
async def _new(
ctx: typer.Context,
name: str,
sql: bool,
pre: bool,
post: bool,
):
config: FluxConfig | None = ctx.obj.config
if config is None:
print("Please run `flux init` to create a configuration file")
raise typer.Exit(code=1)

os.makedirs(config.migration_directory, exist_ok=True)
if pre and post:
print("Cannot create migration with both --pre and --post")
raise typer.Exit(code=1)

repeatable = pre or post

target_dir = config.migration_directory
if pre:
target_dir = os.path.join(target_dir, PRE_APPLY_DIRECTORY)
if post:
target_dir = os.path.join(target_dir, POST_APPLY_DIRECTORY)

os.makedirs(target_dir, exist_ok=True)

date_part = dt.date.today().strftime("%Y%m%d")
migration_index = 1
Expand All @@ -126,7 +143,7 @@ async def _new(ctx: typer.Context, name: str, kind: MigrationKind):
def migration_filename_prefix() -> str:
return f"{date_part}_{migration_index:>03}"

migration_filenames = os.listdir(config.migration_directory)
migration_filenames = os.listdir(target_dir)
while any(
[
filename.startswith(migration_filename_prefix())
Expand All @@ -137,20 +154,17 @@ def migration_filename_prefix() -> str:

migration_basename = f"{migration_filename_prefix()}_{name_part}"

if kind == MigrationKind.sql:
with open(
os.path.join(config.migration_directory, f"{migration_basename}.sql"), "w"
) as f:
if sql:
with open(os.path.join(target_dir, f"{migration_basename}.sql"), "w") as f:
f.write("")
with open(
os.path.join(config.migration_directory, f"{migration_basename}.undo.sql"),
"w",
) as f:
f.write("")
elif kind == MigrationKind.python:
with open(
os.path.join(config.migration_directory, f"{migration_basename}.py"), "w"
) as f:
if not repeatable:
with open(
os.path.join(target_dir, f"{migration_basename}.undo.sql"),
"w",
) as f:
f.write("")
else:
with open(os.path.join(target_dir, f"{migration_basename}.py"), "w") as f:
f.write(
f'''"""
{name}
Expand All @@ -160,27 +174,46 @@ def migration_filename_prefix() -> str:
def apply():
return """ """
'''
)
if not repeatable:
f.write(
'''
def undo():
return """ """
'''
)
else:
print(f"Invalid migration type {kind}")
raise typer.Exit(code=1)
)


@app.command()
def new(
ctx: typer.Context,
name: Annotated[str, typer.Argument(help="Migration name and default comment")],
kind: MigrationKind = MigrationKind.python,
sql: Annotated[bool, typer.Option("--sql")] = False,
pre: Annotated[bool, typer.Option("--pre")] = False,
post: Annotated[bool, typer.Option("--post")] = False,
):
async_run(_new(ctx=ctx, name=name, kind=kind))
async_run(_new(ctx=ctx, name=name, sql=sql, pre=pre, post=post))


async def _print_apply_report(runner: FluxRunner, n: int | None):
def _print_status_report(runner: FluxRunner):
table = Table(title="Status")
table.add_column("ID")
table.add_column("Status")

for migration in runner.list_applied_migrations():
table.add_row(migration.id, APPLIED_STATUS)

for migration in runner.list_unapplied_migrations():
status = NOT_APPLIED_STATUS
table.add_row(migration.id, status)

console = Console()
console.print(table)


def _print_apply_report(runner: FluxRunner, n: int | None):
table = Table(title="Apply Migrations")
table.add_column("ID")
table.add_column("Status")
Expand All @@ -202,7 +235,7 @@ async def _print_apply_report(runner: FluxRunner, n: int | None):
console.print(table)


async def _print_rollback_report(runner: FluxRunner, n: int | None):
def _print_rollback_report(runner: FluxRunner, n: int | None):
table = Table(title="Rollback Migrations")
table.add_column("ID")
table.add_column("Status")
Expand All @@ -224,6 +257,19 @@ async def _print_rollback_report(runner: FluxRunner, n: int | None):
console.print(table)


async def _status(connection_uri: str):
async with FluxRunner.from_file(
path=FLUX_CONFIG_FILE,
connection_uri=connection_uri,
) as runner:
_print_status_report(runner=runner)


@app.command()
def status(connection_uri: str):
async_run(_status(connection_uri=connection_uri))


async def _apply(
ctx: typer.Context,
connection_uri: str,
Expand All @@ -238,7 +284,7 @@ async def _apply(
path=FLUX_CONFIG_FILE,
connection_uri=connection_uri,
) as runner:
await _print_apply_report(runner=runner, n=n)
_print_apply_report(runner=runner, n=n)
if not auto_approve:
if not Confirm.ask("Apply these migrations?"):
raise typer.Exit(1)
Expand Down Expand Up @@ -269,6 +315,7 @@ async def _rollback(
connection_uri: str,
n: int | None,
auto_approve: bool = False,
repeatable: bool | None = None,
):
config: FluxConfig | None = ctx.obj.config
if config is None:
Expand All @@ -278,12 +325,12 @@ async def _rollback(
path=FLUX_CONFIG_FILE,
connection_uri=connection_uri,
) as runner:
await _print_rollback_report(runner=runner, n=n)
_print_rollback_report(runner=runner, n=n)
if not auto_approve:
if not Confirm.ask("Undo these migrations?"):
raise typer.Exit(1)

await runner.rollback_migrations(n=n)
await runner.rollback_migrations(n=n, apply_repeatable=repeatable)


@app.command()
Expand All @@ -299,7 +346,14 @@ def rollback(
),
] = None,
auto_approve: bool = False,
repeatable: bool | None = None,
):
async_run(
_rollback(ctx, connection_uri=connection_uri, n=n, auto_approve=auto_approve)
_rollback(
ctx,
connection_uri=connection_uri,
n=n,
auto_approve=auto_approve,
repeatable=repeatable,
)
)
3 changes: 2 additions & 1 deletion src/flux/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def from_file(cls, path: str):
)

apply_repeatable_on_down = general_config.get(
FLUX_APPLY_REPEATABLE_ON_DOWN_KEY, FLUX_DEFAULT_APPLY_REPEATABLE_ON_DOWN
FLUX_APPLY_REPEATABLE_ON_DOWN_KEY,
FLUX_DEFAULT_APPLY_REPEATABLE_ON_DOWN,
)

log_level = general_config.get(FLUX_LOG_LEVEL_KEY, FLUX_DEFAULT_LOG_LEVEL)
Expand Down
2 changes: 1 addition & 1 deletion src/flux/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

FLUX_DEFAULT_MIGRATION_DIRECTORY = "migrations"
FLUX_DEFAULT_LOG_LEVEL = "INFO"
FLUX_DEFAULT_APPLY_REPEATABLE_ON_DOWN = False
FLUX_DEFAULT_APPLY_REPEATABLE_ON_DOWN = True
24 changes: 19 additions & 5 deletions src/flux/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,24 @@ def migrations_to_rollback(self, n: int | None = None) -> list[Migration]:
)
return migrations_to_rollback[::-1]

async def rollback_migrations(self, n: int | None = None):
async def rollback_migrations(
self,
n: int | None = None,
apply_repeatable: bool | None = None,
):
"""
Rollback applied migrations from the database, applying any undo
migrations if they exist.
"""
await self.validate_applied_migrations()

if self.config.apply_repeatable_on_down:
should_apply_repeatable = (
apply_repeatable
if apply_repeatable is not None
else self.config.apply_repeatable_on_down
)

if should_apply_repeatable:
await self._apply_pre_apply_migrations()

migrations_to_rollback = self.migrations_to_rollback(n=n)
Expand All @@ -198,13 +208,17 @@ async def rollback_migrations(self, n: int | None = None):
f"Failed to rollback migration {migration.id if migration else ''}"
) from e
finally:
if self.config.apply_repeatable_on_down:
if should_apply_repeatable:
async with self.backend.transaction():
await self._apply_post_apply_migrations()

self.applied_migrations = await self.backend.get_applied_migrations()

async def rollback_migration(self, migration_id: str):
async def rollback_migration(
self,
migration_id: str,
apply_repeatable: bool | None = None,
):
"""
Rollback all migrations up to and including the given migration ID
"""
Expand All @@ -222,4 +236,4 @@ async def rollback_migration(self, migration_id: str):

n = len(applied_migrations) - target_migration_index

await self.rollback_migrations(n=n)
await self.rollback_migrations(n=n, apply_repeatable=apply_repeatable)
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def example_config(
backend: str,
migration_directory: str,
log_level: str = "DEBUG",
apply_repeatable_on_down: bool = False,
apply_repeatable_on_down: bool = True,
backend_config: dict | None = None,
):
return FluxConfig(
Expand Down
Loading

0 comments on commit b1840f2

Please sign in to comment.