Merge pull request #91 from qxygxt/main
Optimize the problem of CI test and the structure of OSPP MDIL-SS
jaypume authored Feb 1, 2024
2 parents d8fa17a + 89636a3 commit 7ea4f4a
Showing 16 changed files with 242 additions and 185 deletions.
6 changes: 3 additions & 3 deletions core/common/constant.py
@@ -68,12 +68,12 @@ class SystemMetricType(Enum):
"""
System metric type of ianvs.
"""
# pylint: disable=C0103
SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio"
FWT = "FWT"
BWT = "BWT"
-    Task_Avg_Acc = "Task_Avg_Acc"
-    Matrix = "Matrix"
+    TASK_AVG_ACC = "task_avg_acc"
+    MATRIX = "MATRIX"


class TestObjectType(Enum):
"""
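
Note for reviewers: the rename does more than satisfy pylint's invalid-name check (C0103). The string values of two members change as well ("Task_Avg_Acc" becomes "task_avg_acc", "Matrix" becomes "MATRIX"), so any code or config keyed on the raw strings must move in step with the enum. Reconstructed from this hunk, the enum now reads:

from enum import Enum

class SystemMetricType(Enum):
    """
    System metric type of ianvs.
    """
    # pylint: disable=C0103
    SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio"
    FWT = "FWT"
    BWT = "BWT"
    TASK_AVG_ACC = "task_avg_acc"
    MATRIX = "MATRIX"
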
@@ -64,8 +64,8 @@ def __init__(self, workspace, **kwargs):
self.cloud_task_index = '/tmp/cloud_task/index.pkl'
self.edge_task_index = '/tmp/edge_task/index.pkl'
self.system_metric_info = {SystemMetricType.SAMPLES_TRANSFER_RATIO.value: [],
-                                   SystemMetricType.Matrix.value : {},
-                                   SystemMetricType.Task_Avg_Acc.value: {}}
+                                   SystemMetricType.MATRIX.value : {},
+                                   SystemMetricType.TASK_AVG_ACC.value: {}}

def run(self):
# pylint:disable=duplicate-code
@@ -147,7 +147,7 @@ def run(self):
LOGGER.info(f"{entry} scores: {scores}")
task_avg_score['accuracy'] += scores['accuracy']
task_avg_score['accuracy'] = task_avg_score['accuracy']/i
-        self.system_metric_info[SystemMetricType.Task_Avg_Acc.value] = task_avg_score
+        self.system_metric_info[SystemMetricType.TASK_AVG_ACC.value] = task_avg_score
LOGGER.info(task_avg_score)
job = self.build_paradigm_job(ParadigmType.LIFELONG_LEARNING.value)
inference_dataset = self.dataset.load_data(self.dataset.test_url, "eval",
@@ -160,7 +160,7 @@ def run(self):
for key in my_dict.keys():
matrix = my_dict[key]
#BWT, FWT = self.compute(key, matrix)
-            self.system_metric_info[SystemMetricType.Matrix.value][key] = matrix
+            self.system_metric_info[SystemMetricType.MATRIX.value][key] = matrix

elif mode == 'hard-example-mining':
dataset_files = self._split_dataset(splitting_dataset_times=rounds)
@@ -246,7 +246,7 @@ def run(self):
LOGGER.info(f"{entry} scores: {scores}")
task_avg_score['accuracy'] += scores['accuracy']
task_avg_score['accuracy'] = task_avg_score['accuracy']/i
-        self.system_metric_info[SystemMetricType.Task_Avg_Acc.value] = task_avg_score
+        self.system_metric_info[SystemMetricType.TASK_AVG_ACC.value] = task_avg_score
LOGGER.info(task_avg_score)
test_res, unseen_task_train_samples = self._inference(self.edge_task_index,
self.dataset.test_url,
@@ -256,7 +256,7 @@ def run(self):
for key in my_dict.keys():
matrix = my_dict[key]
#BWT, FWT = self.compute(key, matrix)
-            self.system_metric_info[SystemMetricType.Matrix.value][key] = matrix
+            self.system_metric_info[SystemMetricType.MATRIX.value][key] = matrix

elif mode != 'multi-inference':
dataset_files = self._split_dataset(splitting_dataset_times=rounds)
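
The paradigm above is the main producer of these metrics: its system_metric_info dict is keyed by the renamed enum values. A sketch of its shape after this commit (the core.common.constant import path is an assumption based on the changed-file list; the accuracy figure is illustrative):

from core.common.constant import SystemMetricType  # assumed import path

system_metric_info = {
    SystemMetricType.SAMPLES_TRANSFER_RATIO.value: [],  # "samples_transfer_ratio"
    SystemMetricType.MATRIX.value: {},        # "MATRIX": one accuracy matrix per key, e.g. "all"
    SystemMetricType.TASK_AVG_ACC.value: {},  # "task_avg_acc": filled with e.g. {"accuracy": 0.87}
}
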
74 changes: 33 additions & 41 deletions core/testcasecontroller/metrics/metrics.py
@@ -39,71 +39,56 @@ def samples_transfer_ratio_func(system_metric_info: dict):
"""

-    info = system_metric_info.get(SystemMetricType.SAMPLES_TRANSFER_RATIO.value)
+    info = system_metric_info.get(
+        SystemMetricType.SAMPLES_TRANSFER_RATIO.value)
inference_num = 0
transfer_num = 0
for inference_data, transfer_data in info:
inference_num += len(inference_data)
transfer_num += len(transfer_data)
return round(float(transfer_num) / (inference_num + 1), 4)


def compute(key, matrix):
"""
Compute BWT and FWT scores for a given matrix.
"""
-    # pylint: disable=C0103
-    # pylint: disable=C0301
-    # pylint: disable=C0303
-    # pylint: disable=R0912
-    print(f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}")
+    print(
+        f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}")

length = len(matrix)
accuracy = 0.0
-    BWT_score = 0.0
-    FWT_score = 0.0
+    bwt_score = 0.0
+    fwt_score = 0.0
flag = True

-    if key == 'all':
-        for i in range(length-1, 0, -1):
-            sum_before_i = sum(item['accuracy'] for item in matrix[i][:i])
-            sum_after_i = sum(item['accuracy'] for item in matrix[i][-(length - i - 1):])
-            if i == 0:
-                seen_class_accuracy = 0.0
-            else:
-                seen_class_accuracy = sum_before_i / i
-            if length - 1 - i == 0:
-                unseen_class_accuracy = 0.0
-            else:
-                unseen_class_accuracy = sum_after_i / (length - 1 - i)
-            print(f"round {i} : unseen class accuracy is {unseen_class_accuracy}, seen class accuracy is {seen_class_accuracy}")

for row in matrix:
if not isinstance(row, list) or len(row) != length-1:
flag = False
break

if not flag:
-        BWT_score = np.nan
-        FWT_score = np.nan
-        return BWT_score, FWT_score
+        bwt_score = np.nan
+        fwt_score = np.nan
+        return bwt_score, fwt_score

for i in range(length-1):
for j in range(length-1):
if 'accuracy' in matrix[i+1][j] and 'accuracy' in matrix[i][j]:
accuracy += matrix[i+1][j]['accuracy']
-                BWT_score += matrix[i+1][j]['accuracy'] - matrix[i][j]['accuracy']
+                bwt_score += matrix[i+1][j]['accuracy'] - \
+                    matrix[i][j]['accuracy']

for i in range(0, length-1):
if 'accuracy' in matrix[i][i] and 'accuracy' in matrix[0][i]:
-            FWT_score += matrix[i][i]['accuracy'] - matrix[0][i]['accuracy']
+            fwt_score += matrix[i][i]['accuracy'] - matrix[0][i]['accuracy']

accuracy = accuracy / ((length-1) * (length-1))
-    BWT_score = BWT_score / ((length-1) * (length-1))
-    FWT_score = FWT_score / (length-1)
+    bwt_score = bwt_score / ((length-1) * (length-1))
+    fwt_score = fwt_score / (length-1)

-    print(f"{key} BWT_score: {BWT_score}")
-    print(f"{key} FWT_score: {FWT_score}")
+    print(f"{key} BWT_score: {bwt_score}")
+    print(f"{key} FWT_score: {fwt_score}")

my_matrix = []
for i in range(length-1):
@@ -112,48 +97,53 @@ def compute(key, matrix):
if 'accuracy' in matrix[i+1][j]:
my_matrix[i].append(matrix[i+1][j]['accuracy'])

-    return my_matrix, BWT_score, FWT_score
+    return my_matrix, bwt_score, fwt_score


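To make the renamed bookkeeping concrete, here is a toy run of the formulas compute() applies, as a standalone sketch (the numbers are illustrative, not from this commit). matrix holds one row per evaluation round; each row lists per-task results as {"accuracy": ...} dicts and must have length length-1 to pass the validation above:

matrix = [
    [{"accuracy": 0.50}, {"accuracy": 0.40}],  # round 0
    [{"accuracy": 0.60}, {"accuracy": 0.45}],  # round 1
    [{"accuracy": 0.65}, {"accuracy": 0.55}],  # round 2
]
length = len(matrix)

# bwt_score: mean accuracy change on already-seen tasks across consecutive rounds.
bwt_score = sum(
    matrix[i + 1][j]["accuracy"] - matrix[i][j]["accuracy"]
    for i in range(length - 1)
    for j in range(length - 1)
) / ((length - 1) * (length - 1))  # (0.10 + 0.05 + 0.05 + 0.10) / 4 = 0.075

# fwt_score: mean advantage on task i at round i over the round-0 baseline.
fwt_score = sum(
    matrix[i][i]["accuracy"] - matrix[0][i]["accuracy"]
    for i in range(length - 1)
) / (length - 1)  # (0.00 + 0.05) / 2 = 0.025

A positive bwt_score means later rounds improved previously learned tasks (no forgetting); a positive fwt_score means earlier rounds transferred forward to unseen tasks.
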
def bwt_func(system_metric_info: dict):
"""
compute BWT
"""
# pylint: disable=C0103
# pylint: disable=W0632
-    info = system_metric_info.get(SystemMetricType.Matrix.value)
+    info = system_metric_info.get(SystemMetricType.MATRIX.value)
_, BWT_score, _ = compute("all", info["all"])
return BWT_score


def fwt_func(system_metric_info: dict):
"""
compute FWT
"""
# pylint: disable=C0103
# pylint: disable=W0632
-    info = system_metric_info.get(SystemMetricType.Matrix.value)
+    info = system_metric_info.get(SystemMetricType.MATRIX.value)
_, _, FWT_score = compute("all", info["all"])
return FWT_score


def matrix_func(system_metric_info: dict):
"""
compute FWT
"""
# pylint: disable=C0103
# pylint: disable=W0632
-    info = system_metric_info.get(SystemMetricType.Matrix.value)
+    info = system_metric_info.get(SystemMetricType.MATRIX.value)
my_dict = {}
for key in info.keys():
my_matrix, _, _ = compute(key, info[key])
my_dict[key] = my_matrix
return my_dict


def task_avg_acc_func(system_metric_info: dict):
"""
-    compute Task_Avg_Acc
+    compute task average accuracy
"""
-    info = system_metric_info.get(SystemMetricType.Task_Avg_Acc.value)
+    info = system_metric_info.get(SystemMetricType.TASK_AVG_ACC.value)
return info["accuracy"]


def get_metric_func(metric_dict: dict):
"""
get metric func by metric info
@@ -175,9 +165,11 @@ def get_metric_func(metric_dict: dict):
if url:
try:
load_module(url)
-            metric_func = ClassFactory.get_cls(type_name=ClassType.GENERAL, t_cls_name=name)
+            metric_func = ClassFactory.get_cls(
+                type_name=ClassType.GENERAL, t_cls_name=name)
return name, metric_func
except Exception as err:
-            raise RuntimeError(f"get metric func(url={url}) failed, error: {err}.") from err
+            raise RuntimeError(
+                f"get metric func(url={url}) failed, error: {err}.") from err

return name, getattr(sys.modules[__name__], str.lower(name) + "_func")
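
The fallback on the last line resolves built-in metrics purely by naming convention, str.lower(name) + "_func", so a metric named "BWT" in a benchmark config maps to bwt_func above. A hypothetical lookup (assuming this module is importable at this path and that metric_dict carries "name" and "url" keys, which the folded lines above appear to read):

from core.testcasecontroller.metrics.metrics import get_metric_func  # assumed path

name, metric_func = get_metric_func({"name": "BWT", "url": ""})
# name == "BWT"; metric_func is bwt_func, found via "bwt" + "_func".

This is why the snake_case renames in this file must stay aligned with the metric names that benchmark configs reference.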
@@ -64,9 +64,3 @@ benchmarkingjob:
# 1> "selected_and_all": save selected and all dataitems;
# 2> "selected_only": save selected dataitems;
save_mode: "selected_and_all"
-
-
-
-
-
-
@@ -6,6 +6,7 @@

__all__ = ('accuracy')

+
@ClassFactory.register(ClassType.GENERAL)
def accuracy(y_true, y_pred, **kwargs):
args = val_args()
@@ -35,4 +36,4 @@ def accuracy(y_true, y_pred, **kwargs):
FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()

print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU))
-    return CPA
\ No newline at end of file
+    return CPA
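
The @ClassFactory.register(ClassType.GENERAL) decorator above is the other half of the metric-loading story: it is what lets get_metric_func retrieve a user metric after load_module(url) imports the file. A minimal sketch of a custom metric registered the same way (the sedna import path is an assumption, and the metric itself is a toy exact-match ratio, not the evaluator-based CPA above):

from sedna.common.class_factory import ClassFactory, ClassType

@ClassFactory.register(ClassType.GENERAL)
def my_accuracy(y_true, y_pred, **kwargs):
    # Toy stand-in: fraction of predictions that exactly match the labels.
    matches = sum(int(t == p) for t, p in zip(y_true, y_pred))
    return matches / max(len(y_true), 1)
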