Writing a table into a union node works.
danielballan committed Feb 25, 2024
1 parent 5a767c6 commit 0b4a4b4
Showing 9 changed files with 115 additions and 41 deletions.
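
At a glance: a union node is created from a list of DataSource objects, and each one must now carry a unique, non-NULL name. A minimal sketch of the client-side flow, distilled from the tests in tiled/_tests/test_writing.py below (client comes from from_context(context) as in those tests; the key and names are illustrative):

import pandas
from tiled.structures.core import StructureFamily
from tiled.structures.data_source import DataSource
from tiled.structures.table import TableStructure

# Describe a table and give its data source a unique, non-NULL name.
df = pandas.DataFrame({"A": [1, 2], "B": [3, 4]})
data_source = DataSource(
    structure_family=StructureFamily.table,
    structure=TableStructure.from_pandas(df),
    name="table",
)
client.create_union([data_source], key="x")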
39 changes: 36 additions & 3 deletions tiled/_tests/test_writing.py
@@ -463,6 +463,7 @@ def test_union_one_table(tree):
         data_source = DataSource(
             structure_family=StructureFamily.table,
             structure=structure,
+            name="table",
         )
         client.create_union([data_source], key="x")

@@ -479,16 +480,43 @@ def test_union_two_tables(tree):
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure1,
+                name="table1",
             ),
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure2,
+                name="table2",
             ),
         ],
         key="x",
     )
+
+
+def test_union_two_tables_colliding_names(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df1 = pandas.DataFrame({"A": [], "B": []})
+        df2 = pandas.DataFrame({"C": [], "D": [], "E": []})
+        structure1 = TableStructure.from_pandas(df1)
+        structure2 = TableStructure.from_pandas(df2)
+        with fail_with_status_code(422):
+            client.create_union(
+                [
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure1,
+                        name="table1",
+                    ),
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure2,
+                        name="table1",  # collision
+                    ),
+                ],
+                key="x",
+            )


 def test_union_two_tables_colliding_keys(tree):
     with Context.from_app(build_app(tree)) as context:
         client = from_context(context)
@@ -502,10 +530,12 @@ def test_union_two_tables_colliding_keys(tree):
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure1,
+                name="table1",
             ),
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure2,
+                name="table2",
             ),
         ],
         key="x",
@@ -528,20 +558,22 @@ def test_union_two_tables_two_arrays(tree):
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure1,
+                name="table1",
             ),
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure2,
+                name="table2",
             ),
             DataSource(
                 structure_family=StructureFamily.array,
                 structure=structure3,
-                key="F",
+                name="F",
             ),
             DataSource(
                 structure_family=StructureFamily.array,
                 structure=structure4,
-                key="G",
+                name="G",
             ),
         ],
         key="x",
@@ -561,11 +593,12 @@ def test_union_table_column_array_key_collision(tree):
             DataSource(
                 structure_family=StructureFamily.table,
                 structure=structure1,
+                name="table",
             ),
             DataSource(
                 structure_family=StructureFamily.array,
                 structure=structure2,
-                key="B",
+                name="B",
             ),
         ],
         key="x",
58 changes: 41 additions & 17 deletions tiled/catalog/adapter.py
@@ -264,6 +264,7 @@ def __init__(
         queries=None,
         sorting=None,
         access_policy=None,
+        data_source_name=None,
     ):
         self.context = context
         self.engine = self.context.engine
@@ -282,6 +283,7 @@ def __init__(
         self.ancestors = node.ancestors
         self.key = node.key
         self.access_policy = access_policy
+        self.data_source_name = data_source_name
         self.startup_tasks = [self.startup]
         self.shutdown_tasks = [self.shutdown]
@@ -357,7 +359,7 @@ def structure(self):
                 data_source_id=data_source.id,
                 structure=data_source.structure,
                 structure_family=data_source.structure_family,
-                key=data_source.key,
+                name=data_source.name,
             )
             for data_source in self.data_sources
         ]
@@ -379,7 +381,7 @@ async def async_len(self):
     async def lookup_adapter(
         self,
         segments,
-        data_source_id=None,
+        data_source_name=None,
     ):  # TODO: Accept filter for predicate-pushdown.
         if not segments:
             return self
@@ -389,13 +391,17 @@ async def lookup_adapter(
             # this node, either via user search queries or via access
             # control policy queries. Look up first the _direct_ child of this
             # node, if it exists within the filtered results.
-            first_level = await self.lookup_adapter(segments[:1])
+            first_level = await self.lookup_adapter(
+                segments[:1], data_source_name=data_source_name
+            )
             if first_level is None:
                 return None
             # Now proceed to traverse further down the tree, if needed.
             # Search queries and access controls apply only at the top level.
             assert not first_level.conditions
-            return await first_level.lookup_adapter(segments[1:])
+            return await first_level.lookup_adapter(
+                segments[1:], data_source_name=data_source_name
+            )
         statement = (
             select(orm.Node)
             .filter(orm.Node.ancestors == self.segments + ancestors)
@@ -418,35 +424,41 @@ async def lookup_adapter(
                 # HDF5 file begins.

                 for i in range(len(segments)):
-                    catalog_adapter = await self.lookup_adapter(segments[:i])
+                    catalog_adapter = await self.lookup_adapter(
+                        segments[:i], data_source_name=data_source_name
+                    )
                     if catalog_adapter.data_sources:
-                        adapter = await catalog_adapter.get_adapter(data_source_id)
+                        adapter = await catalog_adapter.get_adapter(data_source_name)
                         for segment in segments[i:]:
                             adapter = await anyio.to_thread.run_sync(adapter.get, segment)
                             if adapter is None:
                                 break
                         return adapter
             return None
         return STRUCTURES[node.structure_family](
-            self.context, node, access_policy=self.access_policy
+            self.context,
+            node,
+            data_source_name=data_source_name,
+            access_policy=self.access_policy,
         )

-    async def get_adapter(self, data_source_id=None):
+    async def get_adapter(self):
         num_data_sources = len(self.data_sources)
-        if data_source_id is not None:
+        if self.data_source_name is not None:
             for data_source in self.data_sources:
-                if data_source_id == data_source.id:
+                if self.data_source_name == data_source.name:
                     break
             else:
                 raise ValueError(
-                    f"No such data_source_id {data_source_id} on this node"
+                    f"No DataSource named {self.data_source_name} on this node"
                 )
         elif num_data_sources > 1:
             raise ValueError(
-                "A data_source_id is required because this node "
+                "A data_source_name is required because this node "
                 f"has {num_data_sources} data sources"
             )
-        (data_source,) = self.data_sources
+        else:
+            (data_source,) = self.data_sources
         try:
             adapter_factory = self.context.adapters_by_mimetype[data_source.mimetype]
         except KeyError:
@@ -666,6 +678,7 @@ async def create_node(
             await db.execute(statement)
         data_source_orm = orm.DataSource(
             structure_family=data_source.structure_family,
+            name=data_source.name,
             mimetype=data_source.mimetype,
             management=data_source.management,
             parameters=data_source.parameters,
@@ -996,10 +1009,21 @@ async def write_partition(self, *args, **kwargs):
 class CatalogUnionAdapter(CatalogNodeAdapter):
     # This does not support direct reading or writing.

-    def get(self, key):
-        pass
-        # breakpoint()
-        # Find the key in self.data_sources
+    async def read(self, *args, **kwargs):
+        return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)
+
+    async def write(self, *args, **kwargs):
+        return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs)
+
+    async def read_partition(self, *args, **kwargs):
+        return await ensure_awaitable(
+            (await self.get_adapter()).read_partition, *args, **kwargs
+        )
+
+    async def write_partition(self, *args, **kwargs):
+        return await ensure_awaitable(
+            (await self.get_adapter()).write_partition, *args, **kwargs
+        )


 def delete_asset(data_uri, is_directory):
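All four new CatalogUnionAdapter methods follow one delegation pattern: resolve self.data_source_name to the adapter for that data source, then forward the same-named call. The ensure_awaitable helper smooths over sync and async adapters; a minimal sketch of its contract (an assumption for illustration, not tiled's actual implementation):

import inspect

import anyio

async def ensure_awaitable(func, *args, **kwargs):
    # Await coroutine functions directly; run sync callables on a worker
    # thread so they do not block the event loop.
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    return await anyio.to_thread.run_sync(lambda: func(*args, **kwargs))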
@@ -1,4 +1,4 @@
-"""Add 'key' column to data_sources table.
+"""Add 'name' column to data_sources table.

 Revision ID: 7c8130c40b8f
 Revises: e756b9381c14
@@ -16,7 +16,7 @@


 def upgrade():
-    op.add_column("data_sources", sa.Column("key", sa.Unicode(1023), nullable=True))
+    op.add_column("data_sources", sa.Column("name", sa.Unicode(1023), nullable=True))


 def downgrade():
2 changes: 1 addition & 1 deletion tiled/catalog/orm.py
@@ -303,7 +303,7 @@ class DataSource(Timestamped, Base):
     structure_family = Column(Enum(StructureFamily), nullable=False)
     # This is used by `union` structures to address arrays.
     # It may have additional uses in the future.
-    key = Column(Unicode(1023), nullable=True)
+    name = Column(Unicode(1023), nullable=True)

     # many-to-one relationship to Structure
     structure: Mapped["Structure"] = relationship(
7 changes: 7 additions & 0 deletions tiled/serialization/table.py
@@ -43,6 +43,13 @@ def serialize_csv(df, metadata, preserve_index=False):
     return file.getvalue().encode()


+@deserialization_registry.register(StructureFamily.table, "text/csv")
+def deserialize_csv(buffer):
+    import pandas
+
+    return pandas.read_csv(io.BytesIO(buffer))
+
+
 serialization_registry.register(StructureFamily.table, "text/csv", serialize_csv)
 serialization_registry.register(
     StructureFamily.table, "text/x-comma-separated-values", serialize_csv
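The new deserialize_csv inverts serialize_csv. A quick round-trip sanity check, assuming serialize_csv emits plain UTF-8 CSV as the code above suggests:

import io

import pandas

df = pandas.DataFrame({"A": [1, 2], "B": [3.5, 4.5]})
buffer = df.to_csv(index=False).encode()  # approximates serialize_csv output
assert pandas.read_csv(io.BytesIO(buffer)).equals(df)  # deserialize_csv's behavior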
5 changes: 4 additions & 1 deletion tiled/server/dependencies.py
@@ -52,6 +52,7 @@ def SecureEntry(scopes):
     async def inner(
         path: str,
         request: Request,
+        data_source: Optional[str] = None,
         principal: str = Depends(get_current_principal),
         root_tree: pydantic.BaseSettings = Depends(get_root_tree),
         session_state: dict = Depends(get_session_state),
@@ -82,7 +83,9 @@ async def inner(
                 # It can jump directly to the node of interest.

                 if hasattr(entry, "lookup_adapter"):
-                    entry = await entry.lookup_adapter(path_parts[i:])
+                    entry = await entry.lookup_adapter(
+                        path_parts[i:], data_source_name=data_source
+                    )
                     if entry is None:
                         raise NoEntry(path_parts)
                     break
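Declared on the dependency, data_source arrives as an optional query parameter on every route that uses SecureEntry, so a client can say which of a node's data sources it means. A hypothetical request (host, route prefix, and names are illustrative):

import httpx

# Ask the /table/full route for node "x", addressing the data source
# named "table1" specifically.
response = httpx.get(
    "http://localhost:8000/api/v1/table/full/x",
    params={"data_source": "table1"},
)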
12 changes: 6 additions & 6 deletions tiled/server/router.py
@@ -687,11 +687,11 @@ async def table_full(
     """
     Fetch the data for the given table.
     """
-    if entry.structure_family != StructureFamily.table:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /table/full route.",
-        )
+    # if entry.structure_family != StructureFamily.table:
+    #     raise HTTPException(
+    #         status_code=404,
+    #         detail=f"Cannot read {entry.structure_family} structure with /table/full route.",
+    #     )
     try:
         with record_timing(request.state.metrics, "read"):
             data = await ensure_awaitable(entry.read, column)
@@ -1138,7 +1138,7 @@ async def _create_node(
                 data_source_id=data_source.id,
                 structure=data_source.structure,
                 structure_family=data_source.structure_family,
-                key=data_source.key,
+                name=data_source.name,
             )
             for data_source in body.data_sources
         ]
4 changes: 2 additions & 2 deletions tiled/server/schemas.py
@@ -148,7 +148,7 @@ class DataSource(pydantic.BaseModel):
     parameters: dict = {}
     assets: List[Asset] = []
     management: Management = Management.writable
-    key: Optional[str] = None
+    name: Optional[str] = None

     @classmethod
     def from_orm(cls, orm):
@@ -160,7 +160,7 @@ def from_orm(cls, orm):
             parameters=orm.parameters,
             assets=[Asset.from_assoc_orm(assoc) for assoc in orm.asset_associations],
             management=orm.management,
-            key=orm.key,
+            name=orm.name,
         )


25 changes: 16 additions & 9 deletions tiled/structures/data_source.py
@@ -31,7 +31,7 @@ class DataSource:
     parameters: dict = dataclasses.field(default_factory=dict)
    assets: List[Asset] = dataclasses.field(default_factory=list)
     management: Management = Management.writable
-    key: Optional[str] = None
+    name: Optional[str] = None


 def validate_data_sources(node_structure_family, data_sources):
@@ -51,23 +51,30 @@ def validate_container_data_sources(node_structure_family, data_sources):
 def validate_union_data_sources(node_structure_family, data_sources):
     "Check that column names and keys of others (e.g. arrays) do not collide."
     keys = set()
+    names = set()
     for data_source in data_sources:
+        if data_source.name is None:
+            raise ValueError(
+                "Data sources backing a union structure_family must "
+                "all have non-NULL names."
+            )
+        if data_source.name in names:
+            raise ValueError(
+                "Data sources must have unique names. "
+                f"This name is used more than once: {data_source.name}"
+            )
+        names.add(data_source.name)
         if data_source.structure_family == StructureFamily.table:
             columns = data_source.structure.columns
             if keys.intersection(columns):
                 raise ValueError(
-                    f"Two data sources provide colliding keys: {keys.intersection(columns)}"
+                    f"Data sources provide colliding keys: {keys.intersection(columns)}"
                 )
             keys.update(columns)
         else:
-            key = data_source.key
-            if key is None:
-                raise ValueError(
-                    f"Data source of type {data_source.structure_family} "
-                    "must have a non-NULL key."
-                )
+            key = data_source.name
             if key in keys:
-                raise ValueError(f"Collision: {key}")
+                raise ValueError(f"Data sources provide colliding keys: {key}")
             keys.add(key)
     return data_sources

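The tightened validator can be exercised on its own. A sketch mirroring test_union_table_column_array_key_collision above, assuming StructureFamily.union exists on this branch; the array's structure is left as None purely as a stand-in, since the name checks fire before the structure is inspected:

import pandas

from tiled.structures.core import StructureFamily
from tiled.structures.data_source import DataSource, validate_union_data_sources
from tiled.structures.table import TableStructure

table = DataSource(
    structure_family=StructureFamily.table,
    structure=TableStructure.from_pandas(pandas.DataFrame({"A": [], "B": []})),
    name="table",
)
array_b = DataSource(
    structure_family=StructureFamily.array,
    structure=None,  # stand-in; not needed to trigger the collision check
    name="B",
)
try:
    validate_union_data_sources(StructureFamily.union, [table, array_b])
except ValueError as err:
    print(err)  # Data sources provide colliding keys: B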
