diff --git a/.gitignore b/.gitignore index 9b3b3cbe57..d3fe831510 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,3 @@ bazel-out/ yalc.lock .rpt2_cache/ package/ -integration_tests/benchmarks/ui/bundle.js diff --git a/package.json b/package.json index 2579b342cf..28021642b8 100644 --- a/package.json +++ b/package.json @@ -22,6 +22,7 @@ "@bazel/typescript": "^0.27.8", "@types/jasmine": "~2.5.53", "@types/node": "~9.6.0", + "@types/node-fetch": "^2.1.2", "browserify": "~16.2.3", "clang-format": "~1.2.4", "jasmine": "~3.1.0", @@ -65,6 +66,7 @@ "@types/seedrandom": "2.4.27", "@types/webgl-ext": "0.0.30", "@types/webgl2": "0.0.4", - "seedrandom": "2.4.3" + "seedrandom": "2.4.3", + "node-fetch": "~2.1.2" } } diff --git a/rollup.config.js b/rollup.config.js index de56e3e0e4..d9722131f2 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -49,7 +49,7 @@ function config({plugins = [], output = {}, external = []}) { node(), // Polyfill require() from dependencies. commonjs({ - ignore: ['crypto'], + ignore: ['crypto', 'node-fetch'], include: 'node_modules/**', namedExports: { './node_modules/seedrandom/index.js': ['alea'], diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts index 9217900be0..8841fe25f6 100644 --- a/src/io/browser_http.ts +++ b/src/io/browser_http.ts @@ -21,8 +21,7 @@ * Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). */ -import {ENV} from '../environment'; -import {assert} from '../util'; +import {assert, fetch as systemFetch} from '../util'; import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils'; import {IORouter, IORouterRegistry} from './router_registry'; import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; @@ -34,7 +33,7 @@ export class BrowserHTTPRequest implements IOHandler { protected readonly path: string; protected readonly requestInit: RequestInit; - private readonly fetchFunc: (path: string, init?: RequestInit) => Response; + private readonly fetchFunc: Function; readonly DEFAULT_METHOD = 'POST'; @@ -50,31 +49,17 @@ export class BrowserHTTPRequest implements IOHandler { this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; - if (loadOptions.fetchFunc == null) { - const systemFetch = ENV.global.fetch; - if (typeof systemFetch === 'undefined') { - throw new Error( - 'browserHTTPRequest is not supported outside the web browser ' + - 'without a fetch polyfill.'); - } - // Make sure fetch is always bound to global object (the - // original object) when available. - loadOptions.fetchFunc = systemFetch.bind(ENV.global); - } else { + if (loadOptions.fetchFunc != null) { assert( typeof loadOptions.fetchFunc === 'function', () => 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'); + this.fetchFunc = loadOptions.fetchFunc; + } else { + this.fetchFunc = systemFetch; } - this.fetchFunc = (path: string, requestInits: RequestInit) => { - // tslint:disable-next-line:no-any - return loadOptions.fetchFunc(path, requestInits).catch((error: any) => { - throw new Error(`Request for ${path} failed due to error: ${error}`); - }); - }; - assert( path != null && path.length > 0, () => diff --git a/src/io/browser_http_test.ts b/src/io/browser_http_test.ts index 1b3f7a8f6c..02022666c8 100644 --- a/src/io/browser_http_test.ts +++ b/src/io/browser_http_test.ts @@ -162,9 +162,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => { try { tf.io.browserHTTPRequest('./model.json'); } catch (err) { - expect(err.message) - .toMatch( - /not supported outside the web browser without a fetch polyfill/); + expect(err.message).toMatch(/Unable to find fetch polyfill./); } }); }); diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index e55828b156..1d0f2cc6d1 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -40,7 +40,7 @@ export async function loadWeightsAsArrayBuffer( } const fetchFunc = - loadOptions.fetchFunc == null ? fetch : loadOptions.fetchFunc; + loadOptions.fetchFunc == null ? util.fetch : loadOptions.fetchFunc; // Create the requests for all of the weights in parallel. const requests = diff --git a/src/util.ts b/src/util.ts index 26498aae2c..38a400765b 100644 --- a/src/util.ts +++ b/src/util.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {ENV} from './environment'; import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types'; /** @@ -664,3 +665,43 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) { `shape [${shape}].`); }); } + +let systemFetch: Function; +const getSystemFetch = () => { + let fetchFunc: Function; + + if (ENV.global.fetch != null) { + fetchFunc = ENV.global.fetch; + } else { + if (ENV.get('IS_NODE')) { + // tslint:disable-next-line:no-require-imports + fetchFunc = require('node-fetch'); + } else { + throw new Error(`Unable to find fetch polyfill.`); + } + } + return fetchFunc; +}; + +/** + * Returns a platform-specific implementation of + * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). + * + * If `fetch` is defined on the global object (`window`, `process`, etc.), + * `tf.util.fetch` returns that function. + * + * If not, `tf.util.fetch` returns a platform-specific solution. + * + * ```js + * tf.util.fetch('path/to/resource') + * .then(response => {}) // handle response + * ``` + */ +/** @doc {heading: 'Util'} */ +export function fetch(path: string, requestInits?: RequestInit) { + console.log('calling fetch util'); + if (systemFetch == null) { + systemFetch = getSystemFetch(); + } + return systemFetch(path, requestInits); +} diff --git a/src/util_test.ts b/src/util_test.ts index 4313c53845..78d1b78c5a 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {ENV} from './environment'; import {scalar, tensor2d} from './ops/ops'; import {inferShape} from './tensor_util_env'; import * as util from './util'; @@ -495,3 +496,14 @@ describe('util.toNestedArray', () => { expect(util.toNestedArray([1, 0, 2], a)).toEqual([]); }); }); + +describe('util.fetch', () => { + it('should allow overriding global fetch', () => { + console.log('in fetch test'); + console.log(ENV.global); + spyOn(ENV.global, 'fetch').and.callFake(() => {}); + + util.fetch(''); + expect(ENV.global.fetch).toHaveBeenCalled(); + }); +}); diff --git a/yarn.lock b/yarn.lock index bfbd6f27da..82049bacf9 100644 --- a/yarn.lock +++ b/yarn.lock @@ -46,6 +46,13 @@ resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-2.5.54.tgz#a6b5f2ae2afb6e0307774e8c7c608e037d491c63" integrity sha512-B9YofFbUljs19g5gBKUYeLIulsh31U5AK70F41BImQRHEZQGm4GcN922UvnYwkduMqbC/NH+9fruWa/zrqvHIg== +"@types/node-fetch@^2.1.2": + version "2.1.7" + resolved "https://registry.yarnpkg.com/@types/node-fetch/-/node-fetch-2.1.7.tgz#0231559340f6e3f3a0608692077d744c87b5b367" + integrity sha512-TZozHCDVrs0Aj1B9ZR5F4Q9MknDNcVd+hO5lxXOCzz07ELBey6s1gMUSZHUYHlPfRFKJFXiTnNuD7ePiI6S4/g== + dependencies: + "@types/node" "*" + "@types/node@*": version "10.12.18" resolved "https://registry.yarnpkg.com/@types/node/-/node-10.12.18.tgz#1d3ca764718915584fcd9f6344621b7672665c67" @@ -2988,6 +2995,11 @@ nice-try@^1.0.4: resolved "https://registry.yarnpkg.com/nice-try/-/nice-try-1.0.5.tgz#a3378a7696ce7d223e88fc9b764bd7ef1089e366" integrity sha512-1nh45deeb5olNY7eX82BkPO7SSxR5SSYJiPTrTdFUVYwAl8CKMA5N9PjTYkHiRjisVcxcQ1HXdLhx2qxxJzLNQ== +node-fetch@~2.1.2: + version "2.1.2" + resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.1.2.tgz#ab884e8e7e57e38a944753cec706f788d1768bb5" + integrity sha1-q4hOjn5X44qUR1POxwb3iNF2i7U= + node-pre-gyp@^0.10.0: version "0.10.3" resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.10.3.tgz#3070040716afdc778747b61b6887bf78880b80fc"