Skip to content

Commit

Permalink
Simplify code in visit_join_over_time_range_node
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 26, 2024
1 parent 90743da commit a0865b7
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,24 +281,23 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
input_data_set = node.parent_node.accept(self)
input_data_set_alias = self._next_unique_table_alias()

metric_time_dimension_spec: Optional[TimeDimensionSpec] = None
metric_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in input_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec_for_join:
metric_time_dimension_instance = instance
metric_time_dimension_spec = instance.spec
break
assert (
metric_time_dimension_instance
), "Specified metric time spec not found in parent data set. This should have been caught by validations."

assert metric_time_dimension_spec
time_spine_data_set_alias = self._next_unique_table_alias()

metric_time_dimension_column_name = self.column_association_resolver.resolve_spec(
metric_time_dimension_spec
metric_time_dimension_instance.spec
).column_name

# Assemble time_spine dataset with metric_time_dimension to join.
# Granularity of time_spine column should match granularity of metric_time column from parent dataset.
assert metric_time_dimension_instance
time_spine_data_set = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
metric_time_dimension_column_name=metric_time_dimension_column_name,
Expand All @@ -309,12 +308,12 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat

# Figure out which columns correspond to the time dimension that we want to join on.
input_data_set_metric_time_column_association = input_data_set.column_association_for_time_dimension(
metric_time_dimension_spec
metric_time_dimension_instance.spec
)
input_data_set_metric_time_col = input_data_set_metric_time_column_association.column_name

time_spine_data_set_column_associations = time_spine_data_set.column_association_for_time_dimension(
metric_time_dimension_spec
metric_time_dimension_instance.spec
)
time_spine_data_set_time_dimension_col = time_spine_data_set_column_associations.column_name

Expand Down Expand Up @@ -342,7 +341,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
[
time_dimension_instance
for time_dimension_instance in input_data_set.instance_set.time_dimension_instances
if time_dimension_instance.spec != metric_time_dimension_spec
if time_dimension_instance != metric_time_dimension_instance
]
),
)
Expand Down

0 comments on commit a0865b7

Please sign in to comment.