Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#3090] Reduce Command Responses in Redis Connection Process (on_connect) #3268

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 137 additions & 70 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,87 +356,154 @@ def on_connect(self):
)
auth_args = cred_provider.get_credentials()

# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
response = self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
# ) != self.protocol:
# raise ConnectionError("Invalid RESP version")
elif auth_args:
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
self.send_command("AUTH", *auth_args, check_health=False)
# try to send HELLO command (for Redis 6.0 and above)
try:
# if resp version is specified and we have auth args,
# we need to send them via HELLO
if auth_args and self.protocol not in [2, "2"]:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
else:
self.send_command("HELLO", self.protocol)

try:
auth_response = self.read_response()
except AuthenticationWrongNumberOfArgsError:
# a username and password were specified but the Redis
# server seems to be < 6.0.0 which expects a single password
# arg. retry auth with just the password.
# https://github.com/andymccurdy/redis-py/issues/1274
self.send_command("AUTH", auth_args[-1], check_health=False)
auth_response = self.read_response()

if str_if_bytes(auth_response) != "OK":
raise AuthenticationError("Invalid Username or Password")

# if resp version is specified, switch to it
elif self.protocol not in [2, "2"]:
if isinstance(self._parser, _RESP2Parser):
self.set_parser(_RESP3Parser)
# update cluster exception classes
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
response = self.read_response()
if (
response.get(b"proto") != self.protocol
and response.get("proto") != self.protocol
):
self.read_response()

except Exception as e:
if str(e) == "Invalid RESP version":
raise ConnectionError("Invalid RESP version")
# fall back to AUTH command (for Redis versions less than 6.0)
else:
# avoid checking health here -- PING will fail if we try
# to check the health prior to the AUTH
if auth_args:
# check if only password is provided and RESP version < 6
if not self.username and self.password and self.protocol in [2, "2"]:
self.send_command("AUTH", self.password, check_health=False)
else:
self.send_command("AUTH", *auth_args, check_health=False)

# start a transaction block with MULTI
try:
self.send_command('MULTI')
self.read_response()

# if a client_name is given, set it
if self.client_name:
self.send_command("CLIENT", "SETNAME", self.client_name)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Error setting client name")
# if a client_name is given, set it
if self.client_name:
self.send_command("CLIENT", "SETNAME", self.client_name)

try:
# set the library name and version
if self.lib_name:
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
self.read_response()
except ResponseError:
pass

try:
if self.lib_version:
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
self.read_response()
except ResponseError:
pass

# if a database is specified, switch to it
if self.db:
self.send_command("SELECT", self.db)
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Invalid Database")
# if a database is specified, switch to it
if self.db:
self.send_command("SELECT", self.db)

# if client caching is enabled, start tracking
if self.client_cache:
self.send_command("CLIENT", "TRACKING", "ON")

# execute the MULTI block
self.send_command('EXEC')
responses = self._read_exec_responses()
self._handle_responses(responses, auth_args)
except AuthenticationError as e:
if str(e) == "Invalid Username or Password":
raise AuthenticationError("Invalid Username or Password") from e
except Exception:
raise ConnectionError("Error during EXEC handling")

def _read_exec_responses(self):
# read the response for EXEC which should be a list
response = self.read_response()
if response == b'OK':
# EXEC did not execute correctly, likely due to previous error
raise ConnectionError("EXEC command did not execute correctly")
while response == b'QUEUED':
response = self.read_response()
if not isinstance(response, list):
raise ConnectionError(f"EXEC command did not return a list: {response}")
return response

# if client caching is enabled, start tracking
if self.client_cache:
self.send_command("CLIENT", "TRACKING", "ON")
self.read_response()
self._parser.set_invalidation_push_handler(self._cache_invalidation_process)
def _handle_responses(self, responses, auth_args):
if not isinstance(responses, list):
raise ConnectionError(f"EXEC command did not return a list: {responses}")

response_iter = iter(responses)

try:
# handle HELLO + AUTH
if auth_args and self.protocol not in [2, "2"]:
response = next(response_iter, None)
if isinstance(response, dict) and (
response.get(b"proto") != self.protocol and response.get("proto") != self.protocol):
raise ConnectionError("Invalid RESP version")

response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise AuthenticationError("Invalid Username or Password")
elif auth_args:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
try:
# a username and password were specified but the Redis
# server seems to be < 6.0.0 which expects a single password
# arg. retry auth with just the password.
# https://github.com/andymccurdy/redis-py/issues/1274
self.send_command("AUTH", auth_args[-1], check_health=False)
auth_response = self.read_response()
if isinstance(auth_response, bytes) and str_if_bytes(
auth_response) != "OK":
raise AuthenticationError("Invalid Username or Password")
# add the retry response to the responses list for further processing
responses = [auth_response] + list(response_iter)
response_iter = iter(responses)
except AuthenticationWrongNumberOfArgsError:
raise AuthenticationError("Invalid Username or Password")

# handle CLIENT SETNAME
if self.client_name:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise ConnectionError("Error setting client name")

# handle CLIENT SETINFO LIB-NAME
if self.lib_name:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise ConnectionError("Error setting client library name")

# handle CLIENT SETINFO LIB-VER
if self.lib_version:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise ConnectionError("Error setting client library version")

# handle SELECT
if self.db:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise ConnectionError("Invalid Database")

# handle CLIENT TRACKING ON
if self.client_cache:
response = next(response_iter, None)
if isinstance(response, bytes) and str_if_bytes(response) != "OK":
raise ConnectionError("Error enabling client tracking")
self._parser.set_invalidation_push_handler(
self._cache_invalidation_process)
except (AuthenticationError, ConnectionError):
raise
except Exception as e:
raise ConnectionError("Error during response handling") from e

def disconnect(self, *args):
"Disconnects from the Redis server"
Expand Down
Loading