Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable querying cumulative metrics with agg_time_dimension #999

Merged
merged 13 commits into the base branch from the feature branch
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240125-220047.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enable querying cumulative metrics with their agg_time_dimension.
time: 2024-01-25T22:00:47.648696-08:00
custom:
Author: courtneyholcomb
Issue: "1000"
29 changes: 22 additions & 7 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,12 +1297,26 @@ def _build_aggregated_measure_from_measure_source_node(
f"Recipe not found for measure spec: {measure_spec} and linkable specs: {required_linkable_specs}"
)

# If a cumulative metric is queried with metric_time, join over time range.
queried_agg_time_dimension_specs = list(queried_linkable_specs.metric_time_specs)
if not queried_agg_time_dimension_specs:
valid_agg_time_dimensions = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_spec.reference
)
queried_agg_time_dimension_specs = list(
set(queried_linkable_specs.time_dimension_specs).intersection(set(valid_agg_time_dimensions))
)

# If a cumulative metric is queried with agg_time_dimension, join over time range.
# Otherwise, the measure will be aggregated over all time.
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative and queried_linkable_specs.contains_metric_time:
if cumulative and queried_agg_time_dimension_specs:
# Use the time dimension spec with the smallest granularity.
agg_time_dimension_spec_for_join = sorted(
queried_agg_time_dimension_specs, key=lambda spec: spec.time_granularity.to_int()
)[0]
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.source_node,
time_dimension_spec_for_join=agg_time_dimension_spec_for_join,
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
time_range_constraint=time_range_constraint
Expand Down Expand Up @@ -1356,12 +1370,13 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
if (
cumulative_metric_adjusted_time_constraint is not None
and time_range_constraint is not None
and queried_linkable_specs.contains_metric_time
):
if cumulative_metric_adjusted_time_constraint is not None and time_range_constraint is not None:
assert (
queried_linkable_specs.contains_metric_time
), "Using time constraints currently requires querying with metric_time."
cumulative_metric_constrained_node = ConstrainTimeRangeNode(
unaggregated_measure_node, time_range_constraint
)
Expand Down
4 changes: 4 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ class JoinOverTimeRangeNode(BaseOutput):
def __init__(
self,
parent_node: BaseOutput,
time_dimension_spec_for_join: TimeDimensionSpec,
window: Optional[MetricTimeWindow],
grain_to_date: Optional[TimeGranularity],
node_id: Optional[NodeId] = None,
Expand All @@ -390,6 +391,7 @@ def __init__(
(eg month to day)
node_id: Override the node ID with this value
time_range_constraint: time range to aggregate over
time_dimension_spec_for_join: time dimension spec to use when joining to time spine
"""
if window and grain_to_date:
raise RuntimeError(
Expand All @@ -400,6 +402,7 @@ def __init__(
self._grain_to_date = grain_to_date
self._window = window
self.time_range_constraint = time_range_constraint
self.time_dimension_spec_for_join = time_dimension_spec_for_join

# Doing a list comprehension throws a type error, so doing it this way.
parent_nodes: List[DataflowPlanNode] = [self._parent_node]
Expand Down Expand Up @@ -447,6 +450,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinOverTi
window=self.window,
grain_to_date=self.grain_to_date,
time_range_constraint=self.time_range_constraint,
time_dimension_spec_for_join=self.time_dimension_spec_for_join,
)


Expand Down
51 changes: 25 additions & 26 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def _next_unique_table_alias(self) -> str:

def _make_time_spine_data_set(
self,
metric_time_dimension_instance: TimeDimensionInstance,
metric_time_dimension_column_name: str,
agg_time_dimension_instance: TimeDimensionInstance,
agg_time_dimension_column_name: str,
time_spine_source: TimeSpineSource,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> SqlDataSet:
Expand All @@ -187,21 +187,21 @@ def _make_time_spine_data_set(
"""
time_spine_instance = (
TimeDimensionInstance(
defined_from=metric_time_dimension_instance.defined_from,
defined_from=agg_time_dimension_instance.defined_from,
associated_columns=(
ColumnAssociation(
column_name=metric_time_dimension_column_name,
column_name=agg_time_dimension_column_name,
single_column_correlation_key=SingleColumnCorrelationKey(),
),
),
spec=metric_time_dimension_instance.spec,
spec=agg_time_dimension_instance.spec,
),
)
time_spine_instance_set = InstanceSet(time_dimension_instances=time_spine_instance)
time_spine_table_alias = self._next_unique_table_alias()

# If the requested granularity is the same as the granularity of the spine, do a direct select.
if metric_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
if agg_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
return SqlDataSet(
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode(
Expand All @@ -214,7 +214,7 @@ def _make_time_spine_data_set(
column_name=time_spine_source.time_column_name,
),
),
column_alias=metric_time_dimension_column_name,
column_alias=agg_time_dimension_column_name,
),
),
from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
Expand All @@ -236,15 +236,15 @@ def _make_time_spine_data_set(
select_columns = (
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=metric_time_dimension_instance.spec.time_granularity,
time_granularity=agg_time_dimension_instance.spec.time_granularity,
arg=SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_table_alias,
column_name=time_spine_source.time_column_name,
),
),
),
column_alias=metric_time_dimension_column_name,
column_alias=agg_time_dimension_column_name,
),
)
return SqlDataSet(
Expand Down Expand Up @@ -281,40 +281,39 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
input_data_set = node.parent_node.accept(self)
input_data_set_alias = self._next_unique_table_alias()

metric_time_dimension_spec: Optional[TimeDimensionSpec] = None
metric_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in input_data_set.metric_time_dimension_instances:
if len(instance.spec.entity_links) == 0:
metric_time_dimension_instance = instance
metric_time_dimension_spec = instance.spec
agg_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in input_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec_for_join:
agg_time_dimension_instance = instance
break
assert (
agg_time_dimension_instance
), "Specified metric time spec not found in parent data set. This should have been caught by validations."

assert metric_time_dimension_spec
time_spine_data_set_alias = self._next_unique_table_alias()

metric_time_dimension_column_name = self.column_association_resolver.resolve_spec(
metric_time_dimension_spec
agg_time_dimension_column_name = self.column_association_resolver.resolve_spec(
agg_time_dimension_instance.spec
).column_name

# Assemble time_spine dataset with metric_time_dimension to join.
# Granularity of time_spine column should match granularity of metric_time column from parent dataset.
assert metric_time_dimension_instance
time_spine_data_set = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
metric_time_dimension_column_name=metric_time_dimension_column_name,
agg_time_dimension_instance=agg_time_dimension_instance,
agg_time_dimension_column_name=agg_time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set

# Figure out which columns correspond to the time dimension that we want to join on.
input_data_set_metric_time_column_association = input_data_set.column_association_for_time_dimension(
metric_time_dimension_spec
agg_time_dimension_instance.spec
)
input_data_set_metric_time_col = input_data_set_metric_time_column_association.column_name

time_spine_data_set_column_associations = time_spine_data_set.column_association_for_time_dimension(
metric_time_dimension_spec
agg_time_dimension_instance.spec
)
time_spine_data_set_time_dimension_col = time_spine_data_set_column_associations.column_name

Expand Down Expand Up @@ -342,7 +341,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
[
time_dimension_instance
for time_dimension_instance in input_data_set.instance_set.time_dimension_instances
if time_dimension_instance.spec != metric_time_dimension_spec
if time_dimension_instance != agg_time_dimension_instance
]
),
)
Expand Down Expand Up @@ -1256,8 +1255,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
).column_name
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
metric_time_dimension_instance=metric_time_dimension_instance,
metric_time_dimension_column_name=metric_time_dimension_column_name,
agg_time_dimension_instance=metric_time_dimension_instance,
agg_time_dimension_column_name=metric_time_dimension_column_name,
time_spine_source=self._time_spine_source,
time_range_constraint=node.time_range_constraint,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class CumulativeMetricRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue):
def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
return (
f"The query includes a cumulative metric {repr(self.metric_reference.element_name)} that does not "
f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)}"
f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)} "
"or the metric's agg_time_dimension."
)

@override
Expand Down
53 changes: 26 additions & 27 deletions metricflow/query/validation_rules/metric_time_requirements.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

from typing import List, Sequence
from typing import Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import MetricType
from typing_extensions import override

from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
Expand All @@ -34,29 +33,11 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
super().__init__(manifest_lookup=manifest_lookup)

metric_time_specs: List[TimeDimensionSpec] = []

for time_granularity in TimeGranularity:
metric_time_specs.append(
TimeDimensionSpec(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
time_granularity=time_granularity,
date_part=None,
)
self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
[Inline review comment from a Contributor on the line above:]

Nice.

time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME), entity_links=()
)
for date_part in DatePart:
for time_granularity in date_part.compatible_granularities:
metric_time_specs.append(
TimeDimensionSpec(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
time_granularity=time_granularity,
date_part=date_part,
)
)

self._metric_time_specs = tuple(metric_time_specs)
)

def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
for group_by_item_input in query_resolver_input.group_by_item_inputs:
Expand All @@ -65,6 +46,18 @@ def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInpu

return False

def _group_by_items_include_agg_time_dimension(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> bool:
valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(
metric_reference
)
for group_by_item_input in query_resolver_input.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
return True

return False

@override
def validate_metric_in_resolution_dag(
self,
Expand All @@ -74,14 +67,20 @@ def validate_metric_in_resolution_dag(
) -> MetricFlowQueryResolutionIssueSet:
metric = self._get_metric(metric_reference)
query_includes_metric_time = self._group_by_items_include_metric_time(resolver_input_for_query)
query_includes_metric_time_or_agg_time_dimension = (
query_includes_metric_time
or self._group_by_items_include_agg_time_dimension(
query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
)
)

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
elif metric.type is MetricType.CUMULATIVE:
if (
metric.type_params is not None
and (metric.type_params.window is not None or metric.type_params.grain_to_date is not None)
and not query_includes_metric_time
and not query_includes_metric_time_or_agg_time_dimension
):
return MetricFlowQueryResolutionIssueSet.from_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
Expand Down
3 changes: 2 additions & 1 deletion metricflow/test/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_saved_query( # noqa: D
resp = cli_runner.run(
query, args=["--saved-query", "p0_booking", "--order", "metric_time__day,listing__capacity_latest"]
)
print(resp.output)

assert resp.exit_code == 0

Expand Down Expand Up @@ -291,5 +292,5 @@ def test_saved_query_with_cumulative_metric( # noqa: D
snapshot_str=resp.output,
sql_engine=sql_client.sql_engine_type,
)

print(resp.output)
assert resp.exit_code == 0
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,31 @@ integration_test:
GROUP BY
subq_3.metric_time__month
ORDER BY subq_3.metric_time__month
---
integration_test:
name: cumulative_metric_with_agg_time_dimension
description: Query a cumulative metric with its agg_time_dimension and a time constraint.
model: SIMPLE_MODEL
metrics: ["trailing_2_months_revenue"]
group_bys: ["revenue_instance__ds__day"]
order_bys: ["revenue_instance__ds__day"]
where_filter: '{{ render_time_constraint("revenue_instance__ds__day", "2020-03-05", "2021-01-04") }}'
check_query: |
SELECT
SUM(b.txn_revenue) as trailing_2_months_revenue
, a.ds AS revenue_instance__ds__day
FROM (
SELECT ds
FROM {{ mf_time_spine_source }}
WHERE {{ render_time_constraint("ds", "2020-01-05", "2021-01-04") }}
) a
INNER JOIN (
SELECT
revenue as txn_revenue
, created_at AS ds
FROM {{ source_schema }}.fct_revenue
) b
ON b.ds <= a.ds AND b.ds > {{ render_date_sub("a", "ds", 2, TimeGranularity.MONTH) }}
WHERE {{ render_time_constraint("a.ds", "2020-03-05", "2021-01-04") }}
GROUP BY a.ds
ORDER BY a.ds
Loading
Loading