Skip to content

Commit

Permalink
Transient entity graph (#1349)
Browse files Browse the repository at this point in the history
* Make base_entity_graph transient

* Add transient snapshots

* Semver

* Fix unit test

* Fix smoke tests
  • Loading branch information
natoverse authored Nov 5, 2024
1 parent 17658c5 commit 634e3ed
Show file tree
Hide file tree
Showing 34 changed files with 209 additions and 96 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241105004012425642.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Transient entity graph and snapshotting."
}
12 changes: 7 additions & 5 deletions docs/config/env_vars.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,13 @@ This section controls the reporting mechanism used by the pipeline, for common e

## Data Snapshotting

| Parameter | Description | Type | Required or Optional | Default |
| ----------------------------------- | ------------------------------------------- | ------ | -------------------- | ------- |
| `GRAPHRAG_SNAPSHOT_GRAPHML` | Whether to enable GraphML snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_RAW_ENTITIES` | Whether to enable raw entity snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES` | Whether to enable top-level node snapshots. | `bool` | optional | False |
| Parameter | Description | Type | Required or Optional | Default |
| -------------------------------------- | ----------------------------------------------- | ------ | -------------------- | ------- |
| `GRAPHRAG_SNAPSHOT_EMBEDDINGS` | Whether to enable embeddings snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_GRAPHML` | Whether to enable GraphML snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_RAW_ENTITIES` | Whether to enable raw entity snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES` | Whether to enable top-level node snapshots. | `bool` | optional | False |
| `GRAPHRAG_SNAPSHOT_TRANSIENT` | Whether to enable transient table snapshots. | `bool` | optional | False |

# Miscellaneous Settings

Expand Down
8 changes: 5 additions & 3 deletions docs/config/json_yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,11 @@ This is the base LLM configuration section. Other steps may override this config

### Fields

- `graphml` **bool** - Emit graphml snapshots.
- `raw_entities` **bool** - Emit raw entity snapshots.
- `top_level_nodes` **bool** - Emit top-level-node snapshots.
- `embeddings` **bool** - Emit embeddings snapshots to parquet.
- `graphml` **bool** - Emit graph snapshots to GraphML.
- `raw_entities` **bool** - Emit raw entity snapshots to JSON.
- `top_level_nodes` **bool** - Emit top-level-node snapshots to JSON.
- `transient` **bool** - Emit transient workflow tables snapshots to parquet.

## encoding_model

Expand Down
1 change: 1 addition & 0 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def hydrate_parallelization_params(
top_level_nodes=reader.bool("top_level_nodes")
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
)
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
umap_model = UmapConfig(
Expand Down
1 change: 1 addition & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
SNAPSHOTS_EMBEDDINGS = False
SNAPSHOTS_TRANSIENT = False
STORAGE_BASE_DIR = "output"
STORAGE_TYPE = StorageType.file
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500
Expand Down
2 changes: 2 additions & 0 deletions graphrag/config/input_models/snapshots_config_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
class SnapshotsConfigInput(TypedDict):
"""Configuration section for snapshots."""

embeddings: NotRequired[bool | str | None]
graphml: NotRequired[bool | str | None]
raw_entities: NotRequired[bool | str | None]
top_level_nodes: NotRequired[bool | str | None]
transient: NotRequired[bool | str | None]
10 changes: 7 additions & 3 deletions graphrag/config/models/snapshots_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
class SnapshotsConfig(BaseModel):
"""Configuration section for snapshots."""

embeddings: bool = Field(
description="A flag indicating whether to take snapshots of embeddings.",
default=defs.SNAPSHOTS_EMBEDDINGS,
)
graphml: bool = Field(
description="A flag indicating whether to take snapshots of GraphML.",
default=defs.SNAPSHOTS_GRAPHML,
Expand All @@ -23,7 +27,7 @@ class SnapshotsConfig(BaseModel):
description="A flag indicating whether to take snapshots of top-level nodes.",
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
)
embeddings: bool = Field(
description="A flag indicating whether to take snapshots of embeddings.",
default=defs.SNAPSHOTS_EMBEDDINGS,
transient: bool = Field(
description="A flag indicating whether to take snapshots of transient tables.",
default=defs.SNAPSHOTS_TRANSIENT,
)
5 changes: 4 additions & 1 deletion graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _text_unit_workflows(
PipelineWorkflowReference(
name=create_base_text_units,
config={
"snapshot_transient": settings.snapshots.transient,
"chunk_by": settings.chunks.group_by_columns,
"text_chunk": {
"strategy": settings.chunks.resolved_strategy(
Expand Down Expand Up @@ -215,7 +216,9 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
PipelineWorkflowReference(
name=create_base_entity_graph,
config={
"graphml_snapshot": settings.snapshots.graphml,
"snapshot_graphml": settings.snapshots.graphml,
"snapshot_transient": settings.snapshots.transient,
"snapshot_raw_entities": settings.snapshots.raw_entities,
"entity_extract": {
**settings.entity_extraction.parallelization.model_dump(),
"async_mode": settings.entity_extraction.async_mode,
Expand Down
21 changes: 16 additions & 5 deletions graphrag/index/flows/create_base_entity_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ async def create_base_entity_graph(
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
embedding_strategy: dict[str, Any] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
snapshot_graphml_enabled: bool = False,
snapshot_raw_entities_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
Expand Down Expand Up @@ -92,15 +93,15 @@ async def create_base_entity_graph(
strategy=embedding_strategy,
)

if raw_entity_snapshot_enabled:
if snapshot_raw_entities_enabled:
await snapshot(
entities,
name="raw_extracted_entities",
storage=storage,
formats=["json"],
)

if graphml_snapshot_enabled:
if snapshot_graphml_enabled:
await snapshot_graphml(
merged_graph,
name="merged_graph",
Expand Down Expand Up @@ -131,4 +132,14 @@ async def create_base_entity_graph(
if embedding_strategy:
final_columns.append("embeddings")

return cast(pd.DataFrame, clustered[final_columns])
output = cast(pd.DataFrame, clustered[final_columns])

if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_entity_graph",
storage=storage,
formats=["parquet"],
)

return output
18 changes: 16 additions & 2 deletions graphrag/index/flows/create_base_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils import gen_md5_hash


def create_base_text_units(
async def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
storage: PipelineStorage,
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
Expand Down Expand Up @@ -73,10 +77,20 @@ def create_base_text_units(
)
chunked["id"] = chunked["chunk_id"]

return cast(
output = cast(
pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
)

if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_text_units",
storage=storage,
formats=["parquet"],
)

return output


# TODO: would be nice to inline this completely in the main method with pandas
def _aggregate_df(
Expand Down
4 changes: 2 additions & 2 deletions graphrag/index/flows/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def create_final_nodes(
storage: PipelineStorage,
layout_strategy: dict[str, Any],
level_for_node_positions: int,
snapshot_top_level_nodes: bool = False,
snapshot_top_level_nodes_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to transform final nodes."""
laid_out_entity_graph = cast(
Expand Down Expand Up @@ -50,7 +50,7 @@ async def create_final_nodes(
nodes = nodes[nodes["level"] == level_for_node_positions].reset_index(drop=True)
nodes = cast(pd.DataFrame, nodes[["id", "x", "y"]])

if snapshot_top_level_nodes:
if snapshot_top_level_nodes_enabled:
await snapshot(
nodes,
name="top_level_nodes",
Expand Down
8 changes: 4 additions & 4 deletions graphrag/index/flows/generate_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def generate_text_embeddings(
storage: PipelineStorage,
text_embed_config: dict,
embedded_fields: set[str],
embeddings_snapshot_enabled: bool = False,
snapshot_embeddings_enabled: bool = False,
) -> None:
"""All the steps to generate all embeddings."""
embedding_param_map = {
Expand Down Expand Up @@ -109,7 +109,7 @@ async def generate_text_embeddings(
cache=cache,
storage=storage,
text_embed_config=text_embed_config,
embeddings_snapshot_enabled=embeddings_snapshot_enabled,
snapshot_embeddings_enabled=snapshot_embeddings_enabled,
**embedding_param_map[field],
)

Expand All @@ -122,7 +122,7 @@ async def _run_and_snapshot_embeddings(
cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict,
embeddings_snapshot_enabled: bool,
snapshot_embeddings_enabled: bool,
) -> None:
"""All the steps to generate single embedding."""
if text_embed_config:
Expand All @@ -137,7 +137,7 @@ async def _run_and_snapshot_embeddings(

data = data.loc[:, ["id", "embedding"]]

if embeddings_snapshot_enabled is True:
if snapshot_embeddings_enabled is True:
await snapshot(
data,
name=f"embeddings.{name}",
Expand Down
10 changes: 6 additions & 4 deletions graphrag/index/workflows/v1/create_base_entity_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def build_steps(
embedding_strategy = embed_graph_config.get("strategy")
embed_graph_enabled = config.get("embed_graph_enabled", False) or False

graphml_snapshot_enabled = config.get("graphml_snapshot", False) or False
raw_entity_snapshot_enabled = config.get("raw_entity_snapshot", False) or False
snapshot_graphml = config.get("snapshot_graphml", False) or False
snapshot_raw_entities = config.get("snapshot_raw_entities", False) or False
snapshot_transient = config.get("snapshot_transient", False) or False

return [
{
Expand All @@ -109,8 +110,9 @@ def build_steps(
"embedding_strategy": embedding_strategy
if embed_graph_enabled
else None,
"raw_entity_snapshot_enabled": raw_entity_snapshot_enabled,
"graphml_snapshot_enabled": graphml_snapshot_enabled,
"snapshot_raw_entities_enabled": snapshot_raw_entities,
"snapshot_graphml_enabled": snapshot_graphml,
"snapshot_transient_enabled": snapshot_transient,
},
"input": ({"source": "workflow:create_base_text_units"}),
},
Expand Down
3 changes: 3 additions & 0 deletions graphrag/index/workflows/v1/create_base_text_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def build_steps(
n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
text_chunk_config = config.get("text_chunk", {})
chunk_strategy = text_chunk_config.get("strategy")

snapshot_transient = config.get("snapshot_transient", False) or False
return [
{
"verb": "create_base_text_units",
Expand All @@ -32,6 +34,7 @@ def build_steps(
"n_tokens_column_name": n_tokens_column_name,
"chunk_by_columns": chunk_by_columns,
"chunk_strategy": chunk_strategy,
"snapshot_transient_enabled": snapshot_transient,
},
"input": {"source": DEFAULT_INPUT_NAME},
},
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/workflows/v1/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def build_steps(
"args": {
"layout_strategy": layout_strategy,
"level_for_node_positions": level_for_node_positions,
"snapshot_top_level_nodes": snapshot_top_level_nodes,
"snapshot_top_level_nodes_enabled": snapshot_top_level_nodes,
},
"input": {"source": "workflow:create_base_entity_graph"},
},
Expand Down
4 changes: 2 additions & 2 deletions graphrag/index/workflows/v1/generate_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def build_steps(
"""
text_embed = config.get("text_embed", {})
embedded_fields = config.get("embedded_fields", {})
embeddings_snapshot_enabled = config.get("snapshot_embeddings", False)
snapshot_embeddings = config.get("snapshot_embeddings", False)
return [
{
"verb": "generate_text_embeddings",
"args": {
"text_embed": text_embed,
"embedded_fields": embedded_fields,
"embeddings_snapshot_enabled": embeddings_snapshot_enabled,
"snapshot_embeddings_enabled": snapshot_embeddings,
},
"input": input,
},
Expand Down
15 changes: 10 additions & 5 deletions graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Any, cast

import pandas as pd
from datashaper import (
AsyncType,
Table,
Expand Down Expand Up @@ -41,8 +42,9 @@ async def create_base_entity_graph(
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
embedding_strategy: dict[str, Any] | None = None,
graphml_snapshot_enabled: bool = False,
raw_entity_snapshot_enabled: bool = False,
snapshot_graphml_enabled: bool = False,
snapshot_raw_entities_enabled: bool = False,
snapshot_transient_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to create the base entity graph."""
Expand All @@ -65,8 +67,11 @@ async def create_base_entity_graph(
summarization_strategy=summarization_strategy,
summarization_num_threads=summarization_num_threads,
embedding_strategy=embedding_strategy,
graphml_snapshot_enabled=graphml_snapshot_enabled,
raw_entity_snapshot_enabled=raw_entity_snapshot_enabled,
snapshot_graphml_enabled=snapshot_graphml_enabled,
snapshot_raw_entities_enabled=snapshot_raw_entities_enabled,
snapshot_transient_enabled=snapshot_transient_enabled,
)

return create_verb_result(cast(Table, output))
await runtime_storage.set("base_entity_graph", output)

return create_verb_result(cast(Table, pd.DataFrame()))
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,27 @@
async def create_base_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
storage: PipelineStorage,
runtime_storage: PipelineStorage,
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
snapshot_transient_enabled: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform base text_units."""
source = cast(pd.DataFrame, input.get_input())

output = create_base_text_units_flow(
output = await create_base_text_units_flow(
source,
callbacks,
storage,
chunk_column_name,
n_tokens_column_name,
chunk_by_columns,
chunk_strategy=chunk_strategy,
snapshot_transient_enabled=snapshot_transient_enabled,
)

await runtime_storage.set("base_text_units", output)
Expand Down
Loading

0 comments on commit 634e3ed

Please sign in to comment.