Fix non-default emitters #1403

Closed
wants to merge 4 commits
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241114001527332372.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fix issue when using emitters that are not the default parque output"
}
3 changes: 2 additions & 1 deletion graphrag/index/emit/csv_table_emitter.py
@@ -18,14 +18,15 @@ class CSVTableEmitter(TableEmitter):
"""CSVTableEmitter class."""

_storage: PipelineStorage
extension = "csv"

def __init__(self, storage: PipelineStorage):
"""Create a new CSV Table Emitter."""
self._storage = storage

async def emit(self, name: str, data: pd.DataFrame) -> None:
"""Emit a dataframe to storage."""
filename = f"{name}.csv"
filename = f"{name}.{self.extension}"
log.info("emitting CSV table %s", filename)
await self._storage.set(
filename,
3 changes: 2 additions & 1 deletion graphrag/index/emit/json_table_emitter.py
@@ -18,14 +18,15 @@ class JsonTableEmitter(TableEmitter):
"""JsonTableEmitter class."""

_storage: PipelineStorage
extension = "json"

def __init__(self, storage: PipelineStorage):
"""Create a new Json Table Emitter."""
self._storage = storage

async def emit(self, name: str, data: pd.DataFrame) -> None:
"""Emit a dataframe to storage."""
filename = f"{name}.json"
filename = f"{name}.{self.extension}"

log.info("emitting JSON table %s", filename)
await self._storage.set(
3 changes: 2 additions & 1 deletion graphrag/index/emit/parquet_table_emitter.py
@@ -22,6 +22,7 @@ class ParquetTableEmitter(TableEmitter):

_storage: PipelineStorage
_on_error: ErrorHandlerFn
extension = "parquet"

def __init__(
self,
@@ -34,7 +35,7 @@ def __init__(

async def emit(self, name: str, data: pd.DataFrame) -> None:
"""Emit a dataframe to storage."""
filename = f"{name}.parquet"
filename = f"{name}.{self.extension}"
log.info("emitting parquet table %s", filename)
try:
await self._storage.set(filename, data.to_parquet())
2 changes: 2 additions & 0 deletions graphrag/index/emit/table_emitter.py
@@ -11,5 +11,7 @@
class TableEmitter(Protocol):
"""TableEmitter protocol for emitting tables to a destination."""

extension: str

async def emit(self, name: str, data: pd.DataFrame) -> None:
"""Emit a dataframe to storage."""
6 changes: 4 additions & 2 deletions graphrag/index/run/workflow.py
@@ -33,6 +33,7 @@ async def _inject_workflow_data_dependencies(
workflow_dependencies: dict[str, list[str]],
dataset: pd.DataFrame,
storage: PipelineStorage,
extension: str,
) -> None:
"""Inject the data dependencies into the workflow."""
workflow.add_table(DEFAULT_INPUT_NAME, dataset)
@@ -41,7 +42,7 @@
for id in deps:
workflow_id = f"workflow:{id}"
try:
table = await _load_table_from_storage(f"{id}.parquet", storage)
table = await _load_table_from_storage(f"{id}.{extension}", storage)
except ValueError:
# our workflows now allow transient tables, and we avoid putting those in primary storage
# however, we need to keep the table in the dependency list for proper execution order
@@ -97,8 +98,9 @@ async def _process_workflow(
return None

context.stats.workflows[workflow_name] = {"overall": 0.0}

await _inject_workflow_data_dependencies(
workflow, workflow_dependencies, dataset, context.storage
workflow, workflow_dependencies, dataset, context.storage, emitters[0].extension
)

workflow_start_time = time.time()
16 changes: 15 additions & 1 deletion graphrag/utils/storage.py
@@ -47,7 +47,21 @@ async def _load_table_from_storage(name: str, storage: PipelineStorage) -> pd.DataFrame:
raise ValueError(msg)
try:
log.info("read table from storage: %s", name)
return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
match name.split(".")[-1]:
case "parquet":
return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
case "json":
return pd.read_json(
BytesIO(await storage.get(name, as_bytes=True)),
lines=True,
orient="records",
)
case "csv":
return pd.read_csv(BytesIO(await storage.get(name, as_bytes=True)))
case _:
msg = f"Unknown file extension for {name}"
log.error(msg)
raise ValueError(msg)
except Exception:
log.exception("error loading table from storage: %s", name)
raise
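
One detail worth noting about the json branch above: pd.read_json is called with lines=True and orient="records", which assumes the emitter wrote line-delimited JSON records (this appears to match JsonTableEmitter's output format, though that code is not shown in this diff). A minimal round-trip sketch under that assumption:

# Round-trip sketch (assumption: the JSON emitter writes line-delimited records).
from io import BytesIO

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "title": ["a", "b"]})
payload = df.to_json(orient="records", lines=True).encode()

restored = pd.read_json(BytesIO(payload), lines=True, orient="records")
assert restored.equals(df)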