diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go new file mode 100644 index 0000000000..b35b60d79f --- /dev/null +++ b/internal/datastore/common/relationships.go @@ -0,0 +1,154 @@ +package common + +import ( + "context" + "database/sql" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errUnableToQueryRels = "unable to query relationships: %w" + +// StaticValueOrAddColumnForSelect adds a column to the list of columns to select if the value +// is not static, otherwise it sets the value to the static value. +func StaticValueOrAddColumnForSelect(colsToSelect []any, queryInfo QueryInfo, colName string, field *string) []any { + // If the value is static, set the field to it and return. + if found, ok := queryInfo.FilteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect +} + +// Querier is an interface for querying the database. +type Querier[R Rows] interface { + QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error +} + +// Rows is a common interface for database rows reading. +type Rows interface { + Scan(dest ...any) error + Next() bool + Err() error +} + +type closeRowsWithError interface { + Rows + Close() error +} + +type closeRows interface { + Rows + Close() +} + +// QueryRelationships queries relationships for the given query and transaction. +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInfo QueryInfo, sqlStatement string, args []any, span trace.Span, tx Querier[R], withIntegrity bool) (datastore.RelationshipIterator, error) { + defer span.End() + + colsToSelect := make([]any, 0, 8) + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName sql.NullString + var caveatCtx C + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &resourceRelation) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) + + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) + if withIntegrity { + colsToSelect = append(colsToSelect, &integrityKeyID, &integrityHash, ×tamp) + } + + return func(yield func(tuple.Relationship, error) bool) { + err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + var r Rows = rows + if crwe, ok := r.(closeRowsWithError); ok { + defer LogOnError(ctx, crwe.Close) + } else if cr, ok := r.(closeRows); ok { + defer cr.Close() + } + + span.AddEvent("Query issued to database") + relCount := 0 + for rows.Next() { + if err := rows.Scan(colsToSelect...); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + + var caveat *corev1.ContextualizedCaveat + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } + } + + var integrity *corev1.RelationshipIntegrity + if integrityKeyID != "" { + integrity = &corev1.RelationshipIntegrity{ + KeyId: integrityKeyID, + Hash: integrityHash, + HashedAt: timestamppb.New(timestamp), + } + } + + relCount++ + if !yield(tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: resourceObjectType, + ObjectID: resourceObjectID, + Relation: resourceRelation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: subjectObjectType, + ObjectID: subjectObjectID, + Relation: subjectRelation, + }, + }, + OptionalCaveat: caveat, + OptionalIntegrity: integrity, + }, nil) { + return nil + } + } + + if err := rows.Err(); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) + } + + span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + return nil + }, sqlStatement, args...) + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + } + }, nil +} diff --git a/internal/datastore/common/tuple.go b/internal/datastore/common/sliceiter.go similarity index 100% rename from internal/datastore/common/tuple.go rename to internal/datastore/common/sliceiter.go diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 536de9385f..33b65b16df 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -64,17 +64,26 @@ const ( // SchemaInformation holds the schema information from the SQL datastore implementation. type SchemaInformation struct { - colNamespace string - colObjectID string - colRelation string - colUsersetNamespace string - colUsersetObjectID string - colUsersetRelation string - colCaveatName string - paginationFilterType PaginationFilterType + RelationshipTableName string + ColNamespace string + ColObjectID string + ColRelation string + ColUsersetNamespace string + ColUsersetObjectID string + ColUsersetRelation string + ColCaveatName string + ColCaveatContext string + PaginationFilterType PaginationFilterType + PlaceholderFolder sq.PlaceholderFormat + + // ExtaFields are additional fields that are not part of the core schema, but are + // requested by the caller for this query. + ExtraFields []string } +// NewSchemaInformation creates a new SchemaInformation object for a query. func NewSchemaInformation( + relationshipTableName, colNamespace, colObjectID, colRelation, @@ -82,9 +91,13 @@ func NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName string, + colCaveatContext string, paginationFilterType PaginationFilterType, + placeholderFormat sq.PlaceholderFormat, + extraFields ...string, ) SchemaInformation { return SchemaInformation{ + relationshipTableName, colNamespace, colObjectID, colRelation, @@ -92,35 +105,95 @@ func NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName, + colCaveatContext, paginationFilterType, + placeholderFormat, + extraFields, } } +type ColumnTracker struct { + SingleValue *string +} + // SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated // way to build query objects. type SchemaQueryFilterer struct { - schema SchemaInformation - queryBuilder sq.SelectBuilder - filteringColumnCounts map[string]int - filterMaximumIDCount uint16 + schema SchemaInformation + queryBuilder sq.SelectBuilder + filteringColumnTracker map[string]ColumnTracker + filterMaximumIDCount uint16 + isCustomQuery bool + extraFields []string + fromSuffix string +} + +// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting +// relationships. This method will automatically filter the columns retrieved from the database, only +// selecting the columns that are not already specified with a single static value in the query. +func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFolder).Select(), + filteringColumnTracker: map[string]ColumnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: false, + extraFields: extraFields, + } } -// NewSchemaQueryFilterer creates a new SchemaQueryFilterer object. -func NewSchemaQueryFilterer(schema SchemaInformation, initialQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { +// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting +// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, +// this method will not auto-filter the columns retrieved from the database. +func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") } return SchemaQueryFilterer{ - schema: schema, - queryBuilder: initialQuery, - filteringColumnCounts: map[string]int{}, - filterMaximumIDCount: filterMaximumIDCount, + schema: schema, + queryBuilder: startingQuery, + filteringColumnTracker: map[string]ColumnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: true, + extraFields: nil, + } +} + +// WithAdditionalFilter returns a new SchemaQueryFilterer with an additional filter applied to the query. +func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { + return SchemaQueryFilterer{ + schema: sqf.schema, + queryBuilder: filter(sqf.queryBuilder), + filteringColumnTracker: sqf.filteringColumnTracker, + filterMaximumIDCount: sqf.filterMaximumIDCount, + isCustomQuery: sqf.isCustomQuery, + extraFields: sqf.extraFields, + } +} + +func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { + return SchemaQueryFilterer{ + schema: sqf.schema, + queryBuilder: sqf.queryBuilder, + filteringColumnTracker: sqf.filteringColumnTracker, + filterMaximumIDCount: sqf.filterMaximumIDCount, + isCustomQuery: sqf.isCustomQuery, + extraFields: sqf.extraFields, + fromSuffix: fromSuffix, } } func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { + spiceerrors.DebugAssert(func() bool { + return sqf.isCustomQuery + }, "UnderlyingQueryBuilder should only be called on custom queries") return sqf.queryBuilder } @@ -128,22 +201,22 @@ func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFi switch order { case options.ByResource: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, ) case options.BySubject: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, ) } @@ -162,47 +235,47 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr columnsAndValues := map[options.SortOrder][]nameAndValue{ options.ByResource: { { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, options.BySubject: { { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, }[order] - switch sqf.schema.paginationFilterType { + switch sqf.schema.PaginationFilterType { case TupleComparison: // For performance reasons, remove any column names that have static values in the query. columnNames := make([]string, 0, len(columnsAndValues)) @@ -210,7 +283,7 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr comparisonSlotCount := 0 for _, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue == nil { columnNames = append(columnNames, cav.name) valueSlots = append(valueSlots, cav.value) comparisonSlotCount++ @@ -230,10 +303,10 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr orClause := sq.Or{} for index, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue != nil { andClause := sq.And{} for _, previous := range columnsAndValues[0:index] { - if sqf.filteringColumnCounts[previous.name] != 1 { + if r, ok := sqf.filteringColumnTracker[previous.name]; !ok || r.SingleValue != nil { andClause = append(andClause, sq.Eq{previous.name: previous.value}) } } @@ -254,25 +327,31 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr // FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the // specified type. func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colNamespace: resourceType}) - sqf.recordColumnValue(sqf.schema.colNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType}) + sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType) return sqf } -func (sqf SchemaQueryFilterer) recordColumnValue(colName string) { - if value, ok := sqf.filteringColumnCounts[colName]; ok { - sqf.filteringColumnCounts[colName] = value + 1 - return +func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) { + existing, ok := sqf.filteringColumnTracker[colName] + if ok { + if existing.SingleValue != nil && *existing.SingleValue != colValue { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + } + } else { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: &colValue} } +} - sqf.filteringColumnCounts[colName] = 1 +func (sqf SchemaQueryFilterer) recordMutableColumnValue(colName string) { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} } // FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the // specified ID. func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colObjectID: objectID}) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID}) + sqf.recordColumnValue(sqf.schema.ColObjectID, objectID) return sqf } @@ -297,7 +376,7 @@ func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (Schema prefix = strings.ReplaceAll(prefix, `\`, `\\`) prefix = strings.ReplaceAll(prefix, "_", `\_`) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.colObjectID: prefix + "%"}) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"}) // NOTE: we do *not* record the use of the resource ID column here, because it is not used // statically and thus is necessary for sorting operations. @@ -320,7 +399,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colObjectID) + builder.WriteString(sqf.schema.ColObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(resourceIds)) @@ -330,7 +409,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema } args = append(args, resourceID) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID) } builder.WriteString("?") @@ -346,8 +425,8 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema // FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the // specified relation. func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colRelation: relation}) - sqf.recordColumnValue(sqf.schema.colRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation}) + sqf.recordColumnValue(sqf.schema.ColRelation, relation) return sqf } @@ -422,12 +501,21 @@ func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...data func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { selectorsOrClause := sq.Or{} + // If there is more than a single filter, record all the subjects as mutable, as the subjects returned + // can differ for each branch. + // TODO(jschorr): Optimize this further where applicable. + if len(selectors) > 1 { + sqf.recordMutableColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) + } + for _, selector := range selectors { selectorClause := sq.And{} if len(selector.OptionalSubjectType) > 0 { - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetNamespace: selector.OptionalSubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType) } if len(selector.OptionalSubjectIds) > 0 { @@ -436,7 +524,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor }, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colUsersetObjectID) + builder.WriteString(sqf.schema.ColUsersetObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(selector.OptionalSubjectIds)) @@ -446,7 +534,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor } args = append(args, subjectID) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID) } builder.WriteString("?") @@ -460,8 +548,8 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if !selector.RelationFilter.IsEmpty() { if selector.RelationFilter.OnlyNonEllipsisRelations { - selectorClause = append(selectorClause, sq.NotEq{sqf.schema.colUsersetRelation: datastore.Ellipsis}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) } else { relations := make([]string, 0, 2) if selector.RelationFilter.IncludeEllipsisRelation { @@ -474,14 +562,14 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if len(relations) == 1 { relName := relations[0] - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetRelation: relName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName) } else { orClause := sq.Or{} for _, relationName := range relations { dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis) - orClause = append(orClause, sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName) } selectorClause = append(selectorClause, orClause) @@ -499,27 +587,27 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor // FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with // subjects that match the specified filter. func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetNamespace: filter.SubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType) if filter.OptionalSubjectId != "" { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetObjectID: filter.OptionalSubjectId}) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId}) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId) } if filter.OptionalRelation != nil { dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis) } return sqf } func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colCaveatName: caveatName}) - sqf.recordColumnValue(sqf.schema.colCaveatName) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName}) + sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName) return sqf } @@ -531,7 +619,7 @@ func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { // QueryExecutor is a tuple query runner shared by SQL implementations of the datastore. type QueryExecutor struct { - Executor ExecuteQueryFunc + Executor ExecuteReadRelsQueryFunc } // ExecuteQuery executes the query. @@ -540,6 +628,10 @@ func (tqs QueryExecutor) ExecuteQuery( query SchemaQueryFilterer, opts ...options.QueryOptionsOption, ) (datastore.RelationshipIterator, error) { + if query.isCustomQuery { + return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries") + } + queryOpts := options.NewQueryOptionsWithOptions(opts...) query = query.TupleOrder(queryOpts.Sort) @@ -562,17 +654,57 @@ func (tqs QueryExecutor) ExecuteQuery( limit = *queryOpts.Limit } - toExecute := query.limit(limit) + if limit < math.MaxInt64 { + query = query.limit(limit) + } + + toExecute := query + + // Set the column names to select. + columnNamesToSelect := make([]string, 0, 8+len(query.extraFields)) + + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColRelation) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetRelation) + + columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext) + columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) + + toExecute.queryBuilder = toExecute.queryBuilder.Columns(columnNamesToSelect...) + + from := query.schema.RelationshipTableName + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } + + toExecute.queryBuilder = toExecute.queryBuilder.From(from) + sql, args, err := toExecute.queryBuilder.ToSql() if err != nil { return nil, err } - return tqs.Executor(ctx, sql, args) + return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker}, sql, args) +} + +func checkColumn(columns []string, tracker map[string]ColumnTracker, colName string) []string { + if r, ok := tracker[colName]; !ok || r.SingleValue == nil { + return append(columns, colName) + } + return columns +} + +// QueryInfo holds the schema information and filtering values for a query. +type QueryInfo struct { + Schema SchemaInformation + FilteringValues map[string]ColumnTracker } -// ExecuteQueryFunc is a function that can be used to execute a single rendered SQL query. -type ExecuteQueryFunc func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) // TxCleanupFunc is a function that should be executed when the caller of // TransactionFactory is done with the transaction. diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index e61339f2f6..58eccaf231 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -3,8 +3,6 @@ package common import ( "testing" - "github.com/google/uuid" - "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/tuple" @@ -20,11 +18,11 @@ var toCursor = options.ToCursor func TestSchemaQueryFilterer(t *testing.T) { tests := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - expectedSQL string - expectedArgs []any - expectedColumnCounts map[string]int + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + expectedSQL string + expectedArgs []any + expectedStaticColumns []string }{ { "relation filter", @@ -33,9 +31,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE relation = ?", []any{"somerelation"}, - map[string]int{ - "relation": 1, - }, + []string{"relation"}, }, { "resource ID filter", @@ -44,9 +40,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE object_id = ?", []any{"someresourceid"}, - map[string]int{ - "object_id": 1, - }, + []string{"object_id"}, }, { "resource IDs filter", @@ -55,7 +49,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE object_id LIKE ?", []any{"someprefix%"}, - map[string]int{}, // object_id is not statically used, so not present in the map + []string{}, }, { "resource IDs prefix filter", @@ -64,9 +58,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE object_id IN (?,?)", []any{"someresourceid", "anotherresourceid"}, - map[string]int{ - "object_id": 2, - }, + []string{}, }, { "resource type filter", @@ -75,9 +67,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "resource filter", @@ -86,11 +76,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND object_id = ? AND relation = ?", []any{"sometype", "someobj", "somerel"}, - map[string]int{ - "ns": 1, - "object_id": 1, - "relation": 1, - }, + []string{"ns", "object_id", "relation"}, }, { "relationships filter with no IDs or relations", @@ -101,9 +87,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "relationships filter with single ID", @@ -115,10 +99,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND object_id IN (?)", []any{"sometype", "someid"}, - map[string]int{ - "ns": 1, - "object_id": 1, - }, + []string{"ns", "object_id"}, }, { "relationships filter with no IDs", @@ -130,9 +111,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "relationships filter with multiple IDs", @@ -144,10 +123,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND object_id IN (?,?)", []any{"sometype", "someid", "anotherid"}, - map[string]int{ - "ns": 1, - "object_id": 2, - }, + []string{"ns"}, }, { "subjects filter with no IDs or relations", @@ -158,9 +134,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ?))", []any{"somesubjectype"}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "multiple subjects filters with just types", @@ -173,9 +147,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ?) OR (subject_ns = ?))", []any{"somesubjectype", "anothersubjectype"}, - map[string]int{ - "subject_ns": 2, - }, + []string{}, }, { "subjects filter with single ID", @@ -187,10 +159,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?)))", []any{"somesubjectype", "somesubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - }, + []string{"subject_ns", "subject_object_id"}, }, { "subjects filter with single ID and no type", @@ -201,9 +170,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_object_id IN (?)))", []any{"somesubjectid"}, - map[string]int{ - "subject_object_id": 1, - }, + []string{"subject_object_id"}, }, { "empty subjects filter", @@ -212,7 +179,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((1=1))", nil, - map[string]int{}, + []string{}, }, { "subjects filter with multiple IDs", @@ -224,10 +191,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?)))", []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, - }, + []string{"subject_ns"}, }, { "subjects filter with single ellipsis relation", @@ -239,10 +203,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?))", []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "subjects filter with single defined relation", @@ -254,10 +215,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?))", []any{"somesubjectype", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "subjects filter with only non-ellipsis", @@ -269,10 +227,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_relation <> ?))", []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns"}, }, { "subjects filter with defined relation and ellipsis", @@ -284,10 +239,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?)))", []any{"somesubjectype", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 2, - }, + []string{"subject_ns"}, }, { "subjects filter", @@ -300,11 +252,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, - }, + []string{"subject_ns"}, }, { "multiple subjects filter", @@ -328,11 +276,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?))", []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, - map[string]int{ - "subject_ns": 3, - "subject_object_id": 4, - "subject_relation": 5, - }, + []string{}, }, { "v1 subject filter with namespace", @@ -343,9 +287,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE subject_ns = ?", []any{"subns"}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "v1 subject filter with subject id", @@ -357,10 +299,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE subject_ns = ? AND subject_object_id = ?", []any{"subns", "subid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - }, + []string{"subject_ns", "subject_object_id"}, }, { "v1 subject filter with relation", @@ -374,10 +313,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE subject_ns = ? AND subject_relation = ?", []any{"subns", "subrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "v1 subject filter with empty relation", @@ -391,10 +327,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE subject_ns = ? AND subject_relation = ?", []any{"subns", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "v1 subject filter", @@ -409,11 +342,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", []any{"subns", "subid", "somerel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_object_id", "subject_relation"}, }, { "limit", @@ -422,7 +351,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * LIMIT 100", nil, - map[string]int{}, + []string{}, }, { "full resources filter", @@ -444,14 +373,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "ns": 1, - "object_id": 2, - "relation": 1, - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, - }, + []string{"ns", "relation", "subject_ns"}, }, { "order by", @@ -464,9 +386,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", []any{"someresourcetype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "after with just namespace", @@ -479,9 +399,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "after with just relation", @@ -494,9 +412,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, - map[string]int{ - "relation": 1, - }, + []string{"relation"}, }, { "after with namespace and single resource id", @@ -510,10 +426,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 1, - }, + []string{"ns", "object_id"}, }, { "after with single resource id", @@ -526,9 +439,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, - map[string]int{ - "object_id": 1, - }, + []string{"object_id"}, }, { "after with namespace and resource ids", @@ -542,10 +453,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 2, - }, + []string{"ns"}, }, { "after with namespace and relation", @@ -559,10 +467,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "relation": 1, - }, + []string{"ns", "relation"}, }, { "after with subject namespace", @@ -573,9 +478,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "after with subject namespaces", @@ -590,9 +493,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "subject_ns": 2, - }, + []string{}, }, { "after with resource ID prefix", @@ -601,7 +502,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{}, + []string{}, }, { "order by subject", @@ -614,9 +515,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ns = ? ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", []any{"someresourcetype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "order by subject, after with subject namespace", @@ -627,9 +526,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "order by subject, after with subject namespace and subject object id", @@ -641,7 +538,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?)", []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 1}, + []string{"subject_ns", "subject_object_id"}, }, { "order by subject, after with subject namespace and multiple subject object IDs", @@ -653,15 +550,15 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 2}, + []string{"subject_ns"}, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - base := sq.Select("*") schema := NewSchemaInformation( + "relationtuples", "ns", "object_id", "relation", @@ -669,12 +566,23 @@ func TestSchemaQueryFilterer(t *testing.T) { "subject_object_id", "subject_relation", "caveat", + "caveat_context", TupleComparison, + sq.Question, ) - filterer := NewSchemaQueryFilterer(schema, base, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) ran := test.run(filterer) - require.Equal(t, test.expectedColumnCounts, ran.filteringColumnCounts) + foundStaticColumns := []string{} + for col, tracker := range ran.filteringColumnTracker { + if tracker.SingleValue != nil { + foundStaticColumns = append(foundStaticColumns, col) + } + } + + require.ElementsMatch(t, test.expectedStaticColumns, foundStaticColumns) + + ran.queryBuilder = ran.queryBuilder.Columns("*") sql, args, err := ran.queryBuilder.ToSql() require.NoError(t, err) @@ -683,27 +591,3 @@ func TestSchemaQueryFilterer(t *testing.T) { }) } } - -func BenchmarkSchemaFilterer(b *testing.B) { - si := NewSchemaInformation( - "namespace", - "object_id", - "object_relation", - "resource_type", - "resource_id", - "resource_relation", - "caveat_name", - TupleComparison, - ) - sqf := NewSchemaQueryFilterer(si, sq.Select("*"), 100) - var names []string - for i := 0; i < 500; i++ { - names = append(names, uuid.NewString()) - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = sqf.FilterToResourceIDs(names) - } -} diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 3f66b95810..ebaae37301 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -23,7 +23,7 @@ var ( ) writeCaveat = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition).Suffix(upsertCaveatSuffix) readCaveat = psql.Select(colCaveatDefinition, colTimestamp) - listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).From(tableCaveat).OrderBy(colCaveatName) + listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).OrderBy(colCaveatName) deleteCaveat = psql.Delete(tableCaveat) ) @@ -35,7 +35,7 @@ const ( ) func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - query := cr.fromBuilder(readCaveat, tableCaveat).Where(sq.Eq{colCaveatName: name}) + query := cr.fromWithAsOfSystemTime(readCaveat.Where(sq.Eq{colCaveatName: name}), tableCaveat) sql, args, err := query.ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) @@ -79,7 +79,7 @@ type bytesAndTimestamp struct { } func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { - caveatsWithNames := cr.fromBuilder(listCaveat, tableCaveat) + caveatsWithNames := cr.fromWithAsOfSystemTime(listCaveat, tableCaveat) if len(caveatNames) > 0 { caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 16f67030f1..6b15c7f92d 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -191,6 +191,32 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond + var extraFields []string + relTableName := tableTuple + if config.withIntegrity { + relTableName = tableTupleWithIntegrity + extraFields = []string{ + colIntegrityKeyID, + colIntegrityHash, + colTimestamp, + } + } + + schema := common.NewSchemaInformation( + relTableName, + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatContextName, + colCaveatContext, + common.ExpandedLogicComparison, + sq.Dollar, + extraFields..., + ) + ds := &crdbDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( config.gcWindow, @@ -211,6 +237,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, + schema: schema, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -294,6 +321,7 @@ type crdbDatastore struct { overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool gcWindow time.Duration + schema common.SchemaInformation beginChangefeedQuery string transactionNowQuery string @@ -312,11 +340,12 @@ func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reade Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(cds.readPool, cds.supportsIntegrity), } - fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String()) + withAsOfSystemTime := func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + return query.From(tableName + " AS OF SYSTEM TIME " + rev.String()) } - return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, fromBuilder, cds.filterMaximumIDCount, cds.tableTupleName(), cds.supportsIntegrity} + asOfSystemTimeSuffix := "AS OF SYSTEM TIME " + rev.String() + return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, withAsOfSystemTime, asOfSystemTimeSuffix, cds.filterMaximumIDCount, cds.schema, cds.supportsIntegrity} } func (cds *crdbDatastore) ReadWriteTx( @@ -360,11 +389,12 @@ func (cds *crdbDatastore) ReadWriteTx( executor, cds.writeOverlapKeyer, cds.overlapKeyInit(ctx), - func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr) + func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + return query.From(tableName) }, + "", // No AS OF SYSTEM TIME for writes cds.filterMaximumIDCount, - cds.tableTupleName(), + cds.schema, cds.supportsIntegrity, }, tx, @@ -519,14 +549,6 @@ func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, er return features, err } -func (cds *crdbDatastore) tableTupleName() string { - if cds.supportsIntegrity { - return tableTupleWithIntegrity - } - - return tableTuple -} - func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, error) { features := datastore.Features{ ContinuousCheckpointing: datastore.Feature{ @@ -557,7 +579,7 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er features.Watch.Reason = fmt.Sprintf("Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: %s", err.Error()) } return nil - }, fmt.Sprintf(cds.beginChangefeedQuery, cds.tableTupleName(), head, "1s")) + }, fmt.Sprintf(cds.beginChangefeedQuery, cds.schema.RelationshipTableName, head, "1s")) <-streamCtx.Done() diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index 0df1fc3e70..8fd62751a6 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -29,17 +29,6 @@ var ( countRels = psql.Select("count(*)") - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - common.ExpandedLogicComparison, - ) - queryCounters = psql.Select( colCounterName, colCounterSerializedFilter, @@ -49,14 +38,15 @@ var ( ) type crdbReader struct { - query pgxcommon.DBFuncQuerier - executor common.QueryExecutor - keyer overlapKeyer - overlapKeySet keySet - fromBuilder func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder - filterMaximumIDCount uint16 - tupleTableName string - withIntegrity bool + query pgxcommon.DBFuncQuerier + executor common.QueryExecutor + keyer overlapKeyer + overlapKeySet keySet + fromWithAsOfSystemTime func(query sq.SelectBuilder, tableName string) sq.SelectBuilder + asOfSystemTimeSuffix string + filterMaximumIDCount uint16 + schema common.SchemaInformation + withIntegrity bool } func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -74,8 +64,8 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return 0, err } - query := cr.fromBuilder(countRels, cr.tupleTableName) - builder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + query := cr.fromWithAsOfSystemTime(countRels, cr.schema.RelationshipTableName) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(cr.schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -103,8 +93,7 @@ func (cr *crdbReader) LookupCounters(ctx context.Context) ([]datastore.Relations } func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { - query := cr.fromBuilder(queryCounters, tableRelationshipCounter) - + query := cr.fromWithAsOfSystemTime(queryCounters, tableRelationshipCounter) if optionalFilterName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalFilterName}) } @@ -176,42 +165,13 @@ func (cr *crdbReader) ReadNamespaceByName( } func (cr *crdbReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromBuilder) + nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromWithAsOfSystemTime) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } return nsDefs, nil } -func (cr *crdbReader) queryTuples() sq.SelectBuilder { - if cr.withIntegrity { - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colIntegrityKeyID, - colIntegrityHash, - colTimestamp, - ) - } - - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - ) -} - func (cr *crdbReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { if len(nsNames) == 0 { return nil, nil @@ -228,8 +188,7 @@ func (cr *crdbReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount).WithFromSuffix(cr.asOfSystemTimeSuffix).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -242,8 +201,8 @@ func (cr *crdbReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount). + WithFromSuffix(cr.asOfSystemTimeSuffix). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -266,8 +225,7 @@ func (cr *crdbReader) ReverseQueryRelationships( } func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsName string) (*core.NamespaceDefinition, time.Time, error) { - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) - + query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) sql, args, err := query.ToSql() if err != nil { return nil, time.Time{}, err @@ -300,8 +258,7 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(clause) - + query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(clause) sql, args, err := query.ToSql() if err != nil { return nil, err @@ -342,7 +299,6 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) { query := fromBuilder(queryReadNamespace, tableNamespace) - sql, args, err := query.ToSql() if err != nil { return nil, err diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index 1aafec83c4..fb55fa68a3 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -115,11 +115,11 @@ var ( ) func (rwt *crdbReadWriteTXN) insertQuery() sq.InsertBuilder { - return psql.Insert(rwt.tupleTableName) + return psql.Insert(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryDeleteTuples() sq.DeleteBuilder { - return psql.Delete(rwt.tupleTableName) + return psql.Delete(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryWriteTuple() sq.InsertBuilder { @@ -542,10 +542,10 @@ var copyColsWithIntegrity = []string{ func (rwt *crdbReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { if rwt.withIntegrity { - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyColsWithIntegrity, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyColsWithIntegrity, iter) } - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyCols, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyCols, iter) } var _ datastore.ReadWriteTransaction = &crdbReadWriteTXN{} diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index 2b66297d91..b01a1f3722 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -44,8 +44,8 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) } - nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, fromStr string) squirrel.SelectBuilder { - return sb.From(fromStr) + nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, tableName string) squirrel.SelectBuilder { + return sb.From(tableName) }) if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) @@ -57,7 +57,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if cds.analyzeBeforeStatistics { if err := cds.readPool.BeginTxFunc(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}, func(tx pgx.Tx) error { - if _, err := tx.Exec(ctx, "ANALYZE "+cds.tableTupleName()); err != nil { + if _, err := tx.Exec(ctx, "ANALYZE "+cds.schema.RelationshipTableName); err != nil { return fmt.Errorf("unable to analyze tuple table: %w", err) } @@ -131,7 +131,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro log.Warn().Bool("has-rows", hasRows).Msg("unable to find row count in statistics query result") return nil - }, "SHOW STATISTICS FOR TABLE "+cds.tableTupleName()); err != nil { + }, "SHOW STATISTICS FOR TABLE "+cds.schema.RelationshipTableName); err != nil { return datastore.Stats{}, fmt.Errorf("unable to query unique estimated row count: %w", err) } diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index 883d14978a..b5fa9ab0d2 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -116,7 +116,7 @@ func (cds *crdbDatastore) watch( tableNames := make([]string, 0, 4) tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { - tableNames = append(tableNames, cds.tableTupleName()) + tableNames = append(tableNames, cds.schema.RelationshipTableName) } if opts.Content&datastore.WatchSchema == datastore.WatchSchema { tableNames = append(tableNames, tableNamespace) @@ -255,7 +255,7 @@ func (cds *crdbDatastore) watch( } switch tableName { - case cds.tableTupleName(): + case cds.schema.RelationshipTableName: var caveatName string var caveatContext map[string]any if details.After != nil && details.After.RelationshipCaveatName != "" { diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index 84283a3bb6..6cb7edafab 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -22,7 +22,7 @@ const ( ) func (mr *mysqlReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := mr.filterer(mr.ReadCaveatQuery) + filteredReadCaveat := mr.aliveFilter(mr.ReadCaveatQuery) sqlStatement, args, err := filteredReadCaveat.Where(sq.Eq{colName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, err @@ -68,7 +68,7 @@ func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) caveatsWithNames = caveatsWithNames.Where(sq.Eq{colName: caveatNames}) } - filteredListCaveat := mr.filterer(caveatsWithNames) + filteredListCaveat := mr.aliveFilter(caveatsWithNames) listSQL, listArgs, err := filteredListCaveat.ToSql() if err != nil { return nil, err diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 9d6d33827f..994c61e913 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -18,7 +18,6 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -29,7 +28,6 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" - "github.com/authzed/spicedb/pkg/tuple" ) const ( @@ -240,6 +238,20 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option -1*config.gcWindow.Seconds(), ) + schema := common.NewSchemaInformation( + driver.RelationTuple(), + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatName, + colCaveatContext, + common.ExpandedLogicComparison, + sq.Question, + ) + store := &Datastore{ db: db, driver: driver, @@ -260,6 +272,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, maxRetries: config.maxRetries, analyzeBeforeStats: config.analyzeBeforeStats, + schema: schema, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), @@ -319,6 +332,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { executor, buildLivingObjectFilterForRevision(rev), mds.filterMaximumIDCount, + mds.schema, } } @@ -362,6 +376,7 @@ func (mds *Datastore) ReadWriteTx( executor, currentlyLivingObjects, mds.filterMaximumIDCount, + mds.schema, }, mds.driver.RelationTuple(), tx, @@ -410,7 +425,24 @@ type querier interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) } -func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { +type wrappedTX struct { + tx querier +} + +func (wtx wrappedTX) QueryFunc(ctx context.Context, f func(context.Context, common.Rows) error, sql string, args ...any) error { + rows, err := wtx.tx.QueryContext(ctx, sql, args...) + if err != nil { + return err + } + + if rows.Err() != nil { + return rows.Err() + } + + return f(ctx, rows) +} + +func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // This implementation does not create a transaction because it's redundant for single statements, and it avoids // the network overhead and reduce contention on the connection pool. From MySQL docs: // @@ -426,79 +458,9 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { // // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) - return func(ctx context.Context, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - span := trace.SpanFromContext(ctx) - - rows, err := tx.QueryContext(ctx, sqlQuery, args...) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - defer common.LogOnError(ctx, rows.Close) - - span.AddEvent("Query issued to database") - - relCount := 0 - - defer func() { - span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - }() - - for rows.Next() { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName string - var caveatContext structpbWrapper - err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatContext, - ) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - caveat, err := common.ContextualizedCaveatFrom(caveatName, caveatContext) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: relation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - }, nil) { - return - } - } - if err := rows.Err(); err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - }, nil + return func(ctx context.Context, queryInfo common.QueryInfo, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, wrappedTX{tx}, false) } } @@ -518,6 +480,7 @@ type Datastore struct { watchBufferWriteTimeout time.Duration maxRetries uint8 filterMaximumIDCount uint16 + schema common.SchemaInformation optimizedRevisionQuery string validTransactionQuery string diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 8d523f09e5..8ce194408e 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -24,8 +24,9 @@ type mysqlReader struct { txSource txFactory executor common.QueryExecutor - filterer queryFilterer + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder @@ -39,17 +40,6 @@ const ( errUnableToReadCount = "unable to read count: %w" ) -var schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - common.ExpandedLogicComparison, -) - func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int, error) { // Ensure the counter is registered. counters, err := mr.lookupCounters(ctx, name) @@ -66,7 +56,7 @@ func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(mr.schema, mr.aliveFilter(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -114,7 +104,7 @@ func (mr *mysqlReader) LookupCounters(ctx context.Context) ([]datastore.Relation } func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) ([]datastore.RelationshipCounter, error) { - query := mr.filterer(mr.ReadCounterQuery) + query := mr.aliveFilter(mr.ReadCounterQuery) if optionalName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalName}) } @@ -175,7 +165,9 @@ func (mr *mysqlReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -188,7 +180,8 @@ func (mr *mysqlReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -218,7 +211,7 @@ func (mr *mysqlReader) ReadNamespaceByName(ctx context.Context, nsName string) ( } defer common.LogOnError(ctx, txCleanup) - loaded, version, err := loadNamespace(ctx, nsName, tx, mr.filterer(mr.ReadNamespaceQuery)) + loaded, version, err := loadNamespace(ctx, nsName, tx, mr.aliveFilter(mr.ReadNamespaceQuery)) switch { case errors.As(err, &datastore.ErrNamespaceNotFound{}): return nil, datastore.NoRevision, err @@ -263,7 +256,7 @@ func (mr *mysqlReader) ListAllNamespaces(ctx context.Context) ([]datastore.Revis } defer common.LogOnError(ctx, txCleanup) - query := mr.filterer(mr.ReadNamespaceQuery) + query := mr.aliveFilter(mr.ReadNamespaceQuery) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { @@ -289,7 +282,7 @@ func (mr *mysqlReader) LookupNamespacesWithNames(ctx context.Context, nsNames [] clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := mr.filterer(mr.ReadNamespaceQuery.Where(clause)) + query := mr.aliveFilter(mr.ReadNamespaceQuery.Where(clause)) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index 567dac4a97..4688bad766 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -33,7 +33,7 @@ const ( ) func (r *pgReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := r.filterer(readCaveat) + filteredReadCaveat := r.aliveFilter(readCaveat) sql, args, err := filteredReadCaveat.Where(sq.Eq{colCaveatName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) @@ -78,7 +78,7 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } - filteredListCaveat := r.filterer(caveatsWithNames) + filteredListCaveat := r.aliveFilter(caveatsWithNames) sql, args, err := filteredListCaveat.ToSql() if err != nil { return nil, fmt.Errorf(errListCaveats, err) diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 546fe98945..d811ce5426 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -2,9 +2,7 @@ package common import ( "context" - "database/sql" "errors" - "fmt" "time" "github.com/ccoveille/go-safecast" @@ -16,139 +14,28 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" - corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/tuple" ) -const errUnableToQueryTuples = "unable to query tuples: %w" - // NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. -func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, false) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, false) } } -func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, withIntegrity) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, withIntegrity) } } -// queryRels queries relationships for the given query and transaction. -func queryRels(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier, withIntegrity bool) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - err := tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { - span.AddEvent("Query issued to database") - - var resourceObjectType string - var resourceObjectID string - var resourceRelation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName sql.NullString - var caveatCtx map[string]any - - relCount := 0 - for rows.Next() { - var integrity *corev1.RelationshipIntegrity - - if withIntegrity { - var integrityKeyID string - var integrityHash []byte - var timestamp time.Time - - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &integrityKeyID, - &integrityHash, - ×tamp, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - - integrity = &corev1.RelationshipIntegrity{ - KeyId: integrityKeyID, - Hash: integrityHash, - HashedAt: timestamppb.New(timestamp), - } - } else { - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - } - - var caveat *corev1.ContextualizedCaveat - if caveatName.Valid { - var err error - caveat, err = common.ContextualizedCaveatFrom(caveatName.String, caveatCtx) - if err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("unable to fetch caveat context: %w", err)) - } - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: resourceRelation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - OptionalIntegrity: integrity, - }, nil) { - return nil - } - } - - if err := rows.Err(); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("rows err: %w", err)) - } - - span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - return nil - }, sqlStatement, args...) - if err != nil { - if !yield(tuple.Relationship{}, err) { - return - } - } - }, nil -} - // ParseConfigWithInstrumentation returns a pgx.ConnConfig that has been instrumented for observability func ParseConfigWithInstrumentation(url string) (*pgx.ConnConfig, error) { connConfig, err := pgx.ParseConfig(url) diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 81648075fc..eadf47a054 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -309,6 +309,20 @@ func newPostgresDatastore( maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond + schema := common.NewSchemaInformation( + tableTuple, + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatContextName, + colCaveatContext, + common.TupleComparison, + sq.Dollar, + ) + datastore := &pgDatastore{ CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, @@ -333,6 +347,7 @@ func newPostgresDatastore( isPrimary: isPrimary, inStrictReadMode: config.readStrictMode, filterMaximumIDCount: config.filterMaximumIDCount, + schema: schema, } if isPrimary && config.readStrictMode { @@ -384,6 +399,7 @@ type pgDatastore struct { watchEnabled bool isPrimary bool inStrictReadMode bool + schema common.SchemaInformation credentialsProvider datastore.CredentialsProvider @@ -415,6 +431,7 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read executor, buildLivingObjectFilterForRevision(rev), pgd.filterMaximumIDCount, + pgd.schema, } } @@ -458,6 +475,7 @@ func (pgd *pgDatastore) ReadWriteTx( executor, currentlyLivingObjects, pgd.filterMaximumIDCount, + pgd.schema, }, tx, newXID, diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index b8c0448ebb..112762575b 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -18,37 +18,16 @@ import ( type pgReader struct { query pgxcommon.DBFuncQuerier executor common.QueryExecutor - filterer queryFilterer + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder var ( - queryTuples = psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - ).From(tableTuple) - countRels = psql.Select("COUNT(*)").From(tableTuple) - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - common.TupleComparison, - ) - readNamespace = psql. Select(colConfig, colCreatedXid). From(tableNamespace) @@ -82,7 +61,7 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(r.schema, r.aliveFilter(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -122,7 +101,7 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d query = query.Where(sq.Eq{colCounterName: optionalName}) } - sql, args, err := r.filterer(query).ToSql() + sql, args, err := r.aliveFilter(query).ToSql() if err != nil { return nil, fmt.Errorf("unable to lookup counters: %w", err) } @@ -170,7 +149,9 @@ func (r *pgReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -183,7 +164,8 @@ func (r *pgReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -206,7 +188,7 @@ func (r *pgReader) ReverseQueryRelationships( } func (r *pgReader) ReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { - loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.filterer) + loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.aliveFilter) switch { case errors.As(err, &datastore.ErrNamespaceNotFound{}): return nil, datastore.NoRevision, err @@ -236,7 +218,7 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco } func (r *pgReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -255,7 +237,7 @@ func (r *pgReader) LookupNamespacesWithNames(ctx context.Context, nsNames []stri } nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, func(original sq.SelectBuilder) sq.SelectBuilder { - return r.filterer(original).Where(clause) + return r.aliveFilter(original).Where(clause) }) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index 48b93c7471..7db7605e4b 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -574,7 +574,7 @@ func (rwt *pgReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*c } func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...string) error { - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -582,7 +582,7 @@ func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...stri tplClauses := make([]sq.Sqlizer, 0, len(nsNames)) querier := pgxcommon.QuerierFuncsFor(rwt.tx) for _, nsName := range nsNames { - _, _, err := rwt.loadNamespace(ctx, nsName, querier, filterer) + _, _, err := rwt.loadNamespace(ctx, nsName, querier, aliveFilter) switch { case errors.As(err, &datastore.ErrNamespaceNotFound{}): return err diff --git a/internal/datastore/postgres/stats.go b/internal/datastore/postgres/stats.go index b428cd4657..0e0bea63f6 100644 --- a/internal/datastore/postgres/stats.go +++ b/internal/datastore/postgres/stats.go @@ -51,7 +51,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return datastore.Stats{}, fmt.Errorf("unable to prepare row count sql: %w", err) } - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -69,7 +69,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return fmt.Errorf("unable to query unique ID: %w", err) } - nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), aliveFilter) if err != nil { return fmt.Errorf("unable to load namespaces: %w", err) } diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 821de67035..6ced17a91f 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -7,6 +7,7 @@ import ( "time" "cloud.google.com/go/spanner" + sq "github.com/Masterminds/squirrel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/codes" @@ -54,7 +55,7 @@ func (sr spannerReader) CountRelationships(ctx context.Context, name string) (in return 0, err } - builder, err := common.NewSchemaQueryFilterer(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -134,7 +135,7 @@ func (sr spannerReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -147,7 +148,7 @@ func (sr spannerReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -171,8 +172,8 @@ func (sr spannerReader) ReverseQueryRelationships( var errStopIterator = fmt.Errorf("stop iteration") -func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { return func(yield func(tuple.Relationship, error) bool) { span := trace.SpanFromContext(ctx) span.AddEvent("Query issued to database") @@ -185,25 +186,28 @@ func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { relCount := 0 defer span.SetAttributes(attribute.Int("count", relCount)) + var resourceObjectType string + var resourceObjectID string + var relation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName spanner.NullString + var caveatCtx spanner.NullJSON + + colsToSelect := make([]any, 0, 8) + + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &relation) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) + + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) + if err := iter.Do(func(row *spanner.Row) error { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName spanner.NullString - var caveatCtx spanner.NullJSON - err := row.Columns( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - ) + err := row.Columns(colsToSelect...) if err != nil { return err } @@ -346,17 +350,6 @@ func readAllNamespaces(iter *spanner.RowIterator, span trace.Span) ([]datastore. return allNamespaces, nil } -var queryTuples = sql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, -).From(tableRelationship) - var countRels = sql.Select("COUNT(*)").From(tableRelationship) var queryTuplesForDelete = sql.Select( @@ -369,6 +362,7 @@ var queryTuplesForDelete = sql.Select( ).From(tableRelationship) var schema = common.NewSchemaInformation( + tableRelationship, colNamespace, colObjectID, colRelation, @@ -376,7 +370,9 @@ var schema = common.NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName, + colCaveatContext, common.ExpandedLogicComparison, + sq.AtP, ) var _ datastore.Reader = spannerReader{} diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index 0c64f12dd8..7005d095fb 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -1573,10 +1573,7 @@ func ConcurrentWriteSerializationTest(t *testing.T, tester DatastoreTester) { <-waitToFinish return err }) - if err != nil { - panic(err) - } - return nil + return err }) <-waitToStart