Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add tf.util.fetch. #1663

Merged
merged 20 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@ bazel-out/
yalc.lock
.rpt2_cache/
package/
integration_tests/benchmarks/ui/bundle.js
9 changes: 7 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"@bazel/typescript": "^0.27.8",
"@types/jasmine": "~2.5.53",
"@types/node": "~9.6.0",
"@types/node-fetch": "~2.1.2",
"browserify": "~16.2.3",
"clang-format": "~1.2.4",
"jasmine": "~3.1.0",
Expand Down Expand Up @@ -69,6 +70,10 @@
"@types/seedrandom": "2.4.27",
"@types/webgl-ext": "0.0.30",
"@types/webgl2": "0.0.4",
"seedrandom": "2.4.3"
"seedrandom": "2.4.3",
"node-fetch": "~2.1.2"
},
"browser": {
"node-fetch": false
}
}
}
2 changes: 1 addition & 1 deletion rollup.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function config({plugins = [], output = {}, external = [], visualize = false}) {
node(),
// Polyfill require() from dependencies.
commonjs({
ignore: ['crypto'],
ignore: ['crypto', 'node-fetch'],
include: 'node_modules/**',
namedExports: {
'./node_modules/seedrandom/index.js': ['alea'],
Expand Down
45 changes: 9 additions & 36 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
* Uses [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
*/

import {ENV} from '../environment';
import {assert} from '../util';
import {assert, fetch} from '../util';
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
Expand All @@ -34,7 +33,7 @@ export class BrowserHTTPRequest implements IOHandler {
protected readonly path: string;
protected readonly requestInit: RequestInit;

private readonly fetchFunc: (path: string, init?: RequestInit) => Response;
private readonly fetch: Function;

readonly DEFAULT_METHOD = 'POST';

Expand All @@ -50,31 +49,17 @@ export class BrowserHTTPRequest implements IOHandler {
this.weightPathPrefix = loadOptions.weightPathPrefix;
this.onProgress = loadOptions.onProgress;

if (loadOptions.fetchFunc == null) {
const systemFetch = ENV.global.fetch;
if (typeof systemFetch === 'undefined') {
throw new Error(
'browserHTTPRequest is not supported outside the web browser ' +
'without a fetch polyfill.');
}
// Make sure fetch is always bound to global object (the
// original object) when available.
loadOptions.fetchFunc = systemFetch.bind(ENV.global);
} else {
if (loadOptions.fetchFunc != null) {
assert(
typeof loadOptions.fetchFunc === 'function',
() => 'Must pass a function that matches the signature of ' +
'`fetch` (see ' +
'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)');
this.fetch = loadOptions.fetchFunc;
} else {
this.fetch = fetch;
}

this.fetchFunc = (path: string, requestInits: RequestInit) => {
// tslint:disable-next-line:no-any
return loadOptions.fetchFunc(path, requestInits).catch((error: any) => {
throw new Error(`Request for ${path} failed due to error: ${error}`);
});
};

assert(
path != null && path.length > 0,
() =>
Expand Down Expand Up @@ -133,7 +118,7 @@ export class BrowserHTTPRequest implements IOHandler {
'model.weights.bin');
}

const response = await this.getFetchFunc()(this.path, init);
const response = await this.fetch(this.path, init);

if (response.ok) {
return {
Expand All @@ -156,8 +141,7 @@ export class BrowserHTTPRequest implements IOHandler {
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
const modelConfigRequest =
await this.getFetchFunc()(this.path, this.requestInit);
const modelConfigRequest = await this.fetch(this.path, this.requestInit);

if (!modelConfigRequest.ok) {
throw new Error(
Expand Down Expand Up @@ -224,22 +208,11 @@ export class BrowserHTTPRequest implements IOHandler {
});
const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
requestInit: this.requestInit,
fetchFunc: this.getFetchFunc(),
fetchFunc: this.fetch,
onProgress: this.onProgress
});
return [weightSpecs, concatenateArrayBuffers(buffers)];
}

/**
* Helper method to get the `fetch`-like function set for this instance.
*
* This is mainly for avoiding confusion with regard to what context
* the `fetch`-like function is bound to. In the default (browser) case,
* the function will be bound to `window`, instead of `this`.
*/
private getFetchFunc() {
return this.fetchFunc;
}
}

/**
Expand Down
9 changes: 3 additions & 6 deletions src/io/browser_http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ const setupFakeWeightFiles =
requestInits: {[key: string]: RequestInit}) => {
windowFetchSpy =
// tslint:disable-next-line:no-any
spyOn(global as any, 'fetch')
spyOn(tf.util, 'fetch')
.and.callFake((path: string, init: RequestInit) => {
if (fileBufferMap[path]) {
requestInits[path] = init;
Expand Down Expand Up @@ -162,9 +162,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {
try {
tf.io.browserHTTPRequest('./model.json');
} catch (err) {
expect(err.message)
.toMatch(
/not supported outside the web browser without a fetch polyfill/);
expect(err.message).toMatch(/Unable to find fetch polyfill./);
}
});
});
Expand Down Expand Up @@ -199,7 +197,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {

beforeEach(() => {
requestInits = [];
spyOn(window, 'fetch').and.callFake((path: string, init: RequestInit) => {
spyOn(tf.util, 'fetch').and.callFake((path: string, init: RequestInit) => {
if (path === 'model-upload-test' || path === 'http://model-upload-test') {
requestInits.push(init);
return Promise.resolve(new Response(null, {status: 200}));
Expand Down Expand Up @@ -766,7 +764,6 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(data).toBeDefined();
done.fail('Loading with fetch rejection succeeded unexpectedly.');
} catch (err) {
expect(err.message).toMatch(/Request for path2\/model.json failed /);
done();
}
});
Expand Down
2 changes: 1 addition & 1 deletion src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export async function loadWeightsAsArrayBuffer(
}

const fetchFunc =
loadOptions.fetchFunc == null ? fetch : loadOptions.fetchFunc;
loadOptions.fetchFunc == null ? util.fetch : loadOptions.fetchFunc;

// Create the requests for all of the weights in parallel.
const requests =
Expand Down
30 changes: 15 additions & 15 deletions src/io/weights_loader_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
[filename: string]: Float32Array|Int32Array|ArrayBuffer|Uint8Array|
Uint16Array
}) => {
spyOn(window, 'fetch').and.callFake((path: string) => {
spyOn(tf.util, 'fetch').and.callFake((path: string) => {
return new Response(
fileBufferMap[path],
{headers: {'Content-type': 'application/octet-stream'}});
Expand All @@ -42,7 +42,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
const weightsNamesToFetch = ['weight0'];
tf.io.loadWeights(manifest, './', weightsNamesToFetch)
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(weightsNamesToFetch.length);
Expand Down Expand Up @@ -70,7 +70,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
// Load the first weight.
tf.io.loadWeights(manifest, './', ['weight0'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(1);
Expand Down Expand Up @@ -98,7 +98,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
// Load the second weight.
tf.io.loadWeights(manifest, './', ['weight1'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(1);
Expand Down Expand Up @@ -126,7 +126,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
// Load all weights.
tf.io.loadWeights(manifest, './', ['weight0', 'weight1'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(2);
Expand Down Expand Up @@ -168,7 +168,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
// Load all weights.
tf.io.loadWeights(manifest, './', ['weight0', 'weight1', 'weight2'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(3);
Expand Down Expand Up @@ -210,7 +210,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {

tf.io.loadWeights(manifest, './', ['weight0'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(3);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(3);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(1);
Expand Down Expand Up @@ -252,7 +252,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {

tf.io.loadWeights(manifest, './', ['weight0', 'weight1'])
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(3);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(3);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(2);
Expand Down Expand Up @@ -297,7 +297,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
tf.io.loadWeights(manifest, './', ['weight0', 'weight1'])
.then(weights => {
// Only the first group should be fetched.
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(2);
Expand Down Expand Up @@ -342,7 +342,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
tf.io.loadWeights(manifest, './', ['weight0', 'weight2'])
.then(weights => {
// Both groups need to be fetched.
expect((window.fetch as jasmine.Spy).calls.count()).toBe(2);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(2);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(2);
Expand Down Expand Up @@ -388,7 +388,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
tf.io.loadWeights(manifest, './')
.then(weights => {
// Both groups need to be fetched.
expect((window.fetch as jasmine.Spy).calls.count()).toBe(2);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(2);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(4);
Expand Down Expand Up @@ -469,8 +469,8 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
.loadWeights(
manifest, './', weightsNamesToFetch, {credentials: 'include'})
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect(window.fetch).toHaveBeenCalledWith('./weightfile0', {
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);
expect(tf.util.fetch).toHaveBeenCalledWith('./weightfile0', {
credentials: 'include'
});
})
Expand Down Expand Up @@ -508,7 +508,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
const weightsNamesToFetch = ['weight0', 'weight1'];
tf.io.loadWeights(manifest, './', weightsNamesToFetch)
.then(weights => {
expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(1);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(weightsNamesToFetch.length);
Expand Down Expand Up @@ -571,7 +571,7 @@ describeWithFlags('loadWeights', BROWSER_ENVS, () => {
tf.io.loadWeights(manifest, './', ['weight0', 'weight2'])
.then(weights => {
// Both groups need to be fetched.
expect((window.fetch as jasmine.Spy).calls.count()).toBe(2);
expect((tf.util.fetch as jasmine.Spy).calls.count()).toBe(2);

const weightNames = Object.keys(weights);
expect(weightNames.length).toEqual(2);
Expand Down
43 changes: 43 additions & 0 deletions src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* =============================================================================
*/

import {ENV} from './environment';
import {DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray} from './types';

/**
Expand Down Expand Up @@ -664,3 +665,45 @@ export function assertNonNegativeIntegerDimensions(shape: number[]) {
`shape [${shape}].`);
});
}

const getSystemFetch = () => {
if (ENV.global.fetch != null) {
return ENV.global.fetch;
} else if (ENV.get('IS_NODE')) {
return getNodeFetch.fetchImport();
}
throw new Error(
`Unable to find the fetch() method. Please add your own fetch() ` +
`function to the global namespace.`);
};

// We are wrapping this within an object so it can be stubbed by Jasmine.
export const getNodeFetch = {
fetchImport: () => {
// tslint:disable-next-line:no-require-imports
return require('node-fetch');
}
};

/**
* Returns a platform-specific implementation of
* [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
*
* If `fetch` is defined on the global object (`window`, `process`, etc.),
* `tf.util.fetch` returns that function.
*
* If not, `tf.util.fetch` returns a platform-specific solution.
*
* ```js
* const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
* // handle response
* ```
*/
/** @doc {heading: 'Util'} */
export let systemFetch: Function;
export function fetch(path: string, requestInits?: RequestInit) {
if (systemFetch == null) {
systemFetch = getSystemFetch();
}
return systemFetch(path, requestInits);
}
Loading