Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
EpsilonPrime committed Sep 20, 2024
1 parent 879ae60 commit b4927d2
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions src/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def convert_pyarrow_schema_to_spark(schema: pa.Schema) -> types_pb2.DataType:


def create_dataframe_view(
session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend
session_id: str, view: commands_pb2.CreateDataFrameViewCommand, backend
) -> None:
"""Register the temporary dataframe."""
read_data_source_relation = view.input.read.data_source
Expand Down Expand Up @@ -235,7 +235,7 @@ def _ReinitializeExecution(self) -> None:
return None

def ExecutePlan(
self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext
self, request: pb2.ExecutePlanRequest, context: grpc.RpcContext
) -> Generator[pb2.ExecutePlanResponse, None, None]:
"""Execute the given plan and return the results."""
self._statistics.execute_requests += 1
Expand Down Expand Up @@ -294,9 +294,9 @@ def ExecutePlan(
_LOGGER.debug(" results are: %s", results)

if (
not self._options.implement_show_string
and request.plan.WhichOneof("op_type") == "root"
and request.plan.root.WhichOneof("rel_type") == "show_string"
not self._options.implement_show_string
and request.plan.WhichOneof("op_type") == "root"
and request.plan.root.WhichOneof("rel_type") == "show_string"
):
yield pb2.ExecutePlanResponse(
session_id=request.session_id,
Expand Down Expand Up @@ -383,16 +383,19 @@ def Config(self, request, context):
need_reset = False
match pair.value:
case "arrow":
if self._backend is not None and self._options.backend.backend != BackendEngine.ARROW:
if (self._backend is not None and
self._options.backend.backend != BackendEngine.ARROW):
need_reset = True
self._options = arrow()
case "duckdb":
if self._backend is not None and self._options.backend.backend != BackendEngine.DUCKDB:
if (self._backend is not None and
self._options.backend.backend != BackendEngine.DUCKDB):
need_reset = True
self._options = duck_db()
case "datafusion":
need_reset = False
if self._backend is not None and self._options.backend.backend != BackendEngine.DATAFUSION:
if (self._backend is not None and
self._options.backend.backend != BackendEngine.DATAFUSION):
need_reset = True
self._options = datafusion()
case _:
Expand All @@ -410,7 +413,7 @@ def Config(self, request, context):
elif key == "spark-substrait-gateway.plan_count":
response.pairs.add(key=key, value=str(len(self._statistics.plans)))
elif key.startswith("spark-substrait-gateway.plan."):
index = int(key[len("spark-substrait-gateway.plan.") :])
index = int(key[len("spark-substrait-gateway.plan."):])
if 0 <= index - 1 < len(self._statistics.plans):
response.pairs.add(key=key, value=self._statistics.plans[index - 1])
elif key == "spark.sql.session.timeZone":
Expand Down Expand Up @@ -488,7 +491,7 @@ def Interrupt(self, request, context):
return pb2.InterruptResponse()

def ReattachExecute(
self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext
self, request: pb2.ReattachExecuteRequest, context: grpc.RpcContext
) -> Generator[pb2.ExecutePlanResponse, None, None]:
"""Reattach the execution of the given plan."""
self._statistics.reattach_requests += 1
Expand All @@ -505,13 +508,13 @@ def ReleaseExecute(self, request, context):


def serve(
port: int,
wait: bool,
tls: list[str] | None = None,
enable_auth: bool = False,
jwt_audience: str | None = None,
secret_key: str | None = None,
log_level: str = "INFO",
port: int,
wait: bool,
tls: list[str] | None = None,
enable_auth: bool = False,
jwt_audience: str | None = None,
secret_key: str | None = None,
log_level: str = "INFO",
) -> grpc.Server:
"""Start the Spark Substrait Gateway server."""
logging.basicConfig(level=getattr(logging, log_level), encoding="utf-8")
Expand Down Expand Up @@ -589,8 +592,8 @@ def serve(
required=False,
metavar=("CERTFILE", "KEYFILE"),
help="Enable transport-level security (TLS/SSL). Provide a "
"Certificate file path, and a Key file path - separated by a space. "
"Example: tls/server.crt tls/server.key",
"Certificate file path, and a Key file path - separated by a space. "
"Example: tls/server.crt tls/server.key",
)
@click.option(
"--enable-auth/--no-enable-auth",
Expand Down Expand Up @@ -621,13 +624,13 @@ def serve(
help="The logging level to use for the server.",
)
def click_serve(
port: int,
wait: bool,
tls: list[str],
enable_auth: bool,
jwt_audience: str,
secret_key: str,
log_level: str,
port: int,
wait: bool,
tls: list[str],
enable_auth: bool,
jwt_audience: str,
secret_key: str,
log_level: str,
) -> grpc.Server:
"""Provide a click interface for starting the Spark Substrait Gateway server."""
return serve(**locals())
Expand Down

0 comments on commit b4927d2

Please sign in to comment.