Skip to content

Commit 463bee3

Browse files
feat: allow extending toEqual (fix #2875) (#4880)
Co-authored-by: Vladimir <[email protected]>
1 parent 7f59a1b commit 463bee3

File tree

11 files changed

+380
-43
lines changed

11 files changed

+380
-43
lines changed

docs/api/expect.md

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,3 +1405,55 @@ Don't forget to include the ambient declaration file in your `tsconfig.json`.
14051405
:::tip
14061406
If you want to know more, checkout [guide on extending matchers](/guide/extending-matchers).
14071407
:::
1408+
1409+
## expect.addEqualityTesters <Badge type="info">1.2.0+</Badge>
1410+
1411+
- **Type:** `(tester: Array<Tester>) => void`
1412+
1413+
You can use this method to define custom testers, which are methods used by matchers, to test if two objects are equal. It is compatible with Jest's `expect.addEqualityTesters`.
1414+
1415+
```ts
1416+
import { expect, test } from 'vitest'
1417+
1418+
class AnagramComparator {
1419+
public word: string
1420+
1421+
constructor(word: string) {
1422+
this.word = word
1423+
}
1424+
1425+
equals(other: AnagramComparator): boolean {
1426+
const cleanStr1 = this.word.replace(/ /g, '').toLowerCase()
1427+
const cleanStr2 = other.word.replace(/ /g, '').toLowerCase()
1428+
1429+
const sortedStr1 = cleanStr1.split('').sort().join('')
1430+
const sortedStr2 = cleanStr2.split('').sort().join('')
1431+
1432+
return sortedStr1 === sortedStr2
1433+
}
1434+
}
1435+
1436+
function isAnagramComparator(a: unknown): a is AnagramComparator {
1437+
return a instanceof AnagramComparator
1438+
}
1439+
1440+
function areAnagramsEqual(a: unknown, b: unknown): boolean | undefined {
1441+
const isAAnagramComparator = isAnagramComparator(a)
1442+
const isBAnagramComparator = isAnagramComparator(b)
1443+
1444+
if (isAAnagramComparator && isBAnagramComparator)
1445+
return a.equals(b)
1446+
1447+
else if (isAAnagramComparator === isBAnagramComparator)
1448+
return undefined
1449+
1450+
else
1451+
return false
1452+
}
1453+
1454+
expect.addEqualityTesters([areAnagramsEqual])
1455+
1456+
test('custom equality tester', () => {
1457+
expect(new AnagramComparator('listen')).toEqual(new AnagramComparator('silent'))
1458+
})
1459+
```

packages/expect/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ export * from './constants'
44
export * from './types'
55
export { getState, setState } from './state'
66
export { JestChaiExpect } from './jest-expect'
7+
export { addCustomEqualityTesters } from './jest-matcher-utils'
78
export { JestExtend } from './jest-extend'
89
export { setupColors } from '@vitest/utils'

packages/expect/src/jest-asymmetric-matchers.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { ChaiPlugin, MatcherState } from './types'
22
import { GLOBAL_EXPECT } from './constants'
33
import { getState } from './state'
4-
import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
4+
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'
55

66
import { equals, isA, iterableEquality, pluralize, subsetEquality } from './jest-utils'
77

@@ -26,7 +26,7 @@ export abstract class AsymmetricMatcher<
2626
...getState(expect || (globalThis as any)[GLOBAL_EXPECT]),
2727
equals,
2828
isNot: this.inverse,
29-
customTesters: [],
29+
customTesters: getCustomEqualityTesters(),
3030
utils: {
3131
...getMatcherUtils(),
3232
diff,
@@ -116,8 +116,9 @@ export class ObjectContaining extends AsymmetricMatcher<Record<string, unknown>>
116116

117117
let result = true
118118

119+
const matcherContext = this.getMatcherContext()
119120
for (const property in this.sample) {
120-
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property])) {
121+
if (!this.hasProperty(other, property) || !equals(this.sample[property], other[property], matcherContext.customTesters)) {
121122
result = false
122123
break
123124
}
@@ -149,11 +150,12 @@ export class ArrayContaining<T = unknown> extends AsymmetricMatcher<Array<T>> {
149150
)
150151
}
151152

153+
const matcherContext = this.getMatcherContext()
152154
const result
153155
= this.sample.length === 0
154156
|| (Array.isArray(other)
155157
&& this.sample.every(item =>
156-
other.some(another => equals(item, another)),
158+
other.some(another => equals(item, another, matcherContext.customTesters)),
157159
))
158160

159161
return this.inverse ? !result : result

packages/expect/src/jest-expect.ts

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import type { Test } from '@vitest/runner'
66
import type { Assertion, ChaiPlugin } from './types'
77
import { arrayBufferEquality, generateToBeMessage, iterableEquality, equals as jestEquals, sparseArrayEquality, subsetEquality, typeEquality } from './jest-utils'
88
import type { AsymmetricMatcher } from './jest-asymmetric-matchers'
9-
import { diff, stringify } from './jest-matcher-utils'
9+
import { diff, getCustomEqualityTesters, stringify } from './jest-matcher-utils'
1010
import { JEST_MATCHERS_OBJECT } from './constants'
1111
import { recordAsyncExpect, wrapSoft } from './utils'
1212

@@ -23,6 +23,7 @@ declare class DOMTokenList {
2323
export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
2424
const { AssertionError } = chai
2525
const c = () => getColors()
26+
const customTesters = getCustomEqualityTesters()
2627

2728
function def(name: keyof Assertion | (keyof Assertion)[], fn: ((this: Chai.AssertionStatic & Assertion, ...args: any[]) => any)) {
2829
const addMethod = (n: keyof Assertion) => {
@@ -80,7 +81,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
8081
const equal = jestEquals(
8182
actual,
8283
expected,
83-
[iterableEquality],
84+
[...customTesters, iterableEquality],
8485
)
8586

8687
return this.assert(
@@ -98,6 +99,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
9899
obj,
99100
expected,
100101
[
102+
...customTesters,
101103
iterableEquality,
102104
typeEquality,
103105
sparseArrayEquality,
@@ -125,6 +127,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
125127
actual,
126128
expected,
127129
[
130+
...customTesters,
128131
iterableEquality,
129132
typeEquality,
130133
sparseArrayEquality,
@@ -140,7 +143,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
140143
const toEqualPass = jestEquals(
141144
actual,
142145
expected,
143-
[iterableEquality],
146+
[...customTesters, iterableEquality],
144147
)
145148

146149
if (toEqualPass)
@@ -159,7 +162,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
159162
def('toMatchObject', function (expected) {
160163
const actual = this._obj
161164
return this.assert(
162-
jestEquals(actual, expected, [iterableEquality, subsetEquality]),
165+
jestEquals(actual, expected, [...customTesters, iterableEquality, subsetEquality]),
163166
'expected #{this} to match object #{exp}',
164167
'expected #{this} to not match object #{exp}',
165168
expected,
@@ -208,7 +211,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
208211
def('toContainEqual', function (expected) {
209212
const obj = utils.flag(this, 'object')
210213
const index = Array.from(obj).findIndex((item) => {
211-
return jestEquals(item, expected)
214+
return jestEquals(item, expected, customTesters)
212215
})
213216

214217
this.assert(
@@ -339,7 +342,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
339342
return utils.getPathInfo(actual, propertyName)
340343
}
341344
const { value, exists } = getValue()
342-
const pass = exists && (args.length === 1 || jestEquals(expected, value))
345+
const pass = exists && (args.length === 1 || jestEquals(expected, value, customTesters))
343346

344347
const valueString = args.length === 1 ? '' : ` with value ${utils.objDisplay(expected)}`
345348

@@ -482,7 +485,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
482485
def(['toHaveBeenCalledWith', 'toBeCalledWith'], function (...args) {
483486
const spy = getSpy(this)
484487
const spyName = spy.getMockName()
485-
const pass = spy.mock.calls.some(callArg => jestEquals(callArg, args, [iterableEquality]))
488+
const pass = spy.mock.calls.some(callArg => jestEquals(callArg, args, [...customTesters, iterableEquality]))
486489
const isNot = utils.flag(this, 'negate') as boolean
487490

488491
const msg = utils.getMessage(
@@ -504,7 +507,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
504507
const nthCall = spy.mock.calls[times - 1]
505508

506509
this.assert(
507-
jestEquals(nthCall, args, [iterableEquality]),
510+
jestEquals(nthCall, args, [...customTesters, iterableEquality]),
508511
`expected ${ordinalOf(times)} "${spyName}" call to have been called with #{exp}`,
509512
`expected ${ordinalOf(times)} "${spyName}" call to not have been called with #{exp}`,
510513
args,
@@ -517,7 +520,7 @@ export const JestChaiExpect: ChaiPlugin = (chai, utils) => {
517520
const lastCall = spy.mock.calls[spy.mock.calls.length - 1]
518521

519522
this.assert(
520-
jestEquals(lastCall, args, [iterableEquality]),
523+
jestEquals(lastCall, args, [...customTesters, iterableEquality]),
521524
`expected last "${spyName}" call to have been called with #{exp}`,
522525
`expected last "${spyName}" call to not have been called with #{exp}`,
523526
args,

packages/expect/src/jest-extend.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import { ASYMMETRIC_MATCHERS_OBJECT, JEST_MATCHERS_OBJECT } from './constants'
1010
import { AsymmetricMatcher } from './jest-asymmetric-matchers'
1111
import { getState } from './state'
1212

13-
import { diff, getMatcherUtils, stringify } from './jest-matcher-utils'
13+
import { diff, getCustomEqualityTesters, getMatcherUtils, stringify } from './jest-matcher-utils'
1414

1515
import {
1616
equals,
@@ -33,8 +33,7 @@ function getMatcherState(assertion: Chai.AssertionStatic & Chai.Assertion, expec
3333

3434
const matcherState: MatcherState = {
3535
...getState(expect),
36-
// TODO: implement via expect.addEqualityTesters
37-
customTesters: [],
36+
customTesters: getCustomEqualityTesters(),
3837
isNot,
3938
utils: jestUtils,
4039
promise,

packages/expect/src/jest-matcher-utils.ts

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { getColors, stringify } from '@vitest/utils'
2-
import type { MatcherHintOptions } from './types'
1+
import { getColors, getType, stringify } from '@vitest/utils'
2+
import type { MatcherHintOptions, Tester } from './types'
3+
import { JEST_MATCHERS_OBJECT } from './constants'
34

45
export { diff } from '@vitest/utils/diff'
56
export { stringify }
@@ -101,3 +102,21 @@ export function getMatcherUtils() {
101102
printExpected,
102103
}
103104
}
105+
106+
export function addCustomEqualityTesters(newTesters: Array<Tester>): void {
107+
if (!Array.isArray(newTesters)) {
108+
throw new TypeError(
109+
`expect.customEqualityTesters: Must be set to an array of Testers. Was given "${getType(
110+
newTesters,
111+
)}"`,
112+
)
113+
}
114+
115+
(globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters.push(
116+
...newTesters,
117+
)
118+
}
119+
120+
export function getCustomEqualityTesters(): Array<Tester> {
121+
return (globalThis as any)[JEST_MATCHERS_OBJECT].customEqualityTesters
122+
}

0 commit comments

Comments
 (0)