@@ -8,7 +8,7 @@ import type {
88 DrizzleInstance ,
99 DrizzleQueryFunction ,
1010 DrizzleQueryFunctionInput ,
11- DrizzleTableType ,
11+ DrizzleTableValueType ,
1212} from "./types/drizzleInstanceType" ;
1313import { RumbleError } from "./types/rumbleError" ;
1414import type {
@@ -211,7 +211,7 @@ export const createAbilityBuilder = <
211211 by : (
212212 explicitFilter : Filter <
213213 UserContext ,
214- DrizzleTableType < DB , TableName >
214+ DrizzleTableValueType < DB , TableName >
215215 > ,
216216 ) => {
217217 for ( const action of actions ) {
@@ -253,7 +253,7 @@ export const createAbilityBuilder = <
253253 } ) {
254254 return ( buildersPerTable [ table ] as any ) . _ . runtimeFilters . get (
255255 action ,
256- ) ! as Filter < UserContext , DrizzleTableType < DB , TableNames > > [ ] ;
256+ ) ! as Filter < UserContext , DrizzleTableValueType < DB , TableNames > > [ ] ;
257257 } ,
258258 build ( ) {
259259 const createFilterForTable = < TableName extends TableNames > (
@@ -494,107 +494,108 @@ export const createAbilityBuilder = <
494494 } ;
495495
496496 return {
497- filter : ( {
498- action,
499- userContext,
500- } : {
501- action : Action ;
502- userContext : UserContext ;
503- } ) => {
504- const filters = queryFilters . get ( action ) ;
497+ withContext : ( userContext : UserContext ) => {
498+ return {
499+ filter : ( action : Action ) => {
500+ const filters = queryFilters . get ( action ) ;
505501
506- // in case we have a wildcard ability, skip the rest and return no filters at all
507- if ( filters === "unrestricted" ) {
508- return transformToResponse ( ) ;
509- }
502+ // in case we have a wildcard ability, skip the rest and return no filters at all
503+ if ( filters === "unrestricted" ) {
504+ return transformToResponse ( ) ;
505+ }
510506
511- // if nothing has been allowed, block everything
512- if ( ! filters ) {
513- nothingRegisteredWarningLogger ( tableName . toString ( ) , action ) ;
514- return transformToResponse ( blockEverythingFilter as any ) ;
515- }
507+ // if nothing has been allowed, block everything
508+ if ( ! filters ) {
509+ nothingRegisteredWarningLogger (
510+ tableName . toString ( ) ,
511+ action ,
512+ ) ;
513+ return transformToResponse ( blockEverythingFilter as any ) ;
514+ }
516515
517- // run all dynamic filters
518- const dynamicResults = new Array <
519- DrizzleQueryFunctionInput < DB , TableName >
520- > ( dynamicQueryFilters [ action ] . length ) ;
521- let filtersReturned = 0 ;
522- for ( let i = 0 ; i < dynamicQueryFilters [ action ] . length ; i ++ ) {
523- const func = dynamicQueryFilters [ action ] [ i ] ;
524- const result = func ( userContext ) ;
525- // if one of the dynamic filters returns "allow", we want to allow everything
526- if ( result === "allow" ) {
527- return transformToResponse ( ) ;
528- }
529- // if nothing is returned, nothing is allowed by this filter
530- if ( result === undefined ) continue ;
531-
532- dynamicResults . push ( result ) ;
533- filtersReturned ++ ;
534- }
535- dynamicResults . length = filtersReturned ;
516+ // run all dynamic filters
517+ const dynamicResults = new Array <
518+ DrizzleQueryFunctionInput < DB , TableName >
519+ > ( dynamicQueryFilters [ action ] . length ) ;
520+ let filtersReturned = 0 ;
521+ for ( let i = 0 ; i < dynamicQueryFilters [ action ] . length ; i ++ ) {
522+ const func = dynamicQueryFilters [ action ] [ i ] ;
523+ const result = func ( userContext ) ;
524+ // if one of the dynamic filters returns "allow", we want to allow everything
525+ if ( result === "allow" ) {
526+ return transformToResponse ( ) ;
527+ }
528+ // if nothing is returned, nothing is allowed by this filter
529+ if ( result === undefined ) continue ;
536530
537- const allQueryFilters = [
538- ... simpleQueryFilters [ action ] ,
539- ... dynamicResults ,
540- ] ;
531+ dynamicResults . push ( result ) ;
532+ filtersReturned ++ ;
533+ }
534+ dynamicResults . length = filtersReturned ;
541535
542- // if we don't have any permitted filters then block everything
543- if ( allQueryFilters . length === 0 ) {
544- return transformToResponse ( blockEverythingFilter as any ) ;
545- }
536+ const allQueryFilters = [
537+ ... simpleQueryFilters [ action ] ,
538+ ... dynamicResults ,
539+ ] ;
546540
547- let highestLimit : number | undefined ;
548- for ( let i = 0 ; i < allQueryFilters . length ; i ++ ) {
549- const conditionObject = allQueryFilters [ i ] ;
550- if ( conditionObject ?. limit ) {
551- if (
552- highestLimit === undefined ||
553- ( conditionObject . limit as number ) > highestLimit
554- ) {
555- highestLimit = conditionObject . limit as number ;
541+ // if we don't have any permitted filters then block everything
542+ if ( allQueryFilters . length === 0 ) {
543+ return transformToResponse ( blockEverythingFilter as any ) ;
556544 }
557- }
558- }
559545
560- let allowedColumns : Set < string > | undefined ;
561- for ( let i = 0 ; i < allQueryFilters . length ; i ++ ) {
562- const conditionObject = allQueryFilters [ i ] ;
563- if ( conditionObject ?. columns ) {
564- if ( allowedColumns === undefined ) {
565- allowedColumns = new Set (
566- Object . keys ( conditionObject . columns ) ,
567- ) ;
568- } else {
569- const fields = Object . keys ( conditionObject . columns ) ;
570- for ( let i = 0 ; i < fields . length ; i ++ ) {
571- allowedColumns . add ( fields [ i ] ) ;
546+ let highestLimit : number | undefined ;
547+ for ( let i = 0 ; i < allQueryFilters . length ; i ++ ) {
548+ const conditionObject = allQueryFilters [ i ] ;
549+ if ( conditionObject ?. limit ) {
550+ if (
551+ highestLimit === undefined ||
552+ ( conditionObject . limit as number ) > highestLimit
553+ ) {
554+ highestLimit = conditionObject . limit as number ;
555+ }
556+ }
557+ }
558+
559+ let allowedColumns : Set < string > | undefined ;
560+ for ( let i = 0 ; i < allQueryFilters . length ; i ++ ) {
561+ const conditionObject = allQueryFilters [ i ] ;
562+ if ( conditionObject ?. columns ) {
563+ if ( allowedColumns === undefined ) {
564+ allowedColumns = new Set (
565+ Object . keys ( conditionObject . columns ) ,
566+ ) ;
567+ } else {
568+ const fields = Object . keys ( conditionObject . columns ) ;
569+ for ( let i = 0 ; i < fields . length ; i ++ ) {
570+ allowedColumns . add ( fields [ i ] ) ;
571+ }
572+ }
572573 }
573574 }
574- }
575- }
576575
577- const accumulatedWhereConditions = allQueryFilters
578- . filter ( ( o ) => o ?. where )
579- . map ( ( o ) => o ! . where ) ;
580-
581- return transformToResponse ( {
582- where :
583- accumulatedWhereConditions . length > 0
584- ? { OR : accumulatedWhereConditions }
585- : undefined ,
586- columns : allowedColumns
587- ? Object . fromEntries (
588- Array . from ( allowedColumns ) . map ( ( key ) => [ key , true ] ) ,
589- )
590- : undefined ,
591- limit : highestLimit ,
592- } as any ) ;
576+ const accumulatedWhereConditions = allQueryFilters
577+ . filter ( ( o ) => o ?. where )
578+ . map ( ( o ) => o ! . where ) ;
579+
580+ return transformToResponse ( {
581+ where :
582+ accumulatedWhereConditions . length > 0
583+ ? { OR : accumulatedWhereConditions }
584+ : undefined ,
585+ columns : allowedColumns
586+ ? Object . fromEntries (
587+ Array . from ( allowedColumns ) . map ( ( key ) => [ key , true ] ) ,
588+ )
589+ : undefined ,
590+ limit : highestLimit ,
591+ } as any ) ;
592+ } ,
593+ } ;
593594 } ,
594595 } ;
595596 } ;
596597
597- const ret = Object . fromEntries (
598+ const abilitiesPerTable = Object . fromEntries (
598599 ( Object . keys ( db . query ) as TableNames [ ] ) . map ( ( tableName ) => [
599600 tableName ,
600601 createFilterForTable ( tableName ) ,
@@ -605,7 +606,20 @@ export const createAbilityBuilder = <
605606
606607 hasBeenBuilt = true ;
607608
608- return ret ;
609+ return ( ctx : UserContext ) => {
610+ return Object . fromEntries (
611+ ( Object . keys ( abilitiesPerTable ) as TableNames [ ] ) . map (
612+ ( tableName ) => [
613+ tableName ,
614+ abilitiesPerTable [ tableName ] . withContext ( ctx ) ,
615+ ] ,
616+ ) ,
617+ ) as {
618+ [ key in TableNames ] : ReturnType <
619+ ReturnType < typeof createFilterForTable < key > > [ "withContext" ]
620+ > ;
621+ } ;
622+ } ;
609623 } ,
610624 } ,
611625 } ;
0 commit comments