Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit 63be3d6

Browse files
louisfengfacebook-github-bot
authored andcommitted
Added numerical checking for GraphOptzTest mergeNonInverseTransposes. (#3580)
Summary: Added numerical result checking for original and optimized graph. Fixes T54749316 Pull Request resolved: #3580 Test Plan: => ctest -R GraphOptz Test project /Users/lofe/git/glow/build_Debug Start 8: GraphOptzTest 1/1 Test #8: GraphOptzTest .................... Passed 0.09 sec 100% tests passed, 0 tests failed out of 1 Total Test time (real) = 0.11 sec Differential Revision: D17699478 Pulled By: SplitInfinity fbshipit-source-id: df59bebe990f88947ff782e3d6189d694d2e4ccc
1 parent 5a32c42 commit 63be3d6

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tests/unittests/GraphOptzTest.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,8 @@ TEST_F(GraphOptz, mergeNonInverseTransposes) {
846846
const size_t origDims[] = {1, 5, 10, 15};
847847
const size_t finalDims[] = {5, 1, 15, 10};
848848

849-
Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
849+
Placeholder *A =
850+
mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
850851
TransposeNode *T1 = F_->createTranspose("transpose", A, {0, 3, 2, 1});
851852
TransposeNode *T2 = F_->createTranspose("transpose", T1, {0, 2, 3, 1});
852853
TransposeNode *T3 = F_->createTranspose("transpose", T2, {1, 0, 3, 2});
@@ -863,15 +864,25 @@ TEST_F(GraphOptz, mergeNonInverseTransposes) {
863864

864865
EXPECT_EQ(F_->getNodes().size(), 5);
865866

866-
::glow::optimize(F_, CompilationMode::Infer);
867-
867+
optimizedF_ = optimizeFunction(F_);
868+
// Find save node in the optimized graph.
869+
for (auto &N : optimizedF_->getNodes()) {
870+
if (N.getKind() == Kinded::Kind::SaveNodeKind) {
871+
save = llvm::dyn_cast<SaveNode>(&N);
872+
}
873+
}
874+
// Get the last transpose node in the optimized graph.
868875
auto *TR = llvm::dyn_cast<TransposeNode>(save->getInput());
869876
ASSERT_NE(TR, nullptr);
870877

871-
EXPECT_EQ(F_->getNodes().size(), 2);
878+
EXPECT_EQ(optimizedF_->getNodes().size(), 2);
872879
EXPECT_EQ(TR->getResult().dims(), llvm::makeArrayRef(finalDims));
873880
EXPECT_EQ(A->getNthResult(0).dims(), llvm::makeArrayRef(origDims));
874881
EXPECT_EQ(TR->getInput().getNode(), A);
882+
883+
bindings_.allocate(mod_.getPlaceholders());
884+
bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
885+
checkNumericalEquivalence();
875886
}
876887

877888
TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodes) {

0 commit comments

Comments
 (0)