@@ -54,7 +54,7 @@ static synDataType getSynType(ElemKind kind) {
   case ElemKind::FloatTy:
     return syn_type_single;
   case ElemKind::Float16Ty:
-    return syn_type_half;
+    GLOW_UNREACHABLE("Unhandled ElemKind: Float16Ty");
   case ElemKind::Int8QTy:
     return syn_type_fixed;
   case ElemKind::Int16QTy:
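
This hunk turns the Float16Ty case from a valid conversion into a hard failure, so fp16 types must now be rejected before getSynType() is ever reached. A minimal sketch of such a guard (the helper name is hypothetical; only the ElemKind cases visible in this hunk are assumed supported):

    // Hypothetical guard mirroring the cases getSynType() still handles.
    static bool hasSynType(ElemKind kind) {
      switch (kind) {
      case ElemKind::FloatTy:  // syn_type_single
      case ElemKind::Int8QTy:  // syn_type_fixed
      case ElemKind::Int16QTy: // (per the cases visible in this hunk)
        return true;
      default:
        return false; // Float16Ty now trips GLOW_UNREACHABLE in getSynType().
      }
    }
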
@@ -92,18 +92,6 @@ static std::string getKernelName(llvm::StringRef kernelBase, ElemKind kind) {
   return std::string(kernelBase) + getKernelSuffix(kind);
 }
 
-/// If \p PH is an output placeholder, \returns the SaveNode.
-static SaveNode *getOutputSave(Function *F, Placeholder *PH) {
-  for (auto &use : PH->getUsers()) {
-    if (auto *save = llvm::dyn_cast<SaveNode>(use.getUser())) {
-      if (save->getParent() == F && save->getPlaceholder() == PH) {
-        return save;
-      }
-    }
-  }
-  return nullptr;
-}
-
 namespace {
 /// Parameters for pooling operation.
 struct synPoolParams {
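
For context, the removed getOutputSave() helper scans a placeholder's users for a SaveNode in the given function; it is still called in allocateGraphTensors() further down, so it presumably moved to a shared header rather than being deleted outright. A hedged sketch of how such a helper is typically used (assuming Glow's Module::getPlaceholders()):

    // Classify each placeholder of F as an output if some SaveNode in F
    // writes to it.
    for (auto *PH : F->getParent()->getPlaceholders()) {
      if (SaveNode *save = getOutputSave(F, PH)) {
        // PH is an output of F; save->getInput() is the node producing it.
      }
    }
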
@@ -217,23 +205,6 @@ class TensorHandle final {
     // Model params need to be floats, even if the tensor is integral or
     // quantized.
     if (ioType == IOType::Static) {
-      // Quantized types: dequantize into float buffer.
-      if (V->isQuantizedType()) {
-        // Check that a weight buffer was passed in; these are model params.
-        assert(!allocated_);
-        Type type = *V;
-        if (V->getElementType() == ElemKind::UInt8FusedQTy) {
-          // Fused quantized values just need to be passed through in raw form.
-          type = Type(ElemKind::Int8QTy, V->dims(), 1.0, 0);
-        }
-        Tensor DT = quantization::dequantizeTensor(Tensor(buffer_, &type),
-                                                   ElemKind::FloatTy);
-        auto bytes = DT.getSizeInBytes();
-        buffer_ = malloc(bytes);
-        memcpy(buffer_, DT.getUnsafePtr(), bytes);
-        allocated_ = true;
-      }
-
       // Int32ITy: Cast to floats.
       if (V->getElementType() == ElemKind::Int32ITy) {
         float *floats_ = (float *)malloc(V->size() * sizeof(float));
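
The deleted block used quantization::dequantizeTensor() to expand static quantized weights into a float buffer; with the descriptor changes below, the raw quantized bytes are handed to Synapse instead. For reference, Glow's dequantization rule is f = scale * (q - offset); a minimal per-element sketch:

    // Minimal sketch of per-element dequantization (Glow's quantization rule).
    static float dequantizeValue(int8_t q, float scale, int32_t offset) {
      return scale * (static_cast<float>(q) - static_cast<float>(offset));
    }
    // e.g. q = 12, scale = 0.5, offset = 2  =>  0.5 * (12 - 2) = 5.0f
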
@@ -258,11 +229,19 @@ class TensorHandle final {
     // Create tensor descriptor, with quantization params if needed.
     synTensorDescriptor desc(elemType, rdims.size(), rdims.data(), buffer_,
                              synMemoryHost, false, name_.data());
-    if (V->isQuantizedType() &&
-        V->getElementType() != ElemKind::UInt8FusedQTy) {
-      desc.m_quantizationParams[0].m_zp = V->getOffset();
-      desc.m_quantizationParams[0].m_scale = V->getScale();
+    if (V->isQuantizedType()) {
+      if (V->getElementType() == ElemKind::UInt8FusedQTy) {
+        desc.m_quantizationParams[0].m_zp = 0;
+        desc.m_quantizationParams[0].m_scale = 1;
+      } else {
+        desc.m_quantizationParams[0].m_zp = V->getOffset();
+        desc.m_quantizationParams[0].m_scale = V->getScale();
+      }
+
       desc.m_quantizationParams[0].m_qDataType = elemType;
+      if (ioType == IOType::Static) {
+        desc.m_isQuantized = true;
+      }
     }
 
     chk(synCreateTensor(&desc, &tensor_, ioType == IOType::Output, false,
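
The rewritten branch gives fused-quantized tensors identity descriptor params (zero point 0, scale 1), since UInt8FusedQTy carries its per-row scale/offset inline in the tensor data, while ordinary quantized types still publish their single scale/offset; static tensors additionally set m_isQuantized so Synapse treats the buffer as already-quantized. A small sketch of that rule (helper and struct names hypothetical):

    // Hypothetical helper expressing the descriptor rule in this hunk.
    struct DescQuantParams {
      int32_t zp;
      float scale;
    };
    static DescQuantParams descParamsFor(const Type &T) {
      if (T.getElementType() == ElemKind::UInt8FusedQTy) {
        return {0, 1.0f}; // per-row params already live inside each row
      }
      return {T.getOffset(), T.getScale()};
    }
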
@@ -307,6 +286,9 @@ class TensorHandle final {
   /// Get the underlying data buffer.
   void *getData() const { return buffer_; }
 
+  /// Get the name of the managed tensor.
+  const std::string &getName() const { return name_; }
+
   /// Get the dimensions of the stored tensor.
   llvm::ArrayRef<unsigned> dims() const { return dims_; }
 
@@ -665,15 +647,18 @@ allocateGraphTensors(Function *F) {
       continue;
     }
     if (auto *save = getOutputSave(F, V)) {
-      // We want to avoid emitting copies for save nodes by simply marking the
-      // save input as an "output" tensor. The exceptions are when the input
-      // is itself a placeholder/constant, or a reshape. (The reshape case is
-      // likely a Synapse bug.)
+      // Naively, we'd generate a memcpy for any SaveNode, but that's a waste,
+      // so we want to avoid it. We can optimize it away by mapping the
+      // SaveNode's input node (N, below) to the output tensor, and then simply
+      // not generating a memcpy if the SaveNode itself has no associated
+      // tensor.
       auto *N = save->getInput().getNode();
-      Node *proxy =
-          (llvm::isa<Storage>(N) || llvm::isa<HabanaReshapeNode>(N)) ? save : N;
-      tensors.emplace(proxy, TensorHandle(V->getType(), V->getName(), nullptr,
-                                          IOType::Output));
+      if (llvm::isa<Storage>(N) || llvm::isa<HabanaReshapeNode>(N) ||
+          N->getNumUsers() > 1) {
+        N = save;
+      }
+      tensors.emplace(
+          N, TensorHandle(V->getType(), V->getName(), nullptr, IOType::Output));
     } else {
       tensors.emplace(V, TensorHandle(V->getType(), V->getName(), nullptr,
                                       IOType::Default));
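
The new control flow maps the output tensor directly onto the SaveNode's producer N, so no memcpy is emitted for the save; it falls back to the SaveNode itself (and hence a copy) when aliasing would be unsafe. A sketch restating the test being applied (hypothetical helper name):

    // Hypothetical restatement of the aliasing test used above.
    static bool mustKeepCopy(Node *N) {
      return llvm::isa<Storage>(N) ||           // constants/placeholders
             llvm::isa<HabanaReshapeNode>(N) || // reshape outputs
             N->getNumUsers() > 1;              // value is consumed elsewhere too
    }
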
@@ -747,6 +732,7 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<std::unique_ptr<ns_ConstantKernel::Params>> constantParams;
   std::vector<std::unique_ptr<ns_TileKernel::Params>> tileParams;
   std::vector<std::unique_ptr<unsigned>> concatParams;
+  std::vector<std::unique_ptr<ns_TakeKernel::Params>> takeParams;
 
   // Keep references to tensor pointer arrays passed into multi-input nodes
   // until the compilation is done.
@@ -755,6 +741,10 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
   std::vector<TensorHandle> tempTensors;
 
   for (const auto &I : F->getNodes()) {
+    if (!isOpSupported(I)) {
+      llvm::errs() << "Unsupported operator: " << I.getDebugDesc() << "\n";
+      GLOW_UNREACHABLE("Unsupported operator");
+    }
     switch (I.getKind()) {
     case Kinded::Kind::HabanaFullyConnectedNodeKind: {
       auto *NI = llvm::cast<HabanaFullyConnectedNode>(&I);
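
Failing fast on unsupported operators replaces a crash deep inside the switch with an explicit message up front. A hedged sketch of what an isOpSupported()-style predicate might look like (the signature and the listed kinds are assumptions for illustration, not the backend's real support list):

    // Illustrative only: kinds and signature are assumptions.
    static bool isOpSupported(const Node &N) {
      switch (N.getKind()) {
      case Kinded::Kind::HabanaFullyConnectedNodeKind:
      case Kinded::Kind::ConcatNodeKind:
      case Kinded::Kind::GatherNodeKind:
      case Kinded::Kind::SaveNodeKind:
        return true;
      default:
        return false;
      }
    }
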
@@ -1116,7 +1106,17 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
           makeConcatParams(CI->getDim(), tensors[CI].dims().size());
       std::vector<synTensor> inputs;
       for (auto const &N : CI->getInputs()) {
-        inputs.push_back(tensors[N].get());
+        std::string memcpyNodeName =
+            llvm::formatv("{0}_memcpy_{1}", N.getNode()->getName(),
+                          inputs.size())
+                .str();
+        TensorHandle memcpy(N.getType(), memcpyNodeName);
+        chk(synCreateGenericNode(
+            &tensors[N].get(), &memcpy.get(), 1, 1, nullptr,
+            getKernelName("memcpy", N.getType()->getElementType()).c_str(),
+            memcpy.getName().c_str(), nullptr, nullptr));
+        inputs.push_back(memcpy.get());
+        tempTensors.emplace_back(std::move(memcpy));
       }
 
       chk(synCreateGenericNode(inputs.data(), &tensors[CI].get(), inputs.size(),
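
Each concat input is now routed through a memcpy kernel into a fresh temporary tensor before being handed to the concat node, and the handles are parked in tempTensors so they outlive compilation. The generated names come from llvm::formatv, e.g.:

    #include "llvm/Support/FormatVariadic.h"
    // "{0}_memcpy_{1}" substitutes the producer's name and the input's position:
    std::string name = llvm::formatv("{0}_memcpy_{1}", "relu3", 2).str();
    // name == "relu3_memcpy_2"
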
@@ -1165,6 +1165,25 @@ HabanaBackend::compile(Function *F, const BackendOptions &opts) const {
       multiInputs.emplace_back(std::move(inputs));
       break;
     }
+    case Kinded::Kind::GatherNodeKind: {
+      auto *gather = llvm::cast<GatherNode>(&I);
+      std::vector<synTensor> inputs = {tensors[gather->getData()].get(),
+                                       tensors[gather->getIndices()].get()};
+
+      auto params = llvm::make_unique<ns_TakeKernel::Params>();
+      params->axis =
+          gather->getData().dims().size() - gather->getBatchDims() - 1;
+      params->mode = 0;
+
+      chk(synCreateGenericNode(
+          inputs.data(), &tensors[gather].get(), inputs.size(), 1, params.get(),
+          getKernelName("take", gather->getResult().getElementType()).c_str(),
+          gather->getName().data(), nullptr, nullptr));
+
+      multiInputs.emplace_back(std::move(inputs));
+      takeParams.emplace_back(std::move(params));
+      break;
+    }
     default: {
       llvm::errs() << "Unhandled node: " << I.getDebugDesc() << "\n";
       GLOW_UNREACHABLE("Unhandled node");
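
The axis arithmetic in the new Gather case appears to account for Synapse's reversed dimension order (cf. the rdims reversal when descriptors are built): a Glow gather at axis batchDims becomes take-kernel axis rank - batchDims - 1. A worked sketch:

    // Worked example: data dims {4, 8} (rank 2), batchDims = 0
    //   axis = 2 - 0 - 1 = 1  (Glow's outermost dim, in Synapse's reversed order)
    static unsigned takeKernelAxis(size_t rank, unsigned batchDims) {
      return static_cast<unsigned>(rank) - batchDims - 1;
    }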