From 38d95f7f138bace1be011c3ea91db8edb6ea8568 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 25 Mar 2019 11:43:23 -0400 Subject: [PATCH 01/10] add packages --- package.json | 4 +++- yarn.lock | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/package.json b/package.json index 4c52018a28..1a10cd04b3 100644 --- a/package.json +++ b/package.json @@ -20,6 +20,7 @@ "devDependencies": { "@types/jasmine": "~2.5.53", "@types/node": "~9.6.0", + "@types/node-fetch": "^2.1.2", "browserify": "~16.2.3", "clang-format": "~1.2.4", "jasmine": "~3.1.0", @@ -62,6 +63,7 @@ "@types/seedrandom": "2.4.27", "@types/webgl-ext": "0.0.30", "@types/webgl2": "0.0.4", - "seedrandom": "2.4.3" + "seedrandom": "2.4.3", + "node-fetch": "~2.1.2" } } diff --git a/yarn.lock b/yarn.lock index 4c801459a6..7f6ff221a7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12,6 +12,13 @@ resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-2.5.54.tgz#a6b5f2ae2afb6e0307774e8c7c608e037d491c63" integrity sha512-B9YofFbUljs19g5gBKUYeLIulsh31U5AK70F41BImQRHEZQGm4GcN922UvnYwkduMqbC/NH+9fruWa/zrqvHIg== +"@types/node-fetch@^2.1.2": + version "2.1.7" + resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.1.7.tgz#0231559340f6e3f3a0608692077d744c87b5b367" + integrity sha512-TZozHCDVrs0Aj1B9ZR5F4Q9MknDNcVd+hO5lxXOCzz07ELBey6s1gMUSZHUYHlPfRFKJFXiTnNuD7ePiI6S4/g== + dependencies: + "@types/node" "*" + "@types/node@*": version "10.12.18" resolved "https://registry.yarnpkg.com/@types/node/-/node-10.12.18.tgz#1d3ca764718915584fcd9f6344621b7672665c67" @@ -2924,6 +2931,11 @@ nice-try@^1.0.4: resolved "https://registry.yarnpkg.com/nice-try/-/nice-try-1.0.5.tgz#a3378a7696ce7d223e88fc9b764bd7ef1089e366" integrity sha512-1nh45deeb5olNY7eX82BkPO7SSxR5SSYJiPTrTdFUVYwAl8CKMA5N9PjTYkHiRjisVcxcQ1HXdLhx2qxxJzLNQ== +node-fetch@~2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.1.2.tgz#ab884e8e7e57e38a944753cec706f788d1768bb5" + integrity sha1-q4hOjn5X44qUR1POxwb3iNF2i7U= + node-pre-gyp@^0.10.0: version "0.10.3" resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc" From 8fdb2508e4bbb335be7cb650507226ad88725523 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 25 Mar 2019 13:02:51 -0400 Subject: [PATCH 02/10] add fetch --- src/util.ts | 112 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 48 deletions(-) diff --git a/src/util.ts b/src/util.ts index 47da965965..9edb496d95 100644 --- a/src/util.ts +++ b/src/util.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; +import { DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray } from './types'; +import { ENV } from './environment'; /** * Shuffles the array in-place using Fisher-Yates algorithm. @@ -30,8 +31,8 @@ import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, Tens */ /** @doc {heading: 'Util'} */ // tslint:disable-next-line:no-any -export function shuffle(array: any[]|Uint32Array|Int32Array| - Float32Array): void { +export function shuffle(array: any[] | Uint32Array | Int32Array | + Float32Array): void { let counter = array.length; let temp = 0; let index = 0; @@ -107,16 +108,16 @@ export function assert(expr: boolean, msg: () => string) { } export function assertShapesMatch( - shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { + shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { assert( - arraysEqual(shapeA, shapeB), - () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); + arraysEqual(shapeA, shapeB), + () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); } export function assertNonNull(a: TensorLike): void { assert( - a != null, - () => `The input to the tensor constructor must be a non-null value.`); + a != null, + () => `The input to the tensor constructor must be a non-null value.`); } // NOTE: We explicitly type out what T extends instead of any so that @@ -136,8 +137,8 @@ export function assertNonNull(a: TensorLike): void { */ /** @doc {heading: 'Util'} */ export function -flatten|TypedArray>( - arr: T|RecursiveArray, result: T[] = []): T[] { + flatten | TypedArray>( + arr: T | RecursiveArray, result: T[] = []): T[] { if (result == null) { result = []; } @@ -238,8 +239,8 @@ export function rightPad(a: string, size: number): string { } export function repeatedTry( - checkFn: () => boolean, delayFn = (counter: number) => 0, - maxCounter?: number): Promise { + checkFn: () => boolean, delayFn = (counter: number) => 0, + maxCounter?: number): Promise { return new Promise((resolve, reject) => { let tryCount = 0; @@ -274,7 +275,7 @@ export function repeatedTry( * @return The inferred shape where -1 is replaced with the inferred size. */ export function inferFromImplicitShape( - shape: number[], size: number): number[] { + shape: number[], size: number): number[] { let shapeProd = 1; let implicitIdx = -1; @@ -284,8 +285,8 @@ export function inferFromImplicitShape( } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error( - `Shapes can only have 1 implicit size. ` + - `Found -1 at dim ${implicitIdx} and dim ${i}`); + `Shapes can only have 1 implicit size. ` + + `Found -1 at dim ${implicitIdx} and dim ${i}`); } implicitIdx = i; } else if (shape[i] < 0) { @@ -302,13 +303,13 @@ export function inferFromImplicitShape( if (shapeProd === 0) { throw Error( - `Cannot infer the missing size in [${shape}] when ` + - `there are 0 elements`); + `Cannot infer the missing size in [${shape}] when ` + + `there are 0 elements`); } if (size % shapeProd !== 0) { throw Error( - `The implicit shape can't be a fractional number. ` + - `Got ${size} / ${shapeProd}`); + `The implicit shape can't be a fractional number. ` + + `Got ${size} / ${shapeProd}`); } const newShape = shape.slice(); @@ -317,7 +318,7 @@ export function inferFromImplicitShape( } export function parseAxisParam( - axis: number|number[], shape: number[]): number[] { + axis: number | number[], shape: number[]): number[] { const rank = shape.length; // Normalize input @@ -325,24 +326,23 @@ export function parseAxisParam( // Check for valid range assert( - axis.every(ax => ax >= -rank && ax < rank), - () => - `All values in axis param must be in range [-${rank}, ${rank}) but ` + - `got axis ${axis}`); + axis.every(ax => ax >= -rank && ax < rank), + () => + `All values in axis param must be in range [-${rank}, ${rank}) but ` + + `got axis ${axis}`); // Check for only integers assert( - axis.every(ax => isInt(ax)), - () => `All values in axis param must be integers but ` + - `got axis ${axis}`); + axis.every(ax => isInt(ax)), + () => `All values in axis param must be integers but ` + + `got axis ${axis}`); // Handle negative axis. return axis.map(a => a < 0 ? rank + a : a); } /** Reduces the shape by removing all dimensions of shape 1. */ -export function squeezeShape(shape: number[], axis?: number[]): - {newShape: number[], keptDims: number[]} { +export function squeezeShape(shape: number[], axis?: number[]): { newShape: number[], keptDims: number[] } { const newShape: number[] = []; const keptDims: number[] = []; const axes = axis == null ? null : parseAxisParam(axis, shape).sort(); @@ -351,7 +351,7 @@ export function squeezeShape(shape: number[], axis?: number[]): if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error( - `Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); + `Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); @@ -366,11 +366,11 @@ export function squeezeShape(shape: number[], axis?: number[]): keptDims.push(i); } } - return {newShape, keptDims}; + return { newShape, keptDims }; } export function getTypedArrayFromDType( - dtype: D, size: number): DataTypeMap[D] { + dtype: D, size: number): DataTypeMap[D] { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); @@ -385,7 +385,7 @@ export function getTypedArrayFromDType( } export function getArrayFromDType( - dtype: D, size: number): DataTypeMap[D] { + dtype: D, size: number): DataTypeMap[D] { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); @@ -402,7 +402,7 @@ export function getArrayFromDType( } export function checkComputationForErrors( - vals: DataTypeMap[D], dtype: D, name: string): void { + vals: DataTypeMap[D], dtype: D, name: string): void { if (dtype !== 'float32') { // Only floating point computations will generate NaN values return; @@ -416,7 +416,7 @@ export function checkComputationForErrors( } export function checkConversionForErrors( - vals: DataTypeMap[D]|number[], dtype: D): void { + vals: DataTypeMap[D] | number[], dtype: D): void { for (let i = 0; i < vals.length; i++) { const num = vals[i] as number; if (isNaN(num) || !isFinite(num)) { @@ -445,9 +445,9 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean { return true; } -export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array { +export function isTypedArray(a: {}): a is Float32Array | Int32Array | Uint8Array { return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array; + a instanceof Uint8Array; } export function bytesPerElement(dtype: DataType): number { @@ -538,7 +538,7 @@ export function computeStrides(shape: number[]): number[] { } export function toTypedArray( - a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray { + a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } @@ -606,12 +606,12 @@ export function toNestedArray(shape: number[], a: TypedArray) { function noConversionNeeded(a: TensorLike, dtype: DataType): boolean { return (a instanceof Float32Array && dtype === 'float32') || - (a instanceof Int32Array && dtype === 'int32') || - (a instanceof Uint8Array && dtype === 'bool'); + (a instanceof Int32Array && dtype === 'int32') || + (a instanceof Uint8Array && dtype === 'bool'); } export function makeOnesTypedArray( - size: number, dtype: D): DataTypeMap[D] { + size: number, dtype: D): DataTypeMap[D] { const array = makeZerosTypedArray(size, dtype); for (let i = 0; i < array.length; i++) { array[i] = 1; @@ -620,7 +620,7 @@ export function makeOnesTypedArray( } export function makeZerosTypedArray( - size: number, dtype: D): DataTypeMap[D] { + size: number, dtype: D): DataTypeMap[D] { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { @@ -650,17 +650,33 @@ export function now(): number { return time[0] * 1000 + time[1] / 1000000; } else { throw new Error( - 'Cannot measure time in this environment. You should run tf.js ' + - 'in the browser or in Node.js'); + 'Cannot measure time in this environment. You should run tf.js ' + + 'in the browser or in Node.js'); } } export function assertNonNegativeIntegerDimensions(shape: number[]) { shape.forEach(dimSize => { assert( - Number.isInteger(dimSize) && dimSize >= 0, - () => - `Tensor must have a shape comprised of positive integers but got ` + - `shape [${shape}].`); + Number.isInteger(dimSize) && dimSize >= 0, + () => + `Tensor must have a shape comprised of positive integers but got ` + + `shape [${shape}].`); }); } + +export function fetch() { + let fetchFunc: Function; + + if (ENV.global.fetch != null) { + fetchFunc = ENV.global.fetch; + } else { + if (!ENV.get('IS_BROWSER')) { + fetchFunc = require('node-fetch'); + } else { + throw new Error(`Unable to find a fetch function on ${ENV.global}`) + } + } + + return fetchFunc; +} From 73bbaed88eece7d7ce7687d6dbbd596301e96d6e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 26 Mar 2019 09:24:35 -0400 Subject: [PATCH 03/10] edit replace --- src/io/weights_loader.ts | 8 +-- src/util.ts | 116 ++++++++++++++++++++++----------------- 2 files changed, 68 insertions(+), 56 deletions(-) diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index e55828b156..161d8c318d 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -27,7 +27,6 @@ import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifes * * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. * @param requestOptions RequestInit (options) for the HTTP requests. - * @param fetchFunc Optional overriding value for the `window.fetch` function. * @param onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same @@ -39,12 +38,9 @@ export async function loadWeightsAsArrayBuffer( loadOptions = {}; } - const fetchFunc = - loadOptions.fetchFunc == null ? fetch : loadOptions.fetchFunc; - // Create the requests for all of the weights in parallel. - const requests = - fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit)); + const requests = fetchURLs.map( + fetchURL => util.fetch()(fetchURL, loadOptions.requestInit)); const fetchStartFraction = 0; const fetchEndFraction = 0.5; diff --git a/src/util.ts b/src/util.ts index 9edb496d95..b672d10c02 100644 --- a/src/util.ts +++ b/src/util.ts @@ -15,8 +15,9 @@ * ============================================================================= */ -import { DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray } from './types'; -import { ENV } from './environment'; +import {ENV} from './environment'; +import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; + /** * Shuffles the array in-place using Fisher-Yates algorithm. @@ -31,8 +32,8 @@ import { ENV } from './environment'; */ /** @doc {heading: 'Util'} */ // tslint:disable-next-line:no-any -export function shuffle(array: any[] | Uint32Array | Int32Array | - Float32Array): void { +export function shuffle(array: any[]|Uint32Array|Int32Array| + Float32Array): void { let counter = array.length; let temp = 0; let index = 0; @@ -108,16 +109,16 @@ export function assert(expr: boolean, msg: () => string) { } export function assertShapesMatch( - shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { + shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void { assert( - arraysEqual(shapeA, shapeB), - () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); + arraysEqual(shapeA, shapeB), + () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); } export function assertNonNull(a: TensorLike): void { assert( - a != null, - () => `The input to the tensor constructor must be a non-null value.`); + a != null, + () => `The input to the tensor constructor must be a non-null value.`); } // NOTE: We explicitly type out what T extends instead of any so that @@ -137,8 +138,8 @@ export function assertNonNull(a: TensorLike): void { */ /** @doc {heading: 'Util'} */ export function - flatten | TypedArray>( - arr: T | RecursiveArray, result: T[] = []): T[] { +flatten|TypedArray>( + arr: T|RecursiveArray, result: T[] = []): T[] { if (result == null) { result = []; } @@ -239,8 +240,8 @@ export function rightPad(a: string, size: number): string { } export function repeatedTry( - checkFn: () => boolean, delayFn = (counter: number) => 0, - maxCounter?: number): Promise { + checkFn: () => boolean, delayFn = (counter: number) => 0, + maxCounter?: number): Promise { return new Promise((resolve, reject) => { let tryCount = 0; @@ -275,7 +276,7 @@ export function repeatedTry( * @return The inferred shape where -1 is replaced with the inferred size. */ export function inferFromImplicitShape( - shape: number[], size: number): number[] { + shape: number[], size: number): number[] { let shapeProd = 1; let implicitIdx = -1; @@ -285,8 +286,8 @@ export function inferFromImplicitShape( } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error( - `Shapes can only have 1 implicit size. ` + - `Found -1 at dim ${implicitIdx} and dim ${i}`); + `Shapes can only have 1 implicit size. ` + + `Found -1 at dim ${implicitIdx} and dim ${i}`); } implicitIdx = i; } else if (shape[i] < 0) { @@ -303,13 +304,13 @@ export function inferFromImplicitShape( if (shapeProd === 0) { throw Error( - `Cannot infer the missing size in [${shape}] when ` + - `there are 0 elements`); + `Cannot infer the missing size in [${shape}] when ` + + `there are 0 elements`); } if (size % shapeProd !== 0) { throw Error( - `The implicit shape can't be a fractional number. ` + - `Got ${size} / ${shapeProd}`); + `The implicit shape can't be a fractional number. ` + + `Got ${size} / ${shapeProd}`); } const newShape = shape.slice(); @@ -318,7 +319,7 @@ export function inferFromImplicitShape( } export function parseAxisParam( - axis: number | number[], shape: number[]): number[] { + axis: number|number[], shape: number[]): number[] { const rank = shape.length; // Normalize input @@ -326,23 +327,24 @@ export function parseAxisParam( // Check for valid range assert( - axis.every(ax => ax >= -rank && ax < rank), - () => - `All values in axis param must be in range [-${rank}, ${rank}) but ` + - `got axis ${axis}`); + axis.every(ax => ax >= -rank && ax < rank), + () => + `All values in axis param must be in range [-${rank}, ${rank}) but ` + + `got axis ${axis}`); // Check for only integers assert( - axis.every(ax => isInt(ax)), - () => `All values in axis param must be integers but ` + - `got axis ${axis}`); + axis.every(ax => isInt(ax)), + () => `All values in axis param must be integers but ` + + `got axis ${axis}`); // Handle negative axis. return axis.map(a => a < 0 ? rank + a : a); } /** Reduces the shape by removing all dimensions of shape 1. */ -export function squeezeShape(shape: number[], axis?: number[]): { newShape: number[], keptDims: number[] } { +export function squeezeShape(shape: number[], axis?: number[]): + {newShape: number[], keptDims: number[]} { const newShape: number[] = []; const keptDims: number[] = []; const axes = axis == null ? null : parseAxisParam(axis, shape).sort(); @@ -351,7 +353,7 @@ export function squeezeShape(shape: number[], axis?: number[]): { newShape: numb if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error( - `Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); + `Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); @@ -366,11 +368,11 @@ export function squeezeShape(shape: number[], axis?: number[]): { newShape: numb keptDims.push(i); } } - return { newShape, keptDims }; + return {newShape, keptDims}; } export function getTypedArrayFromDType( - dtype: D, size: number): DataTypeMap[D] { + dtype: D, size: number): DataTypeMap[D] { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); @@ -385,7 +387,7 @@ export function getTypedArrayFromDType( } export function getArrayFromDType( - dtype: D, size: number): DataTypeMap[D] { + dtype: D, size: number): DataTypeMap[D] { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); @@ -402,7 +404,7 @@ export function getArrayFromDType( } export function checkComputationForErrors( - vals: DataTypeMap[D], dtype: D, name: string): void { + vals: DataTypeMap[D], dtype: D, name: string): void { if (dtype !== 'float32') { // Only floating point computations will generate NaN values return; @@ -416,7 +418,7 @@ export function checkComputationForErrors( } export function checkConversionForErrors( - vals: DataTypeMap[D] | number[], dtype: D): void { + vals: DataTypeMap[D]|number[], dtype: D): void { for (let i = 0; i < vals.length; i++) { const num = vals[i] as number; if (isNaN(num) || !isFinite(num)) { @@ -445,9 +447,9 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean { return true; } -export function isTypedArray(a: {}): a is Float32Array | Int32Array | Uint8Array { +export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array { return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array; + a instanceof Uint8Array; } export function bytesPerElement(dtype: DataType): number { @@ -538,7 +540,7 @@ export function computeStrides(shape: number[]): number[] { } export function toTypedArray( - a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray { + a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } @@ -606,12 +608,12 @@ export function toNestedArray(shape: number[], a: TypedArray) { function noConversionNeeded(a: TensorLike, dtype: DataType): boolean { return (a instanceof Float32Array && dtype === 'float32') || - (a instanceof Int32Array && dtype === 'int32') || - (a instanceof Uint8Array && dtype === 'bool'); + (a instanceof Int32Array && dtype === 'int32') || + (a instanceof Uint8Array && dtype === 'bool'); } export function makeOnesTypedArray( - size: number, dtype: D): DataTypeMap[D] { + size: number, dtype: D): DataTypeMap[D] { const array = makeZerosTypedArray(size, dtype); for (let i = 0; i < array.length; i++) { array[i] = 1; @@ -620,7 +622,7 @@ export function makeOnesTypedArray( } export function makeZerosTypedArray( - size: number, dtype: D): DataTypeMap[D] { + size: number, dtype: D): DataTypeMap[D] { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { @@ -650,31 +652,45 @@ export function now(): number { return time[0] * 1000 + time[1] / 1000000; } else { throw new Error( - 'Cannot measure time in this environment. You should run tf.js ' + - 'in the browser or in Node.js'); + 'Cannot measure time in this environment. You should run tf.js ' + + 'in the browser or in Node.js'); } } export function assertNonNegativeIntegerDimensions(shape: number[]) { shape.forEach(dimSize => { assert( - Number.isInteger(dimSize) && dimSize >= 0, - () => - `Tensor must have a shape comprised of positive integers but got ` + - `shape [${shape}].`); + Number.isInteger(dimSize) && dimSize >= 0, + () => + `Tensor must have a shape comprised of positive integers but got ` + + `shape [${shape}].`); }); } +/** + * Returns a platform-specific implementation of `window.fetch`. + * + * If `fetch` is defined on the global object (`window`, `process`, etc.), + * `tf.util.fetch` returns that function. + * + * If not, `tf.util.fetch` returns a platform-specific solution. + * + * ```js + * tf.util.fetch('path/to/resource') + * .then(response => {}) // handle response + * ``` + */ export function fetch() { let fetchFunc: Function; if (ENV.global.fetch != null) { fetchFunc = ENV.global.fetch; } else { - if (!ENV.get('IS_BROWSER')) { + if (ENV.get('IS_NODE')) { + // tslint:disable-next-line:no-require-imports fetchFunc = require('node-fetch'); } else { - throw new Error(`Unable to find a fetch function on ${ENV.global}`) + throw new Error(`Unable to find fetch solution on ${ENV.global}`) } } From 492e9d3972d0c743aab85d0de7b3d6ff0354602f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 26 Mar 2019 10:58:51 -0400 Subject: [PATCH 04/10] remove fetch logic from browserhttp --- src/io/browser_http.ts | 10 ++-------- src/io/weights_loader.ts | 8 ++++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts index d27ee1c7d6..25ed359e49 100644 --- a/src/io/browser_http.ts +++ b/src/io/browser_http.ts @@ -22,7 +22,7 @@ */ import {ENV} from '../environment'; -import {assert} from '../util'; +import {assert, fetch as systemFetch} from '../util'; import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -51,15 +51,9 @@ export class BrowserHTTPRequest implements IOHandler { this.onProgress = loadOptions.onProgress; if (loadOptions.fetchFunc == null) { - const systemFetch = ENV.global.fetch; - if (typeof systemFetch === 'undefined') { - throw new Error( - 'browserHTTPRequest is not supported outside the web browser ' + - 'without a fetch polyfill.'); - } // Make sure fetch is always bound to global object (the // original object) when available. - loadOptions.fetchFunc = systemFetch.bind(ENV.global); + loadOptions.fetchFunc = systemFetch().bind(ENV.global); } else { assert( typeof loadOptions.fetchFunc === 'function', diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index 161d8c318d..1553290b78 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -27,6 +27,7 @@ import {DTYPE_VALUE_SIZE_MAP, LoadOptions, WeightsManifestConfig, WeightsManifes * * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. * @param requestOptions RequestInit (options) for the HTTP requests. + * @param fetchFunc Optional overriding value for the `window.fetch` function. * @param onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same @@ -38,9 +39,12 @@ export async function loadWeightsAsArrayBuffer( loadOptions = {}; } + const fetchFunc = + loadOptions.fetchFunc == null ? util.fetch() : loadOptions.fetchFunc; + // Create the requests for all of the weights in parallel. - const requests = fetchURLs.map( - fetchURL => util.fetch()(fetchURL, loadOptions.requestInit)); + const requests = + fetchURLs.map(fetchURL => fetchFunc(fetchURL, loadOptions.requestInit)); const fetchStartFraction = 0; const fetchEndFraction = 0.5; From f895a3b794734344b4f1459c161d47580ad14b70 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 26 Mar 2019 11:08:39 -0400 Subject: [PATCH 05/10] add to rollup ignore --- .gitignore | 1 - package.json | 1 - rollup.config.js | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index bbd3452ac8..08e8599807 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,3 @@ dist/ yalc.lock .rpt2_cache/ package/ -integration_tests/benchmarks/ui/bundle.js diff --git a/package.json b/package.json index 1a10cd04b3..605de8a810 100644 --- a/package.json +++ b/package.json @@ -54,7 +54,6 @@ "coverage": "KARMA_COVERAGE=1 karma start --browsers='Chrome' --singleRun", "test": "karma start", "run-browserstack": "karma start --browserstack", - "test-benchmark": "cd integration_tests/benchmarks && yarn benchmark-travis && cd ../../", "test-node": "node dist/test_node.js", "test-integration": "./scripts/test-integration.sh", "test-travis": "./scripts/test-travis.sh" diff --git a/rollup.config.js b/rollup.config.js index de56e3e0e4..d9722131f2 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -49,7 +49,7 @@ function config({plugins = [], output = {}, external = []}) { node(), // Polyfill require() from dependencies. commonjs({ - ignore: ['crypto'], + ignore: ['crypto', 'node-fetch'], include: 'node_modules/**', namedExports: { './node_modules/seedrandom/index.js': ['alea'], From 4cc532a08fe917b532f09192d1899f0ce87b4542 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 26 Mar 2019 14:38:45 -0400 Subject: [PATCH 06/10] modify test --- src/io/browser_http_test.ts | 4 +--- src/util.ts | 2 +- src/util_test.ts | 10 ++++++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/io/browser_http_test.ts b/src/io/browser_http_test.ts index 36fd9b0b6c..68f9232a91 100644 --- a/src/io/browser_http_test.ts +++ b/src/io/browser_http_test.ts @@ -164,9 +164,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => { try { tf.io.browserHTTPRequest('./model.json'); } catch (err) { - expect(err.message) - .toMatch( - /not supported outside the web browser without a fetch polyfill/); + expect(err.message).toMatch(/Unable to find fetch polyfill./); } }); }); diff --git a/src/util.ts b/src/util.ts index b672d10c02..250d586837 100644 --- a/src/util.ts +++ b/src/util.ts @@ -690,7 +690,7 @@ export function fetch() { // tslint:disable-next-line:no-require-imports fetchFunc = require('node-fetch'); } else { - throw new Error(`Unable to find fetch solution on ${ENV.global}`) + throw new Error(`Unable to find fetch polyfill.`) } } diff --git a/src/util_test.ts b/src/util_test.ts index 4313c53845..ccfbb00112 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {ENV} from './environment'; import {scalar, tensor2d} from './ops/ops'; import {inferShape} from './tensor_util_env'; import * as util from './util'; @@ -495,3 +496,12 @@ describe('util.toNestedArray', () => { expect(util.toNestedArray([1, 0, 2], a)).toEqual([]); }); }); + +describe('util.fetch', () => { + it('should allow overriding global fetch', () => { + spyOn(ENV.global, 'fetch').and.callFake(() => {}); + + util.fetch()(); + expect(ENV.global.fetch).toHaveBeenCalled(); + }); +}); From c36440dcf86d384da56d5ef6e738bc69a80ae8c1 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 26 Mar 2019 16:41:39 -0400 Subject: [PATCH 07/10] lazily evaluate --- src/io/browser_http.ts | 19 +++++-------------- src/io/weights_loader.ts | 2 +- src/util.ts | 36 ++++++++++++++++++++++-------------- src/util_test.ts | 2 +- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts index 25ed359e49..b169fa8293 100644 --- a/src/io/browser_http.ts +++ b/src/io/browser_http.ts @@ -21,7 +21,6 @@ * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). */ -import {ENV} from '../environment'; import {assert, fetch as systemFetch} from '../util'; import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; @@ -34,7 +33,7 @@ export class BrowserHTTPRequest implements IOHandler { protected readonly path: string; protected readonly requestInit: RequestInit; - private readonly fetchFunc: (path: string, init?: RequestInit) => Response; + private readonly fetchFunc: Function; readonly DEFAULT_METHOD = 'POST'; @@ -50,25 +49,17 @@ export class BrowserHTTPRequest implements IOHandler { this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; - if (loadOptions.fetchFunc == null) { - // Make sure fetch is always bound to global object (the - // original object) when available. - loadOptions.fetchFunc = systemFetch().bind(ENV.global); - } else { + if (loadOptions.fetchFunc != null) { assert( typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'); + this.fetchFunc = loadOptions.fetchFunc; + } else { + this.fetchFunc = systemFetch; } - this.fetchFunc = (path: string, requestInits: RequestInit) => { - // tslint:disable-next-line:no-any - return loadOptions.fetchFunc(path, requestInits).catch((error: any) => { - throw new Error(`Request for ${path} failed due to error: ${error}`); - }); - }; - assert( path != null && path.length > 0, () => diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index 1553290b78..1d0f2cc6d1 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -40,7 +40,7 @@ export async function loadWeightsAsArrayBuffer( } const fetchFunc = - loadOptions.fetchFunc == null ? util.fetch() : loadOptions.fetchFunc; + loadOptions.fetchFunc == null ? util.fetch : loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = diff --git a/src/util.ts b/src/util.ts index 250d586837..dfa1880dda 100644 --- a/src/util.ts +++ b/src/util.ts @@ -667,6 +667,23 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) { }); } +let systemFetch: Function; +const getSystemFetch = () => { + let fetchFunc: Function; + + if (ENV.global.fetch != null) { + fetchFunc = ENV.global.fetch; + } else { + if (ENV.get('IS_NODE')) { + // tslint:disable-next-line:no-require-imports + fetchFunc = require('node-fetch'); + } else { + throw new Error(`Unable to find fetch polyfill.`); + } + } + return fetchFunc; +}; + /** * Returns a platform-specific implementation of `window.fetch`. * @@ -680,19 +697,10 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) { * .then(response => {}) // handle response * ``` */ -export function fetch() { - let fetchFunc: Function; - - if (ENV.global.fetch != null) { - fetchFunc = ENV.global.fetch; - } else { - if (ENV.get('IS_NODE')) { - // tslint:disable-next-line:no-require-imports - fetchFunc = require('node-fetch'); - } else { - throw new Error(`Unable to find fetch polyfill.`) - } +/** @doc {heading: 'Util'} */ +export function fetch(path: string, requestInits?: RequestInit) { + if (systemFetch == null) { + systemFetch = getSystemFetch(); } - - return fetchFunc; + return systemFetch(path, requestInits); } diff --git a/src/util_test.ts b/src/util_test.ts index ccfbb00112..c212a77d2d 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -501,7 +501,7 @@ describe('util.fetch', () => { it('should allow overriding global fetch', () => { spyOn(ENV.global, 'fetch').and.callFake(() => {}); - util.fetch()(); + util.fetch(''); expect(ENV.global.fetch).toHaveBeenCalled(); }); }); From 5ed7a4de0938109815a42daecee2792196824202 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 4 Apr 2019 15:45:55 -0400 Subject: [PATCH 08/10] whitespace --- src/util.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/util.ts b/src/util.ts index bedd3a2127..417192ec9d 100644 --- a/src/util.ts +++ b/src/util.ts @@ -18,7 +18,6 @@ import {ENV} from './environment'; import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; - /** * Shuffles the array in-place using Fisher-Yates algorithm. * From b1d619de49c1cbfdd737e0b1b692fa160c50452b Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 4 Apr 2019 15:50:55 -0400 Subject: [PATCH 09/10] pr comments --- src/util.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/util.ts b/src/util.ts index 417192ec9d..51f794e3c2 100644 --- a/src/util.ts +++ b/src/util.ts @@ -684,7 +684,8 @@ const getSystemFetch = () => { }; /** - * Returns a platform-specific implementation of `window.fetch`. + * Returns a platform-specific implementation of + * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). * * If `fetch` is defined on the global object (`window`, `process`, etc.), * `tf.util.fetch` returns that function. From efec158210a06e85be0dd4d8bdda776cd4022a1f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 4 Apr 2019 16:47:45 -0400 Subject: [PATCH 10/10] debug --- src/util.ts | 1 + src/util_test.ts | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/util.ts b/src/util.ts index 51f794e3c2..38a400765b 100644 --- a/src/util.ts +++ b/src/util.ts @@ -699,6 +699,7 @@ const getSystemFetch = () => { */ /** @doc {heading: 'Util'} */ export function fetch(path: string, requestInits?: RequestInit) { + console.log('calling fetch util'); if (systemFetch == null) { systemFetch = getSystemFetch(); } diff --git a/src/util_test.ts b/src/util_test.ts index c212a77d2d..78d1b78c5a 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -499,6 +499,8 @@ describe('util.toNestedArray', () => { describe('util.fetch', () => { it('should allow overriding global fetch', () => { + console.log('in fetch test'); + console.log(ENV.global); spyOn(ENV.global, 'fetch').and.callFake(() => {}); util.fetch('');