Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit 5a32c42

Browse files
Meghan Lelefacebook-github-bot
authored andcommitted
Make pool tests test layout conversion (#3578)
Summary: **Summary** This commit modifies the max and average pool tests in `GradCheckTest` to test layout conversion as well. At present, `NHWC `-> `NCHW` layout conversion for these test cases ends up becoming a `Reshape`, which can be a no-op on the device. **Testing** These tests are enabled for the `Interpreter`. All unit tests pass. Pull Request resolved: #3578 Differential Revision: D17694767 Pulled By: SplitInfinity fbshipit-source-id: 36cc1dd733c3fefed254520c52824c9acfd17443
1 parent 43edf7f commit 5a32c42

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tests/unittests/GradCheckTest.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,19 +580,20 @@ TEST_P(GradCheck, gradientCheckAvgPool) {
580580
auto &mod = EE->getModule();
581581
bindings.clear();
582582
Function *F = mod.createFunction("main");
583-
A = mod.createPlaceholder(ElemKind::FloatTy, {1, numDim, numDim, 1}, "A",
583+
A = mod.createPlaceholder(ElemKind::FloatTy, {1, numDim, numDim, 2}, "A",
584584
false);
585585
Exp = mod.createPlaceholder(ElemKind::FloatTy, {1, numOutputElem}, "Exp",
586586
false);
587587

588588
Node *O = F->createAvgPool("pool", A, 3, 3, 1);
589589
O = F->createTanh("tanh", O);
590+
O = F->createSlice("slice", O, {0, 0, 0, 0}, {1, 4, 4, 1});
590591
O = F->createFullyConnected(bindings, "fc", O, numOutputElem);
591592
O = F->createRegression("reg", O, Exp);
592593
result = F->createSave("ret", O);
593594
}
594595

595-
Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 1});
596+
Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 2});
596597
Tensor outputs(ElemKind::FloatTy, {1, numOutputElem});
597598

598599
auto inputsH = inputs.getHandle<>();
@@ -616,19 +617,20 @@ TEST_P(GradCheck, gradientCheckMaxPool) {
616617
auto &mod = EE->getModule();
617618
bindings.clear();
618619
Function *F = mod.createFunction("main");
619-
A = mod.createPlaceholder(ElemKind::FloatTy, {1, numDim, numDim, 1}, "A",
620+
A = mod.createPlaceholder(ElemKind::FloatTy, {1, numDim, numDim, 2}, "A",
620621
false);
621622
Exp = mod.createPlaceholder(ElemKind::FloatTy, {1, numOutputElem}, "Exp",
622623
false);
623624

624625
MaxPoolNode *P = F->createMaxPool("pool", A, 3, 3, 1);
625626
Node *O = F->createTanh("tanh", P->getResult());
627+
O = F->createSlice("slice", O, {0, 0, 0, 0}, {1, 4, 4, 1});
626628
O = F->createFullyConnected(bindings, "fc", O, numOutputElem);
627629
O = F->createRegression("reg", O, Exp);
628630
result = F->createSave("ret", O);
629631
}
630632

631-
Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 1});
633+
Tensor inputs(ElemKind::FloatTy, {1, numDim, numDim, 2});
632634
Tensor outputs(ElemKind::FloatTy, {1, numOutputElem});
633635

634636
auto inputsH = inputs.getHandle<>();

0 commit comments

Comments
 (0)