Commit 0cd8342
[Cherry-Pick][CINN] Release fix dynamic arange related model core dump bugs (#74437)
* [CINN] Robustify min/max type matching (#74316)
* [CINN] Fixed dynamic arange symbolic values extraction. (#74412)
* [CINN] Fix cinn_op.generate_shape attribute storing useless dim_expr
* [CINN] Removed unnecessary VLOGs
* [CINN] Simplify dynamic arange logic and fix bugs.
1 parent 1983aea commit 0cd8342

File tree

6 files changed: +72 -125 lines changed

paddle/cinn/backends/codegen_gpu_dev.cc

Lines changed: 27 additions & 10 deletions

@@ -217,14 +217,33 @@ void CodeGenGpuDev::VisitStmt(const ir::stmt::Alloc &stmt) {
   PrintTempBufferCreation(stmt->destination().as_buffer_ref());
 }
 
+inline void ProcessMinMaxOperand(ir::Expr *a,
+                                 ir::Expr *b,
+                                 int unify_bit,
+                                 bool both_dyn) {
+  if (unify_bit > 0) {
+    std::string type_func = "int" + std::to_string(unify_bit) + "_t";
+    if (both_dyn) {
+      // If both operands contain dynamic symbols, like min(S0, S1), it is
+      // likely that S0 is int and S1 is int64_t, so we need to enforce the
+      // type conversion with an ir::Call.
+      *a = ir::Call::Make(
+          common::Int(unify_bit), type_func, {*a}, {}, ir::CallType::Intrinsic);
+      *b = ir::Call::Make(
+          common::Int(unify_bit), type_func, {*b}, {}, ir::CallType::Intrinsic);
+    } else {
+      *a = ir::Cast::Make(common::Int(unify_bit), *a);
+      *b = ir::Cast::Make(common::Int(unify_bit), *b);
+    }
+  }
+}
+
 void CodeGenGpuDev::Visit(const ir::Min *op) {
   str_ += "min(";
   ir::Expr a = op->a(), b = op->b();
-  int unify_bit = common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
-  if (unify_bit > 0) {
-    a = ir::Cast::Make(common::Int(unify_bit), a);
-    b = ir::Cast::Make(common::Int(unify_bit), b);
-  }
+  auto [unify_bit, both_dyn] =
+      common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
+  ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
   IrPrinter::Visit(a);
   str_ += ", ";
   IrPrinter::Visit(b);
@@ -234,11 +253,9 @@ void CodeGenGpuDev::Visit(const ir::Min *op) {
 void CodeGenGpuDev::Visit(const ir::Max *op) {
   str_ += "max(";
   ir::Expr a = op->a(), b = op->b();
-  int unify_bit = common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
-  if (unify_bit > 0) {
-    a = ir::Cast::Make(common::Int(unify_bit), a);
-    b = ir::Cast::Make(common::Int(unify_bit), b);
-  }
+  auto [unify_bit, both_dyn] =
+      common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
+  ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
   IrPrinter::Visit(a);
   str_ += ", ";
   IrPrinter::Visit(b);
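Why the both-dynamic case needs an explicit coercion call rather than a plain cast: when the two operands of a min/max end up with different integer widths (say S0 as int and S1 as int64_t), a two-argument min cannot resolve cleanly. A minimal standalone sketch of the same mismatch in plain host C++ (not CINN IR; the values here are made up for illustration):

#include <algorithm>
#include <cstdint>

int main() {
  int s0 = 7;           // stand-in for a 32-bit dynamic dim S0
  int64_t s1 = 100000;  // stand-in for a 64-bit dynamic dim S1

  // std::min(s0, s1);  // ill-formed: T deduces to both int and int64_t
  // Unifying the operand types first, as the emitted int64_t(...) coercion
  // does in the generated kernel, makes the call unambiguous:
  int64_t m = std::min<int64_t>(s0, s1);
  return m == 7 ? 0 : 1;
}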

paddle/cinn/common/ir_util.cc

Lines changed: 15 additions & 15 deletions

@@ -614,21 +614,21 @@ struct DynamicSymbolExprBitTracker : public ir::IRVisitor {
   int dyn_symbol_bit = 0;
 };
 
-#define VISIT_OP(NodeType)                                                  \
-  int UnifiedOperandTypeBits(                                               \
-      const std::unordered_map<std::string, common::Type> *search_map,      \
-      const ir::NodeType *node) {                                           \
-    if (search_map->empty()) return 0;                                      \
-    if (!node->a().type().is_int() || !node->b().type().is_int()) return 0; \
-    int node_a_bits = node->a().type().bits();                              \
-    int node_b_bits = node->b().type().bits();                              \
-    if (node_a_bits < 32 || node_b_bits < 32) return 0;                     \
-    DynamicSymbolExprBitTracker tracker;                                    \
-    tracker(search_map, &node->a());                                        \
-    int target_bit = tracker(search_map, &node->b());                       \
-    if (target_bit > 0) {                                                   \
-    }                                                                       \
-    return target_bit;                                                      \
+#define VISIT_OP(NodeType)                                              \
+  std::pair<int, bool> UnifiedOperandTypeBits(                          \
+      const std::unordered_map<std::string, common::Type> *search_map,  \
+      const ir::NodeType *node) {                                       \
+    if (search_map->empty()) return {0, false};                         \
+    if (!node->a().type().is_int() || !node->b().type().is_int())       \
+      return {0, false};                                                \
+    int node_a_bits = node->a().type().bits();                          \
+    int node_b_bits = node->b().type().bits();                          \
+    if (node_a_bits < 32 || node_b_bits < 32) return {0, false};        \
+    DynamicSymbolExprBitTracker tracker;                                \
+    int b1 = tracker(search_map, &node->a());                           \
+    tracker.dyn_symbol_bit = 0;                                         \
+    int b2 = tracker(search_map, &node->b());                           \
+    return std::make_pair(std::max(b1, b2), b1 > 0 && b2 > 0);          \
   }
 
 VISIT_OP(Min)
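The reset of tracker.dyn_symbol_bit between the two operands is what makes the both-dynamic flag meaningful: each operand is tracked independently, the wider bit width wins, and the flag is set only when both sides saw a dynamic symbol. A self-contained sketch of that decision logic, with a plain map standing in for the tracker (all names here are illustrative, not CINN API):

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for DynamicSymbolExprBitTracker: returns the bit
// width recorded for a symbol, or 0 when the operand has no dynamic symbol.
static int TrackBits(const std::unordered_map<std::string, int> &bits,
                     const std::string &operand) {
  auto it = bits.find(operand);
  return it == bits.end() ? 0 : it->second;
}

// Mirrors the returned pair: the widest operand bit width, plus whether
// *both* operands carry dynamic symbols.
static std::pair<int, bool> UnifyBits(
    const std::unordered_map<std::string, int> &bits,
    const std::string &a,
    const std::string &b) {
  int b1 = TrackBits(bits, a);  // tracked independently per operand,
  int b2 = TrackBits(bits, b);  // like resetting dyn_symbol_bit in between
  return {std::max(b1, b2), b1 > 0 && b2 > 0};
}

int main() {
  std::unordered_map<std::string, int> bits{{"S0", 32}, {"S1", 64}};
  auto [unify_bit, both_dyn] = UnifyBits(bits, "S0", "S1");
  // min(S0, S1): unify to 64 bits; both sides are dynamic, so the codegen
  // emits an int64_t(...) intrinsic call rather than a plain ir::Cast.
  return (unify_bit == 64 && both_dyn) ? 0 : 1;
}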

paddle/cinn/common/ir_util.h

Lines changed: 6 additions & 3 deletions

@@ -185,11 +185,14 @@ void OpDataTypePromote(ir::LoweredFunc *func);
 
 // only process ir::Min and ir::Max where the operands 1. contain dynamic shape
 // symbols. 2. the operands are both int types and both are 32/64 bits. Returns
-// the number of bits for unifying operands (by casting)
-int UnifiedOperandTypeBits(
+// the number of bits for unifying operands (by casting). The bool flag
+// indicates whether both sides have dynamic shape symbols; if true (like
+// min(S0, S1)), we should emit an ir::Call (coercion) rather than an ir::Cast.
+std::pair<int, bool> UnifiedOperandTypeBits(
     const std::unordered_map<std::string, common::Type> *search_map,
     const ir::Min *op);
-int UnifiedOperandTypeBits(
+std::pair<int, bool> UnifiedOperandTypeBits(
     const std::unordered_map<std::string, common::Type> *search_map,
     const ir::Max *op);
 }  // namespace common

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 0 additions & 76 deletions

@@ -667,79 +667,6 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
   return funcs;
 }
 
-/**
- * This function converts pir::Value::defining_op for ir::Tensor::operation
- * Normally, ir::Tensor::operation will only be used to record the name
- * of the compiler-generated var name, which is useless. However, operation
- * has Attributes field, so can be used to record the op info.
- */
-ir::PlaceholderOp* TensorOperationRecording(const ::pir::Value& value) {
-  // TODO(heqianyue): I think this is kinda ugly, since we should manually
-  // specify the rules to convert all the op (and their attribute), yet current
-  // implementation works and can be quickly written.
-  const ::pir::Operation* define_op = value.defining_op();
-  ir::PlaceholderOp* res = nullptr;
-  if (!define_op) return res;
-  res = cinn::common::make_shared<ir::PlaceholderOp>();
-  res->name = define_op->name();
-  // we filter some of the ops, and only record the **needed** attributes
-  if (define_op->name() == "pd_op.full") {
-    auto dtype = define_op->attribute("dtype")
-                     .dyn_cast<paddle::dialect::DataTypeAttribute>()
-                     .data();
-    phi::Scalar data = define_op->attribute("value")
-                           .dyn_cast<paddle::dialect::ScalarAttribute>()
-                           .data();
-    ir::Expr value;
-#define DEFINE_CASE(TypeFlag, Type)    \
-  case phi::DataType::TypeFlag:        \
-    value = ir::Expr(data.to<Type>()); \
-    break;
-    switch (dtype) {
-      DEFINE_CASE(FLOAT32, float)
-      DEFINE_CASE(FLOAT64, double)
-      DEFINE_CASE(INT32, int)
-      DEFINE_CASE(BFLOAT16, float)
-      value->set_type(cinn::common::BFloat16());
-      break;
-      DEFINE_CASE(FLOAT16, float)
-      value->set_type(cinn::common::Float16());
-      break;
-      default:
-        value = ir::Expr(data.to<int64_t>());
-    }
-#undef DEFINE_CASE
-    res->attrs.emplace("value", value);
-  } else if (define_op->name() == "cinn_op.generate_shape") {
-    // pir::Attribute --> symbol::DimExpr --> ir::Expr
-    auto ir_dim_expr = [&]() {
-      auto dim_expr_attr = define_op->attribute("output_dim_exprs");
-      auto dim_exprs = dialect::ConvertAttributeToDimExprs(dim_expr_attr);
-      PADDLE_ENFORCE_EQ(
-          dim_exprs.has_value(),
-          true,
-          ::common::errors::PreconditionNotMet(
-              "Required success to execute convert attribute to dim exprs."));
-      auto expr_vec = dim_exprs.value();
-      PADDLE_ENFORCE_EQ(
-          expr_vec.empty(),
-          false,
-          ::common::errors::PreconditionNotMet(
-              "Generate shape op can not yield empty symbolic shape."));
-      // only the first dim_expr matters for ArangeOp
-      return common::DimExprConverter().ConvertToIrExpr(expr_vec[0]);
-    }();
-    res->attrs.emplace("value", ir_dim_expr);
-  } else {
-    VLOG(6) << "Tensor defining op recording: not currently supported op.";
-    return nullptr;
-  }
-  return res;
-}
-
 ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
                                     const ::pir::Value& value) {
   auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -778,9 +705,6 @@ ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
       tensor->set_value(*tensor_value);
     }
   }
-  if (auto op_ptr = TensorOperationRecording(value)) {
-    tensor->operation = ir::FunctionRef(op_ptr);
-  }
   return tensor;
 }

paddle/cinn/hlir/framework/pir/op_lowering_util.cc

Lines changed: 2 additions & 2 deletions

@@ -180,8 +180,8 @@ std::unordered_set<::pir::Operation*> GetMasters(
 }
 
 bool IsConstOp(const ::pir::Operation* op) {
-  static std::unordered_set<std::string> const_op_type = {
-      "const_scalar", "fill_constant", "arange"};
+  static std::unordered_set<std::string> const_op_type = {"const_scalar",
+                                                          "fill_constant"};
   return const_op_type.count(CompatibleInfo::OpName(*op));
 }

paddle/cinn/hlir/op/elementwise.cc

Lines changed: 22 additions & 19 deletions

@@ -1257,16 +1257,12 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
     const Target &target) {
   bool all_static = true;
   for (int i = 0; i < 3; i++) {
-    auto op_node = inputs[i]->operation->as<ir::PlaceholderOp>();
-    PADDLE_ENFORCE_NE(
-        op_node,
-        nullptr,
-        ::common::errors::PreconditionNotMet(
-            "The defining op of the input tensor is not set! Please check."));
-    if (op_node->name == "cinn_op.generate_shape") {
-      all_static = false;
-      break;
-    }
+    if (!inputs[i]->value().has_value()) continue;
+    auto input_val = inputs[i]->value().value();
+    if (input_val.empty()) continue;
+    if (input_val[0].is_constant()) continue;
+    all_static = false;
+    break;
   }
   auto attr_store = attrs.attr_store;
   auto dtype =
@@ -1341,15 +1337,22 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
           "bfloat16 or float16."));
   }
 #undef EXPR_FROM_ATTR
-  } else {  // has dynamic shape, some of the operands come from
-            // cinn_op.generate_shape
-    // in op_lowering_impl.cc, tensor op recorder unified the attribute name
-    start = Expr(
-        inputs[0]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
-    step = Expr(
-        inputs[2]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
-    Expr end = Expr(
-        inputs[1]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
+  } else {
+    for (int i = 0; i < 3; i++) {
+      PADDLE_ENFORCE_EQ(
+          inputs[i]->value().has_value(),
+          true,
+          ::common::errors::InvalidArgument(
+              "The input tensor of dynamic arange should have valid values."));
+      PADDLE_ENFORCE_NE(
+          inputs[i]->value().value().empty(),
+          true,
+          ::common::errors::InvalidArgument(
+              "The tensor value of dynamic arange should not be empty."));
+    }
+    start = inputs[0]->value().value()[0];
+    step = inputs[2]->value().value()[0];
+    Expr end = inputs[1]->value().value()[0];
     auto IrAbs = [=](Expr ir) -> Expr {
       return ir::Call::Make(step.type(), "abs", {ir}, {}, ir::CallType::Extern);
     };
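The rewrite drops the PlaceholderOp plumbing entirely: instead of asking which op defined each input, the strategy inspects the symbolic values recorded on the tensors themselves. A self-contained sketch of the new static/dynamic detection, using a stub type in place of ir::Expr (all names here are illustrative, not CINN API):

#include <optional>
#include <vector>

// Hypothetical stand-in for ir::Expr: constant when it holds a concrete value.
struct ExprStub {
  bool constant;
  bool is_constant() const { return constant; }
};

// Mirrors the all_static scan in StrategyForArangeSymbolic: arange is static
// only if every recorded start/end/step value is a compile-time constant;
// inputs without a recorded value are skipped rather than treated as dynamic.
static bool AllStatic(
    const std::vector<std::optional<std::vector<ExprStub>>> &inputs) {
  for (const auto &value : inputs) {
    if (!value.has_value() || value->empty()) continue;  // no recorded value
    if (!(*value)[0].is_constant()) return false;        // symbolic operand
  }
  return true;
}

int main() {
  // start=0 and step=1 are constants; end carries a dynamic symbol.
  std::vector<std::optional<std::vector<ExprStub>>> inputs{
      {{ExprStub{true}}}, {{ExprStub{false}}}, {{ExprStub{true}}}};
  return AllStatic(inputs) ? 1 : 0;  // dynamic case, so returns 0 here
}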
