Skip to content

Commit

Permalink
[drizzle] Fix cursors using aliased columsn
Browse files Browse the repository at this point in the history
  • Loading branch information
hayes committed Sep 9, 2024
1 parent 27a4189 commit cc5f993
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .changeset/chatty-crabs-mate.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@pothos/plugin-drizzle": patch
---

Fix cursors using aliased coluns
2 changes: 1 addition & 1 deletion packages/plugin-drizzle/src/drizzle-field-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ export class DrizzleObjectFieldBuilder<
(parent as Record<string, never>)[name],
args,
select.limit,
getCursorFormatter(cursorColumns),
getCursorFormatter(cursorColumns, schemaConfig),
);
},
},
Expand Down
5 changes: 3 additions & 2 deletions packages/plugin-drizzle/src/schema-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ schemaBuilderProto.drizzleNode = function drizzleNode(
Column
>,
) {
const tableConfig = getSchemaConfig(this).schema![table];
const schemaConfig = getSchemaConfig(this);
const tableConfig = schemaConfig.schema![table];
const idColumn = typeof column === 'function' ? column(tableConfig.columns) : column;
const idColumns = Array.isArray(idColumn) ? idColumn : [idColumn];
const interfaceRef = this.nodeInterfaceRef?.();
const resolve = getIDSerializer(idColumns);
const resolve = getIDSerializer(idColumns, schemaConfig);
const idParser = getIDParser(idColumns);
const typeName = variant ?? name ?? table;
const nodeRef = new DrizzleNodeRef(typeName, table, {
Expand Down
28 changes: 28 additions & 0 deletions packages/plugin-drizzle/src/utils/config.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import { type SchemaTypes, createContextCache } from '@pothos/core';
import {
type Column,
type RelationalSchemaConfig,
type TableRelationalConfig,
type TablesRelationalConfig,
createTableRelationsHelpers,
extractTablesRelationalConfig,
getTableName,
} from 'drizzle-orm';
import type { DrizzleClient } from '../types';

export interface PothosDrizzleSchemaConfig extends RelationalSchemaConfig<TablesRelationalConfig> {
dbToSchema: Record<string, TableRelationalConfig>;
columnToTsName: (column: Column) => string;
}
const configCache = createContextCache(
(builder: PothosSchemaTypes.SchemaBuilder<SchemaTypes>): PothosDrizzleSchemaConfig => {
Expand Down Expand Up @@ -37,8 +40,33 @@ const configCache = createContextCache(
{},
);

const columnMappings = Object.values(dbToSchema).reduce<Record<string, Record<string, string>>>(
(acc, table) => {
acc[table.dbName] = Object.entries(table.columns).reduce<Record<string, string>>(
(acc, [name, column]) => {
acc[column.name] = name;
return acc;
},
{},
);
return acc;
},
{},
);

return {
dbToSchema,
columnToTsName: (column) => {
const tableName = getTableName(column.table);
const table = columnMappings[tableName];
const columnName = table?.[column.name];

if (!columnName) {
throw new Error(`Could not find column mapping for ${tableName}.${column.name}`);
}

return columnName;
},
...config,
};
},
Expand Down
26 changes: 15 additions & 11 deletions packages/plugin-drizzle/src/utils/cursors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ export function formatCursorChunk(value: unknown) {
}
}

export function formatDrizzleCursor(record: Record<string, unknown>, fields: Column[]) {
return getCursorFormatter(fields)(record);
export function formatDrizzleCursor(
record: Record<string, unknown>,
fields: Column[],
config: PothosDrizzleSchemaConfig,
) {
return getCursorFormatter(fields, config)(record);
}

export function formatIDChunk(value: unknown) {
Expand All @@ -58,40 +62,40 @@ export function formatIDChunk(value: unknown) {
}
}

export function getIDSerializer(fields: Column[]) {
export function getIDSerializer(fields: Column[], config: PothosDrizzleSchemaConfig) {
if (fields.length === 0) {
throw new PothosValidationError('Column serializer must have at least one field');
}

return (value: Record<string, unknown>) => {
if (fields.length > 1) {
return `${JSON.stringify(fields.map((col) => value[col.name]))}`;
return `${JSON.stringify(fields.map((col) => value[config.columnToTsName(col)]))}`;
}

return `${formatIDChunk(value[fields[0].name])}`;
return `${formatIDChunk(value[config.columnToTsName(fields[0])])}`;
};
}

export function getColumnSerializer(fields: Column[]) {
export function getColumnSerializer(fields: Column[], config: PothosDrizzleSchemaConfig) {
if (fields.length === 0) {
throw new PothosValidationError('Column serializer must have at least one field');
}

return (value: Record<string, unknown>) => {
if (fields.length > 1) {
return `J:${JSON.stringify(fields.map((col) => value[col.name]))}`;
return `J:${JSON.stringify(fields.map((col) => value[config.columnToTsName(col)]))}`;
}

return `${formatCursorChunk(value[fields[0].name])}`;
return `${formatCursorChunk(value[config.columnToTsName(fields[0])])}`;
};
}

export function getCursorFormatter(fields: Column[]) {
export function getCursorFormatter(fields: Column[], config: PothosDrizzleSchemaConfig) {
if (fields.length === 0) {
throw new PothosValidationError('Cursor must have at least one field');
}

const serializer = getColumnSerializer(fields);
const serializer = getColumnSerializer(fields, config);

return (value: Record<string, unknown>) => {
return encodeBase64(`DC:${serializer(value)}`);
Expand Down Expand Up @@ -489,7 +493,7 @@ export async function resolveDrizzleCursorConnection<T extends {}>(
: q.orderBy
: table.primaryKey,
});
formatter = getCursorFormatter(cursorColumns);
formatter = getCursorFormatter(cursorColumns, config);

const where = typeof q.where === 'function' ? q.where(table.columns, getOperators()) : q.where;

Expand Down
36 changes: 15 additions & 21 deletions packages/plugin-drizzle/tests/example/db/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,21 @@ export const profile = sqliteTable('profile', {
bio: text('bio'),
});

export const posts = sqliteTable(
'posts',
{
id: integer('id'),
slug: text('slug').unique(),
title: text('title').notNull(),
content: text('content').notNull(),
published: integer('published').notNull().default(0),
authorId: integer('author_id')
.notNull()
.references(() => users.id, {
onDelete: 'cascade',
}),
categoryId: integer('category_id').references(() => categories.id),
createdAt: text('createdAt').notNull().default(sql`(current_timestamp)`),
updatedAt: text('createdAt').notNull().default(sql`(current_timestamp)`),
},
(table) => ({
pk: primaryKey({ columns: [table.id] }),
}),
);
export const posts = sqliteTable('posts', {
id: integer('id').primaryKey({ autoIncrement: true }),
slug: text('slug').unique(),
title: text('title').notNull(),
content: text('content').notNull(),
published: integer('published').notNull().default(0),
authorId: integer('author_id')
.notNull()
.references(() => users.id, {
onDelete: 'cascade',
}),
categoryId: integer('category_id').references(() => categories.id),
createdAt: text('createdAt').notNull().default(sql`(current_timestamp)`),
updatedAt: text('createdAt').notNull().default(sql`(current_timestamp)`),
});

export const postLikes = sqliteTable(
'post_likes',
Expand Down

0 comments on commit cc5f993

Please sign in to comment.