diff --git a/src/middleware/index.ts b/src/middleware/index.ts
index c47baa4..f55eaca 100644
--- a/src/middleware/index.ts
+++ b/src/middleware/index.ts
@@ -1,11 +1,9 @@
-import EventEmitter from 'events';
 import { parse, validate } from 'graphql';
 import { GraphQLSchema } from 'graphql/type/schema';
 import { Request, Response, NextFunction, RequestHandler } from 'express';
 import buildTypeWeightsFromSchema, { defaultTypeWeightsConfig } from '../analysis/buildTypeWeights';
 import setupRateLimiter from './rateLimiterSetup';
 import { ExpressMiddlewareConfig, ExpressMiddlewareSet } from '../@types/expressMiddleware';
-import { RateLimiterResponse } from '../@types/rateLimit';
 import { connect } from '../utils/redis';
 import QueryParser from '../analysis/QueryParser';
 
@@ -66,77 +64,12 @@ export default function expressGraphQLRateLimiter(
         middlewareSetup.redis.keyExpiry
     );
-    /**
-     * We are using a queue and event emitter to handle situations where a user has two concurrent requests being processed.
-     * The trailing request will be added to the queue to and await the prior request processing by the rate-limiter
-     * This will maintain the consistency and accuracy of the cache when under load from one user
-     */
-    // stores request IDs for each user in an array to be processed
-    const requestQueues: { [index: string]: string[] } = {};
-    // Manages processing of requests queue
-    const requestEvents = new EventEmitter();
-
-    // processes requests (by resolving promises) that have been throttled by throttledProcess
-    async function processRequestResolver(
-        userId: string,
-        timestamp: number,
-        tokens: number,
-        resolve: (value: RateLimiterResponse | PromiseLike<RateLimiterResponse>) => void,
-        reject: (reason: any) => void
-    ) {
-        try {
-            const response = await rateLimiter.processRequest(userId, timestamp, tokens);
-            requestQueues[userId] = requestQueues[userId].slice(1);
-            resolve(response);
-            // trigger the next event and delete the request queue for this user if there are no more requests to process
-            requestEvents.emit(requestQueues[userId][0]);
-            if (requestQueues[userId].length === 0) delete requestQueues[userId];
-        } catch (err) {
-            reject(err);
-        }
-    }
-
-    /**
-     * Throttle rateLimiter.processRequest based on user IP to prevent inaccurate redis reads
-     * Throttling is based on a event driven promise fulfillment approach.
-     * Each time a request is received a promise is added to the user's request queue. The promise "subscribes"
-     * to the previous request in the user's queue then calls processRequest and resolves once the previous request
-     * is complete.
-     * @param userId
-     * @param timestamp
-     * @param tokens
-     * @returns
-     */
-    async function throttledProcess(
-        userId: string,
-        timestamp: number,
-        tokens: number
-    ): Promise<RateLimiterResponse> {
-        // Alternatively use crypto.randomUUID() to generate a random uuid
-        const requestId = `${timestamp}${tokens}`;
-
-        if (!requestQueues[userId]) {
-            requestQueues[userId] = [];
-        }
-        requestQueues[userId].push(requestId);
-
-        return new Promise<RateLimiterResponse>((resolve, reject) => {
-            if (requestQueues[userId].length > 1) {
-                requestEvents.once(requestId, async () => {
-                    await processRequestResolver(userId, timestamp, tokens, resolve, reject);
-                });
-            } else {
-                processRequestResolver(userId, timestamp, tokens, resolve, reject);
-            }
-        });
-    }
-
     /** Rate-limiting middleware */
     return async (
         req: Request,
         res: Response,
         next: NextFunction
-    ): Promise<void | Response<Record<string, any>>> => {
+    ): Promise<void | Response<Record<string, unknown>>> => {
         const requestTimestamp = new Date().valueOf();
         // access the query and variables passed to the server in the body or query string
         let query;
@@ -149,8 +82,9 @@ export default function expressGraphQLRateLimiter(
             variables = req.body.variables;
         }
         if (!query) {
+            // eslint-disable-next-line no-console
             console.error(
-                'Error in expressGraphQLRateLimiter: There is no query on the request. Rate-Limiting skipped'
+                '[graphql-gate] Error in expressGraphQLRateLimiter: There is no query on the request. Rate-Limiting skipped'
             );
             return next();
         }
@@ -169,7 +103,7 @@ export default function expressGraphQLRateLimiter(
         const queryComplexity = queryParser.processQuery(queryAST);
 
         try {
-            const rateLimiterResponse = await throttledProcess(
+            const rateLimiterResponse = await rateLimiter.processRequest(
                 ip,
                 requestTimestamp,
                 queryComplexity
@@ -207,9 +141,9 @@ export default function expressGraphQLRateLimiter(
             }
             return next();
         } catch (err) {
-            // log the error to the console and pass the request onto the next middleware.
+            // eslint-disable-next-line no-console
             console.error(
-                `Error in expressGraphQLRateLimiter processing query. Rate limiting is skipped: ${err}`
+                `[graphql-gate] Error in expressGraphQLRateLimiter processing query. Rate limiting is skipped: ${err}`
             );
             return next(err);
         }
diff --git a/src/middleware/rateLimiterSetup.ts b/src/middleware/rateLimiterSetup.ts
index 26166f8..caea6e9 100644
--- a/src/middleware/rateLimiterSetup.ts
+++ b/src/middleware/rateLimiterSetup.ts
@@ -1,5 +1,8 @@
+import EventEmitter from 'events';
+
 import Redis from 'ioredis';
-import { RateLimiterConfig } from '../@types/rateLimit';
+
+import { RateLimiter, RateLimiterConfig, RateLimiterResponse } from '../@types/rateLimit';
 import TokenBucket from '../rateLimiters/tokenBucket';
 import SlidingWindowCounter from '../rateLimiters/slidingWindowCounter';
 import SlidingWindowLog from '../rateLimiters/slidingWindowLog';
@@ -9,22 +12,106 @@ import FixedWindow from '../rateLimiters/fixedWindow';
 /**
  * Instantiate the rate-limiting algorithm class based on the developer's selection and options
  *
  * @export
- * @param {RateLimiterConfig} rateLimiter limiter selection and option
+ * @param {RateLimiterConfig} rateLimiterConfig limiter selection and options
  * @param {Redis} client
  * @param {number} keyExpiry
- * @return {*}
+ * @return {RateLimiter}
  */
 export default function setupRateLimiter(
-    rateLimiter: RateLimiterConfig,
+    rateLimiterConfig: RateLimiterConfig,
     client: Redis,
     keyExpiry: number
-) {
+): RateLimiter {
+    let rateLimiter: RateLimiter;
+
+    /**
+     * We are using a queue and event emitter to handle situations where a user has two concurrent requests being processed.
+     * The trailing request will be added to the queue to await the prior request's processing by the rate-limiter.
+     * This will maintain the consistency and accuracy of the cache when under load from one user.
+     */
+    // stores request IDs for each user in an array to be processed
+    const requestQueues: { [index: string]: string[] } = {};
+    // Manages processing of requests queue
+    const requestEvents = new EventEmitter();
+
+    // processes requests (by resolving promises) that have been throttled by throttledProcess
+    async function processRequestResolver(
+        userId: string,
+        timestamp: number,
+        tokens: number,
+        processRequest: (
+            userId: string,
+            timestamp: number,
+            tokens: number
+        ) => Promise<RateLimiterResponse>,
+        resolve: (value: RateLimiterResponse | PromiseLike<RateLimiterResponse>) => void,
+        reject: (reason: unknown) => void
+    ) {
+        try {
+            const response = await processRequest(userId, timestamp, tokens);
+            requestQueues[userId] = requestQueues[userId].slice(1);
+            resolve(response);
+            // trigger the next event and delete the request queue for this user if there are no more requests to process
+            requestEvents.emit(requestQueues[userId][0]);
+            if (requestQueues[userId].length === 0) delete requestQueues[userId];
+        } catch (err) {
+            reject(err);
+        }
+    }
+
+    /**
+     * Throttle rateLimiter.processRequest based on user IP to prevent inaccurate Redis reads
+     * Throttling is based on an event-driven promise-fulfillment approach.
+     * Each time a request is received, a promise is added to the user's request queue. The promise "subscribes"
+     * to the previous request in the user's queue, then calls processRequest and resolves once the previous request
+     * is complete.
+     * @param userId
+     * @param timestamp
+     * @param tokens
+     * @returns
+     */
+    async function throttledProcess(
+        processRequest: (
+            userId: string,
+            timestamp: number,
+            tokens: number
+        ) => Promise<RateLimiterResponse>,
+        userId: string,
+        timestamp: number,
+        tokens = 1
+    ): Promise<RateLimiterResponse> {
+        // Alternatively use crypto.randomUUID() to generate a random uuid
+        const requestId = `${timestamp}${tokens}`;
+
+        if (!requestQueues[userId]) {
+            requestQueues[userId] = [];
+        }
+        requestQueues[userId].push(requestId);
+
+        return new Promise<RateLimiterResponse>((resolve, reject) => {
+            if (requestQueues[userId].length > 1) {
+                requestEvents.once(requestId, async () => {
+                    processRequestResolver(
+                        userId,
+                        timestamp,
+                        tokens,
+                        processRequest,
+                        resolve,
+                        reject
+                    );
+                });
+            } else {
+                processRequestResolver(userId, timestamp, tokens, processRequest, resolve, reject);
+            }
+        });
+    }
+
     try {
-        switch (rateLimiter.type) {
+        switch (rateLimiterConfig.type) {
             case 'TOKEN_BUCKET':
-                return new TokenBucket(
-                    rateLimiter.capacity,
-                    rateLimiter.refillRate,
+                rateLimiter = new TokenBucket(
+                    rateLimiterConfig.capacity,
+                    rateLimiterConfig.refillRate,
                     client,
                     keyExpiry
                 );
+                break;
@@ -32,23 +119,25 @@ export default function setupRateLimiter(
             case 'LEAKY_BUCKET':
                 throw new Error('Leaky Bucket algorithm has not been implemented.');
             case 'FIXED_WINDOW':
-                return new FixedWindow(
-                    rateLimiter.capacity,
-                    rateLimiter.windowSize,
+                rateLimiter = new FixedWindow(
+                    rateLimiterConfig.capacity,
+                    rateLimiterConfig.windowSize,
                     client,
                     keyExpiry
                 );
+                break;
             case 'SLIDING_WINDOW_LOG':
-                return new SlidingWindowLog(
-                    rateLimiter.windowSize,
-                    rateLimiter.capacity,
+                rateLimiter = new SlidingWindowLog(
+                    rateLimiterConfig.windowSize,
+                    rateLimiterConfig.capacity,
                     client,
                     keyExpiry
                 );
+                break;
             case 'SLIDING_WINDOW_COUNTER':
-                return new SlidingWindowCounter(
-                    rateLimiter.windowSize,
-                    rateLimiter.capacity,
+                rateLimiter = new SlidingWindowCounter(
+                    rateLimiterConfig.windowSize,
+                    rateLimiterConfig.capacity,
                     client,
                     keyExpiry
                 );
+                break;
@@ -57,6 +146,19 @@ export default function setupRateLimiter(
             // typescript should never let us invoke this function with anything other than the options above
             throw new Error('Selected rate limiting algorithm is not supported');
         }
+
+        // Overwrite the processRequest method with a throttled implementation to ensure async Redis interactions are handled
+        // sequentially for each user.
+        const boundProcessRequest = rateLimiter.processRequest.bind(rateLimiter);
+
+        rateLimiter.processRequest = async (
+            userId: string,
+            timestamp: number,
+            tokens = 1
+        ): Promise<RateLimiterResponse> =>
+            throttledProcess(boundProcessRequest, userId, timestamp, tokens);
+
+        return rateLimiter;
     } catch (err) {
         throw new Error(`Error in expressGraphQLRateLimiter setting up rate-limiter: ${err}`);
     }
diff --git a/src/rateLimiters/slidingWindowLog.ts b/src/rateLimiters/slidingWindowLog.ts
index 3407920..98020cd 100644
--- a/src/rateLimiters/slidingWindowLog.ts
+++ b/src/rateLimiters/slidingWindowLog.ts
@@ -1,5 +1,5 @@
 import Redis from 'ioredis';
-import { RateLimiter, RateLimiterResponse, RedisBucket, RedisLog } from '../@types/rateLimit';
+import { RateLimiter, RateLimiterResponse, RedisLog } from '../@types/rateLimit';
 
 /**
  * The SlidingWindowLog instance of a RateLimiter limits requests based on a unique user ID.
diff --git a/test/middleware/express.test.ts b/test/middleware/express.test.ts
index 6518d2a..10462d1 100644
--- a/test/middleware/express.test.ts
+++ b/test/middleware/express.test.ts
@@ -375,9 +375,7 @@ describe('Express Middleware tests', () => {
     describe('Adds expected properties to res.locals', () => {
         test('Adds UNIX timestamp', async () => {
-            jest.useRealTimers();
             await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
-            jest.useFakeTimers();
             // confirm that this is timestamp +/- 5 minutes of now.
             const now: number = Date.now().valueOf();
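
For reviewers who want to exercise the queueing behavior in isolation, here is a minimal, self-contained sketch of the pattern this diff moves into `setupRateLimiter`. The `RateLimiterResponse` shape and the `ProcessFn` alias are simplified stand-ins for the library's types, and the `Math.random()` suffix on `requestId` is an addition in this sketch to avoid ID collisions (the diff derives the ID from `timestamp` and `tokens` alone):

```ts
import EventEmitter from 'events';

// Simplified stand-in for the library's RateLimiterResponse type.
interface RateLimiterResponse {
    success: boolean;
    tokens: number;
}

// Hypothetical alias for the signature of RateLimiter.processRequest.
type ProcessFn = (
    userId: string,
    timestamp: number,
    tokens: number
) => Promise<RateLimiterResponse>;

// One FIFO queue of pending request IDs per user, plus an emitter that wakes
// the next queued request when its predecessor settles.
const requestQueues: { [userId: string]: string[] } = {};
const requestEvents = new EventEmitter();

function throttledProcess(
    processRequest: ProcessFn,
    userId: string,
    timestamp: number,
    tokens = 1
): Promise<RateLimiterResponse> {
    // Unique per request; the random suffix guards against two requests arriving
    // in the same millisecond with the same token cost.
    const requestId = `${timestamp}:${tokens}:${Math.random()}`;
    if (!requestQueues[userId]) requestQueues[userId] = [];
    requestQueues[userId].push(requestId);

    // Run the real processRequest, dequeue ourselves, then wake the next waiter.
    const run = async (
        resolve: (value: RateLimiterResponse) => void,
        reject: (reason: unknown) => void
    ) => {
        try {
            const response = await processRequest(userId, timestamp, tokens);
            requestQueues[userId] = requestQueues[userId].slice(1);
            resolve(response);
            requestEvents.emit(requestQueues[userId][0]);
            if (requestQueues[userId].length === 0) delete requestQueues[userId];
        } catch (err) {
            reject(err);
        }
    };

    return new Promise<RateLimiterResponse>((resolve, reject) => {
        if (requestQueues[userId].length > 1) {
            // A request for this user is already in flight: wait for our turn.
            requestEvents.once(requestId, () => run(resolve, reject));
        } else {
            run(resolve, reject);
        }
    });
}
```

The emit/once pair does work a plain promise chain could also do; the event-driven form lets each trailing request subscribe to its own ID before the predecessor completes, so Redis reads for one user never interleave.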
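
And a sketch of how the diff's bind-and-override step composes with the function above: an existing `processRequest` is bound (so the original implementation keeps its `this`) and then replaced with the throttled wrapper. The inline `limiter` object is a stand-in for illustration, not the library's `RateLimiter` class:

```ts
(async () => {
    // Stand-in limiter; a real one would read and write Redis.
    const limiter = {
        async processRequest(
            userId: string,
            timestamp: number,
            tokens = 1
        ): Promise<RateLimiterResponse> {
            return { success: true, tokens: 10 - tokens };
        },
    };

    // Bind first, then overwrite, exactly as the diff does in setupRateLimiter.
    const boundProcessRequest = limiter.processRequest.bind(limiter);
    limiter.processRequest = (userId, timestamp, tokens = 1) =>
        throttledProcess(boundProcessRequest, userId, timestamp, tokens);

    // Two concurrent requests for the same user are now processed in order.
    const results = await Promise.all([
        limiter.processRequest('user-1', Date.now(), 2),
        limiter.processRequest('user-1', Date.now(), 3),
    ]);
    console.log(results.map((r) => r.success)); // [true, true], resolved sequentially
})();
```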