-
-
Save ikupenov/10bc89d92d92eaba8cc5569013e04069 to your computer and use it in GitHub Desktop.
| import { and, type DBQueryConfig, eq, type SQLWrapper } from "drizzle-orm"; | |
| import { drizzle } from "drizzle-orm/postgres-js"; | |
| import postgres, { type Sql } from "postgres"; | |
| import { type AnyArgs } from "@/common"; | |
| import { | |
| type DbClient, | |
| type DbTable, | |
| type DeleteArgs, | |
| type DeleteFn, | |
| type FindArgs, | |
| type FindFn, | |
| type FromArgs, | |
| type FromFn, | |
| type InsertArgs, | |
| type JoinArgs, | |
| type JoinFn, | |
| type Owner, | |
| type RlsDbClient, | |
| type SetArgs, | |
| type SetFn, | |
| type UpdateArgs, | |
| type ValuesArgs, | |
| type ValuesFn, | |
| type WhereArgs, | |
| type WhereFn, | |
| } from "./db-client.types"; | |
| import * as schema from "./schema"; | |
/**
 * Creates a postgres.js client for the given connection string.
 *
 * @param connectionString - Postgres connection URL.
 * @returns The postgres.js `Sql` client.
 */
export const connectDb = (connectionString: string) => postgres(connectionString);
/**
 * Wraps a postgres.js client in a Drizzle instance that is aware of the
 * application schema (which is what makes `db.query.<table>` available).
 *
 * @param client - postgres.js client, e.g. from `connectDb`.
 * @returns The schema-aware Drizzle client.
 */
export const createDbClient = (client: Sql): DbClient => {
  const drizzleConfig = { schema };
  return drizzle(client, drizzleConfig);
};
/**
 * Wraps a Drizzle client so that every query built through it is scoped to
 * `owner` (application-level row level security).
 *
 * For every table that has an `ownerId` column the proxy:
 * - ANDs `ownerId = owner.id` into relational queries
 *   (`query.<table>.findMany` / `findFirst`), including every relation loaded
 *   through `with`, at any nesting depth;
 * - ANDs the same condition into `select ... from`, joins, `update ... set`
 *   and `delete`, on top of any user-supplied `where` / `on` filter;
 * - stamps `ownerId = owner.id` onto every inserted row.
 *
 * Calls to `execute` are rejected because they cannot be scoped. Calls to
 * `transaction` (on `db`, or nested on `tx`) hand the callback a proxied `tx`,
 * so the same rules apply inside transactions; any transaction config is
 * forwarded unchanged.
 *
 * NOTE(review): the `*.execute` pattern also rejects `.execute()` on query
 * builders (e.g. `db.select().from(t).execute()`); use `await` instead.
 *
 * @param client - postgres.js client used to build the Drizzle instance.
 * @param owner - Principal whose `id` is compared against `ownerId`.
 * @returns A proxied Drizzle client with the access policy applied.
 * @throws Error (from the returned client) when `execute` is invoked.
 */
export const createRlsDbClient = (client: Sql, owner: Owner): RlsDbClient => {
  const db = createDbClient(client);
  const ownerIdColumn = "ownerId" as const;

  /** Result of drizzle's condition helpers such as `and()`: `SQL | undefined`. */
  type Condition = ReturnType<typeof and>;

  /** A relational-query `where`: a condition or a `(fields, operators) => condition` callback. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  type WhereOption = Condition | ((fields: any, operators: any) => Condition);

  /** One entry of a relational-query `with` clause. */
  type RelationOption = true | null | undefined | DBQueryConfig<"many">;

  interface InvokeContext {
    /** Property names walked from the root, e.g. `["db", "select", "from"]`. */
    path?: string[];
    /** Function calls made along that path, with the arguments each received. */
    fnPath?: { name: string; args: unknown[] }[];
  }

  interface InterceptFn {
    invoke: (...args: unknown[]) => unknown;
    name: string;
    args: unknown[];
  }

  interface OverrideFn {
    pattern: string | string[];
    action: () => unknown;
  }

  // eslint-disable-next-line import/namespace
  const getTable = (table: DbTable) => schema[table];

  const getAccessPolicy = (
    table: {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      [ownerIdColumn]: any;
    },
    owner: Owner,
  ) => eq(table[ownerIdColumn], owner.id);

  /**
   * ANDs `policy` with a user filter. Drizzle accepts filters either as SQL or
   * as a callback it invokes itself; `and()` only accepts SQL conditions, so a
   * callback filter is kept as a callback that ANDs the policy with its result.
   */
  const withPolicy = (policy: SQLWrapper, filter: unknown) =>
    typeof filter === "function"
      ? (...filterArgs: unknown[]) =>
          and(
            policy,
            (filter as (...args: unknown[]) => Condition)(...filterArgs),
          )
      : and(policy, filter as SQLWrapper | undefined);

  /** Expands each pattern for the root client (`db`) and for proxied transactions (`tx`). */
  const onDbAndTx = (...patterns: string[]) => [
    ...patterns.map((pattern) => `db.${pattern}`),
    ...patterns.map((pattern) => `tx.${pattern}`),
  ];

  /**
   * Scopes one relation of a `with` clause. `true` becomes a config holding
   * only the policy. An object config keeps its options, gets its own `where`
   * ANDed with the policy and its nested `with` scoped recursively. Any other
   * value is returned unchanged.
   */
  const scopeRelation = (value: RelationOption): unknown => {
    if (value === true) {
      return {
        where: (fields: object) =>
          ownerIdColumn in fields
            ? // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any
              getAccessPolicy(fields as any, owner)
            : undefined,
      };
    }
    if (typeof value === "object" && value !== null) {
      const userWhere: WhereOption = value.where;
      return {
        ...value,
        // Without this, `with: { a: { with: { b: true } } }` would load `b` unscoped.
        ...(value.with ? { with: scopeRelations(value.with) } : {}),
        where: (fields: object, operators: unknown) => {
          const resolved =
            typeof userWhere === "function"
              ? userWhere(fields, operators)
              : userWhere;
          return ownerIdColumn in fields
            ? // eslint-disable-next-line @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-explicit-any
              and(getAccessPolicy(fields as any, owner), resolved)
            : resolved;
        },
      };
    }
    return value;
  };

  /** Applies `scopeRelation` to every entry of a `with` clause. */
  const scopeRelations = (relations: object): DBQueryConfig["with"] =>
    Object.entries(relations).reduce<Record<string, unknown>>(
      (acc, [key, value]) => ({
        ...acc,
        [key]: scopeRelation(value as RelationOption),
      }),
      {},
    ) as DBQueryConfig["with"];

  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    // `*` in a pattern matches any run of path segments.
    const matchPath = (pattern: string) =>
      new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      ).test(pathAsString);

    /** Arguments of the most recent call named `name` in the current chain. */
    const findLastCallArgs = (name: string) =>
      fnPath.findLast((call) => call.name === name)?.args;

    const overrides: OverrideFn[] = [
      {
        pattern: onDbAndTx("transaction"),
        action: () => {
          const transactionFn = fn.invoke as DbClient["transaction"];
          const [callback, ...config] = fn.args as Parameters<
            DbClient["transaction"]
          >;
          // The raw `tx` would bypass every policy; hand out a proxied one and
          // forward the optional transaction config untouched.
          return transactionFn(
            async (tx) => callback(createProxy(tx, { path: ["tx"] })),
            ...config,
          );
        },
      },
      {
        pattern: onDbAndTx("execute", "*.execute"),
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        pattern: onDbAndTx(
          "query.findMany",
          "query.*.findMany",
          "query.findFirst",
          "query.*.findFirst",
        ),
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          const tableIndex = path.findIndex((x) => x === "query") + 1;
          const tableName = path[tableIndex]! as keyof typeof db.query;
          const table = getTable(tableName as DbTable);

          if (ownerIdColumn in table) {
            const policy = getAccessPolicy(table, owner);
            let [config] = findArgs;
            const userWhere = config?.where;
            config = {
              ...config,
              where: userWhere ? withPolicy(policy, userWhere) : policy,
            };
            if (config.with) {
              config = { ...config, with: scopeRelations(config.with) };
            }
            return findFn(...([config] as FindArgs));
          }
          return findFn(...findArgs);
        },
      },
      {
        pattern: onDbAndTx("*.from"),
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs;
          if (ownerIdColumn in table) {
            return fromFn(...fromArgs).where(getAccessPolicy(table, owner));
          }
          return fromFn(...fromArgs);
        },
      },
      {
        pattern: onDbAndTx("*.from.where", "*.from.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("from") as FromArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              withPolicy(getAccessPolicy(table, owner), whereFilter),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: onDbAndTx(
          "*.leftJoin",
          "*.rightJoin",
          "*.innerJoin",
          "*.fullJoin",
        ),
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOn] = joinArgs;
          if (ownerIdColumn in table) {
            return joinFn(
              table,
              withPolicy(getAccessPolicy(table, owner), joinOn),
            );
          }
          return joinFn(...joinArgs);
        },
      },
      {
        pattern: onDbAndTx("insert.values"),
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          const [table] = findLastCallArgs("insert") as InsertArgs;
          if (ownerIdColumn in table) {
            const [values] = valuesArgs;
            const rows = Array.isArray(values) ? values : [values];
            return valuesFn(
              rows.map((row) => ({ ...row, [ownerIdColumn]: owner.id })),
            );
          }
          return valuesFn(...valuesArgs);
        },
      },
      {
        pattern: onDbAndTx("update.set"),
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = findLastCallArgs("update") as UpdateArgs;
          if (ownerIdColumn in table) {
            return setFn(...setArgs).where(getAccessPolicy(table, owner));
          }
          return setFn(...setArgs);
        },
      },
      {
        pattern: onDbAndTx("update.where", "update.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("update") as UpdateArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              withPolicy(getAccessPolicy(table, owner), whereFilter),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: onDbAndTx("delete"),
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs;
          if (ownerIdColumn in table) {
            return deleteFn(...deleteArgs).where(getAccessPolicy(table, owner));
          }
          return deleteFn(...deleteArgs);
        },
      },
      {
        pattern: onDbAndTx("delete.where", "delete.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("delete") as DeleteArgs;
          if (ownerIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(
              withPolicy(getAccessPolicy(table, owner), whereFilter),
            );
          }
          return whereFn(...whereArgs);
        },
      },
    ];

    // First override whose pattern(s) match the current path wins.
    const override = overrides.find(({ pattern }) =>
      (Array.isArray(pattern) ? pattern : [pattern]).some(matchPath),
    );

    return override ? override.action() : fn.invoke(...fn.args);
  };

  /**
   * Recursively proxies `target`. Function calls are routed through
   * `intercept`, and object results or properties are wrapped again, so that
   * chained builder calls keep their `path` / `fnPath`.
   */
  const createProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const currentPath = path.concat(innerTargetProp.toString());
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );
        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [
              ...fnPath,
              { name: innerTargetProp.toString(), args },
            ];
            const result = intercept(
              {
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: innerTargetProp.toString(),
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              return createProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        } else if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // wrap nested objects in a proxy as well
          return createProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }
        return innerTargetPropValue;
      },
    });
  };

  return createProxy(db, { path: ["db"] });
};
| import { type drizzle } from "drizzle-orm/postgres-js"; | |
| import type * as schema from "./schema"; | |
/**
 * Type-only handle on a schema-aware Drizzle client. Being `declare`d, nothing
 * is emitted at runtime; it exists so the aliases below can use `typeof db`.
 */
declare const db: ReturnType<typeof drizzle<typeof schema>>;

/** Principal that RLS-scoped queries are filtered by. */
export interface Owner {
  id: string | null;
}

/** Drizzle client bound to the application schema. */
export type DbClient = typeof db;
/** The application schema module. */
export type DbSchema = typeof schema;
/** Export names of the schema module, used to look a table up by its query key. */
export type DbTable = keyof DbSchema;
/** Client handed to RLS callers: everything except raw `execute`. */
export type RlsDbClient = Omit<DbClient, "execute">;

/** Relational query API (`db.query`), keyed by table name. */
type QueryApi = DbClient["query"];

/** `findFirst` / `findMany` of table `K` (any table by default). */
export type FindFn<K extends keyof QueryApi = keyof QueryApi> = (
  ...args:
    | Parameters<QueryApi[K]["findFirst"]>
    | Parameters<QueryApi[K]["findMany"]>
) =>
  | ReturnType<QueryApi[K]["findFirst"]>
  | ReturnType<QueryApi[K]["findMany"]>;
export type FindArgs<K extends keyof QueryApi = keyof QueryApi> = Parameters<
  FindFn<K>
>;

// Steps of `db.select(...).from(...)`.
export type SelectFn = DbClient["select"];
export type SelectArgs = Parameters<SelectFn>;
export type FromFn = ReturnType<SelectFn>["from"];
export type FromArgs = Parameters<FromFn>;
export type WhereFn = ReturnType<FromFn>["where"];
export type WhereArgs = Parameters<WhereFn>;
export type JoinFn = ReturnType<FromFn>["leftJoin"];
export type JoinArgs = Parameters<JoinFn>;

// Steps of `db.insert(...).values(...)`.
export type InsertFn = DbClient["insert"];
export type InsertArgs = Parameters<InsertFn>;
export type ValuesFn = ReturnType<InsertFn>["values"];
export type ValuesArgs = Parameters<ValuesFn>;

// Steps of `db.update(...).set(...)`.
export type UpdateFn = DbClient["update"];
export type UpdateArgs = Parameters<UpdateFn>;
export type SetFn = ReturnType<UpdateFn>["set"];
export type SetArgs = Parameters<SetFn>;

// `db.delete(...)`.
export type DeleteFn = DbClient["delete"];
export type DeleteArgs = Parameters<DeleteFn>;
I was able to make transactions work, with a promising solution!
I updated all of the patterns to also match `tx.*` calls, and then added one more pattern that matches `db.transaction` and overrides it, reusing the `createProxy` helper to wrap every property of the `tx` object. I can't believe this actually works 🤯
/* eslint-disable @typescript-eslint/no-unsafe-call */
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-non-null-assertion */
import type { DBQueryConfig, SQLWrapper } from "drizzle-orm";
import { and, eq } from "drizzle-orm";
import type { db as _db } from "../client";
import type {
DbTable,
DeleteArgs,
DeleteFn,
FindArgs,
FindFn,
FromArgs,
FromFn,
InsertArgs,
JoinArgs,
JoinFn,
SetArgs,
SetFn,
Team,
TeamDbClient,
UpdateArgs,
ValuesArgs,
ValuesFn,
WhereArgs,
WhereFn,
} from "./teamDb.types";
import { db } from "../client";
import * as schema from "../schema";
/** Argument list of an arbitrary intercepted call. */
type AnyArgs = any[];

/** One function call observed while walking a proxied call chain. */
interface CallRecord {
  name: string;
  args: unknown[];
}

/** Position of a proxied value within a chain such as `db.select().from(t)`. */
interface InvokeContext {
  /** Property names walked from the root, e.g. `["db", "select", "from"]`. */
  path?: string[];
  /** Function calls made along that path, oldest first. */
  fnPath?: CallRecord[];
}

/** A pending call: the bound function, its property name and its arguments. */
interface InterceptFn {
  name: string;
  args: unknown[];
  invoke: (...args: unknown[]) => unknown;
}

/** Replacement behavior for calls whose path matches `pattern` (`*` = any segments). */
interface OverrideFn {
  action: () => unknown;
  pattern: string | string[];
}
/**
 * Returns a proxied Drizzle client whose queries are scoped to `team`
 * (application-level row level security).
 *
 * For every table with a `teamId` column:
 * - relational queries (`query.<table>.findMany` / `findFirst`) and every
 *   relation they load through `with` (at any depth) are ANDed with
 *   `teamId = team.id`;
 * - `select ... from`, joins, `update ... set` and `delete` get the same
 *   condition ANDed with any user-supplied `where` / `on` filter;
 * - inserted rows get `teamId = team.id`.
 *
 * Calls to `execute` are rejected. `transaction` (on `db`, or nested on `tx`)
 * passes the callback a proxied `tx` and forwards any transaction config.
 *
 * NOTE(review): the `*.execute` pattern also rejects `.execute()` on query
 * builders (e.g. `db.select().from(t).execute()`); use `await` instead.
 *
 * @param team - Team whose `id` is compared against `teamId`.
 * @returns The team-scoped client.
 * @throws Error (from the returned client) when `execute` is called.
 */
export const getTeamDb = (team: Team): TeamDbClient => {
  const teamIdColumn = "teamId";

  /** Result of drizzle's condition helpers such as `and()`: `SQL | undefined`. */
  type Condition = ReturnType<typeof and>;

  /** A relational-query `where`: a condition or a `(fields, operators) => condition` callback. */
  type WhereOption = Condition | ((fields: any, operators: any) => Condition);

  /** One entry of a relational-query `with` clause. */
  type RelationOption = true | null | undefined | DBQueryConfig<"many">;

  const getTable = (table: DbTable) => schema[table];

  const getAccessPolicy = (
    table: {
      [teamIdColumn]: any;
    },
    scope: Team,
  ) => eq(table[teamIdColumn], scope.id);

  /**
   * ANDs `policy` with a user filter. Drizzle accepts filters either as SQL or
   * as a callback it invokes itself; `and()` only accepts SQL conditions, so a
   * callback filter is kept as a callback that ANDs the policy with its result.
   */
  const withPolicy = (policy: SQLWrapper, filter: unknown) =>
    typeof filter === "function"
      ? (...filterArgs: unknown[]) =>
          and(
            policy,
            (filter as (...args: unknown[]) => Condition)(...filterArgs),
          )
      : and(policy, filter as SQLWrapper | undefined);

  /** Expands each pattern for the root client (`db`) and for proxied transactions (`tx`). */
  const onDbAndTx = (...patterns: string[]) => [
    ...patterns.map((pattern) => `db.${pattern}`),
    ...patterns.map((pattern) => `tx.${pattern}`),
  ];

  /**
   * Scopes one relation of a `with` clause. `true` becomes a config holding
   * only the policy. An object config keeps its options, gets its own `where`
   * ANDed with the policy and its nested `with` scoped recursively. Any other
   * value is returned unchanged.
   */
  const scopeRelation = (value: RelationOption): unknown => {
    if (value === true) {
      return {
        where: (fields: object) =>
          teamIdColumn in fields
            ? // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
              getAccessPolicy(fields as any, team)
            : undefined,
      };
    }
    if (typeof value === "object" && value !== null) {
      const userWhere: WhereOption = value.where;
      return {
        ...value,
        // Without this, `with: { a: { with: { b: true } } }` would load `b` unscoped.
        ...(value.with ? { with: scopeRelations(value.with) } : {}),
        where: (fields: object, operators: unknown) => {
          const resolved =
            typeof userWhere === "function"
              ? userWhere(fields, operators)
              : userWhere;
          return teamIdColumn in fields
            ? // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
              and(getAccessPolicy(fields as any, team), resolved)
            : resolved;
        },
      };
    }
    return value;
  };

  /** Applies `scopeRelation` to every entry of a `with` clause. */
  const scopeRelations = (relations: object): DBQueryConfig["with"] =>
    Object.entries(relations).reduce<Record<string, unknown>>(
      (acc, [key, value]) => ({
        ...acc,
        [key]: scopeRelation(value as RelationOption),
      }),
      {},
    ) as DBQueryConfig["with"];

  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    // `*` in a pattern matches any run of path segments.
    const matchPath = (pattern: string) =>
      new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      ).test(pathAsString);

    /**
     * Arguments of the most recent call named `name` in the current chain.
     * Written as a reverse loop so it does not depend on `Array#findLast`.
     */
    const findLastCallArgs = (name: string): unknown[] | undefined => {
      for (let index = fnPath.length - 1; index >= 0; index -= 1) {
        const call = fnPath[index];
        if (call?.name === name) {
          return call.args;
        }
      }
      return undefined;
    };

    const overrides: OverrideFn[] = [
      {
        // Also matches nested `tx.transaction`, whose raw `tx` would otherwise bypass the policy.
        pattern: onDbAndTx("transaction"),
        action: () => {
          const transactionFn = fn.invoke as typeof db.transaction;
          const [callback, ...config] = fn.args as Parameters<
            typeof db.transaction
          >;
          // Forward the optional transaction config instead of dropping it.
          return transactionFn(
            async (tx) => callback(createProxy(tx, { path: ["tx"] })),
            ...config,
          );
        },
      },
      {
        pattern: onDbAndTx("execute", "*.execute"),
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        pattern: onDbAndTx(
          "query.findMany",
          "query.*.findMany",
          "query.findFirst",
          "query.*.findFirst",
        ),
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          const tableIndex = path.findIndex((x) => x === "query") + 1;
          const tableName = path[tableIndex]! as keyof typeof db.query;
          const table = getTable(tableName as DbTable);

          if (teamIdColumn in table) {
            const policy = getAccessPolicy(table, team);
            let [config] = findArgs;
            const userWhere = config?.where;
            config = {
              ...config,
              where: userWhere ? withPolicy(policy, userWhere) : policy,
            };
            if (config.with) {
              config = { ...config, with: scopeRelations(config.with) };
            }
            return findFn(...([config] as FindArgs));
          }
          return findFn(...findArgs);
        },
      },
      {
        pattern: onDbAndTx("*.from"),
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs;
          if (teamIdColumn in table) {
            return fromFn(...fromArgs).where(getAccessPolicy(table, team));
          }
          return fromFn(...fromArgs);
        },
      },
      {
        pattern: onDbAndTx("*.from.where", "*.from.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("from") as FromArgs;
          if (teamIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(withPolicy(getAccessPolicy(table, team), whereFilter));
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: onDbAndTx(
          "*.leftJoin",
          "*.rightJoin",
          "*.innerJoin",
          "*.fullJoin",
        ),
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOn] = joinArgs;
          if (teamIdColumn in table) {
            return joinFn(table, withPolicy(getAccessPolicy(table, team), joinOn));
          }
          return joinFn(...joinArgs);
        },
      },
      {
        pattern: onDbAndTx("insert.values"),
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          const [table] = findLastCallArgs("insert") as InsertArgs;
          if (teamIdColumn in table) {
            const [values] = valuesArgs;
            const rows = Array.isArray(values) ? values : [values];
            // Stamp the same column the read policy filters on (`teamId`).
            return valuesFn(
              rows.map((row) => ({ ...row, [teamIdColumn]: team.id })),
            );
          }
          return valuesFn(...valuesArgs);
        },
      },
      {
        pattern: onDbAndTx("update.set"),
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = findLastCallArgs("update") as UpdateArgs;
          if (teamIdColumn in table) {
            return setFn(...setArgs).where(getAccessPolicy(table, team));
          }
          return setFn(...setArgs);
        },
      },
      {
        pattern: onDbAndTx("update.where", "update.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("update") as UpdateArgs;
          if (teamIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(withPolicy(getAccessPolicy(table, team), whereFilter));
          }
          return whereFn(...whereArgs);
        },
      },
      {
        pattern: onDbAndTx("delete"),
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs;
          if (teamIdColumn in table) {
            return deleteFn(...deleteArgs).where(getAccessPolicy(table, team));
          }
          return deleteFn(...deleteArgs);
        },
      },
      {
        pattern: onDbAndTx("delete.where", "delete.*.where"),
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = findLastCallArgs("delete") as DeleteArgs;
          if (teamIdColumn in table) {
            const [whereFilter] = whereArgs;
            return whereFn(withPolicy(getAccessPolicy(table, team), whereFilter));
          }
          return whereFn(...whereArgs);
        },
      },
    ];

    // First override whose pattern(s) match the current path wins.
    const override = overrides.find(({ pattern }) =>
      (Array.isArray(pattern) ? pattern : [pattern]).some(matchPath),
    );

    return override ? override.action() : fn.invoke(...fn.args);
  };

  /**
   * Recursively proxies `target`. Function calls are routed through
   * `intercept`, and object results or properties are wrapped again, so that
   * chained builder calls keep their `path` / `fnPath`.
   */
  const createProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const currentPath = path.concat(innerTargetProp.toString());
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );
        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [
              ...fnPath,
              { name: innerTargetProp.toString(), args },
            ];
            const result = intercept(
              {
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: innerTargetProp.toString(),
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              return createProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        } else if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // wrap nested objects in a proxy as well
          return createProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }
        return innerTargetPropValue;
      },
    });
  };

  return createProxy(db, { path: ["db"] });
};
Glad it's working for you guys! This is a separate thing from the DB-level RLS that the Drizzle team is working on. This solution is app-level RLS.
And yeah, in the initial version transactions were not supported. That's fixed now, and I have since made some improvements that allow for more flexible policies defined at the table level. I have created a new gist if you're interested: https://gist.github.com/ikupenov/26f3775821c05f17b6f8b7a037fb2c7a.
Here's an example:
// schema/entities/example-entity.ts
import { and, eq, isNotNull, or, sql } from "drizzle-orm";
import { pgTable, text, uuid } from "drizzle-orm/pg-core";
import { hasRole } from "@sheetah/common";
import { policy } from "@sheetah/db/orm";
import {
getOptionalOrgOwnedBaseEntityProps,
getOwnedBaseEntityProps,
} from "./base";
// Example table registered with `policy` below. The two spread helpers from
// ./base presumably supply `ownerId` and `organizationId`: the policy reads
// both columns, but neither is declared in this literal (TODO confirm in ./base).
export const entities = pgTable("entity", {
...getOwnedBaseEntityProps(),
...getOptionalOrgOwnedBaseEntityProps(),
description: text("description"),
transactionId: uuid("transaction_id").unique().notNull(),
categoryId: uuid("category_id"),
taxRateId: uuid("tax_rate_id"),
});
// Access rule for `entities`: a row is visible if the current user owns it,
// or if it belongs to the current organization and the user is an org admin.
// Each branch falls back to `false` so the combined condition is never empty.
policy(entities, ({ userId, orgId, role }) => {
  const ownedByUser = userId ? eq(entities.ownerId, userId) : sql`false`;

  const visibleToOrgAdmin =
    orgId && hasRole(role, ["org:admin"])
      ? and(
          isNotNull(entities.organizationId),
          eq(entities.organizationId, orgId),
        )
      : sql`false`;

  return or(ownedByUser, visibleToOrgAdmin) ?? sql`false`;
});
/** Row type selected from `entities` (the table defined in this file). */
export type Entity = typeof entities.$inferSelect;
FYI, v1 introduces a breaking change for this pattern, since the query builder API is getting revamped, so this solution will need a touch up.
I copied the code, and miraculously it works with my MySQL setup after some modifications.
However, it does not work with transactions.
If someone finds a solution, I would appreciate it.