From d97e311130a2ed285c131e8b75c8987651df5662 Mon Sep 17 00:00:00 2001
From: Takuya ASADA
Date: Sat, 30 Sep 2023 05:20:44 +0900
Subject: [PATCH] Optimize cloud detection

Currently get_cloud_instance() is very slow on GCE / Azure: the AWS
metadata server cannot be reached there, so curl() retries the request
multiple times before giving up.

To speed up cloud detection, introduce the following changes:
- Use DMI information for detection when it is available
- Use async/await for metadata server access, so the providers are
  probed concurrently

Fixes #481
---
 lib/scylla_cloud.py          | 151 +++++++++++++++++++++++++++--------
 tests/test_aws_instance.py   |  59 ++++++++++----
 tests/test_azure_instance.py |  56 +++++++++----
 tests/test_gcp_instance.py   |  47 ++++++++---
 4 files changed, 238 insertions(+), 75 deletions(-)

diff --git a/lib/scylla_cloud.py b/lib/scylla_cloud.py
index 4faa4baa..a7f8638c 100644
--- a/lib/scylla_cloud.py
+++ b/lib/scylla_cloud.py
@@ -18,6 +18,7 @@
 import distro
 import base64
 import datetime
+import asyncio
 from subprocess import run, CalledProcessError
 from abc import ABCMeta, abstractmethod
 
@@ -58,23 +59,43 @@ def scylla_excepthook(etype, value, tb):
 
 sys.excepthook = scylla_excepthook
 
+def _curl_one(url, headers=None, method=None, byte=False, timeout=3):
+    req = urllib.request.Request(url, headers=headers or {}, method=method)
+    with urllib.request.urlopen(req, timeout=timeout) as res:
+        if byte:
+            return res.read()
+        else:
+            return res.read().decode('utf-8')
+
 # @param headers dict of k:v
 def curl(url, headers=None, method=None, byte=False, timeout=3, max_retries=5, retry_interval=5):
     retries = 0
     while True:
         try:
-            req = urllib.request.Request(url, headers=headers or {}, method=method)
-            with urllib.request.urlopen(req, timeout=timeout) as res:
-                if byte:
-                    return res.read()
-                else:
-                    return res.read().decode('utf-8')
+            return _curl_one(url, headers, method, byte, timeout)
         except (urllib.error.URLError, socket.timeout):
             time.sleep(retry_interval)
             retries += 1
             if retries >= max_retries:
                 raise
 
+async def aiocurl(url, headers=None, method=None, byte=False, timeout=3, max_retries=5, retry_interval=5):
+    retries = 0
+    while True:
+        try:
+            return _curl_one(url, headers, method, byte, timeout)
+        except (urllib.error.URLError, socket.timeout):
+            await asyncio.sleep(retry_interval)
+            retries += 1
+            if retries >= max_retries:
+                raise
+
+def read_one_line(filename):
+    try:
+        with open(filename) as f:
+            return f.read().strip()
+    except FileNotFoundError:
+        return ''
 
 class cloud_instance(metaclass=ABCMeta):
     @abstractmethod
@@ -126,6 +147,20 @@ def nvme_disk_count(self):
     def endpoint_snitch(self):
         pass
 
+    @classmethod
+    @abstractmethod
+    def identify_dmi(cls):
+        pass
+
+    @classmethod
+    @abstractmethod
+    async def identify_metadata(cls):
+        pass
+
+    @classmethod
+    async def identify(cls):
+        return cls.identify_dmi() or await cls.identify_metadata()
+
 
 class gcp_instance(cloud_instance):
 
@@ -150,13 +185,20 @@ def endpoint_snitch(self):
 
         return self.ENDPOINT_SNITCH
 
-    @staticmethod
-    def is_gce_instance():
+    @classmethod
+    def identify_dmi(cls):
+        product_name = read_one_line('/sys/class/dmi/id/product_name')
+        if product_name == "Google Compute Engine":
+            return cls
+        return None
+
+    @classmethod
+    async def identify_metadata(cls):
         """Check if it's GCE instance via DNS lookup to metadata server."""
         try:
             addrlist = socket.getaddrinfo('metadata.google.internal', 80)
         except socket.gaierror:
-            return False
+            return None
         for res in addrlist:
             af, socktype, proto, canonname, sa = res
             if af == socket.AF_INET:
                 addr, port = sa
                 if addr == "169.254.169.254":
                     # Make sure it is not on GKE
                     try:
-                        gcp_instance().__instance_metadata("machine-type")
+                        await aiocurl(cls.META_DATA_BASE_URL + "machine-type?recursive=false",
+                                      headers={"Metadata-Flavor": "Google"})
                     except urllib.error.HTTPError:
-                        return False
-                    return True
-        return False
+                        return None
+                    return cls
+        return None
+
 
     def __instance_metadata(self, path, recursive=False):
         return curl(self.META_DATA_BASE_URL + path + "?recursive=%s" % str(recursive).lower(),
@@ -423,17 +467,28 @@ def endpoint_snitch(self):
 
         return self.ENDPOINT_SNITCH
 
     @classmethod
-    def is_azure_instance(cls):
+    def identify_dmi(cls):
+        # On Azure we cannot distinguish Azure from bare-metal Hyper-V via DMI
+        # alone, but only Azure ships waagent, so its presence is used to
+        # detect Azure.
+        sys_vendor = read_one_line('/sys/class/dmi/id/sys_vendor')
+        if sys_vendor == "Microsoft Corporation" and os.path.exists('/etc/waagent.conf'):
+            return cls
+        return None
+
+    @classmethod
+    async def identify_metadata(cls):
         """Check if it's Azure instance via query to metadata server."""
         try:
-            curl(cls.META_DATA_BASE_URL + cls.API_VERSION + "&format=text", headers = { "Metadata": "True" }, max_retries=2, retry_interval=1)
-            return True
+            await aiocurl(cls.META_DATA_BASE_URL + cls.API_VERSION + "&format=text", headers = { "Metadata": "True" })
+            return cls
         except (urllib.error.URLError, urllib.error.HTTPError):
-            return False
+            return None
 
     # as per https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=windows#supported-api-versions
     API_VERSION = "?api-version=2021-01-01"
+
     def __instance_metadata(self, path):
         """query Azure metadata server"""
         return curl(self.META_DATA_BASE_URL + path + self.API_VERSION + "&format=text", headers = { "Metadata": "True" })
@@ -715,13 +770,26 @@ def endpoint_snitch(self):
 
     @classmethod
-    def is_aws_instance(cls):
+    def identify_dmi(cls):
+        product_version = read_one_line('/sys/class/dmi/id/product_version')
+        bios_vendor = read_one_line('/sys/class/dmi/id/bios_vendor')
+        # On Xen instances, product_version looks like "4.11.amazon"
+        if product_version.endswith('.amazon'):
+            return cls
+        # On Nitro and bare-metal instances, bios_vendor is "Amazon EC2"
+        if bios_vendor == 'Amazon EC2':
+            return cls
+        return None
+
+    @classmethod
+    async def identify_metadata(cls):
         """Check if it's AWS instance via query to metadata server."""
         try:
-            curl(cls.META_DATA_BASE_URL + "api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": cls.METADATA_TOKEN_TTL}, method="PUT")
-            return True
+            await aiocurl(cls.META_DATA_BASE_URL + "api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": cls.METADATA_TOKEN_TTL}, method="PUT")
+            return cls
         except (urllib.error.URLError, urllib.error.HTTPError, socket.timeout):
-            return False
+            return None
 
     @property
     def instancetype(self):
@@ -823,25 +891,44 @@ def user_data(self):
 
         return ''
 
 
+async def identify_cloud_async():
+    tasks = [
+        asyncio.create_task(aws_instance.identify()),
+        asyncio.create_task(gcp_instance.identify()),
+        asyncio.create_task(azure_instance.identify())
+    ]
+    result = None
+    for task in asyncio.as_completed(tasks):
+        result = await task
+        if result:
+            for other_task in tasks:
+                other_task.cancel()
+            break
+    return result
+
+_identify_cloud_result = None
+def identify_cloud():
+    global _identify_cloud_result
+    if _identify_cloud_result:
+        return _identify_cloud_result
+    result = asyncio.run(identify_cloud_async())
+    _identify_cloud_result = result
+    return result
 
 def is_ec2():
-    return aws_instance.is_aws_instance()
+    return identify_cloud() == aws_instance
 
 def is_gce():
-    return gcp_instance.is_gce_instance()
+    return identify_cloud() == gcp_instance
 
 def is_azure():
-    return azure_instance.is_azure_instance()
+    return identify_cloud() == azure_instance
 
 def get_cloud_instance():
-    if is_ec2():
-        return aws_instance()
-    elif is_gce():
-        return gcp_instance()
-    elif is_azure():
-        return azure_instance()
-    else:
-        raise Exception("Unknown cloud provider! Only AWS/GCP/Azure supported.")
+    cls = identify_cloud()
+    if not cls:
+        raise Exception("Unknown cloud provider! Only AWS/GCP/Azure supported.")
+    return cls()
 
 
 CONCOLORS = {'green': '\033[1;32m', 'red': '\033[1;31m', 'yellow': '\033[1;33m', 'nocolor': '\033[0m'}
diff --git a/tests/test_aws_instance.py b/tests/test_aws_instance.py
index bc553ddc..acf6c556 100644
--- a/tests/test_aws_instance.py
+++ b/tests/test_aws_instance.py
@@ -3,7 +3,7 @@
 import httpretty
 import unittest.mock
 from pathlib import Path
-from unittest import TestCase
+from unittest import TestCase, IsolatedAsyncioTestCase
 from subprocess import CalledProcessError
 from collections import namedtuple
 
@@ -23,6 +23,8 @@ def _mock_multi_open(files, filename, *args, **kwargs):
 
 def mock_multi_open_i3en_2xlarge(filename, *args, **kwargs):
     files = {
+        '/sys/class/dmi/id/product_version': '',
+        '/sys/class/dmi/id/bios_vendor': 'Amazon EC2',
         '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
         '/sys/class/nvme/nvme0/model': 'Amazon Elastic Block Store\n',
         '/sys/class/nvme/nvme1/model': 'Amazon EC2 NVMe Instance Storage\n',
@@ -33,6 +35,8 @@ def mock_multi_open_i3en_2xlarge(filename, *args, **kwargs):
 
 def mock_multi_open_i3en_2xlarge_with_ebs(filename, *args, **kwargs):
     files = {
+        '/sys/class/dmi/id/product_version': '',
+        '/sys/class/dmi/id/bios_vendor': 'Amazon EC2',
         '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
         '/sys/class/nvme/nvme0/model': 'Amazon Elastic Block Store\n',
         '/sys/class/nvme/nvme1/model': 'Amazon Elastic Block Store\n',
@@ -46,6 +50,8 @@ def mock_multi_open_i3en_2xlarge_with_ebs(filename, *args, **kwargs):
 
 def mock_multi_open_i3_2xlarge(filename, *args, **kwargs):
     files = {
+        '/sys/class/dmi/id/product_version': '4.11.amazon',
+        '/sys/class/dmi/id/bios_vendor': 'Xen',
         '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
         '/sys/class/nvme/nvme0/model': 'Amazon EC2 NVMe Instance Storage\n',
         '/sys/class/nvme/nvme1/model': 'Amazon EC2 NVMe Instance Storage\n',
@@ -86,14 +92,7 @@ def mock_multi_run_i3_2xlarge(*popenargs,
 
 mock_listdevdir_i3_2xlarge = ['md0', 'root', 'nvme0n1', 'nvme1n1', 'xvda1', 'xvda', 'nvme1', 'nvme0', 'zero', 'null']
 
-class TestAwsInstance(TestCase):
-    def setUp(self):
-        httpretty.enable(verbose=True, allow_net_connect=False)
-
-    def tearDown(self):
-        httpretty.disable()
-        httpretty.reset()
-
+class AwsMetadata:
     def httpretty_aws_metadata(self, instance_type='i3en.2xlarge', with_ebs=False, with_userdata=False):
         if not with_userdata:
             httpretty.register_uri(
@@ -205,21 +204,37 @@ def httpretty_aws_metadata(self, instance_type='i3en.2xlarge', with_ebs=False, w
             )
 
-    def test_is_aws_instance(self):
+class TestAsyncAwsInstance(IsolatedAsyncioTestCase, AwsMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
+
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
+
+    async def test_identify_metadata(self):
         self.httpretty_aws_metadata()
-        assert aws_instance.is_aws_instance()
+        assert await aws_instance.identify_metadata()
 
-    def test_is_not_aws_instance(self):
+    async def test_not_identify_metadata(self):
         httpretty.disable()
-        real_curl = lib.scylla_cloud.curl
+        real_curl = lib.scylla_cloud.aiocurl
 
-        def mocked_curl(*args, **kwargs):
+        async def mocked_curl(*args, **kwargs):
             kwargs['timeout'] = 0.1
             kwargs['retry_interval'] = 0.001
-            return real_curl(*args, **kwargs)
+            return await real_curl(*args, **kwargs)
+
+        with unittest.mock.patch('lib.scylla_cloud.aiocurl', new=mocked_curl):
+            assert not await aws_instance.identify_metadata()
+
+class TestAwsInstance(TestCase, AwsMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
 
-        with unittest.mock.patch('lib.scylla_cloud.curl', new=mocked_curl):
-            assert not aws_instance.is_aws_instance()
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
 
     def test_endpoint_snitch(self):
         self.httpretty_aws_metadata()
@@ -306,6 +321,11 @@ def test_no_user_data(self):
 
         assert not ins.user_data
 
+    def test_identify_dmi_i3en_2xlarge(self):
+        self.httpretty_aws_metadata()
+        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_i3en_2xlarge)):
+            assert aws_instance.identify_dmi()
+
     def test_non_root_nvmes_i3en_2xlarge(self):
         self.httpretty_aws_metadata()
         with unittest.mock.patch('psutil.disk_partitions', return_value=mock_disk_partitions),\
@@ -447,6 +467,11 @@ def test_get_remote_disks_i3en_2xlarge_with_ebs(self):
 
         assert ins.get_remote_disks() == ['nvme2n1', 'nvme1n1']
 
+    def test_identify_dmi_i3_2xlarge(self):
+        self.httpretty_aws_metadata()
+        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_i3_2xlarge)):
+            assert aws_instance.identify_dmi()
+
     def test_non_root_nvmes_i3_2xlarge(self):
         self.httpretty_aws_metadata()
         with unittest.mock.patch('psutil.disk_partitions', return_value=mock_disk_partitions),\
diff --git a/tests/test_azure_instance.py b/tests/test_azure_instance.py
index 464974df..560889c3 100644
--- a/tests/test_azure_instance.py
+++ b/tests/test_azure_instance.py
@@ -4,7 +4,7 @@
 import unittest.mock
 import base64
 import re
-from unittest import TestCase
+from unittest import TestCase, IsolatedAsyncioTestCase
 from collections import namedtuple
 from pathlib import Path
 
@@ -42,14 +42,20 @@
 mock_glob_glob_dev_standard_l16s_v2_2persistent_noswap = ['/dev/sdc', '/dev/sdb', '/dev/sda15', '/dev/sda14', '/dev/sda1', '/dev/sda']
 mock_glob_glob_dev_standard_l32s_v2 = mock_glob_glob_dev_standard_l16s_v2
 
-class TestAzureInstance(TestCase):
-    def setUp(self):
-        httpretty.enable(verbose=True, allow_net_connect=False)
+def _mock_multi_open(files, filename, *args, **kwargs):
+    if filename in files:
+        return unittest.mock.mock_open(read_data=files[filename]).return_value
+    else:
+        raise FileNotFoundError(f'Unable to open {filename}')
 
-    def tearDown(self):
-        httpretty.disable()
-        httpretty.reset()
+def mock_multi_open_l(filename, *args, **kwargs):
+    files = {
+        '/sys/class/dmi/id/sys_vendor': 'Microsoft Corporation',
+        '/etc/waagent.conf': ''
+    }
+    return _mock_multi_open(files, filename, *args, **kwargs)
 
+class AzureMetadata:
     def httpretty_azure_metadata(self, instance_type='Standard_L16s_v2', with_userdata=False):
         httpretty.register_uri(
             httpretty.GET,
@@ -93,24 +99,44 @@ def httpretty_no_azure_metadata(self):
             status=404
         )
 
-    def test_is_azure_instance(self):
+class TestAsyncAzureInstance(IsolatedAsyncioTestCase, AzureMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
+
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
+
+    async def test_identify_metadata(self):
         self.httpretty_azure_metadata()
-        assert azure_instance.is_azure_instance()
+        assert await azure_instance.identify_metadata()
 
     # XXX: Seems like Github Actions is running in Azure, we cannot disable
     # httpretty here (it suceeded to connect metadata server even we disabled
     # httpretty)
-    def test_is_not_azure_instance(self):
+    async def test_not_identify_metadata(self):
         self.httpretty_no_azure_metadata()
-        real_curl = lib.scylla_cloud.curl
+        real_curl = lib.scylla_cloud.aiocurl
 
-        def mocked_curl(*args, **kwargs):
+        async def mocked_curl(*args, **kwargs):
             kwargs['timeout'] = 0.001
             kwargs['retry_interval'] = 0.0001
-            return real_curl(*args, **kwargs)
+            return await real_curl(*args, **kwargs)
+
+        with unittest.mock.patch('lib.scylla_cloud.aiocurl', new=mocked_curl):
+            assert not await azure_instance.identify_metadata()
+
+class TestAzureInstance(TestCase, AzureMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
+
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
 
-        with unittest.mock.patch('lib.scylla_cloud.curl', new=mocked_curl):
-            assert not azure_instance.is_azure_instance()
+    def test_identify_dmi(self):
+        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_l)),\
+             unittest.mock.patch('os.path.exists', unittest.mock.MagicMock(side_effect=mock_multi_open_l)):
+            assert azure_instance.identify_dmi()
 
     def test_endpoint_snitch(self):
         self.httpretty_azure_metadata()
diff --git a/tests/test_gcp_instance.py b/tests/test_gcp_instance.py
index cb9619b2..02177239 100644
--- a/tests/test_gcp_instance.py
+++ b/tests/test_gcp_instance.py
@@ -3,7 +3,7 @@
 import httpretty
 import unittest.mock
 import json
-from unittest import TestCase
+from unittest import TestCase, IsolatedAsyncioTestCase
 from collections import namedtuple
 from socket import AddressFamily, SocketKind
 from pathlib import Path
@@ -35,14 +35,20 @@
 mock_glob_glob_dev_n2_highcpu_8_4ssd = mock_glob_glob_dev_n2_standard_8
 mock_glob_glob_dev_n2_standard_8_4ssd_2persistent = ['/dev/sdc', '/dev/sdb', '/dev/sda15', '/dev/sda14', '/dev/sda1', '/dev/sda']
 
-class TestGcpInstance(TestCase):
-    def setUp(self):
-        httpretty.enable(verbose=True, allow_net_connect=False)
+def _mock_multi_open(files, filename, *args, **kwargs):
+    if filename in files:
+        return unittest.mock.mock_open(read_data=files[filename]).return_value
+    else:
+        raise FileNotFoundError(f'Unable to open {filename}')
 
-    def tearDown(self):
-        httpretty.disable()
-        httpretty.reset()
+def mock_multi_open_n2(filename, *args, **kwargs):
+    files = {
+        '/sys/class/dmi/id/product_name': 'Google Compute Engine'
+    }
+    return _mock_multi_open(files, filename, *args, **kwargs)
+
+class GcpMetadata:
     def httpretty_gcp_metadata(self, instance_type='n2-standard-8', project_number='431729375847', instance_name='testcase_1', num_local_disks=4, num_remote_disks=0, with_userdata=False):
         httpretty.register_uri(
             httpretty.GET,
@@ -82,14 +88,33 @@ def httpretty_gcp_metadata(self, instance_type='n2-standard-8', project_number='
             '{"scylla_yaml": {"cluster_name": "test-cluster"}}'
         )
 
+class TestAsyncGcpInstance(IsolatedAsyncioTestCase, GcpMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
+
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
 
-    def test_is_gce_instance(self):
+    async def test_identify_metadata(self):
         self.httpretty_gcp_metadata()
         with unittest.mock.patch('socket.getaddrinfo', return_value=[(AddressFamily.AF_INET, SocketKind.SOCK_STREAM, 6, '', ('169.254.169.254', 80))]):
-            assert gcp_instance.is_gce_instance()
+            assert await gcp_instance.identify_metadata()
+
+    async def test_not_identify_metadata(self):
+        assert not await gcp_instance.identify_metadata()
+
+class TestGcpInstance(TestCase, GcpMetadata):
+    def setUp(self):
+        httpretty.enable(verbose=True, allow_net_connect=False)
+
+    def tearDown(self):
+        httpretty.disable()
+        httpretty.reset()
 
-    def test_is_not_gce_instance(self):
-        assert not gcp_instance.is_gce_instance()
+    def test_identify_dmi(self):
+        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_n2)):
+            assert gcp_instance.identify_dmi()
 
     def test_endpoint_snitch(self):
         self.httpretty_gcp_metadata()
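
Below is a brief usage sketch of the detection flow this patch introduces. It is illustrative only and not part of the patch; it assumes the patched module is importable as lib.scylla_cloud, the same import path the tests above use:

    # Illustrative caller code (assumed import path: lib.scylla_cloud).
    # Each provider's identify() checks DMI first and only falls back to its
    # async metadata-server probe; identify_cloud() caches the winning class.
    from lib.scylla_cloud import get_cloud_instance, is_ec2, is_gce, is_azure

    instance = get_cloud_instance()        # aws_instance / gcp_instance / azure_instance object
    print(type(instance).__name__)
    print(instance.endpoint_snitch)        # per-cloud snitch name
    print(is_ec2(), is_gce(), is_azure())  # all three reuse the cached identify_cloud() result
    # On a host that is none of AWS/GCP/Azure, get_cloud_instance() raises an exception.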