Skip to content

Commit

Permalink
[SPARK-48938][PYTHON] Improve error messages when registering Python …
Browse files Browse the repository at this point in the history
…UDTFs

### What changes were proposed in this pull request?

This PR improves the error messages when registering Python UDTFs.
Before this PR:
```python
class TestUDTF:
   ...

spark.udtf.register("test_udtf", TestUDTF)
```
This fails with
```
AttributeError: type object "TestUDTF" has no attribute "evalType"
```
After this PR:
```python
spark.udtf.register("test_udtf", TestUDTF)
```
Now we have a nicer error:
```
[CANNOT_REGISTER_UDTF] Cannot register the UDTF 'test_udtf': expected a 'UserDefinedTableFunction'. Please make sure the UDTF is correctly defined as a class, and then either wrap it in the `udtf()` function or annotate it with `udtf(...)`.`
```

### Why are the changes needed?

To improve usability.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing and new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47408 from allisonwang-db/spark-48938-udtf-register-err-msg.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
allisonwang-db authored and HyukjinKwon committed Jul 20, 2024
1 parent 1632568 commit 5785098
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@
"Metadata can only be provided for a single column."
]
},
"CANNOT_REGISTER_UDTF": {
"message": [
"Cannot register the UDTF '<name>': expected a 'UserDefinedTableFunction'. Please make sure the UDTF is correctly defined as a class, and then either wrap it in the `udtf()` function or annotate it with `@udtf(...)`."
]
},
"CANNOT_SET_TOGETHER": {
"message": [
"<arg_list> should not be set together."
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,14 @@ def register(
name: str,
f: "UserDefinedTableFunction",
) -> "UserDefinedTableFunction":
if not isinstance(f, UserDefinedTableFunction):
raise PySparkTypeError(
error_class="CANNOT_REGISTER_UDTF",
message_parameters={
"name": name,
},
)

if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
raise PySparkTypeError(
error_class="INVALID_UDTF_EVAL_TYPE",
Expand Down
17 changes: 15 additions & 2 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,10 +935,23 @@ def upper(s: str):

self.check_error(
exception=e.exception,
error_class="INVALID_UDTF_EVAL_TYPE",
error_class="CANNOT_REGISTER_UDTF",
message_parameters={
"name": "test_udf",
"eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
},
)

class TestUDTF:
...

with self.assertRaises(PySparkTypeError) as e:
self.spark.udtf.register("test_udtf", TestUDTF)

self.check_error(
exception=e.exception,
error_class="CANNOT_REGISTER_UDTF",
message_parameters={
"name": "test_udtf",
},
)

Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,14 @@ def register(
>>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
[Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
"""
if not isinstance(f, UserDefinedTableFunction):
raise PySparkTypeError(
error_class="CANNOT_REGISTER_UDTF",
message_parameters={
"name": name,
},
)

if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
raise PySparkTypeError(
error_class="INVALID_UDTF_EVAL_TYPE",
Expand Down

0 comments on commit 5785098

Please sign in to comment.