Skip to content

Commit

Permalink
Refactor dir structure
Browse files Browse the repository at this point in the history
  • Loading branch information
herbiebradley committed Jul 10, 2023
1 parent 5070deb commit 3c9c8be
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 119 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def search(self, init_steps: int, total_steps: int, atol: float = 0.0) -> str:
self.fitness_history["mean"].append(self.mean_fitness())
self.fitness_history["qd_score"].append(self.qd_score())
self.fitness_history["niches_filled"].append(self.niches_filled())
self.fitness_history["qd_score"].append(self.qd_score())
self.fitness_history["niches_filled"].append(self.niches_filled())

if (
self.save_snapshot_interval is not None
Expand Down
9 changes: 3 additions & 6 deletions src/openelm/benchmarks/benchmark_crossover.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import functools
import json
import os

os.environ["TRANSFORMERS_CACHE"] = "/fsx/hyperion/hf_cache"
import time
from dataclasses import asdict, dataclass, field
from itertools import permutations
Expand All @@ -16,10 +14,9 @@
from omegaconf import OmegaConf
from tqdm import trange

from openelm.algorithms.map_elites import MAPElites
from openelm.codegen import model_setup, sample, truncate
from openelm.configs import BaseConfig, MAPElitesConfig
from openelm.environments import SQUARE_SEED
from openelm.environments.environments import Sodarace, Sodaracer
from openelm.environments.sodaracer import (
CIRCLE,
CPPN_FIXED,
Expand All @@ -35,8 +32,8 @@
SQUARE_PREREQ,
WHEEL,
)
from openelm.environments.sodaracer.sodarace import Sodarace, Sodaracer
from openelm.environments.sodaracer.walker import Walker
from openelm.map_elites import MAPElites
from openelm.utils.code_eval import pool_exec_processes

INSTRUCTIONS = {
Expand Down Expand Up @@ -128,7 +125,7 @@ def benchmark_seeds(self, seeds):
).to(self.device)

sodarace_env = Sodarace(
seed=SQUARE_SEED,
seed=SEEDS_DICT["square"],
config=self.cfg,
diff_model=self.model,
eval_ms=self.cfg.eval_ms,
Expand Down
4 changes: 2 additions & 2 deletions src/openelm/elm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hydra.core.hydra_config import HydraConfig

from openelm.configs import DiffModelConfig, ELMConfig, PromptModelConfig
from openelm.environments import QD_DICT, BaseEnvironment, load_env
from openelm.environments import BaseEnvironment, load_algorithm, load_env
from openelm.mutation_model import DiffModel, MutationModel, PromptModel


Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
)
else:
self.environment = env
self.qd_algorithm = QD_DICT[qd_name](
self.qd_algorithm = load_algorithm(qd_name)(
env=self.environment,
config=self.config.qd,
)
Expand Down
32 changes: 17 additions & 15 deletions src/openelm/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,56 @@
from typing import Any

from openelm.environments.environments import BaseEnvironment, Genotype
from openelm.map_elites import CVTMAPElites, MAPElites
from openelm.algorithms.map_elites import CVTMAPElites, MAPElites
from openelm.environments.base import BaseEnvironment, Genotype


def load_env(env_name: str) -> Any:
if env_name == "sodarace":
from openelm.environments.sodarace_env import Sodarace
from openelm.environments.sodaracer.sodarace import Sodarace

return Sodarace
elif env_name == "image_evolution":
from openelm.environments.environments import ImageOptim
from openelm.environments.base import ImageOptim

return ImageOptim
elif env_name == "match_string":
from openelm.environments.environments import MatchString
from openelm.environments.base import MatchString

return MatchString
elif env_name == "function_optim":
from openelm.environments.environments import FunctionOptim
from openelm.environments.base import FunctionOptim

return FunctionOptim
elif env_name == "p3_probsol":
from openelm.environments.p3_env import P3ProbSol
from openelm.environments.p3.p3 import P3ProbSol

return P3ProbSol
elif env_name == "p3_problem":
from openelm.environments.p3_env import P3Problem
from openelm.environments.p3.p3 import P3Problem

return P3Problem
elif env_name == "prompt_evolution":
from openelm.environments.prompt_env import PromptEvolution
from openelm.environments.prompt.prompt import PromptEvolution

return PromptEvolution
elif env_name == "qdaif":
from openelm.environments.environments import PoetryEvolution
from openelm.environments.poetry import PoetryEvolution

return PoetryEvolution
else:
raise ValueError(f"Unknown environment {env_name}")


QD_DICT: dict[str, Any] = {
"mapelites": MAPElites,
"cvtmapelites": CVTMAPElites,
}
def load_algorithm(algorithm_name: str) -> Any:
    """
    Look up a quality-diversity algorithm class by its config name.

    Args:
        algorithm_name (str): Name of the algorithm; "mapelites" or
            "cvtmapelites".

    Returns:
        The matching algorithm class (`MAPElites` or `CVTMAPElites`), not an
        instance; the caller is expected to instantiate it.

    Raises:
        ValueError: If `algorithm_name` is not one of the known names.
    """
    if algorithm_name == "mapelites":
        return MAPElites
    elif algorithm_name == "cvtmapelites":
        return CVTMAPElites
    else:
        # Fail loudly here, mirroring load_env, instead of returning None and
        # letting the caller crash later with "'NoneType' is not callable".
        raise ValueError(f"Unknown algorithm {algorithm_name}")


__all__ = [
"Genotype",
"BaseEnvironment",
"QD_DICT",
"load_algorithm",
"load_env",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import requests

from openelm.configs import EnvConfig, ImageEnvConfig, StringEnvConfig
from openelm.environments.env_utils import NULL_SEED, get_image_target
from openelm.environments.utils import NULL_SEED, get_image_target
from openelm.mutation_model import MutationModel
from openelm.utils.code_eval import pool_exec_processes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers import pipeline

from openelm.configs import P3ProblemEnvConfig, P3ProbSolEnvConfig
from openelm.environments.environments import BaseEnvironment, Genotype, Phenotype
from openelm.environments.base import BaseEnvironment, Genotype, Phenotype
from openelm.environments.p3 import (
P3_IMPORTS,
P3_PROBLEM_LONG_SEED,
Expand Down Expand Up @@ -184,7 +184,7 @@ def construct_prompt(
prompt_str += "\n"
else:
prompt_str += (
f"\n\n# Old version of g6()" f"\n# TODO: fix bugs in the code below\n"
"\n\n# Old version of g6()\n# TODO: fix bugs in the code below\n"
)
if isinstance(code_batch, list):
# TODO: get nearby genotypes
Expand Down Expand Up @@ -235,7 +235,7 @@ def generate_programs(self, code_batch: list[str]) -> list[P3Solution]:
processes=self.config.processes,
debug=self.config.debug,
)
except Exception as e:
except Exception:
return self.generate_programs(code_batch)

results = [
Expand Down Expand Up @@ -482,10 +482,10 @@ def construct_prompt(
i_g6 = program_str.find("def g6_1(")
lines = program_str[:i_g6].strip().split("\n")
new_lines = []
for l in lines:
if l.strip().startswith("#") or len(l.strip()) == 0:
for line in lines:
if line.strip().startswith("#") or len(line.strip()) == 0:
continue
new_lines.append(l)
new_lines.append(line)
program_str = "\n".join(new_lines) + "\n\n" + program_str[i_g6:]
program_str = program_str.strip()

Expand Down Expand Up @@ -525,7 +525,7 @@ def generate_programs(self, code_batch: list[str]) -> list[P3ProbSolResult]:
processes=self.config.processes,
debug=self.config.debug,
)
except Exception as e:
except Exception:
return self.generate_programs(code_batch)

results = [
Expand Down Expand Up @@ -564,10 +564,10 @@ def fitness(self, probsol: P3ProbSolResult) -> float:
debug=self.config.debug,
)

if result[0] != True:
if result[0] is not True:
return -np.inf

### Do pass@k eval ###
# Do pass@k eval

# Get f6_2() and make it the new f6()
problem_str = probsol.problem_func.replace("def f6_2(", "def f6(")
Expand All @@ -576,10 +576,10 @@ def fitness(self, probsol: P3ProbSolResult) -> float:
# Remove comments with # (and remove empty lines)
lines = problem_str.strip().split("\n")
new_lines = []
for l in lines:
if l.strip().startswith("#") or len(l.strip()) == 0:
for line in lines:
if line.strip().startswith("#") or len(line.strip()) == 0:
continue
new_lines.append(l)
new_lines.append(line)
problem_str = "\n".join(new_lines)
# Get solution_preamble for g6()
i_end_preamble = probsol.solution_func.find("):")
Expand All @@ -599,7 +599,7 @@ def fitness(self, probsol: P3ProbSolResult) -> float:

c = 0
for s in solutions:
if p3_problem.evaluate_solution(s) == True:
if p3_problem.evaluate_solution(s) is True:
c += 1

pak = pass_at_k(len(solutions), c, self.config.eval_k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from langchain.schema import HumanMessage

from openelm.configs import QDEnvConfig
from openelm.environments.environments import BaseEnvironment, Genotype, Phenotype
from openelm.environments.prompt_env import PromptGenotype
from openelm.environments.base import BaseEnvironment
from openelm.environments.prompt.prompt import PromptGenotype
from openelm.mutation_model import MutationModel, get_model


Expand Down Expand Up @@ -62,7 +62,7 @@ def evaluate(self, model) -> float:
self.genre = json.loads(diversity_result)["genre"]
self.tone = json.loads(diversity_result)["tone"]
return float(self.quality)
except Exception as e:
except Exception:
return -np.inf

def to_phenotype(self) -> Optional[np.ndarray]:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from transformers import pipeline

from openelm.configs import PromptEnvConfig
from openelm.environments.env_utils import (
from openelm.environments.base import BaseEnvironment, Genotype, Phenotype
from openelm.environments.prompt.utils import (
AnimalPromptTask,
AntonymPromptTask,
COTPromptTask,
ToyPromptTask,
)
from openelm.environments.environments import BaseEnvironment, Genotype, Phenotype
from openelm.mutation_model import MutationModel


Expand Down Expand Up @@ -53,12 +53,14 @@ class PromptGenotype(Genotype):
"""
Genotype wrapper for a LangChain template.
This consists of a base format for all individuals, as well as individual-specific fields which will be evolved.
This consists of a base format for all individuals, as well as
individual-specific fields which will be evolved.
Remaining fields will be filled in at evaluation time.
Args:
prompt (PromptTemplate): The base template for all individuals.
fixed_inputs (dict[str, str], optional): Individual-specific fields to fill in. Defaults to None.
fixed_inputs (dict[str, str], optional): Individual-specific fields to
fill in. Defaults to None.
"""

def __init__(
Expand Down Expand Up @@ -271,7 +273,8 @@ def rewrite_string(self, input_str, rewrite_instruction, variable_name):
Args:
input_str: The string to rewrite.
rewrite_instruction: String prompt template for the LLM
variable_name: The name of the variable in the template to replace with input_str
variable_name: The name of the variable in the template to replace
with input_str
"""
rewrite_prompt = PromptTemplate(
input_variables=[variable_name],
Expand Down Expand Up @@ -343,7 +346,8 @@ def fitness(self, x: PromptGenotype) -> float:

def evaluate_template(self, eval_template, instruction_str, input_str, output_str):
"""
Evaluates a template on the log likelihood of the output_str, given the instruction_str and input_str.
Evaluates a template on the log likelihood of the output_str, given the
instruction_str and input_str.
Args:
eval_template: The template to evaluate.
Expand All @@ -352,7 +356,8 @@ def evaluate_template(self, eval_template, instruction_str, input_str, output_st
output_str: The output string.
Returns:
The log likelihood of the tokens in the output string, given the instruction and input strings.
The log likelihood of the tokens in the output string, given the
instruction and input strings.
"""
model = self.fitness_model.model.model
tokenizer = self.fitness_model.model.tokenizer
Expand All @@ -361,7 +366,8 @@ def evaluate_template(self, eval_template, instruction_str, input_str, output_st
filled_prompt = partial_template.format(
input_str=input_str, output_str=output_str
)
# hack; replace the output string to figure out which token numbers correspond to the output (see APE)
# hack; replace the output string to figure out which token numbers
# correspond to the output (see APE)
reference_prompt = partial_template.format(input_str=input_str, output_str="~")

tokens_filled = tokenizer.encode(filled_prompt, return_tensors="pt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,6 @@
from openelm.constants import SRC_PATH


def get_image_target(name: str) -> np.ndarray:
if name == "circle":
target = np.zeros((32, 32, 3))
for y in range(32):
for x in range(32):
if (y - 16) ** 2 + (x - 16) ** 2 <= 100: # a radius-10 circle
target[y, x] = np.array([255, 255, 0])
else:
raise NotImplementedError(f"Image target {name} not implemented")
return target


IMAGE_SEED: str = """
def draw():
\tpic = np.zeros((32, 32, 3))
\tfor x in range(2, 30):
\t\tfor y in range(2, 30):
\t\t\tpic[x, y] = np.array([0, 0, 255])
\treturn pic
"""

NULL_SEED: str = ""


@dataclass
class ToyPromptTask:
base_template = "{few_shot_examples}\n{instruction_str} the word {target} {n_repetitions} times:"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import requests

from openelm.configs import SodaraceEnvConfig
from openelm.environments.environments import BaseEnvironment, Genotype, Phenotype
from openelm.environments.base import BaseEnvironment, Genotype, Phenotype
from openelm.environments.sodaracer import (
CIRCLE,
GALLOPER_PREREQ,
Expand Down
25 changes: 25 additions & 0 deletions src/openelm/environments/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np


def get_image_target(name: str) -> np.ndarray:
    """
    Build the target image for the named shape.

    Args:
        name (str): Name of the target; only "circle" is implemented.

    Returns:
        np.ndarray: A (32, 32, 3) float array of zeros in which every pixel
        within distance 10 of (16, 16) is set to [255, 255, 0].

    Raises:
        NotImplementedError: If `name` is anything other than "circle".
    """
    if name != "circle":
        raise NotImplementedError(f"Image target {name} not implemented")

    side, centre, radius_sq = 32, 16, 100  # a radius-10 circle
    offsets = np.arange(side) - centre
    # Broadcast row and column offsets into a (32, 32) boolean disc mask.
    inside = offsets[:, None] ** 2 + offsets[None, :] ** 2 <= radius_sq

    target = np.zeros((side, side, 3))
    target[inside] = np.array([255, 255, 0])
    return target


IMAGE_SEED: str = """
def draw():
\tpic = np.zeros((32, 32, 3))
\tfor x in range(2, 30):
\t\tfor y in range(2, 30):
\t\t\tpic[x, y] = np.array([0, 0, 255])
\treturn pic
"""

NULL_SEED: str = ""
Loading

0 comments on commit 3c9c8be

Please sign in to comment.