diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index dd70e814b1ea8..4061d024a83cd 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -134,6 +134,11 @@ "Metadata can only be provided for a single column." ] }, + "CANNOT_REGISTER_UDTF": { + "message": [ + "Cannot register the UDTF '<name>': expected a 'UserDefinedTableFunction'. Please make sure the UDTF is correctly defined as a class, and then either wrap it in the `udtf()` function or annotate it with `@udtf(...)`." + ] + }, "CANNOT_SET_TOGETHER": { "message": [ "<arg_list> should not be set together." diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 739289d72a3b1..1a55f0aa08bf3 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -192,6 +192,14 @@ def register( name: str, f: "UserDefinedTableFunction", ) -> "UserDefinedTableFunction": + if not isinstance(f, UserDefinedTableFunction): + raise PySparkTypeError( + error_class="CANNOT_REGISTER_UDTF", + message_parameters={ + "name": name, + }, + ) + if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]: raise PySparkTypeError( error_class="INVALID_UDTF_EVAL_TYPE", diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index d69eea8a80c43..74a2a40a46314 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -935,10 +935,23 @@ def upper(s: str): self.check_error( exception=e.exception, - error_class="INVALID_UDTF_EVAL_TYPE", + error_class="CANNOT_REGISTER_UDTF", message_parameters={ "name": "test_udf", - "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF", }, ) + + class TestUDTF: + ... 
+ + with self.assertRaises(PySparkTypeError) as e: + self.spark.udtf.register("test_udtf", TestUDTF) + + self.check_error( + exception=e.exception, + error_class="CANNOT_REGISTER_UDTF", + message_parameters={ + "name": "test_udtf", }, ) diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index f560880202230..83ef1d488d960 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -461,6 +461,14 @@ def register( >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect() [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)] """ + if not isinstance(f, UserDefinedTableFunction): + raise PySparkTypeError( + error_class="CANNOT_REGISTER_UDTF", + message_parameters={ + "name": name, + }, + ) + if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]: raise PySparkTypeError( error_class="INVALID_UDTF_EVAL_TYPE",