Optimize cloud detection
Currently, get_cloud_instance() is very slow on GCE / Azure, since it fails to
detect the AWS metadata server and retries multiple times in curl().

To speed up cloud detection, we introduce the following:
 - Use DMI parameters for detection when they are available
 - Use async/await for metadata server access

Fixes #481
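
In short, the new flow checks DMI strings first (a purely local read, no network
round trip) and only falls back to the metadata servers when DMI is inconclusive,
probing all providers concurrently instead of one after another. Below is a
minimal, self-contained sketch of that idea; the names (read_dmi, probe,
detect_cloud) are simplified stand-ins rather than the code in this commit, and the
real implementation is in lib/scylla_cloud.py further down.

import asyncio

def read_dmi(name):
    # DMI strings are exposed by the kernel under /sys/class/dmi/id/.
    try:
        with open(f'/sys/class/dmi/id/{name}') as f:
            return f.read().strip()
    except FileNotFoundError:
        return ''

def detect_by_dmi():
    # Cheap local checks, using the same DMI values the commit relies on.
    if read_dmi('bios_vendor') == 'Amazon EC2' or read_dmi('product_version').endswith('.amazon'):
        return 'aws'
    if read_dmi('product_name') == 'Google Compute Engine':
        return 'gcp'
    return None

async def probe(provider, delay):
    # Stand-in for a metadata-server query: the real code issues an HTTP
    # request to 169.254.169.254 and returns None when it fails.
    await asyncio.sleep(delay)
    return provider

async def detect_by_metadata():
    tasks = [asyncio.create_task(probe('aws', 0.3)),
             asyncio.create_task(probe('gcp', 0.1)),
             asyncio.create_task(probe('azure', 0.2))]
    for done in asyncio.as_completed(tasks):
        result = await done
        if result:
            for task in tasks:
                task.cancel()  # stop the remaining probes once one answers
            return result
    return None

def detect_cloud():
    return detect_by_dmi() or asyncio.run(detect_by_metadata())

The first probe that answers wins and the rest are cancelled, so an unreachable
AWS metadata server no longer stalls detection on GCE or Azure.
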
syuu1228 committed Dec 18, 2023
1 parent d980383 commit f99bef1
Showing 4 changed files with 238 additions and 75 deletions.
151 changes: 119 additions & 32 deletions lib/scylla_cloud.py
@@ -18,6 +18,7 @@
import distro
import base64
import datetime
import asyncio
from subprocess import run, CalledProcessError
from abc import ABCMeta, abstractmethod

@@ -58,23 +59,43 @@ def scylla_excepthook(etype, value, tb):
sys.excepthook = scylla_excepthook


def _curl_one(url, headers=None, method=None, byte=False, timeout=3):
    req = urllib.request.Request(url, headers=headers or {}, method=method)
    with urllib.request.urlopen(req, timeout=timeout) as res:
        if byte:
            return res.read()
        else:
            return res.read().decode('utf-8')

# @param headers dict of k:v
def curl(url, headers=None, method=None, byte=False, timeout=3, max_retries=5, retry_interval=5):
    retries = 0
    while True:
        try:
            req = urllib.request.Request(url, headers=headers or {}, method=method)
            with urllib.request.urlopen(req, timeout=timeout) as res:
                if byte:
                    return res.read()
                else:
                    return res.read().decode('utf-8')
            return _curl_one(url, headers, method, byte, timeout)
        except (urllib.error.URLError, socket.timeout):
            time.sleep(retry_interval)
            retries += 1
            if retries >= max_retries:
                raise

async def aiocurl(url, headers=None, method=None, byte=False, timeout=3, max_retries=5, retry_interval=5):
    retries = 0
    while True:
        try:
            return _curl_one(url, headers, method, byte, timeout)
        except (urllib.error.URLError, socket.timeout):
            await asyncio.sleep(retry_interval)
            retries += 1
            if retries >= max_retries:
                raise

def read_one_line(filename):
    try:
        with open(filename) as f:
            return f.read().strip()
    except FileNotFoundError:
        return ''

class cloud_instance(metaclass=ABCMeta):
    @abstractmethod
@@ -126,6 +147,20 @@ def nvme_disk_count(self):
    def endpoint_snitch(self):
        pass

    @classmethod
    @abstractmethod
    def identify_dmi(cls):
        pass

    @classmethod
    @abstractmethod
    async def identify_metadata(cls):
        pass

    @classmethod
    async def identify(cls):
        return cls.identify_dmi() or await cls.identify_metadata()



class gcp_instance(cloud_instance):
@@ -150,25 +185,34 @@ def endpoint_snitch(self):
        return self.ENDPOINT_SNITCH


    @staticmethod
    def is_gce_instance():
    @classmethod
    def identify_dmi(cls):
        product_name = read_one_line('/sys/class/dmi/id/product_name')
        if product_name == "Google Compute Engine":
            return cls
        return None

    @classmethod
    async def identify_metadata(cls):
        """Check if it's GCE instance via DNS lookup to metadata server."""
        try:
            addrlist = socket.getaddrinfo('metadata.google.internal', 80)
        except socket.gaierror:
            return False
            return None
        for res in addrlist:
            af, socktype, proto, canonname, sa = res
            if af == socket.AF_INET:
                addr, port = sa
                if addr == "169.254.169.254":
                    # Make sure it is not on GKE
                    try:
                        gcp_instance().__instance_metadata("machine-type")
                        await aiocurl(cls.META_DATA_BASE_URL + "machine-type?recursive=false",
                                      headers={"Metadata-Flavor": "Google"})
                    except urllib.error.HTTPError:
                        return False
                    return True
        return False
                        return None
                    return cls
        return None


    def __instance_metadata(self, path, recursive=False):
        return curl(self.META_DATA_BASE_URL + path + "?recursive=%s" % str(recursive).lower(),
@@ -423,17 +467,28 @@ def endpoint_snitch(self):
        return self.ENDPOINT_SNITCH

    @classmethod
    def is_azure_instance(cls):
    def identify_dmi(cls):
        # On Azure, we cannot discriminate between Azure and baremetal Hyper-V
        # from DMI.
        # But only Azure has waagent, so we can use it for Azure detection.
        sys_vendor = read_one_line('/sys/class/dmi/id/sys_vendor')
        if sys_vendor == "Microsoft Corporation" and os.path.exists('/etc/waagent.conf'):
            return cls
        return None

    @classmethod
    async def identify_metadata(cls):
        """Check if it's Azure instance via query to metadata server."""
        try:
            curl(cls.META_DATA_BASE_URL + cls.API_VERSION + "&format=text", headers = { "Metadata": "True" }, max_retries=2, retry_interval=1)
            return True
            await aiocurl(cls.META_DATA_BASE_URL + cls.API_VERSION + "&format=text", headers = { "Metadata": "True" })
            return cls
        except (urllib.error.URLError, urllib.error.HTTPError):
            return False
            return None

    # as per https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=windows#supported-api-versions
    API_VERSION = "?api-version=2021-01-01"


    def __instance_metadata(self, path):
        """query Azure metadata server"""
        return curl(self.META_DATA_BASE_URL + path + self.API_VERSION + "&format=text", headers = { "Metadata": "True" })
@@ -715,13 +770,26 @@ def endpoint_snitch(self):


    @classmethod
    def is_aws_instance(cls):
    def identify_dmi(cls):
        product_version = read_one_line('/sys/class/dmi/id/product_version')
        bios_vendor = read_one_line('/sys/class/dmi/id/bios_vendor')
        # On Xen instance, product_version is like "4.11.amazon"
        if product_version.endswith('.amazon'):
            return cls
        # On Nitro instance / Baremetal instance, bios_vendor is "Amazon EC2"
        if bios_vendor == 'Amazon EC2':
            return cls
        return None

    @classmethod
    async def identify_metadata(cls):
        """Check if it's AWS instance via query to metadata server."""
        try:
            curl(cls.META_DATA_BASE_URL + "api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": cls.METADATA_TOKEN_TTL}, method="PUT")
            return True
            res = await aiocurl(cls.META_DATA_BASE_URL + "api/token", headers={"X-aws-ec2-metadata-token-ttl-seconds": cls.METADATA_TOKEN_TTL}, method="PUT")
            print(f'aws_instance: {res}')
            return cls
        except (urllib.error.URLError, urllib.error.HTTPError, socket.timeout):
            return False
            return None

    @property
    def instancetype(self):
@@ -823,25 +891,44 @@ def user_data(self):
        return ''


async def identify_cloud_async():
    tasks = [
        asyncio.create_task(aws_instance.identify()),
        asyncio.create_task(gcp_instance.identify()),
        asyncio.create_task(azure_instance.identify())
    ]
    result = None
    for task in asyncio.as_completed(tasks):
        result = await task
        if result:
            for other_task in tasks:
                other_task.cancel()
            break
    return result

_identify_cloud_result = None
def identify_cloud():
    global _identify_cloud_result
    if _identify_cloud_result:
        return _identify_cloud_result
    time_start = datetime.datetime.now()
    result = asyncio.run(identify_cloud_async())
    time_end = datetime.datetime.now()
    _identify_cloud_result = result
    return result

def is_ec2():
    return aws_instance.is_aws_instance()
    return identify_cloud() == aws_instance

def is_gce():
    return gcp_instance.is_gce_instance()
    return identify_cloud() == gcp_instance

def is_azure():
    return azure_instance.is_azure_instance()
    return identify_cloud() == azure_instance

def get_cloud_instance():
    if is_ec2():
        return aws_instance()
    elif is_gce():
        return gcp_instance()
    elif is_azure():
        return azure_instance()
    else:
        raise Exception("Unknown cloud provider! Only AWS/GCP/Azure supported.")
    cls = identify_cloud()
    return cls()


CONCOLORS = {'green': '\033[1;32m', 'red': '\033[1;31m', 'yellow': '\033[1;33m', 'nocolor': '\033[0m'}
59 changes: 42 additions & 17 deletions tests/test_aws_instance.py
@@ -3,7 +3,7 @@
import httpretty
import unittest.mock
from pathlib import Path
from unittest import TestCase
from unittest import TestCase, IsolatedAsyncioTestCase
from subprocess import CalledProcessError
from collections import namedtuple

@@ -23,6 +23,8 @@ def _mock_multi_open(files, filename, *args, **kwargs):

def mock_multi_open_i3en_2xlarge(filename, *args, **kwargs):
    files = {
        '/sys/class/dmi/id/product_version': '',
        '/sys/class/dmi/id/bios_vendor': 'Amazon EC2',
        '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
        '/sys/class/nvme/nvme0/model': 'Amazon Elastic Block Store\n',
        '/sys/class/nvme/nvme1/model': 'Amazon EC2 NVMe Instance Storage\n',
@@ -33,6 +35,8 @@ def mock_multi_open_i3en_2xlarge(filename, *args, **kwargs):

def mock_multi_open_i3en_2xlarge_with_ebs(filename, *args, **kwargs):
    files = {
        '/sys/class/dmi/id/product_version': '',
        '/sys/class/dmi/id/bios_vendor': 'Amazon EC2',
        '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
        '/sys/class/nvme/nvme0/model': 'Amazon Elastic Block Store\n',
        '/sys/class/nvme/nvme1/model': 'Amazon Elastic Block Store\n',
@@ -46,6 +50,8 @@ def mock_multi_open_i3en_2xlarge_with_ebs(filename, *args, **kwargs):

def mock_multi_open_i3_2xlarge(filename, *args, **kwargs):
    files = {
        '/sys/class/dmi/id/product_version': '4.11.amazon',
        '/sys/class/dmi/id/bios_vendor': 'Xen',
        '/sys/class/net/eth0/address': '00:00:5e:00:53:00\n',
        '/sys/class/nvme/nvme0/model': 'Amazon EC2 NVMe Instance Storage\n',
        '/sys/class/nvme/nvme1/model': 'Amazon EC2 NVMe Instance Storage\n',
@@ -86,14 +92,7 @@ def mock_multi_run_i3_2xlarge(*popenargs,
mock_listdevdir_i3_2xlarge = ['md0', 'root', 'nvme0n1', 'nvme1n1', 'xvda1', 'xvda', 'nvme1', 'nvme0', 'zero', 'null']


class TestAwsInstance(TestCase):
    def setUp(self):
        httpretty.enable(verbose=True, allow_net_connect=False)

    def tearDown(self):
        httpretty.disable()
        httpretty.reset()

class AwsMetadata:
    def httpretty_aws_metadata(self, instance_type='i3en.2xlarge', with_ebs=False, with_userdata=False):
        if not with_userdata:
            httpretty.register_uri(
@@ -205,21 +204,37 @@ def httpretty_aws_metadata(self, instance_type='i3en.2xlarge', with_ebs=False, w
            )


    def test_is_aws_instance(self):
class TestAsyncAwsInstance(IsolatedAsyncioTestCase, AwsMetadata):
    def setUp(self):
        httpretty.enable(verbose=True, allow_net_connect=False)

    def tearDown(self):
        httpretty.disable()
        httpretty.reset()

    async def test_identify_metadata(self):
        self.httpretty_aws_metadata()
        assert aws_instance.is_aws_instance()
        assert await aws_instance.identify_metadata()

    def test_is_not_aws_instance(self):
    async def test_not_identify_metadata(self):
        httpretty.disable()
        real_curl = lib.scylla_cloud.curl
        real_curl = lib.scylla_cloud.aiocurl

        def mocked_curl(*args, **kwargs):
        async def mocked_curl(*args, **kwargs):
            kwargs['timeout'] = 0.1
            kwargs['retry_interval'] = 0.001
            return real_curl(*args, **kwargs)
            return await real_curl(*args, **kwargs)

        with unittest.mock.patch('lib.scylla_cloud.aiocurl', new=mocked_curl):
            assert not await aws_instance.identify_metadata()

class TestAwsInstance(TestCase, AwsMetadata):
    def setUp(self):
        httpretty.enable(verbose=True, allow_net_connect=False)

        with unittest.mock.patch('lib.scylla_cloud.curl', new=mocked_curl):
            assert not aws_instance.is_aws_instance()
    def tearDown(self):
        httpretty.disable()
        httpretty.reset()

    def test_endpoint_snitch(self):
        self.httpretty_aws_metadata()
@@ -306,6 +321,11 @@ def test_no_user_data(self):
        assert not ins.user_data


    def test_identify_dmi_i3en_2xlarge(self):
        self.httpretty_aws_metadata()
        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_i3en_2xlarge)):
            assert aws_instance.identify_dmi()

    def test_non_root_nvmes_i3en_2xlarge(self):
        self.httpretty_aws_metadata()
        with unittest.mock.patch('psutil.disk_partitions', return_value=mock_disk_partitions),\
@@ -447,6 +467,11 @@ def test_get_remote_disks_i3en_2xlarge_with_ebs(self):
        assert ins.get_remote_disks() == ['nvme2n1', 'nvme1n1']


    def test_identify_dmi_i3_2xlarge(self):
        self.httpretty_aws_metadata()
        with unittest.mock.patch('builtins.open', unittest.mock.MagicMock(side_effect=mock_multi_open_i3_2xlarge)):
            assert aws_instance.identify_dmi()

    def test_non_root_nvmes_i3_2xlarge(self):
        self.httpretty_aws_metadata()
        with unittest.mock.patch('psutil.disk_partitions', return_value=mock_disk_partitions),\
