import { assert } from '../helpers/error';
import { combineURL } from '../helpers/combineURL';
import { isFunction, isString } from '../helpers/isTypes';

/**
 * Loosest context shape a middleware chain accepts. `flush` reads
 * `ctx.req.baseURL` and `ctx.req.url`, so in practice a context carries a
 * `req` object. NOTE(review): presumably mirrors the project's shared
 * "any object" type — confirm and swap in that alias if one exists.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AnyMiddlewareContext = Record<string, any>;

/**
 * Continuation handed to each middleware. Calling it runs the next matching
 * middleware (or the final handler after the last one) and resolves when the
 * rest of the chain has finished.
 */
export interface MiddlewareNext {
  (): Promise<void>;
}

/**
 * A middleware: receives the shared context and a `next` continuation.
 * It should call `next()` at most once; `flush` rejects on a second call.
 */
export interface MiddlewareCallback<
  Context extends AnyMiddlewareContext = AnyMiddlewareContext,
> {
  (ctx: Context, next: MiddlewareNext): Promise<void>;
}

/** Overloaded signature of `MiddlewareManager#use`. */
export interface MiddlewareUse<
  Context extends AnyMiddlewareContext = AnyMiddlewareContext,
> {
  /**
   * Adds a middleware scoped to `path`.
   *
   * @param path Middleware path; at flush time it is combined with
   *   `ctx.req.baseURL` and matched against `ctx.req.url`
   * @param callback Middleware callback
   */
  (
    path: string,
    callback: MiddlewareCallback<Context>,
  ): MiddlewareManager<Context>;
  /**
   * Adds a global middleware (registered under path `'/'`).
   *
   * @param callback Middleware callback
   */
  (callback: MiddlewareCallback<Context>): MiddlewareManager<Context>;
}

/**
 * Whether `url` lies under `base`: it equals `base`, or continues it with a
 * `/` or `?` (e.g. base `/user` matches `/user`, `/user/1`, `/user?id=1`, but
 * not `/users`). A `base` ending in `/` already marks the boundary, so any
 * continuation matches. Compared literally — no RegExp metacharacters apply.
 */
function isUnderPath(url: unknown, base: string): boolean {
  if (typeof url !== 'string' || !url.startsWith(base)) {
    return false;
  }
  const rest = url.slice(base.length);
  return (
    rest === '' ||
    base.endsWith('/') ||
    rest.startsWith('/') ||
    rest.startsWith('?')
  );
}

/**
 * Registry of path-scoped middlewares.
 *
 * Middlewares registered under `'/'` run for every request; others run only
 * when `ctx.req.url` lies under `combineURL(ctx.req.baseURL, path)`. Paths are
 * visited in the order they were first registered, and middlewares within a
 * path in the order they were added.
 */
export default class MiddlewareManager<
  Context extends AnyMiddlewareContext = AnyMiddlewareContext,
> {
  /** Middlewares keyed by path; the Map keeps first-registration order. */
  #map = new Map<string, MiddlewareCallback<Context>[]>();

  /**
   * Adds a middleware, optionally scoped to a path (defaults to `'/'`).
   * Returns the manager so calls can be chained.
   */
  use: MiddlewareUse<Context> = (
    path: string | MiddlewareCallback<Context>,
    callback?: MiddlewareCallback<Context>,
  ) => {
    // `use(callback)` form: shift the callback over and register globally.
    if (isFunction(path)) {
      callback = path;
      path = '/';
    }
    assert(isString(path), 'path 不是一个 string');
    assert(!!path, 'path 不是一个长度大于零的 string');
    assert(isFunction(callback), 'callback 不是一个 function');

    const middlewares = this.#map.get(path) ?? [];
    middlewares.push(callback!);
    this.#map.set(path, middlewares);
    return this;
  };

  /**
   * Runs every middleware whose path matches `ctx.req.url`, then `finish`.
   *
   * @param ctx Context passed to each middleware
   * @param finish Final handler, invoked when the last matching middleware
   *   calls `next()` (or immediately if none match)
   * @returns Promise that settles when the chain completes. It rejects if a
   *   middleware throws or rejects, or calls `next()` more than once.
   */
  flush(ctx: Context, finish: MiddlewareNext): Promise<void> {
    const chain: MiddlewareCallback<Context>[] = [];
    for (const [path, middlewares] of this.#map) {
      if (
        path === '/' ||
        isUnderPath(ctx.req.url, combineURL(ctx.req.baseURL, path))
      ) {
        chain.push(...middlewares);
      }
    }

    // Highest index dispatched so far. Guards against `next()` being called
    // twice, which would otherwise skip ahead through later middlewares.
    let dispatched = -1;
    const dispatch = (index: number): Promise<void> => {
      if (index <= dispatched) {
        return Promise.reject(new Error('next() called multiple times'));
      }
      dispatched = index;
      const middleware = chain[index];
      try {
        // Past the last middleware: hand over to the final handler.
        return middleware
          ? middleware(ctx, () => dispatch(index + 1))
          : finish();
      } catch (error) {
        // Surface synchronous throws as rejections, matching the Promise contract.
        return Promise.reject(error);
      }
    };
    return dispatch(0);
  }
}