From d8a98475036a4fba28b3d3eb508b3d1f3f5072aa Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 2 Feb 2024 12:56:02 -0500 Subject: [PATCH] fix: column values with NaN (#26946) --- superset/models/helpers.py | 7 ++++++- tests/integration_tests/conftest.py | 21 ++++++++++--------- .../integration_tests/datasource/api_tests.py | 10 +++++++++ tests/integration_tests/datasource_tests.py | 9 +++++++- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index fa2c9b8102136..9322e8c46d993 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1340,7 +1340,10 @@ def get_time_filter( # pylint: disable=too-many-arguments return and_(*l) def values_for_column( - self, column_name: str, limit: int = 10000, denormalize_column: bool = False + self, + column_name: str, + limit: int = 10000, + denormalize_column: bool = False, ) -> list[Any]: # denormalize column name before querying for values # unless disabled in the dataset configuration @@ -1378,6 +1381,8 @@ def values_for_column( sql = self.mutate_query_from_config(sql) df = pd.read_sql_query(sql=sql, con=engine) + # replace NaN with None to ensure it can be serialized to JSON + df = df.replace({np.nan: None}) return df["column_values"].to_list() def get_timestamp_expression( diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 3e6aa963072b1..b90416587c2e8 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -296,25 +296,25 @@ def virtual_dataset(): dataset = SqlaTable( table_name="virtual_dataset", sql=( - "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 " + "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 " "UNION ALL " - "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' " + "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL " "UNION ALL " - "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' " + "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3 " "UNION ALL " - "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' " + "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 " "UNION ALL " - "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' " + "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 " "UNION ALL " - "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' " + "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 " "UNION ALL " - "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' " + "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7 " "UNION ALL " - "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' " + "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8 " "UNION ALL " - "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' " + "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9 " "UNION ALL " - "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' " + "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10" ), database=get_example_database(), ) @@ -324,6 +324,7 @@ def virtual_dataset(): TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset) # Different database dialect datetime type is not consistent, so temporarily use varchar TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col6", type="INTEGER", table=dataset) SqlMetric(metric_name="count", expression="count(*)", table=dataset) db.session.add(dataset) diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py index 6f37186963ce2..554875e58d953 100644 --- a/tests/integration_tests/datasource/api_tests.py +++ b/tests/integration_tests/datasource/api_tests.py @@ -72,6 +72,16 @@ def test_get_column_values_nulls(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual(response["result"], [None]) + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_column_values_integers_with_nulls(self): + self.login(username="admin") + table = self.get_virtual_dataset() + rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col6/values/") + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + for val in [1, None, 3, 4, 5, 6, 7, 8, 9, 10]: + assert val in response["result"] + @pytest.mark.usefixtures("app_context", "virtual_dataset") def test_get_column_values_invalid_datasource_type(self): self.login(username="admin") diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index dce33ea2ccaea..4e05b63002fd0 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -602,7 +602,14 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset): }, ) assert rv.status_code == 200 - assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"] + assert rv.json["result"]["colnames"] == [ + "col1", + "col2", + "col3", + "col4", + "col5", + "col6", + ] assert rv.json["result"]["rowcount"] == 1 # empty results