Skip to content
This repository was archived by the owner on Dec 7, 2021. It is now read-only.

Commit 9d64f4a

Browse files
hermanho authored and JacopoMangiavacchi committed
fix: test asset distribution to include all tags on test/train split (#823)
* fix: test asset distribution to include all tags on test/train split. The test assets may not include all tags when exporting with the test/train split option in the current version (2.1.0). * Extract the same split logic into a helper function * Formatting * Inverting if statement
1 parent c0201ca commit 9d64f4a

File tree

6 files changed

+229
-41
lines changed

6 files changed

+229
-41
lines changed

src/providers/export/cntk.test.ts

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,37 @@ describe("CNTK Export Provider", () => {
115115

116116
const assetsToExport = await getAssetsSpy.mock.results[0].value;
117117
const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100;
118-
const testCount = Math.ceil(assetsToExport.length * testSplit);
119-
const testArray = assetsToExport.slice(0, testCount);
120-
const trainArray = assetsToExport.slice(testCount, assetsToExport.length);
118+
119+
const trainArray = [];
120+
const testArray = [];
121+
const tagsAssestList: {
122+
[index: string]: {
123+
assetSet: Set<string>,
124+
testArray: string[],
125+
trainArray: string[],
126+
},
127+
} = {};
128+
testProject.tags.forEach((tag) =>
129+
tagsAssestList[tag.name] = {
130+
assetSet: new Set(), testArray: [],
131+
trainArray: [],
132+
});
133+
assetsToExport.forEach((assetMetadata) => {
134+
assetMetadata.regions.forEach((region) => {
135+
region.tags.forEach((tagName) => {
136+
if (tagsAssestList[tagName]) {
137+
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
138+
}
139+
});
140+
});
141+
});
142+
143+
for (const tagKey of Object.keys(tagsAssestList)) {
144+
const assetSet = tagsAssestList[tagKey].assetSet;
145+
const testCount = Math.ceil(assetSet.size * testSplit);
146+
testArray.push(...Array.from(assetSet).slice(0, testCount));
147+
trainArray.push(...Array.from(assetSet).slice(testCount, assetSet.size));
148+
}
121149

122150
const storageProviderMock = LocalFileSystemProxy as any;
123151
const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls;

src/providers/export/cntk.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ExportProvider, IExportResults } from "./exportProvider";
33
import { IAssetMetadata, IExportProviderOptions, IProject } from "../../models/applicationState";
44
import HtmlFileReader from "../../common/htmlFileReader";
55
import Guard from "../../common/guard";
6+
import { splitTestAsset } from "./testAssetsSplitHelper";
67

78
enum ExportSplit {
89
Test,
@@ -33,13 +34,17 @@ export class CntkExportProvider extends ExportProvider<ICntkExportProviderOption
3334
public async export(): Promise<IExportResults> {
3435
await this.createFolderStructure();
3536
const assetsToExport = await this.getAssetsForExport();
37+
const testAssets: string[] = [];
38+
3639
const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100;
37-
const testCount = Math.ceil(assetsToExport.length * testSplit);
38-
const testArray = assetsToExport.slice(0, testCount);
40+
if (testSplit > 0 && testSplit <= 1) {
41+
const splittedAssets = splitTestAsset(assetsToExport, this.project.tags, testSplit);
42+
testAssets.push(...splittedAssets);
43+
}
3944

4045
const results = await assetsToExport.mapAsync(async (assetMetadata) => {
4146
try {
42-
const exportSplit = testArray.find((am) => am.asset.id === assetMetadata.asset.id)
47+
const exportSplit = testAssets.find((am) => am === assetMetadata.asset.id)
4348
? ExportSplit.Test
4449
: ExportSplit.Train;
4550

src/providers/export/pascalVOC.test.ts

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ describe("PascalVOC Json Export Provider", () => {
6969
beforeEach(() => {
7070
const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
7171
assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => {
72-
const mockTag = MockFactory.createTestTag();
72+
const mockTag1 = MockFactory.createTestTag("1");
73+
const mockTag2 = MockFactory.createTestTag("2");
74+
const mockTag = Number(asset.id.split("-")[1]) > 7 ? mockTag1 : mockTag2;
7375
const mockRegion1 = MockFactory.createTestRegion("region-1", [mockTag.name]);
7476
const mockRegion2 = MockFactory.createTestRegion("region-2", [mockTag.name]);
7577

@@ -352,27 +354,70 @@ describe("PascalVOC Json Export Provider", () => {
352354
};
353355

354356
const testProject = { ...baseTestProject };
355-
const testAssets = MockFactory.createTestAssets(10, 0);
357+
const testAssets = MockFactory.createTestAssets(13, 0);
356358
testAssets.forEach((asset) => asset.state = AssetState.Tagged);
357359
testProject.assets = _.keyBy(testAssets, (asset) => asset.id);
358-
testProject.tags = [MockFactory.createTestTag("1")];
360+
testProject.tags = MockFactory.createTestTags(3);
359361

360362
const exportProvider = new PascalVOCExportProvider(testProject, options);
363+
const getAssetsSpy = jest.spyOn(exportProvider, "getAssetsForExport");
364+
361365
await exportProvider.export();
362366

363367
const storageProviderMock = LocalFileSystemProxy as any;
364368
const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls as any[];
365369

366-
const valDataIndex = writeTextFileCalls
370+
const valDataIndex1 = writeTextFileCalls
367371
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_val.txt"));
368-
const trainDataIndex = writeTextFileCalls
372+
const trainDataIndex1 = writeTextFileCalls
369373
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 1_train.txt"));
370-
371-
const expectedTrainCount = (testTrainSplit / 100) * testAssets.length;
372-
const expectedTestCount = ((100 - testTrainSplit) / 100) * testAssets.length;
373-
374-
expect(writeTextFileCalls[valDataIndex][1].split("\n")).toHaveLength(expectedTestCount);
375-
expect(writeTextFileCalls[trainDataIndex][1].split("\n")).toHaveLength(expectedTrainCount);
374+
const valDataIndex2 = writeTextFileCalls
375+
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_val.txt"));
376+
const trainDataIndex2 = writeTextFileCalls
377+
.findIndex((args) => args[0].endsWith("/ImageSets/Main/Tag 2_train.txt"));
378+
379+
const assetsToExport = await getAssetsSpy.mock.results[0].value;
380+
const trainArray = [];
381+
const testArray = [];
382+
const tagsAssestList: {
383+
[index: string]: {
384+
assetSet: Set<string>,
385+
testArray: string[],
386+
trainArray: string[],
387+
},
388+
} = {};
389+
testProject.tags.forEach((tag) =>
390+
tagsAssestList[tag.name] = {
391+
assetSet: new Set(), testArray: [],
392+
trainArray: [],
393+
});
394+
assetsToExport.forEach((assetMetadata) => {
395+
assetMetadata.regions.forEach((region) => {
396+
region.tags.forEach((tagName) => {
397+
if (tagsAssestList[tagName]) {
398+
tagsAssestList[tagName].assetSet.add(assetMetadata.asset.name);
399+
}
400+
});
401+
});
402+
});
403+
404+
for (const tagKey of Object.keys(tagsAssestList)) {
405+
const assetSet = tagsAssestList[tagKey].assetSet;
406+
const testCount = Math.ceil(((100 - testTrainSplit) / 100) * assetSet.size);
407+
tagsAssestList[tagKey].testArray = Array.from(assetSet).slice(0, testCount);
408+
tagsAssestList[tagKey].trainArray = Array.from(assetSet).slice(testCount, assetSet.size);
409+
testArray.push(...tagsAssestList[tagKey].testArray);
410+
trainArray.push(...tagsAssestList[tagKey].trainArray);
411+
}
412+
413+
expect(writeTextFileCalls[valDataIndex1][1].split(/\r?\n/).filter((line) =>
414+
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].testArray.length);
415+
expect(writeTextFileCalls[trainDataIndex1][1].split(/\r?\n/).filter((line) =>
416+
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 1"].trainArray.length);
417+
expect(writeTextFileCalls[valDataIndex2][1].split(/\r?\n/).filter((line) =>
418+
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].testArray.length);
419+
expect(writeTextFileCalls[trainDataIndex2][1].split(/\r?\n/).filter((line) =>
420+
line.endsWith(" 1"))).toHaveLength(tagsAssestList["Tag 2"].trainArray.length);
376421
}
377422

378423
it("Correctly generated files based on 50/50 test / train split", async () => {

src/providers/export/pascalVOC.ts

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import HtmlFileReader from "../../common/htmlFileReader";
66
import { itemTemplate, annotationTemplate, objectTemplate } from "./pascalVOC/pascalVOCTemplates";
77
import { interpolate } from "../../common/strings";
88
import os from "os";
9+
import { splitTestAsset } from "./testAssetsSplitHelper";
910

1011
interface IObjectInfo {
1112
name: string;
@@ -253,40 +254,58 @@ export class PascalVOCExportProvider extends ExportProvider<IPascalVOCExportProv
253254
}
254255
});
255256

256-
// Save ImageSets
257-
await tags.forEachAsync(async (tag) => {
258-
const tagInstances = tagUsage.get(tag.name) || 0;
259-
if (!exportUnassignedTags && tagInstances === 0) {
260-
return;
261-
}
257+
if (testSplit > 0 && testSplit <= 1) {
258+
const tags = this.project.tags;
259+
const testAssets: string[] = splitTestAsset(allAssets, tags, testSplit);
262260

263-
const assetList = [];
264-
assetUsage.forEach((tags, assetName) => {
265-
if (tags.has(tag.name)) {
266-
assetList.push(`${assetName} 1`);
267-
} else {
268-
assetList.push(`${assetName} -1`);
261+
await tags.forEachAsync(async (tag) => {
262+
const tagInstances = tagUsage.get(tag.name) || 0;
263+
if (!exportUnassignedTags && tagInstances === 0) {
264+
return;
269265
}
270-
});
271-
272-
if (testSplit > 0 && testSplit <= 1) {
273-
// Split in Test and Train sets
274-
const totalAssets = assetUsage.size;
275-
const testCount = Math.ceil(totalAssets * testSplit);
276-
277-
const testArray = assetList.slice(0, testCount);
278-
const trainArray = assetList.slice(testCount, totalAssets);
266+
const testArray = [];
267+
const trainArray = [];
268+
assetUsage.forEach((tags, assetName) => {
269+
let assetString = "";
270+
if (tags.has(tag.name)) {
271+
assetString = `${assetName} 1`;
272+
} else {
273+
assetString = `${assetName} -1`;
274+
}
275+
if (testAssets.find((am) => am === assetName)) {
276+
testArray.push(assetString);
277+
} else {
278+
trainArray.push(assetString);
279+
}
280+
});
279281

280282
const testImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_val.txt`;
281283
await this.storageProvider.writeText(testImageSetFileName, testArray.join(os.EOL));
282284

283285
const trainImageSetFileName = `${imageSetsMainFolderName}/${tag.name}_train.txt`;
284286
await this.storageProvider.writeText(trainImageSetFileName, trainArray.join(os.EOL));
287+
});
288+
} else {
289+
290+
// Save ImageSets
291+
await tags.forEachAsync(async (tag) => {
292+
const tagInstances = tagUsage.get(tag.name) || 0;
293+
if (!exportUnassignedTags && tagInstances === 0) {
294+
return;
295+
}
296+
297+
const assetList = [];
298+
assetUsage.forEach((tags, assetName) => {
299+
if (tags.has(tag.name)) {
300+
assetList.push(`${assetName} 1`);
301+
} else {
302+
assetList.push(`${assetName} -1`);
303+
}
304+
});
285305

286-
} else {
287306
const imageSetFileName = `${imageSetsMainFolderName}/${tag.name}.txt`;
288307
await this.storageProvider.writeText(imageSetFileName, assetList.join(os.EOL));
289-
}
290-
});
308+
});
309+
}
291310
}
292311
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import _ from "lodash";
2+
import {
3+
IAssetMetadata, AssetState, IRegion,
4+
RegionType, IPoint, IExportProviderOptions,
5+
} from "../../models/applicationState";
6+
import MockFactory from "../../common/mockFactory";
7+
import { splitTestAsset } from "./testAssetsSplitHelper";
8+
import { appInfo } from "../../common/appInfo";
9+
10+
describe("splitTestAsset Helper tests", () => {

    describe("Test Train Splits", () => {
        /**
         * Runs splitTestAsset over 13 tagged mock assets, each holding one region with one tag.
         * The leading assets use the first mock tag and the remaining ones use the second,
         * then the size of the returned test set and the presence of both tags in it are verified.
         *
         * @param testTrainSplit Percentage of assets intended for training (e.g. 80 for an 80/20 split)
         */
        async function testTestTrainSplit(testTrainSplit: number): Promise<void> {
            const assetArray = MockFactory.createTestAssets(13, 0);
            const tags = MockFactory.createTestTags(2);
            for (const asset of assetArray) {
                asset.state = AssetState.Tagged;
            }

            const testSplit = (100 - testTrainSplit) / 100;
            const testCount = Math.ceil(testSplit * assetArray.length);
            // Number of leading assets that receive the first tag; the rest receive the second tag
            const firstTagCount = assetArray.length - testCount;

            const assetMetadatas = assetArray.map((asset, index) => {
                const tagName = index < firstTagCount ? tags[0].name : tags[1].name;
                const region = MockFactory.createTestRegion("Region" + index, [tagName]);
                return MockFactory.createTestAssetMetadata(asset, [region]);
            });
            const testAssetsNames = splitTestAsset(assetMetadatas, tags, testSplit);

            // Partition the metadata by whether its asset name was selected for the test set
            const isTestAsset = (assetMetadata: IAssetMetadata): boolean =>
                testAssetsNames.indexOf(assetMetadata.asset.name) >= 0;
            const trainAssetsArray = assetMetadatas.filter((assetMetadata) => !isTestAsset(assetMetadata));
            const testAssetsArray = assetMetadatas.filter(isTestAsset);

            // The split is applied per tag, so the expected size is the sum of both per-tag test counts
            const secondTagTestCount = Math.ceil(testSplit * testCount);
            const firstTagTestCount = Math.ceil(testSplit * firstTagCount);
            const expectedTestCount = secondTagTestCount + firstTagTestCount;

            expect(testAssetsNames).toHaveLength(expectedTestCount);
            expect(trainAssetsArray.length + testAssetsArray.length).toEqual(assetMetadatas.length);
            expect(testAssetsArray).toHaveLength(expectedTestCount);

            // Both tags must be represented in the test set
            const countTestAssetsWithTag = (tagName: string): number =>
                testAssetsArray.filter((assetMetadata) => assetMetadata.regions[0].tags[0] === tagName).length;
            expect(countTestAssetsWithTag(tags[0].name)).toBeGreaterThan(0);
            expect(countTestAssetsWithTag(tags[1].name)).toBeGreaterThan(0);
        }

        it("Correctly generated files based on 50/50 test / train split", async () => {
            await testTestTrainSplit(50);
        });

        it("Correctly generated files based on 60/40 test / train split", async () => {
            await testTestTrainSplit(60);
        });

        it("Correctly generated files based on 80/20 test / train split", async () => {
            await testTestTrainSplit(80);
        });

        it("Correctly generated files based on 90/10 test / train split", async () => {
            await testTestTrainSplit(90);
        });
    });
});
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { IAssetMetadata, ITag } from "../../models/applicationState";
2+
3+
/**
 * Selects the assets that make up the test set of a test/train split.
 *
 * Assets are grouped by the tags used in their regions and the split ratio is applied to
 * each tag's group separately, so every tag that is used by at least one asset contributes
 * at least one asset to the test set. For each tag, the first `ceil(count * testSplitRatio)`
 * assets (in the order they first appear in `allAssets`) are selected.
 *
 * @param allAssets Metadata of the assets to split
 * @param tags Tags to balance the split over; region tags that are not listed here are ignored
 * @param testSplitRatio Fraction of each tag's assets to place in the test set, in the range (0, 1]
 * @returns Names (`asset.name`) of the selected test assets, each listed once;
 * an empty array when `testSplitRatio` is outside (0, 1]
 */
export function splitTestAsset(allAssets: IAssetMetadata[], tags: ITag[], testSplitRatio: number): string[] {
    // Written as a negated range check so that NaN is rejected explicitly as well
    if (!(testSplitRatio > 0 && testSplitRatio <= 1)) { return []; }

    // A Map (rather than a plain object) keeps region tag names such as "constructor" or
    // "toString" from resolving to members inherited from Object.prototype.
    const assetNamesByTag = new Map<string, Set<string>>();
    tags.forEach((tag) => assetNamesByTag.set(tag.name, new Set<string>()));

    allAssets.forEach((assetMetadata) => {
        assetMetadata.regions.forEach((region) => {
            region.tags.forEach((tagName) => {
                const assetNames = assetNamesByTag.get(tagName);
                if (assetNames) {
                    assetNames.add(assetMetadata.asset.name);
                }
            });
        });
    });

    // Collect into a Set so an asset selected through several of its tags is returned only once
    const testAssets = new Set<string>();
    assetNamesByTag.forEach((assetNames) => {
        const testCount = Math.ceil(assetNames.size * testSplitRatio);
        Array.from(assetNames)
            .slice(0, testCount)
            .forEach((assetName) => testAssets.add(assetName));
    });
    return Array.from(testAssets);
}

0 commit comments

Comments
 (0)