[Feat] support pass@k
yingfhu committed Nov 16, 2023
1 parent 9740bba commit 62d3497
Showing 1 changed file with 121 additions and 72 deletions.
193 changes: 121 additions & 72 deletions opencompass/datasets/mbpp.py
@@ -1,9 +1,11 @@
import contextlib
import io
import itertools
import multiprocessing
import re
import signal
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Sequence, Union

import numpy as np
@@ -73,6 +75,50 @@ class TimeOutException(Exception):
pass


@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with redirect_stdin(stream):
yield


@contextlib.contextmanager
def time_limit(seconds: float):

def signal_handler(signum, frame):
raise TimeOutException('Time out!')

signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)


class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from."""

def read(self, *args, **kwargs):
raise IOError

def readline(self, *args, **kwargs):
raise IOError

def readlines(self, *args, **kwargs):
raise IOError

def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False


class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'


@ICL_EVALUATORS.register_module()
class MBPPEvaluator(BaseEvaluator):

@@ -87,8 +133,8 @@ def score(self, predictions, references):
                # Pass exec_globals so that exec does not raise an
                # unnecessary NameError for a correct answer
exec_globals = {}
with self.swallow_io():
with self.time_limit(2):
with swallow_io():
with time_limit(2):
exec(programs, exec_globals)
result['pass'] += 1
except TimeOutException:
@@ -121,46 +167,6 @@ def _process_test(self, test_case, pred):
formatted += test_case
return formatted

@contextlib.contextmanager
def swallow_io(self):
stream = self.WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with self.redirect_stdin(stream):
yield

@contextlib.contextmanager
def time_limit(self, seconds: float):

def signal_handler(signum, frame):
raise TimeOutException('Time out!')

signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)

class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from."""

def read(self, *args, **kwargs):
raise IOError

def readline(self, *args, **kwargs):
raise IOError

def readlines(self, *args, **kwargs):
raise IOError

def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False

class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'


@ICL_EVALUATORS.register_module()
class MBPPEvaluator2(MBPPEvaluator):
@@ -200,6 +206,54 @@ def _process_answer(self, text):
return text


def execution(programs, task_id, timeout):
"""Execution function for running generation code.
Args:
programs(str): Python code to be executed.
task_id(int): Task id of the current example.
timeout(int): Time limit for execution, avoid unnecessary
blocking.
In pass@k scenario, a lot of programs should be executed.
Some internal error cannot be handled properly, such as
`RecursionError` might cause system break. It is better to
separate the execution in thread or multiprocess to better
control the process.
"""

def _execution(programs, timeout):
try:
            # Pass exec_globals so that exec does not raise an
            # unnecessary NameError for a correct answer
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(programs, exec_globals)
key.append('pass')
except TimeOutException:
key.append('timeout')
except AssertionError:
key.append('wrong_answer')
except BaseException as e:
print(e)
key.append('failed')

manager = multiprocessing.Manager()
key = manager.list()
    # `signal` cannot be used in a child thread, so we create a
    # separate process inside the thread instead.
p = multiprocessing.Process(target=_execution,
args=(programs, timeout - 1))
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.kill()
        # `key` may be empty if the process was killed before appending
return task_id, 'timeout'
return task_id, key[0]
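
A minimal usage sketch, not taken from the commit's diff: running `execution` on a program that never terminates should come back as 'timeout' instead of hanging the evaluator. This assumes a Unix-like platform where multiprocessing uses the fork start method, since the nested `_execution` target cannot be pickled and the SIGALRM-based `time_limit` only works in the main thread of a process:

if __name__ == '__main__':
    # A pathological sample: an infinite loop that would block forever
    # without the process-level timeout.
    hanging_program = 'while True:\n    pass\n'
    task_id, status = execution(hanging_program, task_id=0, timeout=3)
    print(task_id, status)  # expected: 0 timeout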


class MBPPPassKEvaluator(MBPPEvaluator):
"""Better use for pass k evaluation.
@@ -246,37 +300,32 @@ def score(self, predictions, references):
task_total = defaultdict(int)

result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
for refer, preds in zip(references, predictions):
            # handles two cases:
            # 1. a repeated dataset
            # 2. `num_return_sequences` used to generate multiple responses
if not isinstance(preds, list):
preds = [preds]
test_case = refer['test_list_2']
task_id = refer['task_id']
# create empty task_pass in case all example failed
if task_id not in task_pass:
task_pass[task_id] = 0
for pred in preds:
pred = self._process_answer(pred)
programs = self._process_test(test_case, pred)
try:
                    # Pass exec_globals so that exec does not raise an
                    # unnecessary NameError for a correct answer
exec_globals = {}
with self.swallow_io():
with self.time_limit(2):
exec(programs, exec_globals)
result['pass'] += 1
with ThreadPoolExecutor() as executor:
futures = []
for refer, preds in zip(references, predictions):
                # handles two cases:
                # 1. a repeated dataset
                # 2. `num_return_sequences` used to generate multiple responses
if not isinstance(preds, list):
preds = [preds]
test_case = refer['test_list_2']
task_id = refer['task_id']
# create empty task_pass in case all example failed
if task_id not in task_pass:
task_pass[task_id] = 0
for pred in preds:
pred = self._process_answer(pred)
programs = self._process_test(test_case, pred)
future = executor.submit(execution, programs, task_id, 3)
futures.append(future)

from tqdm import tqdm
for future in tqdm(as_completed(futures), total=len(futures)):
task_id, key = future.result()
result[key] += 1
task_total[task_id] += 1
if key == 'pass':
task_pass[task_id] += 1
except TimeOutException:
result['timeout'] += 1
except AssertionError:
result['wrong_answer'] += 1
except BaseException:
result['failed'] += 1
finally:
task_total[task_id] += 1

def get_number(tasks):
return np.array([
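
The rest of the diff is truncated here. Presumably the remaining lines convert the per-task `task_pass` / `task_total` counts into pass@k scores; a minimal sketch of the standard unbiased pass@k estimator from the Codex paper (Chen et al., 2021) is shown below for reference. The function name and its exact wiring into the evaluator are assumptions, not necessarily what this commit implements:

import numpy as np


def estimate_pass_at_k(num_samples, num_correct, k):
    """Unbiased pass@k sketch: for a task with n samples of which c
    pass, pass@k = 1 - C(n - c, k) / C(n, k), averaged over tasks."""

    def estimator(n, c, k):
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    return float(np.mean([estimator(int(n), int(c), k)
                          for n, c in zip(num_samples, num_correct)]))

# e.g. estimate_pass_at_k([4], [1], 2) == 0.5, since C(3, 2) / C(4, 2) = 0.5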
