From d40040b00d5738ae09e42c1ba71b28f6ca5efea0 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 9 Oct 2024 16:29:22 +0000 Subject: [PATCH] Consolidate set_cell immutability behavior --- .pre-commit-config.yaml | 2 +- great_tables/_tbl_data.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 212a375c8..3627b677c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: "(.*\\.svg)|(.*\\.qmd)|(.*\\.ambr)|(.*\\.csv)|(.*\\.txt)|(.*\\.json)" repos: - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.1.1 hooks: - id: flake8 types: diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 4003506da..3c5de8c8b 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -249,15 +249,20 @@ def _set_cell(data: DataFrameLike, row: int, column: str, value: Any): def _(data, row: int, column: str, value: Any) -> PdDataFrame: # TODO: This assumes column names are unique # if this is violated, get_loc will return a mask - col_indx = data.columns.get_loc(column) - data.iloc[row, col_indx] = value - return data + data_new = data.copy(deep=False) # make a shallow copy and only update the specific column. + data_new[column] = data_new[column].copy() + data_new.at[row, column] = value + return data_new @_set_cell.register(PlDataFrame) def _(data, row: int, column: str, value: Any) -> PlDataFrame: - data[row, column] = value - return data + # While using scatter is considered an antipattern, + # it is easier to read than a when.then.otherwise expression, + # and it is generally better performing. + col_series_modified = data[column].scatter(row, value) + data_new = data.with_columns(col_series_modified) + return data_new @_set_cell.register(PyArrowTable) @@ -268,8 +273,8 @@ def _(data: PyArrowTable, row: int, column: str, value: Any) -> PyArrowTable: col = data.column(column) pylist = col.to_pylist() pylist[row] = value - data = data.set_column(colindex, column, pa.array(pylist)) - return data + data_new = data.set_column(colindex, column, pa.array(pylist)) + return data_new # _get_column_dtype ----