diff --git a/src/utils/byte_utils.ts b/src/utils/byte_utils.ts index d611c906..8cc7e56d 100644 --- a/src/utils/byte_utils.ts +++ b/src/utils/byte_utils.ts @@ -23,8 +23,6 @@ export type ByteUtils = { fromHex: (hex: string) => Uint8Array; /** Create a lowercase hex string from bytes */ toHex: (buffer: Uint8Array) => string; - /** Create a Uint8Array containing utf8 code units from a string */ - fromUTF8: (text: string) => Uint8Array; /** Create a string from utf8 code units, fatal=true will throw an error if UTF-8 bytes are invalid, fatal=false will insert replacement characters */ toUTF8: (buffer: Uint8Array, start: number, end: number, fatal: boolean) => string; /** Get the utf8 code unit count from a string if it were to be transformed to utf8 */ diff --git a/src/utils/latin.ts b/src/utils/latin.ts index 860497db..5dd5c91f 100644 --- a/src/utils/latin.ts +++ b/src/utils/latin.ts @@ -13,7 +13,11 @@ * @param end - The index to stop searching the uint8array * @returns string if all bytes are within the basic latin range, otherwise null */ -export function tryLatin(uint8array: Uint8Array, start: number, end: number): string | null { +export function tryReadBasicLatin( + uint8array: Uint8Array, + start: number, + end: number +): string | null { if (uint8array.length === 0) { return ''; } @@ -59,3 +63,42 @@ export function tryLatin(uint8array: Uint8Array, start: number, end: number): st return String.fromCharCode(...latinBytes); } + +/** + * This function is an optimization for writing small basic latin strings. + * @internal + * @remarks + * ### Important characteristics: + * - If the string length is 0 return 0, do not perform any work + * - If a string is longer than 25 code units return null + * - If any code unit exceeds 128 this function returns null + * + * @param destination - The uint8array to serialize the string to + * @param source - The string to turn into UTF-8 bytes if it fits in the basic latin range + * @param offset - The position in the destination to begin writing bytes to + * @returns the number of bytes written to destination if all code units are below 128, otherwise null + */ +export function tryWriteBasicLatin( + destination: Uint8Array, + source: string, + offset: number +): number | null { + if (source.length === 0) return 0; + + if (source.length > 25) return null; + + if (destination.length - offset < source.length) return null; + + for ( + let charOffset = 0, destinationOffset = offset; + charOffset < source.length; + charOffset++, destinationOffset++ + ) { + const char = source.charCodeAt(charOffset); + if (char > 127) return null; + + destination[destinationOffset] = char; + } + + return source.length; +} diff --git a/src/utils/node_byte_utils.ts b/src/utils/node_byte_utils.ts index 85811e14..519fad2f 100644 --- a/src/utils/node_byte_utils.ts +++ b/src/utils/node_byte_utils.ts @@ -1,6 +1,6 @@ import { BSONError } from '../error'; import { validateUtf8 } from '../validate_utf8'; -import { tryLatin } from './latin'; +import { tryReadBasicLatin, tryWriteBasicLatin } from './latin'; type NodeJsEncoding = 'base64' | 'hex' | 'utf8' | 'binary'; type NodeJsBuffer = ArrayBufferView & @@ -123,12 +123,8 @@ export const nodeJsByteUtils = { return nodeJsByteUtils.toLocalBufferType(buffer).toString('hex'); }, - fromUTF8(text: string): NodeJsBuffer { - return Buffer.from(text, 'utf8'); - }, - toUTF8(buffer: Uint8Array, start: number, end: number, fatal: boolean): string { - const basicLatin = end - start <= 20 ? tryLatin(buffer, start, end) : null; + const basicLatin = end - start <= 20 ? tryReadBasicLatin(buffer, start, end) : null; if (basicLatin != null) { return basicLatin; } @@ -153,6 +149,11 @@ export const nodeJsByteUtils = { }, encodeUTF8Into(buffer: Uint8Array, source: string, byteOffset: number): number { + const latinBytesWritten = tryWriteBasicLatin(buffer, source, byteOffset); + if (latinBytesWritten != null) { + return latinBytesWritten; + } + return nodeJsByteUtils.toLocalBufferType(buffer).write(source, byteOffset, undefined, 'utf8'); }, diff --git a/src/utils/web_byte_utils.ts b/src/utils/web_byte_utils.ts index 9e38efd4..374e3835 100644 --- a/src/utils/web_byte_utils.ts +++ b/src/utils/web_byte_utils.ts @@ -1,5 +1,5 @@ import { BSONError } from '../error'; -import { tryLatin } from './latin'; +import { tryReadBasicLatin } from './latin'; type TextDecoder = { readonly encoding: string; @@ -169,12 +169,8 @@ export const webByteUtils = { return Array.from(uint8array, byte => byte.toString(16).padStart(2, '0')).join(''); }, - fromUTF8(text: string): Uint8Array { - return new TextEncoder().encode(text); - }, - toUTF8(uint8array: Uint8Array, start: number, end: number, fatal: boolean): string { - const basicLatin = end - start <= 20 ? tryLatin(uint8array, start, end) : null; + const basicLatin = end - start <= 20 ? tryReadBasicLatin(uint8array, start, end) : null; if (basicLatin != null) { return basicLatin; } @@ -190,11 +186,11 @@ export const webByteUtils = { }, utf8ByteLength(input: string): number { - return webByteUtils.fromUTF8(input).byteLength; + return new TextEncoder().encode(input).byteLength; }, encodeUTF8Into(buffer: Uint8Array, source: string, byteOffset: number): number { - const bytes = webByteUtils.fromUTF8(source); + const bytes = new TextEncoder().encode(source); buffer.set(bytes, byteOffset); return bytes.byteLength; }, diff --git a/test/node/byte_utils.test.ts b/test/node/byte_utils.test.ts index 9526329a..fa6d7f89 100644 --- a/test/node/byte_utils.test.ts +++ b/test/node/byte_utils.test.ts @@ -365,33 +365,35 @@ const toISO88591Tests: ByteUtilTest<'toISO88591'>[] = [ } } ]; -const fromUTF8Tests: ByteUtilTest<'fromUTF8'>[] = [ +const fromUTF8Tests: ByteUtilTest<'encodeUTF8Into'>[] = [ { - name: 'should create buffer from utf8 input', - inputs: [Buffer.from('abc\u{1f913}', 'utf8').toString('utf8')], + name: 'should insert utf8 bytes into buffer', + inputs: [Buffer.alloc(7), 'abc\u{1f913}', 0], expectation({ output, error }) { expect(error).to.be.null; - expect(output).to.deep.equal(Buffer.from('abc\u{1f913}', 'utf8')); + expect(output).to.equal(7); + expect(this.inputs[0]).to.deep.equal(Buffer.from('abc\u{1f913}', 'utf8')); } }, { - name: 'should return empty buffer for empty string input', - inputs: [''], + name: 'should return 0 and not modify input buffer', + inputs: [Uint8Array.from([2, 2]), '', 0], expectation({ output, error }) { expect(error).to.be.null; - expect(output).to.have.property('byteLength', 0); + expect(output).to.equal(0); + expect(this.inputs[0]).to.deep.equal(Uint8Array.from([2, 2])); } }, { - name: 'should return bytes with replacement character if string is not encodable', - inputs: ['\u{1f913}'.slice(0, 1)], + name: 'should insert replacement character bytes if string is not encodable', + inputs: [Uint8Array.from({ length: 10 }, () => 2), '\u{1f913}'.slice(0, 1), 2], expectation({ output, error }) { expect(error).to.be.null; - expect(output).to.have.property('byteLength', 3); - expect(output).to.have.property('0', 0xef); - expect(output).to.have.property('1', 0xbf); - expect(output).to.have.property('2', 0xbd); - const backToString = Buffer.from(output!).toString('utf8'); + expect(output).to.equal(3); + expect(this.inputs[0]).to.have.property('2', 0xef); + expect(this.inputs[0]).to.have.property('3', 0xbf); + expect(this.inputs[0]).to.have.property('4', 0xbd); + const backToString = Buffer.from(this.inputs[0].subarray(2, 5)).toString('utf8'); const replacementCharacter = '\u{fffd}'; expect(backToString).to.equal(replacementCharacter); } @@ -507,7 +509,7 @@ const table = new Map[]>([ ['toHex', toHexTests], ['fromISO88591', fromISO88591Tests], ['toISO88591', toISO88591Tests], - ['fromUTF8', fromUTF8Tests], + ['encodeUTF8Into', fromUTF8Tests], ['toUTF8', toUTF8Tests], ['utf8ByteLength', utf8ByteLengthTests], ['randomBytes', randomBytesTests] diff --git a/test/node/utils/latin.test.ts b/test/node/utils/latin.test.ts index 96e0e6fa..6caea393 100644 --- a/test/node/utils/latin.test.ts +++ b/test/node/utils/latin.test.ts @@ -1,17 +1,17 @@ import { expect } from 'chai'; -import { tryLatin } from '../../../src/utils/latin'; +import { tryReadBasicLatin, tryWriteBasicLatin } from '../../../src/utils/latin'; import * as sinon from 'sinon'; -describe('tryLatin()', () => { +describe('tryReadBasicLatin()', () => { context('when given a buffer of length 0', () => { it('returns an empty string', () => { - expect(tryLatin(new Uint8Array(), 0, 10)).to.equal(''); + expect(tryReadBasicLatin(new Uint8Array(), 0, 10)).to.equal(''); }); }); context('when the distance between end and start is 0', () => { it('returns an empty string', () => { - expect(tryLatin(new Uint8Array([1, 2, 3]), 0, 0)).to.equal(''); + expect(tryReadBasicLatin(new Uint8Array([1, 2, 3]), 0, 0)).to.equal(''); }); }); @@ -30,17 +30,17 @@ describe('tryLatin()', () => { context('when there is 1 byte', () => { context('that exceed 127', () => { it('returns null', () => { - expect(tryLatin(new Uint8Array([128]), 0, 1)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128]), 0, 1)).be.null; }); }); it('calls fromCharCode once', () => { - tryLatin(new Uint8Array([95]), 0, 1); + tryReadBasicLatin(new Uint8Array([95]), 0, 1); expect(fromCharCodeSpy).to.have.been.calledOnce; }); it('never calls array.push', () => { - tryLatin(new Uint8Array([95]), 0, 1); + tryReadBasicLatin(new Uint8Array([95]), 0, 1); expect(pushSpy).to.have.not.been.called; }); }); @@ -48,19 +48,19 @@ describe('tryLatin()', () => { context('when there is 2 bytes', () => { context('that exceed 127', () => { it('returns null', () => { - expect(tryLatin(new Uint8Array([0, 128]), 0, 2)).be.null; - expect(tryLatin(new Uint8Array([128, 0]), 0, 2)).be.null; - expect(tryLatin(new Uint8Array([128, 128]), 0, 2)).be.null; + expect(tryReadBasicLatin(new Uint8Array([0, 128]), 0, 2)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 0]), 0, 2)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 128]), 0, 2)).be.null; }); }); it('calls fromCharCode twice', () => { - tryLatin(new Uint8Array([95, 105]), 0, 2); + tryReadBasicLatin(new Uint8Array([95, 105]), 0, 2); expect(fromCharCodeSpy).to.have.been.calledTwice; }); it('never calls array.push', () => { - tryLatin(new Uint8Array([95, 105]), 0, 2); + tryReadBasicLatin(new Uint8Array([95, 105]), 0, 2); expect(pushSpy).to.have.not.been.called; }); }); @@ -68,23 +68,23 @@ describe('tryLatin()', () => { context('when there is 3 bytes', () => { context('that exceed 127', () => { it('returns null', () => { - expect(tryLatin(new Uint8Array([0, 0, 128]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([0, 128, 0]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([128, 0, 0]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([128, 128, 128]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([128, 128, 0]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([128, 0, 128]), 0, 3)).be.null; - expect(tryLatin(new Uint8Array([0, 128, 128]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([0, 0, 128]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([0, 128, 0]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 0, 0]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 128, 128]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 128, 0]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([128, 0, 128]), 0, 3)).be.null; + expect(tryReadBasicLatin(new Uint8Array([0, 128, 128]), 0, 3)).be.null; }); }); it('calls fromCharCode thrice', () => { - tryLatin(new Uint8Array([95, 105, 100]), 0, 3); + tryReadBasicLatin(new Uint8Array([95, 105, 100]), 0, 3); expect(fromCharCodeSpy).to.have.been.calledThrice; }); it('never calls array.push', () => { - tryLatin(new Uint8Array([95, 105, 100]), 0, 3); + tryReadBasicLatin(new Uint8Array([95, 105, 100]), 0, 3); expect(pushSpy).to.have.not.been.called; }); }); @@ -93,17 +93,18 @@ describe('tryLatin()', () => { context(`when there is ${stringLength} bytes`, () => { context('that exceed 127', () => { it('returns null', () => { - expect(tryLatin(new Uint8Array(stringLength).fill(128), 0, stringLength)).be.null; + expect(tryReadBasicLatin(new Uint8Array(stringLength).fill(128), 0, stringLength)).be + .null; }); }); it('calls fromCharCode once', () => { - tryLatin(new Uint8Array(stringLength).fill(95), 0, stringLength); + tryReadBasicLatin(new Uint8Array(stringLength).fill(95), 0, stringLength); expect(fromCharCodeSpy).to.have.been.calledOnce; }); it(`calls array.push ${stringLength}`, () => { - tryLatin(new Uint8Array(stringLength).fill(95), 0, stringLength); + tryReadBasicLatin(new Uint8Array(stringLength).fill(95), 0, stringLength); expect(pushSpy).to.have.callCount(stringLength); }); }); @@ -111,8 +112,67 @@ describe('tryLatin()', () => { context('when there is >21 bytes', () => { it('returns null', () => { - expect(tryLatin(new Uint8Array(21).fill(95), 0, 21)).be.null; - expect(tryLatin(new Uint8Array(201).fill(95), 0, 201)).be.null; + expect(tryReadBasicLatin(new Uint8Array(21).fill(95), 0, 21)).be.null; + expect(tryReadBasicLatin(new Uint8Array(201).fill(95), 0, 201)).be.null; + }); + }); +}); + +describe('tryWriteBasicLatin()', () => { + context('when given a string of length 0', () => { + it('returns 0 and does not modify the destination', () => { + const input = Uint8Array.from({ length: 10 }, () => 1); + expect(tryWriteBasicLatin(input, '', 2)).to.equal(0); + expect(input).to.deep.equal(Uint8Array.from({ length: 10 }, () => 1)); + }); + }); + + context('when given a string with a length larger than the buffer', () => { + it('returns null', () => { + const input = Uint8Array.from({ length: 10 }, () => 1); + expect(tryWriteBasicLatin(input, 'a'.repeat(11), 0)).to.be.null; + expect(tryWriteBasicLatin(input, 'a'.repeat(13), 2)).to.be.null; + }); + }); + + let charCodeAtSpy; + + beforeEach(() => { + charCodeAtSpy = sinon.spy(String.prototype, 'charCodeAt'); + }); + + afterEach(() => { + sinon.restore(); + }); + + for (let stringLength = 1; stringLength <= 25; stringLength++) { + context(`when there is ${stringLength} bytes`, () => { + context('that exceed 127', () => { + it('returns null', () => { + expect( + tryWriteBasicLatin( + new Uint8Array(stringLength * 3), + 'a'.repeat(stringLength - 1) + '\x80', + 0 + ) + ).be.null; + }); + }); + + it(`calls charCodeAt ${stringLength}`, () => { + tryWriteBasicLatin( + new Uint8Array(stringLength * 3), + String.fromCharCode(127).repeat(stringLength), + stringLength + ); + expect(charCodeAtSpy).to.have.callCount(stringLength); + }); + }); + } + + context('when there is >25 characters', () => { + it('returns null', () => { + expect(tryWriteBasicLatin(new Uint8Array(75), 'a'.repeat(26), 0)).be.null; }); }); });