Skip to content

Commit

Permalink
move env-var reading prior to the model parse
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Apr 2, 2024
1 parent 56e582b commit e0ba225
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ jobs:
name: Unit Tests
env:
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHRAG_BLOB_CONNECTION_STRING: ${{ secrets.GRAPHRAG_BLOB_CONNECTION_STRING}}

- run: yarn test:integration
name: Integration Tests
Expand All @@ -92,3 +91,4 @@ jobs:
GRAPHRAG_API_KEY: ${{ secrets.OPENAI_API_KEY }}
GRAPHRAG_LLM_MODEL: gpt-3.5-turbo
GRAPHRAG_EMBEDDING_MODEL: text-embedding-3-small
GRAPHRAG_BLOB_CONNECTION_STRING: ${{ secrets.GRAPHRAG_BLOB_CONNECTION_STRING}}
9 changes: 3 additions & 6 deletions python/graphrag/graphrag/index/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from graphrag.index.cache import NoopPipelineCache
from graphrag.index.config import PipelineConfig
from graphrag.index.default_config import (
DefaultConfigParametersModel,
default_config,
default_config_parameters,
default_config_parameters_from_dict,
default_config_parameters_from_env_vars,
)
from graphrag.index.progress import (
Expand Down Expand Up @@ -170,17 +169,15 @@ def _read_config_parameters(root: str, reporter: ProgressReporter):
import yaml

data = yaml.safe_load(file)
model = DefaultConfigParametersModel.model_validate(data)
return default_config_parameters(model, root)
return default_config_parameters_from_dict(data, root)

if settings_json.exists():
reporter.success(f"Reading settings from {settings_json}")
with settings_json.open("r") as file:
import json

data = json.loads(file.read())
model = DefaultConfigParametersModel.model_validate(data)
return default_config_parameters(data, root)
return default_config_parameters_from_dict(data, root)

reporter.success("Reading settings from environment variables")
return default_config_parameters_from_env_vars(root)
Expand Down
2 changes: 2 additions & 0 deletions python/graphrag/graphrag/index/default_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TextEmbeddingConfigModel,
UmapConfigModel,
default_config_parameters,
default_config_parameters_from_dict,
default_config_parameters_from_env_vars,
)

Expand All @@ -48,6 +49,7 @@
"UmapConfigModel",
"default_config",
"default_config_parameters",
"default_config_parameters_from_dict",
"default_config_parameters_from_env_vars",
"load_pipeline_config",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .default_config_parameters_model import DefaultConfigParametersModel
from .factories import (
default_config_parameters,
default_config_parameters_from_dict,
default_config_parameters_from_env_vars,
)
from .models import (
Expand Down Expand Up @@ -48,6 +49,7 @@
"TextEmbeddingConfigModel",
"UmapConfigModel",
"default_config_parameters",
"default_config_parameters_from_dict",
"default_config_parameters_from_env_vars",
"read_dotenv",
]
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,41 @@
)


def perform_replacements(values: dict, env: Env):
"""Perform env-var replacements in the dictionary."""

def traverse_dict(dictionary):
for key, value in dictionary.items():
if isinstance(value, dict):
traverse_dict(value)
elif (
isinstance(value, str)
and value.startswith("${")
and value.endswith("}")
):
env_var = value[2:-1]
dictionary[key] = env(env_var)

traverse_dict(values)


def default_config_parameters_from_dict(values: dict, root_dir: str | None):
"""Load Configuration Parameters from a dictionary."""
root_dir = root_dir or str(Path.cwd())
env = _make_env(root_dir)
perform_replacements(values, env)

model = DefaultConfigParametersModel.model_validate(values)
return default_config_parameters(model, root_dir)


def default_config_parameters(
values: DefaultConfigParametersModel, root_dir: str | None
):
"""Load Configuration Parameters from a dictionary."""
root_dir = root_dir or str(Path.cwd())
env = _make_env(root_dir)

return DefaultConfigParametersDict(values, env, root_dir)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, values: CacheConfigModel, env: Env):
@property
def type(self) -> PipelineCacheType:
"""The cache type to use."""
return self.replace(self._values.type, DEFAULT_CACHE_TYPE)
result = self.replace(str(self._values.type))
return PipelineCacheType(result) if result else DEFAULT_CACHE_TYPE

@property
def connection_string(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, values: InputConfigModel, env: Env):
@property
def type(self) -> PipelineInputType:
"""The input type to use."""
return self.replace(self._values.type, DEFAULT_INPUT_TYPE)
result = self.replace(str(self._values.type))
return PipelineInputType(result) if result else DEFAULT_INPUT_TYPE

@property
def storage_type(self) -> PipelineInputStorageType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, values: ReportingConfigModel, env: Env):
@property
def type(self) -> PipelineReportingType:
"""The reporting type to use."""
return self.replace(self._values.type, DEFAULT_REPORTING_TYPE)
result = self.replace(str(self._values.type))
return PipelineReportingType(result) if result else DEFAULT_REPORTING_TYPE

@property
def connection_string(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def __init__(self, values: StorageConfigModel, env: Env):
@property
def type(self) -> PipelineStorageType:
"""The storage type to use."""
return self.replace(self._values.type, DEFAULT_STORAGE_TYPE)
result = self.replace(str(self._values.type))
return PipelineStorageType(result) if result else DEFAULT_STORAGE_TYPE

@property
def connection_string(self) -> str:
Expand Down
6 changes: 4 additions & 2 deletions python/graphrag/tests/smoke/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
BLOB_CONNECTION_STRING = os.environ.get(
"GRAPHRAG_BLOB_CONNECTION_STRING", WELL_KNOWN_AZURITE_CONNECTION_STRING
)
CACHE_TYPE = os.environ.get("CACHE_TYPE", "file") or "file"
CACHE_CONTAINER_NAME = os.environ.get("CACHE_CONTAINER_NAME", "cache") or "cache"


def _load_fixtures():
Expand Down Expand Up @@ -244,8 +246,8 @@ def __run_query(self, root: Path, query_config: dict[str, str]):
os.environ,
{
**os.environ,
"CACHE_TYPE": os.environ.get("CACHE_TYPE", "file"),
"CACHE_CONTAINER_NAME": os.environ.get("CACHE_CONTAINER_NAME", "cache"),
"CACHE_TYPE": CACHE_TYPE,
"CACHE_CONTAINER_NAME": CACHE_CONTAINER_NAME,
"BLOB_STORAGE_CONNECTION_STRING": BLOB_CONNECTION_STRING,
},
clear=True,
Expand Down

0 comments on commit e0ba225

Please sign in to comment.