@@ -548,10 +548,10 @@ makeSynPoolParams(llvm::ArrayRef<unsigned_t> kernel,
   params->sW = stride[0];
   params->sH = stride[1];
   // Padding
-  params->pWbegin = pad[0];
-  params->pWend = pad[0];
-  params->pHbegin = pad[1];
-  params->pHend = pad[1];
+  params->pHbegin = pad[0];
+  params->pWbegin = pad[1];
+  params->pHend = pad[2];
+  params->pWend = pad[3];
   // Dilation
   params->dilW = 1;
   params->dilH = 1;
@@ -591,6 +591,16 @@ makeSynSliceAxisParams(unsigned axis, unsigned axes, unsigned outputAxisSize,
   return params;
 }
 
+static std::unique_ptr<ns_LrnKernel::Params>
+makeLrnParams(float alpha, float beta, float knorm, int halfWindowSize) {
+  auto params = llvm::make_unique<ns_LrnKernel::Params>();
+  params->alpha = alpha;
+  params->beta = beta;
+  params->knorm = knorm;
+  params->nsize = 2 * halfWindowSize + 1;
+  return params;
+}
+
 static std::unique_ptr<ns_ConstantKernel::Params>
 makeConstantParams(float value) {
   auto params = llvm::make_unique<ns_ConstantKernel::Params>();
@@ -733,6 +743,8 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<std::unique_ptr<ns_TileKernel::Params>> tileParams;
   std::vector<std::unique_ptr<unsigned>> concatParams;
   std::vector<std::unique_ptr<ns_TakeKernel::Params>> takeParams;
+  std::vector<std::unique_ptr<ns_LrnKernel::Params>> lrnParams;
+  std::vector<std::unique_ptr<synGEMMParams>> gemmParams;
 
   // Keep references to tensor pointer arrays passed into multi-input nodes
   // until the compilation is done.
@@ -965,12 +977,16 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       if (MI->getLHS().getType()->isQuantizedType()) {
         // Let GEMM run on MME via FullyConnected node.
         // MME only runs on quantized types, e.g., int8 or int16.
-        auto params = llvm::make_unique<synFCParams>();
-        params->activation.reluEnable = false;
-        chk(synFullyConnected(tensors[MI->getLHS()].get(),
-                              tensors[MI->getRHS()].get(), nullptr,
-                              tensors[MI].get(), *params, ""));
-        fcParams.emplace_back(std::move(params));
+        // The default params are OK - don't transpose A and B.
+        auto params = llvm::make_unique<synGEMMParams>();
+        std::vector<synTensor> inputs;
+        inputs.push_back(tensors[MI->getLHS()].get());
+        inputs.push_back(tensors[MI->getRHS()].get());
+        chk(synCreateGenericNode(inputs.data(), &tensors[MI].get(),
+                                 inputs.size(), 1, nullptr, "gemm",
+                                 MI->getName().data(), nullptr, nullptr));
+        gemmParams.emplace_back(std::move(params));
+
       } else {
         std::vector<synTensor> inputs;
         inputs.push_back(tensors[MI->getLHS()].get());
@@ -1015,6 +1031,18 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       convParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::LocalResponseNormalizationNodeKind: {
+      auto *NI = llvm::cast<LocalResponseNormalizationNode>(&I);
+      std::unique_ptr<ns_LrnKernel::Params> params = makeLrnParams(
+          NI->getAlpha(), NI->getBeta(), NI->getK(), NI->getHalfWindowSize());
+
+      chk(synCreateGenericNode(&tensors[NI->getInput()].get(),
+                               &tensors[NI].get(), 1, 1, (void *)params.get(),
+                               "lrn_f32", NI->getName().str().c_str(), nullptr,
+                               nullptr));
+      lrnParams.emplace_back(std::move(params));
+      break;
+    }
     case Kinded::Kind::TransposeNodeKind: {
       auto *TI = llvm::cast<TransposeNode>(&I);
       std::unique_ptr<synTransposeParams> params =
@@ -1126,6 +1154,14 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       concatParams.emplace_back(std::move(params));
       break;
     }
+    case Kinded::Kind::RescaleQuantizedNodeKind: {
+      auto *RI = llvm::cast<RescaleQuantizedNode>(&I);
+      chk(synCreateGenericNode(
+          &tensors[RI->getInput()].get(), &tensors[RI].get(), 1, 1, nullptr,
+          getKernelName("requant", RI->getResult().getElementType()).c_str(),
+          RI->getName().data(), nullptr, nullptr));
+      break;
+    }
     case Kinded::Kind::SaveNodeKind: {
       auto *CI = llvm::cast<SaveNode>(&I);
       if (tensors.count(CI)) {
@@ -1237,7 +1273,11 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::SplatNodeKind:
   case Kinded::Kind::SubNodeKind:
   case Kinded::Kind::TileNodeKind:
+  case Kinded::Kind::ConcatNodeKind:
     return true;
+  case Kinded::Kind::RescaleQuantizedNodeKind:
+    return NI.allInputsAndOutputsHaveSameElemKind(
+        {ElemKind::Int8QTy, ElemKind::Int16QTy});
   default:
     return false;
   }
@@ -1273,6 +1313,7 @@ bool HabanaBackend::isOpSupported(const NodeInfo &NI) const {
   case Kinded::Kind::TransposeNodeKind:
   case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
   case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
+  case Kinded::Kind::LocalResponseNormalizationNodeKind:
     return true;
   default:
     return false;
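
A note on the first hunk: the pooling params previously reused pad[0]/pad[1] for both the begin and end padding of each dimension; the new indexing reads all four entries, which matches Glow's usual {top, left, bottom, right} pad ordering. A minimal sketch of that mapping, assuming that ordering; the helper and struct names below are hypothetical, not part of the patch:

// Sketch only: map Glow-style pads {top, left, bottom, right} onto
// Habana-style per-dimension begin/end padding fields.
#include <array>

struct PadHW {
  unsigned pHbegin, pWbegin, pHend, pWend;
};

static PadHW mapPadsToHabana(const std::array<unsigned, 4> &pad) {
  PadHW p;
  p.pHbegin = pad[0]; // top    -> height begin
  p.pWbegin = pad[1]; // left   -> width begin
  p.pHend = pad[2];   // bottom -> height end
  p.pWend = pad[3];   // right  -> width end
  return p;
}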
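The makeLrnParams helper added in the second hunk turns Glow's half window size into the full window size (nsize) passed to the "lrn_f32" kernel. A tiny worked example, assuming a node with halfWindowSize = 2; the function name is illustrative only:

// Sketch: nsize is the full LRN window derived from the half window size.
// halfWindowSize = 2  ->  nsize = 2 * 2 + 1 = 5 (a 5-wide window).
int lrnWindowSize(int halfWindowSize) { return 2 * halfWindowSize + 1; }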