@@ -372,38 +372,37 @@ ConvolutionNode *Function::createConv(llvm::StringRef name, NodeValue input,
372372 auto OT = getParent ()->uniqueType (ElemKind::FloatTy, outDims);
373373
374374 return addNode (new ConvolutionNode (name, OT, input, filter, bias, kernel,
375- stride, pad, depth, /* group = */ 1 ));
375+ stride, pad, /* group = */ 1 ));
376376}
377377
378378// / Check that the dimensions that are passed in when the convolution is
379379// / constructed are correct.
380380static void assertConvDims (NodeValue input, NodeValue filter, NodeValue bias,
381- size_t depth , size_t kernel , size_t stride ,
382- size_t pad, size_t group) {
381+ size_t kernel , size_t stride , size_t pad ,
382+ size_t group) {
383383 ShapeNHWC idim = ShapeNHWC (input.dims ());
384384 assert (idim.w >= kernel && idim.h >= kernel &&
385385 " buffer too small for selected stride" );
386386 assert (idim.c % group == 0 && " channels number must be divisible by groups" );
387387 (void )idim;
388388
389389 auto filterDims = filter->dims ();
390- assert (filterDims[0 ] == depth * group && filterDims[1 ] == kernel &&
390+ assert (filterDims[0 ] % group == 0 && filterDims[1 ] == kernel &&
391391 filterDims[2 ] == kernel && filterDims[3 ] == idim.c / group &&
392392 " Invalid filter dims" );
393393 (void )filterDims;
394394
395- assert (bias->getType ()->size () == depth * group && " Invalid bias size" );
395+ assert (bias->getType ()->size () == filterDims[ 0 ] && " Invalid bias size" );
396396}
397397
398398ConvolutionNode *Function::createConv (llvm::StringRef name, NodeValue input,
399399 NodeValue filter, NodeValue bias,
400- TypeRef outTy, size_t depth,
401- size_t kernel, size_t stride, size_t pad,
402- size_t group) {
403- assertConvDims (input, filter, bias, depth, kernel, stride, pad, group);
400+ TypeRef outTy, size_t kernel,
401+ size_t stride, size_t pad, size_t group) {
402+ assertConvDims (input, filter, bias, kernel, stride, pad, group);
404403 auto OT = getParent ()->uniqueType (*outTy);
405404 return addNode (new ConvolutionNode (name, OT, input, filter, bias, kernel,
406- stride, pad, depth, group));
405+ stride, pad, group));
407406}
408407
409408PoolMaxNode *Function::createPoolMax (llvm::StringRef name, NodeValue input,
0 commit comments