Skip to content

Commit

Permalink
[fix] Improve logic for local deployment of PythonScript (#168)
Browse files Browse the repository at this point in the history
* add spark_python_task support in Workflow class and update example workflows

* fix: tests and expected bundle

* feat: add adjustment of python file in case of local for spark_python

* fix: adjust logic for workspace

* feat: add common_task_parameters support

* fix: missed slash

* fix: improve tests

* feat: update poetry lock

* Feat/workflow parameters (#1)

* Add workflows parameters

* feat: add JobsParameters import in sample_workflows

* fix: adjust logic for local deployment in PythonScript

* fix: update version number to 0.11.0a0 in pyproject.toml
  • Loading branch information
mikita-sakalouski authored Oct 18, 2024
1 parent 922953d commit f133723
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
27 changes: 23 additions & 4 deletions brickflow/codegen/databricks_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
Targets,
Workspace,
)
from brickflow.cli.projects import MultiProjectManager, get_brickflow_root
from brickflow.codegen import (
CodegenInterface,
DatabricksDefaultClusterTagKeys,
Expand Down Expand Up @@ -461,12 +462,30 @@ def adjust_file_path(self, file_path: str) -> str:
]
).replace("//", "/")

# Finds the start position of the project name in the given file path and calculates the cut position.
# - `file_path.find(self.project.name)`: Finds the start index of the project name in the file path.
# - `+ len(self.project.name) + 1`: Moves the start position to the character after the project name.
multi_project_manager = MultiProjectManager(
config_file_name=str(get_brickflow_root())
)
bf_project = multi_project_manager.get_project(self.project.name)

start_index_of_project_root = file_path.find(
bf_project.path_from_repo_root_to_project_root
)

if start_index_of_project_root < 0:
raise ValueError(
f"Error while adjusting file path. "
f"Project root not found in the file path: {file_path}."
)

# Finds the start position of `path_from_repo_root_to_project_root` in the given file path
# and calculates the cut position.
# - `file_path.find(...)`: Finds the start index of the project root in the file path.
# - `+ len(...) + 1`: Moves the cut position past the project root and the one character
#   that follows it (presumably the path separator).
# - Adjusts the file path by prefixing the cut file path with the local bundle path.
cut_file_path = file_path[
file_path.find(self.project.name) + len(self.project.name) + 1 :
start_index_of_project_root
+ len(bf_project.path_from_repo_root_to_project_root)
+ 1 :
]
file_path = (
bundle_files_local_path + file_path
Expand Down
11 changes: 11 additions & 0 deletions tests/codegen/test_databricks_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class TestBundleCodegen(TestCase):
BrickflowEnvVars.BRICKFLOW_PROJECT_TAGS.value: "tag1 = value1, tag2 =value2 ", # spaces will be trimmed
},
)
@patch("brickflow.codegen.databricks_bundle.MultiProjectManager")
@patch("brickflow.engine.task.get_job_id", return_value=12345678901234.0)
@patch("subprocess.check_output")
@patch("brickflow.context.ctx.get_parameter")
Expand All @@ -96,12 +97,16 @@ def test_generate_bundle_local(
dbutils: Mock,
sub_proc_mock: Mock,
get_job_id_mock: Mock,
multi_project_manager_mock: Mock,
):
dbutils.return_value = None
sub_proc_mock.return_value = b""
bf_version_mock.return_value = "1.0.0"
workspace_client = get_workspace_client_mock()
get_job_id_mock.return_value = 12345678901234.0
multi_project_manager_mock.return_value.get_project.return_value = MagicMock(
path_from_repo_root_to_project_root="test-project"
)
# get caller part breaks here
with Project(
"test-project",
Expand Down Expand Up @@ -138,6 +143,7 @@ def test_generate_bundle_local(
BrickflowEnvVars.BRICKFLOW_WORKFLOW_SUFFIX.value: "_suffix",
},
)
@patch("brickflow.codegen.databricks_bundle.MultiProjectManager")
@patch("brickflow.engine.task.get_job_id", return_value=12345678901234.0)
@patch("subprocess.check_output")
@patch("brickflow.context.ctx.get_parameter")
Expand All @@ -146,18 +152,23 @@ def test_generate_bundle_local(
"brickflow.context.ctx.get_current_timestamp",
MagicMock(return_value=1704067200000),
)
# @patch()
def test_generate_bundle_local_prefix_suffix(
self,
bf_version_mock: Mock,
dbutils: Mock,
sub_proc_mock: Mock,
get_job_id_mock: Mock,
multi_project_manager_mock: Mock,
):
dbutils.return_value = None
sub_proc_mock.return_value = b""
bf_version_mock.return_value = "1.0.0"
workspace_client = get_workspace_client_mock()
get_job_id_mock.return_value = 12345678901234.0
multi_project_manager_mock.return_value.get_project.return_value = MagicMock(
path_from_repo_root_to_project_root="test-project"
)
# get caller part breaks here
with Project(
"test-project",
Expand Down

0 comments on commit f133723

Please sign in to comment.