Ufunc input does not respect over expressions #15900

Closed
2 tasks done
hixan opened this issue Apr 26, 2024 · 3 comments
Labels: A-interop (Area: interoperability with other libraries), bug (Something isn't working), python (Related to Python Polars)


hixan commented Apr 26, 2024

Checks

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of Polars.

Reproducible example

import numba
import numpy as np
import polars as pl

df = pl.DataFrame(
    [
        pl.Series("category", ["4", "5", "4"]),
        pl.Series("value", [-4.0, -2.0, -4.0], dtype=pl.Float64),
    ]
)


@numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n)->(n)")
def my_func(mis, res):
    print('called myfunc with', mis)
    v = np.max(mis)
    for i in range(len(res)):
        res[i] = v


transform = my_func(pl.col("value")).over("category").alias("my_func_output")
filter = pl.col("category") == "4"

print(df.with_columns(transform).filter(filter))
print(df.filter(filter).with_columns(transform))

Log output

No response

Issue description

As shown above, when a guvectorized ufunc (here declared with numba) is called on an expression and .over() is then applied, Polars produces unexpected results: it passes the entire column to the ufunc, as if the expression had not been constrained by the over.

This is likely related to #14507

Expected behavior

I expect the output to be equivalent in both cases, with my_func called twice in the first case (once per category) and only once in the second. Instead, I see that the function is called as if .over was not included in transform:

called myfunc with [-4. -2. -4.]
shape: (2, 3)
┌──────────┬───────┬────────────────┐
│ category ┆ value ┆ my_func_output │
│ ---      ┆ ---   ┆ ---            │
│ str      ┆ f64   ┆ f64            │
╞══════════╪═══════╪════════════════╡
│ 4        ┆ -4.0  ┆ -2.0           │
│ 4        ┆ -4.0  ┆ -2.0           │
└──────────┴───────┴────────────────┘
called myfunc with [-4. -4.]
shape: (2, 3)
┌──────────┬───────┬────────────────┐
│ category ┆ value ┆ my_func_output │
│ ---      ┆ ---   ┆ ---            │
│ str      ┆ f64   ┆ f64            │
╞══════════╪═══════╪════════════════╡
│ 4        ┆ -4.0  ┆ -4.0           │
│ 4        ┆ -4.0  ┆ -4.0           │
└──────────┴───────┴────────────────┘
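
For reference, the per-category result the report expects can be sketched in plain NumPy. This is only an illustration of the intended semantics, not how Polars evaluates .over; the helper name group_max_broadcast is hypothetical.

```python
import numpy as np

# Same data as the reproducible example above.
category = np.array(["4", "5", "4"])
value = np.array([-4.0, -2.0, -4.0])


def group_max_broadcast(keys, values):
    """Hypothetical helper: broadcast each group's max back to its rows."""
    out = np.empty_like(values)
    for key in np.unique(keys):
        mask = keys == key
        # Only the rows of this group are reduced, mirroring what my_func
        # would see if it were called once per category.
        out[mask] = np.max(values[mask])
    return out


expected = group_max_broadcast(category, value)
print(expected)                   # all three rows
print(expected[category == "4"])  # rows kept by the filter
```

With a per-category evaluation, filtering before or after the transform gives the same values for category "4", which is the equivalence the report asks for.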

Installed versions

--------Version info---------
Polars:               0.20.22
Index type:           UInt32
Platform:             Linux-5.14.0-284.30.1.el9_2.x86_64-x86_64-with-glibc2.28
Python:               3.11.8 (main, Feb 16 2024, 19:42:16) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]

----Optional dependencies----
adbc_driver_manager:  <not installed>
cloudpickle:          3.0.0
connectorx:           <not installed>
deltalake:            <not installed>
fastexcel:            <not installed>
fsspec:               2024.3.1
gevent:               <not installed>
hvplot:               <not installed>
matplotlib:           3.8.4
nest_asyncio:         1.6.0
numpy:                1.26.4
openpyxl:             3.1.1
pandas:               2.2.2
pyarrow:              11.0.0
pydantic:             1.10.15
pyiceberg:            <not installed>
pyxlsb:               <not installed>
sqlalchemy:           1.4.52
xlsx2csv:             <not installed>
xlsxwriter:           <not installed>

hixan added the bug, needs triage, and python labels on Apr 26, 2024
Collaborator

deanm0000 commented Apr 26, 2024

I don't think we can make .over work on our own.

The way ufuncs work, in the general case, is that when any of their inputs is an instance of a class that defines an __array_ufunc__ method, the ufunc runs that method. Polars makes ufuncs accept expressions as direct inputs by defining such a method on expressions. When the ufunc runs the __array_ufunc__ method, the input is just pl.col('value').

What should happen is that either numpy/python should raise because my_func doesn't have an .over method OR it should pass along the method with the rest of the arguments to the __array_ufunc__ method. If numpy implements the latter case then this could work as you're trying it.

What you'll have to do is:

df.select(pl.col('value').map_batches(my_func).over('category'))

I think this PR will fix this.

It turns out that when we do my_func(pl.col("value")), it resolves to col('value').map_batches(my_func), which means that my_func(pl.col("value")).over('category') resolves to col('value').map_batches(my_func).over('category').
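
The dispatch described in this comment can be sketched with a toy class. This is a minimal illustration that assumes only NumPy's __array_ufunc__ protocol; ToyExpr and its description strings are hypothetical stand-ins, not the Polars implementation.

```python
import numpy as np


class ToyExpr:
    """Hypothetical stand-in for an expression object (not pl.Expr)."""

    def __init__(self, description):
        self.description = description

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy hands control to this hook instead of computing a result.
        # The hook only sees the inputs passed to the ufunc call itself.
        args = ", ".join(
            i.description if isinstance(i, ToyExpr) else repr(i) for i in inputs
        )
        return ToyExpr(f"{args}.map_batches({ufunc.__name__})")

    def over(self, key):
        # Chained afterwards, so it wraps whatever the hook returned.
        return ToyExpr(f"{self.description}.over({key!r})")


col = ToyExpr("col('value')")
wrapped = np.sqrt(col)                     # dispatches to __array_ufunc__
windowed = np.sqrt(col).over("category")   # .over applies to the hook's result

print(wrapped.description)
print(windowed.description)
```

In this sketch the ufunc call never sees the later .over; the window is attached to the object the hook returns, which is why the fix has to live in how that returned expression is evaluated.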

deanm0000 added the A-interop label and removed the needs triage label on Apr 26, 2024
deanm0000 self-assigned this on Apr 26, 2024
@MarcoGorelli
Collaborator

I think this #15916 will fix this.

Thanks for looking into this. That PR is merged; if it fixes the issue, should a test be added and the issue closed?

@deanm0000
Collaborator

The test would need a numba dependency. I'm under the impression (perhaps mistakenly) that having numba required isn't wanted in the test suite. Even then it's probably only helpful to test that case sporadically (to catch when/if numba or numpy change) rather than for every PR.
