diff --git a/src/axios.ts b/src/axios.ts index 2ed4789..4d55c38 100644 --- a/src/axios.ts +++ b/src/axios.ts @@ -84,7 +84,9 @@ function createInstance(config: AxiosRequestConfig) { const axios = createInstance(defaults) as AxiosStatic; axios.create = function create(config) { - return createInstance(mergeConfig(axios.defaults, config)); + const instance = createInstance(mergeConfig(axios.defaults, config)); + instance.flush = axios.middleware.wrap(instance.flush); + return instance; }; axios.version = version; diff --git a/src/constants/methods.ts b/src/constants/methods.ts index a7dccda..5cbb5ec 100644 --- a/src/constants/methods.ts +++ b/src/constants/methods.ts @@ -16,4 +16,7 @@ export const WITH_DATA_METHODS = ['post', 'put', 'patch'] as const; /** * 可以携带 data 的请求方法 */ -export const WITH_DATA_RE = new RegExp(`^${WITH_DATA_METHODS.join('|')}`, 'i'); +export const WITH_DATA_RE = new RegExp( + `^(${WITH_DATA_METHODS.join('|')})`, + 'i', +); diff --git a/src/core/Axios.ts b/src/core/Axios.ts index abfd342..dfbdf32 100644 --- a/src/core/Axios.ts +++ b/src/core/Axios.ts @@ -1,9 +1,6 @@ import { buildURL } from '../helpers/buildURL'; -import { isAbsoluteURL } from '../helpers/isAbsoluteURL'; import { combineURL } from '../helpers/combineURL'; -import { isString } from '../helpers/isTypes'; import { CancelToken } from '../request/cancel'; -import { dispatchRequest } from '../request/dispatchRequest'; import { AxiosTransformer } from '../request/transformData'; import { AxiosAdapter, @@ -14,7 +11,7 @@ import { } from '../adpater/createAdapter'; import InterceptorManager, { Interceptor } from './InterceptorManager'; import { mergeConfig } from './mergeConfig'; -import AxiosDomain from './AxiosDomain'; +import AxiosDomain, { AxiosDomainRequestHandler } from './AxiosDomain'; /** * 请求方法 @@ -319,7 +316,7 @@ export default class Axios extends AxiosDomain { }; constructor(defaults: AxiosRequestConfig = {}) { - super(defaults, (config) => this.#processRequest(config)); + super(defaults, (...args) => this.#processRequest(...args)); } getUri(config: AxiosRequestConfig): string { @@ -334,17 +331,21 @@ export default class Axios extends AxiosDomain { * 派生领域 */ fork = (config: AxiosRequestConfig = {}) => { - if (isString(config.baseURL) && !isAbsoluteURL(config.baseURL)) { - config.baseURL = combineURL(this.defaults.baseURL, config.baseURL); - } - return new AxiosDomain(mergeConfig(this.defaults, config), (config) => - this.#processRequest(config), + config.baseURL = combineURL(this.defaults.baseURL, config.baseURL); + const domain = new AxiosDomain( + mergeConfig(this.defaults, config), + (...args) => this.#processRequest(...args), ); + domain.flush = this.middleware.wrap(domain.flush); + return domain; }; - #processRequest(config: AxiosRequestConfig) { + #processRequest( + config: AxiosRequestConfig, + requestHandlerFn: AxiosDomainRequestHandler, + ) { const requestHandler = { - resolved: dispatchRequest, + resolved: requestHandlerFn, }; const errorHandler = { rejected: config.errorHandler, diff --git a/src/core/AxiosDomain.ts b/src/core/AxiosDomain.ts index 17955c6..e631226 100644 --- a/src/core/AxiosDomain.ts +++ b/src/core/AxiosDomain.ts @@ -5,6 +5,8 @@ import { } from '../constants/methods'; import { isString, isUndefined } from '../helpers/isTypes'; import { deepMerge } from '../helpers/deepMerge'; +import { combineURL } from '../helpers/combineURL'; +import { dispatchRequest } from '../request/dispatchRequest'; import { mergeConfig } from './mergeConfig'; import { AxiosRequestConfig, @@ -12,6 +14,10 @@ import { AxiosResponse, AxiosResponseData, } from './Axios'; +import MiddlewareManager, { + MiddlewareContext, + MiddlewareFlush, +} from './MiddlewareManager'; /** * 请求函数 @@ -56,12 +62,18 @@ export type AxiosDomainRequestMethodWithData = < config?: AxiosRequestConfig, ) => Promise>; +export interface AxiosDomainRequestHandler { + (config: AxiosRequestConfig): Promise; +} + export default class AxiosDomain { /** * 默认请求配置 */ defaults: AxiosRequestConfig; + middleware = new MiddlewareManager(); + /** * 发送请求 */ @@ -112,9 +124,14 @@ export default class AxiosDomain { */ connect!: AxiosDomainRequestMethod; + flush: MiddlewareFlush; + constructor( defaults: AxiosRequestConfig, - processRequest: (config: AxiosRequestConfig) => Promise, + processRequest: ( + config: AxiosRequestConfig, + requestHandler: AxiosDomainRequestHandler, + ) => Promise, ) { this.defaults = defaults; @@ -132,9 +149,25 @@ export default class AxiosDomain { config.method = 'get'; } - return processRequest(mergeConfig(this.defaults, config)); + return processRequest( + mergeConfig(this.defaults, config), + this.#requestHandler, + ); }; + this.flush = this.middleware.wrap(async (ctx) => { + ctx.res = await dispatchRequest(ctx.req); + }); } + + #requestHandler: AxiosDomainRequestHandler = async (config) => { + config.url = combineURL(config.baseURL, config.url); + const ctx: MiddlewareContext = { + req: config, + res: null, + }; + await this.flush(ctx); + return ctx.res as AxiosResponse; + }; } for (const method of PLAIN_METHODS) { diff --git a/src/core/MiddlewareManager.ts b/src/core/MiddlewareManager.ts new file mode 100644 index 0000000..c669bf9 --- /dev/null +++ b/src/core/MiddlewareManager.ts @@ -0,0 +1,70 @@ +import { assert } from '../helpers/error'; +import { combineURL } from '../helpers/combineURL'; +import { isFunction } from '../helpers/isTypes'; +import { AxiosRequestConfig, AxiosResponse } from './Axios'; + +export interface MiddlewareContext { + req: AxiosRequestConfig; + res: null | AxiosResponse; +} + +export interface MiddlewareNext { + (): Promise; +} + +export interface MiddlewareCallback { + (ctx: MiddlewareContext, next: MiddlewareNext): Promise; +} + +export interface MiddlewareFlush { + (ctx: MiddlewareContext): Promise; +} + +export default class MiddlewareManager { + #map = new Map(); + + use(callback: MiddlewareCallback): MiddlewareManager; + use(path: string, callback: MiddlewareCallback): MiddlewareManager; + use(path: string | MiddlewareCallback, callback?: MiddlewareCallback) { + if (isFunction(path)) { + callback = path; + path = '/'; + } + assert(!!path, 'path 不是一个非空的 string'); + + const middlewares = this.#map.get(path) ?? []; + middlewares.push(callback!); + this.#map.set(path, middlewares); + + return this; + } + + wrap(flush: MiddlewareFlush): MiddlewareFlush { + return (ctx) => this.#performer(ctx, flush); + } + + #performer(ctx: MiddlewareContext, flush: MiddlewareFlush) { + const middlewares = [...this.#getAllMiddlewares(ctx), flush]; + + function next(): Promise { + return middlewares.shift()!(ctx, next); + } + + return next(); + } + + #getAllMiddlewares(ctx: MiddlewareContext) { + const allMiddlewares: MiddlewareCallback[] = []; + + for (const [path, middlewares] of this.#map.entries()) { + const url = combineURL(ctx.req.baseURL, path); + + const checkRE = new RegExp(`^${url}([/?].*)?`); + if (checkRE.test(ctx.req.url!)) { + allMiddlewares.push(...middlewares); + } + } + + return allMiddlewares; + } +} diff --git a/src/helpers/combineURL.ts b/src/helpers/combineURL.ts index 94540a0..9a276a5 100644 --- a/src/helpers/combineURL.ts +++ b/src/helpers/combineURL.ts @@ -1,5 +1,10 @@ +import { isAbsoluteURL } from './isAbsoluteURL'; + const combineRE = /(^|[^:])\/{2,}/g; const removeRE = /\/$/; export function combineURL(baseURL = '', url = ''): string { + if (isAbsoluteURL(url)) { + return url; + } return `${baseURL}/${url}`.replace(combineRE, '$1/').replace(removeRE, ''); } diff --git a/src/index.ts b/src/index.ts index 6b560ba..e58ab62 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,11 @@ export type { AxiosUploadProgressEvent, AxiosUploadProgressCallback, } from './core/Axios'; +export type { + MiddlewareContext, + MiddlewareCallback, + MiddlewareNext, +} from './core/MiddlewareManager'; export type { AxiosAdapter, AxiosAdapterRequestConfig, diff --git a/src/request/dispatchRequest.ts b/src/request/dispatchRequest.ts index 41d9491..7276363 100644 --- a/src/request/dispatchRequest.ts +++ b/src/request/dispatchRequest.ts @@ -27,7 +27,6 @@ export function dispatchRequest(config: AxiosRequestConfig) { assert(isString(config.url), 'url 不是一个 string'); assert(isString(config.method), 'method 不是一个 string'); - config.url = transformURL(config); config.method = config.method!.toUpperCase() as AxiosRequestMethod; config.headers = flattenHeaders(config); @@ -39,6 +38,8 @@ export function dispatchRequest(config: AxiosRequestConfig) { delete config.data; } + config.url = transformURL(config); + function onSuccess(response: AxiosResponse) { throwIfCancellationRequested(config); dataTransformer(response, config.transformResponse); diff --git a/src/request/request.ts b/src/request/request.ts index e52e13b..a6dc768 100644 --- a/src/request/request.ts +++ b/src/request/request.ts @@ -6,7 +6,6 @@ import { } from '../core/Axios'; import { AxiosAdapterRequestConfig, - AxiosAdapterRequestMethod, AxiosAdapterResponse, AxiosAdapterResponseError, AxiosAdapterPlatformTask, diff --git a/src/request/transformURL.ts b/src/request/transformURL.ts index b6f4ab8..de9365e 100644 --- a/src/request/transformURL.ts +++ b/src/request/transformURL.ts @@ -2,15 +2,14 @@ import { isPlainObject } from '../helpers/isTypes'; import { buildURL } from '../helpers/buildURL'; import { combineURL } from '../helpers/combineURL'; import { dynamicURL } from '../helpers/dynamicURL'; -import { isAbsoluteURL } from '../helpers/isAbsoluteURL'; import { AxiosRequestConfig } from '../core/Axios'; export function transformURL(config: AxiosRequestConfig) { - let url = config.url ?? ''; - - if (!isAbsoluteURL(url)) url = combineURL(config.baseURL ?? '', url); - const data = isPlainObject(config.data) ? config.data : {}; + + let url = config.url ?? '/'; + + url = combineURL(config.baseURL ?? '', url); url = dynamicURL(url, config.params, data); url = buildURL(url, config.params, config.paramsSerializer); diff --git a/test/axios.test.ts b/test/axios.test.ts index 7dbe3bc..01d497c 100644 --- a/test/axios.test.ts +++ b/test/axios.test.ts @@ -37,6 +37,7 @@ describe('src/axios.ts', () => { }, }), baseURL: 'http://api.com', + method: 'post', data: { id: 1, }, diff --git a/test/core/MiddlewareManager.test.ts b/test/core/MiddlewareManager.test.ts new file mode 100644 index 0000000..3d80831 --- /dev/null +++ b/test/core/MiddlewareManager.test.ts @@ -0,0 +1,78 @@ +import { describe, test, expect, vi } from 'vitest'; +import MiddlewareManager from '@/core/MiddlewareManager'; + +describe('src/core/MiddlewareManager.ts', () => { + test('应该有这些实例属性', () => { + const m = new MiddlewareManager(); + + expect(m.use).toBeTypeOf('function'); + expect(m.wrap).toBeTypeOf('function'); + }); + + test('应该可以添加中间件回调', async () => { + const m = new MiddlewareManager(); + const ctx = { + req: { url: 'https://api.com' }, + res: null, + }; + const res = { + 'src/core/MiddlewareManager.ts': true, + }; + const midde = vi.fn(async (ctx, next) => { + expect(ctx).toBe(ctx); + ctx.req.url = 'test'; + await next(); + expect(ctx.res).toBe(res); + }); + const flush = vi.fn(async (ctx) => { + expect(ctx.req.url).toBe('test'); + ctx.res = res; + }); + + m.use(midde); + await m.wrap(flush)(ctx); + + expect(ctx.res).toBe(res); + expect(midde).toBeCalled(); + }); + + test('应该可以给路径添加中间件回调', async () => { + const m = new MiddlewareManager(); + const ctx1 = { + req: { + baseURL: 'https://api.com', + url: 'https://api.com', + }, + res: null, + }; + const ctx2 = { + req: { + baseURL: 'https://api.com', + url: 'https://api.com/test', + }, + res: null, + }; + const res = { + 'src/core/MiddlewareManager.ts': true, + }; + const midde = vi.fn(async (ctx, next) => { + expect(ctx).toBe(ctx); + await next(); + expect(ctx.res).toBe(res); + }); + const flush = vi.fn(async (ctx) => { + ctx.res = res; + }); + + m.use('/test', midde); + await m.wrap(flush)(ctx1); + + expect(ctx1.res).toBe(res); + expect(midde).not.toBeCalled(); + + m.use('/test', midde); + await m.wrap(flush)(ctx2); + + expect(midde).toBeCalled(); + }); +}); diff --git a/test/helpers/combineURL.test.ts b/test/helpers/combineURL.test.ts index c13226f..9a2c6c4 100644 --- a/test/helpers/combineURL.test.ts +++ b/test/helpers/combineURL.test.ts @@ -14,6 +14,12 @@ describe('src/helpers/combineURL.ts', () => { expect(combineURL('unknow://api.com', '')).toBe('unknow://api.com'); }); + test('应该直接返回第二个参数', () => { + expect(combineURL('', 'http://api.com')).toBe('http://api.com'); + expect(combineURL('', 'file://api.com')).toBe('file://api.com'); + expect(combineURL('', 'unknow://api.com')).toBe('unknow://api.com'); + }); + test('应该得到拼接后的结果', () => { expect(combineURL('http://api.com', 'test')).toBe('http://api.com/test'); expect(combineURL('file://api.com', '/test')).toBe('file://api.com/test'); diff --git a/test/request/dispatchRequest.test.ts b/test/request/dispatchRequest.test.ts index 255adff..9893340 100644 --- a/test/request/dispatchRequest.test.ts +++ b/test/request/dispatchRequest.test.ts @@ -100,6 +100,7 @@ describe('src/request/dispatchRequest.ts', () => { }; const c3 = { ...defaults, + method: 'post' as const, url: 'test/:id', data: { id: 1,