diff --git a/src/swsssdk/port_util.py b/src/swsssdk/port_util.py index a21f5288a..27de486c2 100644 --- a/src/swsssdk/port_util.py +++ b/src/swsssdk/port_util.py @@ -164,13 +164,14 @@ def get_rif_port_map(db): return rif_port_oid_map -def get_vlan_interface_oid_map(db): +def get_vlan_interface_oid_map(db, blocking=True): """ Get Vlan Interface names and sai oids """ db.connect('COUNTERS_DB') - rif_name_map = db.get_all('COUNTERS_DB', 'COUNTERS_RIF_NAME_MAP', blocking=True) - rif_type_name_map = db.get_all('COUNTERS_DB', 'COUNTERS_RIF_TYPE_MAP', blocking=True) + + rif_name_map = db.get_all('COUNTERS_DB', 'COUNTERS_RIF_NAME_MAP', blocking=blocking) + rif_type_name_map = db.get_all('COUNTERS_DB', 'COUNTERS_RIF_TYPE_MAP', blocking=blocking) if not rif_name_map or not rif_type_name_map: return {} diff --git a/test/test_port_util.py b/test/test_port_util.py new file mode 100644 index 000000000..2281f43d9 --- /dev/null +++ b/test/test_port_util.py @@ -0,0 +1,19 @@ +import os +import sys + +if sys.version_info.major == 3: + from unittest import mock +else: + import mock + +modules_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(modules_path, 'src')) + +class TestPortUtil: + def test_get_vlan_interface_oid_map(self): + db = mock.MagicMock() + db.get_all = mock.MagicMock() + db.get_all.return_value = {} + + from swsssdk.port_util import get_vlan_interface_oid_map + assert not get_vlan_interface_oid_map(db, True)