Skip to content

[AutoDiff] Fix quite subtle but nasty bug in linear map tuple types computation #68413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions lib/SILOptimizer/Differentiation/LinearMapInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB,
heapAllocatedContext = true;
decl->setInterfaceType(astCtx.TheRawPointerType);
} else { // Otherwise the payload is the linear map tuple.
auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType();
auto *linearMapStructTy = getLinearMapTupleType(predBB);
assert(linearMapStructTy && "must have linear map struct type for predecessor BB");
auto canLinearMapStructTy = linearMapStructTy->getCanonicalType();
decl->setInterfaceType(
linearMapStructTy->hasArchetype()
? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy);
canLinearMapStructTy->hasArchetype()
? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy);
}
// Create enum element and enum case declarations.
auto *paramList = ParameterList::create(astCtx, {decl});
Expand Down Expand Up @@ -331,10 +333,28 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}

// Add linear map fields to the linear map tuples.
for (auto &origBB : *original) {
//
// Now we need to be very careful as we're having a very subtle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain why the reverse post-order traversal(RPOT) is needed and how it is helping?

I understand that we need the branch trace enum for a BB to be fully constructed before using it to derive the type for the linear map tuple of the same BB. But I don't understand how going in RPOT order is going to ensure that we have the right types for the branch trace enums?

I think so because while fully constructing a branch trace enum declaration in populateBranchingTraceDecl we look up the linear map tuple types of the predecessor BBs, but these types haven't been set yet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When traversing BBs in RPOT we ensure that each BB is processed after its predecessors. As a result, we know that all linear map tuples for predecessor BBs are already finalized and therefore the enum type that we created will also be "complete" – we will not need to add any entries later.

The problem was not linear map tuples, but that we passed "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to be always trivial (despite non-trivial payloads added afterwards, the lowered type will be cached and not re-calculated).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When traversing BBs in RPOT we ensure that each BB is processed after its predecessors.

Ah gotcha. I was mistaking RPOT for regular post order traversal but (Right->Left->Node) instead of (Left->Right->Node). But it seems like it's more or less a "pre-order" traversal.

Copy link
Contributor Author

@asl asl Sep 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's neither a pre-order nor a post-order, it's not a DFS. We cannot use DFS because we're having a DAG, not a tree. Think about A -> B -> C -> D; B -> D CFG. We need to visit D after both B and C. RPOT is quite a standard technique for various data-flow problems (and could be used to compute topological sorting of the graph).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks for the explanation!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem was not linear map tuples, but that we passed "incomplete" EnumDecl to getBranchingTraceEnumLoweredType. As a result, several flags on the corresponding SIL type were set improperly, causing the branch trace enum type to be always trivial (despite non-trivial payloads added afterwards, the lowered type will be cached and not re-calculated).

You're not supposed to mutate EnumDecls like that at all and changing the order of iteration is only papering over the issue. Type lowering caches the results because it assumes the inputs are immutable. You need to build your AST before you get to SIL.

// chicken-and-egg problem. We need lowered branch trace enum type for the
// linear map typle type. However branch trace enum type lowering depends on
// the lowering of its elements (at very least, the type classification of
// being trivial / non-trivial). As the lowering is cached we need to ensure
// we compute lowered type for the branch trace enum when the corresponding
// EnumDecl is fully complete: we cannot add more entries without causing some
// very subtle issues later on. However, the elements of the enum are linear
// map tuples of predecessors, that correspondingly may contain branch trace
// enums of corresponding predecessor BBs.
//
// Traverse all BBs in reverse post-order traversal order to ensure we process
// each BB before its predecessors.
llvm::ReversePostOrderTraversal<SILFunction *> RPOT(original);
for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) {
auto *origBB = *Iter;
SmallVector<TupleTypeElt, 4> linearTupleTypes;
if (!origBB.isEntry()) {
CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType();
if (!origBB->isEntry()) {
populateBranchingTraceDecl(origBB, loopInfo);

CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType();
linearTupleTypes.emplace_back(traceEnumType,
astCtx.getIdentifier(traceEnumFieldName));
}
Expand All @@ -343,7 +363,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
// Do not add linear map fields for semantic member accessors, which have
// special-case pullback generation. Linear map tuples should be empty.
} else {
for (auto &inst : origBB) {
for (auto &inst : *origBB) {
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
// Add linear map field to struct for active `apply` instructions.
// Skip array literal intrinsic applications since array literal
Expand All @@ -363,12 +383,9 @@ void LinearMapInfo::generateDifferentiationDataStructures(
}
}

linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)});
linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)});
}

for (auto &origBB : *original)
populateBranchingTraceDecl(&origBB, loopInfo);

// Print generated linear map structs and branching trace enums.
// These declarations do not show up with `-emit-sil` because they are
// implicit. Instead, use `-Xllvm -debug-only=differentiation` to test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ func cond(_ x: Float) -> Float {
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]]
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0)
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
// CHECK-SIL: return [[VJP_RESULT]]


// CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float {
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0):
// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0):
// CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3

// CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))):
Expand Down Expand Up @@ -132,6 +132,39 @@ func loop_generic<T : Differentiable & FloatingPoint>(_ x: T) -> T {
return result
}

@differentiable(reverse)
@_silgen_name("loop_context")
func loop_context(x: Float) -> Float {
let y = x + 1
for _ in 0 ..< 1 {}
return y
}

// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context:
// CHECK-DATA-STRUCTURES: (_: (Float) -> Float)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0)
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float))
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: }
// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 {
// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer)
// CHECK-DATA-STRUCTURES: }

// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float {
// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer):
// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)
// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0))
// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)):

// Test `switch_enum`.

enum Enum {
Expand Down Expand Up @@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float {
// CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float)
// CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)

// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1)
// CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr
// CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]])
// CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)
Expand Down