Skip to content

Commit

Permalink
chore(sqlalchemy): Remove erroneous SQLAlchemy ORM session.merge oper…
Browse files Browse the repository at this point in the history
…ations (#24776)
  • Loading branch information
john-bodley authored Nov 21, 2023
1 parent e7797b6 commit dd58b31
Show file tree
Hide file tree
Showing 49 changed files with 34 additions and 82 deletions.
2 changes: 1 addition & 1 deletion superset/examples/bart_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "BART lines"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
2 changes: 1 addition & 1 deletion superset/examples/country_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "dttm"
obj.database = database
obj.filter_select_enabled = True
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
Expand Down
4 changes: 2 additions & 2 deletions superset/examples/css_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def load_css_templates() -> None:
obj = db.session.query(CssTemplate).filter_by(template_name="Flat").first()
if not obj:
obj = CssTemplate(template_name="Flat")
db.session.add(obj)
css = textwrap.dedent(
"""\
.navbar {
Expand All @@ -51,12 +52,12 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()

obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj:
obj = CssTemplate(template_name="Courier Black")
db.session.add(obj)
css = textwrap.dedent(
"""\
h2 {
Expand Down Expand Up @@ -96,5 +97,4 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/deck.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements

if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
js = POSITION_JSON
pos = json.loads(js)
Expand All @@ -540,5 +541,4 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
dash.dashboard_title = title
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def load_energy(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Energy consumption"
tbl.database = database
tbl.filter_select_enabled = True
Expand All @@ -76,7 +77,6 @@ def load_energy(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)

db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

Expand Down
2 changes: 1 addition & 1 deletion superset/examples/flights.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
print("Done loading table!")
2 changes: 1 addition & 1 deletion superset/examples/long_lat.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "datetime"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
Expand Down
2 changes: 1 addition & 1 deletion superset/examples/misc_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def load_misc_dashboard() -> None:

if not dash:
dash = Dashboard()
db.session.add(dash)
js = textwrap.dedent(
"""\
{
Expand Down Expand Up @@ -215,5 +216,4 @@ def load_misc_dashboard() -> None:
dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG
dash.slices = slices
db.session.merge(dash)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/multiformat_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
Expand All @@ -100,7 +101,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
col.python_date_format = dttm_and_expr[0]
col.database_expression = dttm_and_expr[1]
col.is_dttm = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
Expand Down
2 changes: 1 addition & 1 deletion superset/examples/paris.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Map of Paris"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
2 changes: 1 addition & 1 deletion superset/examples/random_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def load_random_time_series_data(
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
Expand Down
2 changes: 1 addition & 1 deletion superset/examples/sf_population_polygons.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def load_sf_population_polygons(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Population density of San Francisco"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
3 changes: 1 addition & 2 deletions superset/examples/tabbed_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def load_tabbed_dashboard(_: bool = False) -> None:

if not dash:
dash = Dashboard()
db.session.add(dash)

js = textwrap.dedent(
"""
Expand Down Expand Up @@ -556,6 +557,4 @@ def load_tabbed_dashboard(_: bool = False) -> None:
dash.slices = slices
dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug

db.session.merge(dash)
db.session.commit()
4 changes: 2 additions & 2 deletions superset/examples/world_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = utils.readfile(
os.path.join(get_examples_folder(), "countries.md")
)
Expand All @@ -110,7 +111,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
SqlMetric(metric_name=metric, expression=f"{aggr_func}({col})")
)

db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

Expand All @@ -126,6 +126,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s

if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
pos = dashboard_positions
slices = update_slice_ids(pos)
Expand All @@ -134,7 +135,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
dash.position_json = json.dumps(pos, indent=4)
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()


Expand Down
1 change: 0 additions & 1 deletion superset/key_value/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def update(self) -> Optional[Key]:
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
db.session.commit()
return Key(id=entry.id, uuid=entry.uuid)

Expand Down
1 change: 0 additions & 1 deletion superset/key_value/commands/upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def upsert(self) -> Key:
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
db.session.commit()
return Key(entry.id, entry.uuid)

Expand Down
12 changes: 4 additions & 8 deletions superset/migrations/shared/migrate_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
]

@classmethod
def upgrade_slice(cls, slc: Slice) -> Slice:
def upgrade_slice(cls, slc: Slice) -> None:
clz = cls(slc.params)
form_data_bak = copy.deepcopy(clz.data)

Expand All @@ -141,10 +141,9 @@ def upgrade_slice(cls, slc: Slice) -> Slice:
if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc

@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
def downgrade_slice(cls, slc: Slice) -> None:
form_data = try_load_json(slc.params)
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
Expand All @@ -153,7 +152,6 @@ def downgrade_slice(cls, slc: Slice) -> Slice:
if "form_data" in query_context:
query_context["form_data"] = form_data_bak
slc.query_context = json.dumps(query_context)
return slc

@classmethod
def upgrade(cls, session: Session) -> None:
Expand All @@ -162,8 +160,7 @@ def upgrade(cls, session: Session) -> None:
slices,
lambda current, total: print(f"Upgraded {current}/{total} charts"),
):
new_viz = cls.upgrade_slice(slc)
session.merge(new_viz)
cls.upgrade_slice(slc)

@classmethod
def downgrade(cls, session: Session) -> None:
Expand All @@ -177,5 +174,4 @@ def downgrade(cls, session: Session) -> None:
slices,
lambda current, total: print(f"Downgraded {current}/{total} charts"),
):
new_viz = cls.downgrade_slice(slc)
session.merge(new_viz)
cls.downgrade_slice(slc)
1 change: 0 additions & 1 deletion superset/migrations/shared/security_converge.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,6 @@ def migrate_roles(
if new_pvm not in role.permissions:
logger.info(f"Add {new_pvm} to {role}")
role.permissions.append(new_pvm)
session.merge(role)

# Delete old permissions
_delete_old_permissions(session, pvm_map)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def upgrade():
for slc in session.query(Slice).all():
if slc.datasource:
slc.perm = slc.datasource.perm
session.merge(slc)
session.commit()
db.session.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def upgrade():
slc.datasource_id = slc.druid_datasource_id
if slc.table_id:
slc.datasource_id = slc.table_id
session.merge(slc)
session.commit()
session.close()

Expand All @@ -69,7 +68,6 @@ def downgrade():
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == "table":
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column("slices", "datasource_id")
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def upgrade():
try:
d = json.loads(slc.params or "{}")
slc.params = json.dumps(d, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def upgrade():
"/".join(split[:-1]) + "/?form_data=" + parse.quote_plus(json.dumps(d))
)
url.url = newurl
session.merge(url)
session.commit()
print(f"Updating url ({i}/{urls_len})")
session.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def upgrade():
del params["latitude"]
del params["longitude"]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def upgrade():
)
params["annotation_layers"] = new_layers
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

Expand All @@ -86,6 +85,5 @@ def downgrade():
if layers:
params["annotation_layers"] = [layer["value"] for layer in layers]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def upgrade():
pos["v"] = 1

dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()

session.close()
Expand All @@ -85,6 +84,5 @@ def downgrade():
pos["v"] = 0

dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()
pass
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def upgrade():
params["metrics"] = [params.get("metric")]
del params["metric"]
slc.params = json.dumps(params, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,6 @@ def upgrade():

sorted_by_key = collections.OrderedDict(sorted(v2_layout.items()))
dashboard.position_json = json.dumps(sorted_by_key, indent=2)
session.merge(dashboard)
session.commit()
else:
print(f"Skip converted dash_id: {dashboard.id}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def upgrade():
dashboard.id, len(original_text), len(text)
)
)
session.merge(dashboard)
session.commit()


Expand Down
Loading

0 comments on commit dd58b31

Please sign in to comment.