From 8ae5c53f6efa2bdac69f56f6be3c63c27abf8a24 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Fri, 8 Sep 2023 19:21:04 -0700 Subject: [PATCH] Fix quite subtle but nasty bug in linear map tuple types computation: we need lowered type for branch trace enum in order to compute linear map tuple type. However, the lowering of branch trace enum type depends on the types of its elements (the payloads are linear map tuples of predecessor BB). As lowered types are cached, we cannot populate branch trace enum entries in the end as we did before: we already used wrong lowered types for linear map tuples. Traverse basic blocks in reverse post-order traver order building linear map tuples and branch tracing enumns in one go, ensuring that we've done with predecessor BBs before processing the BB itself. --- .../Differentiation/LinearMapInfo.cpp | 39 +++++++++++++------ .../differentiation_control_flow_sil.swift | 39 +++++++++++++++++-- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index 9797c3e982381..3585697e96875 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -142,10 +142,12 @@ void LinearMapInfo::populateBranchingTraceDecl(SILBasicBlock *originalBB, heapAllocatedContext = true; decl->setInterfaceType(astCtx.TheRawPointerType); } else { // Otherwise the payload is the linear map tuple. - auto linearMapStructTy = getLinearMapTupleType(predBB)->getCanonicalType(); + auto *linearMapStructTy = getLinearMapTupleType(predBB); + assert(linearMapStructTy && "must have linear map struct type for predecessor BB"); + auto canLinearMapStructTy = linearMapStructTy->getCanonicalType(); decl->setInterfaceType( - linearMapStructTy->hasArchetype() - ? linearMapStructTy->mapTypeOutOfContext() : linearMapStructTy); + canLinearMapStructTy->hasArchetype() + ? canLinearMapStructTy->mapTypeOutOfContext() : canLinearMapStructTy); } // Create enum element and enum case declarations. auto *paramList = ParameterList::create(astCtx, {decl}); @@ -331,10 +333,28 @@ void LinearMapInfo::generateDifferentiationDataStructures( } // Add linear map fields to the linear map tuples. - for (auto &origBB : *original) { + // + // Now we need to be very careful as we're having a very subtle + // chicken-and-egg problem. We need lowered branch trace enum type for the + // linear map typle type. However branch trace enum type lowering depends on + // the lowering of its elements (at very least, the type classification of + // being trivial / non-trivial). As the lowering is cached we need to ensure + // we compute lowered type for the branch trace enum when the corresponding + // EnumDecl is fully complete: we cannot add more entries without causing some + // very subtle issues later on. However, the elements of the enum are linear + // map tuples of predecessors, that correspondingly may contain branch trace + // enums of corresponding predecessor BBs. + // + // Traverse all BBs in reverse post-order traversal order to ensure we process + // each BB before its predecessors. + llvm::ReversePostOrderTraversal RPOT(original); + for (auto Iter = RPOT.begin(), E = RPOT.end(); Iter != E; ++Iter) { + auto *origBB = *Iter; SmallVector linearTupleTypes; - if (!origBB.isEntry()) { - CanType traceEnumType = getBranchingTraceEnumLoweredType(&origBB).getASTType(); + if (!origBB->isEntry()) { + populateBranchingTraceDecl(origBB, loopInfo); + + CanType traceEnumType = getBranchingTraceEnumLoweredType(origBB).getASTType(); linearTupleTypes.emplace_back(traceEnumType, astCtx.getIdentifier(traceEnumFieldName)); } @@ -343,7 +363,7 @@ void LinearMapInfo::generateDifferentiationDataStructures( // Do not add linear map fields for semantic member accessors, which have // special-case pullback generation. Linear map tuples should be empty. } else { - for (auto &inst : origBB) { + for (auto &inst : *origBB) { if (auto *ai = dyn_cast(&inst)) { // Add linear map field to struct for active `apply` instructions. // Skip array literal intrinsic applications since array literal @@ -363,12 +383,9 @@ void LinearMapInfo::generateDifferentiationDataStructures( } } - linearMapTuples.insert({&origBB, TupleType::get(linearTupleTypes, astCtx)}); + linearMapTuples.insert({origBB, TupleType::get(linearTupleTypes, astCtx)}); } - for (auto &origBB : *original) - populateBranchingTraceDecl(&origBB, loopInfo); - // Print generated linear map structs and branching trace enums. // These declarations do not show up with `-emit-sil` because they are // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test diff --git a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift index 31c137a502d49..3884c11420e62 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_control_flow_sil.swift @@ -56,7 +56,7 @@ func cond(_ x: Float) -> Float { // CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__cond_bb3__Pred__src_0_wrt_0, #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt, [[BB2_PB_STRUCT]] // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__cond_bb3__Pred__src_0_wrt_0) -// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0) +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0) // CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @condTJpSpSr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float) @@ -64,7 +64,7 @@ func cond(_ x: Float) -> Float { // CHECK-SIL-LABEL: sil private [ossa] @condTJpSpSr : $@convention(thin) (Float, @owned _AD__cond_bb3__Pred__src_0_wrt_0) -> Float { -// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : $_AD__cond_bb3__Pred__src_0_wrt_0): +// CHECK-SIL: bb0([[SEED:%.*]] : $Float, [[BB3_PRED:%.*]] : @owned $_AD__cond_bb3__Pred__src_0_wrt_0): // CHECK-SIL: switch_enum [[BB3_PRED]] : $_AD__cond_bb3__Pred__src_0_wrt_0, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__cond_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb3 // CHECK-SIL: bb1([[BB3_PRED2_TRAMP_PB_STRUCT:%.*]] : @owned $(predecessor: _AD__cond_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> (Float, Float))): @@ -132,6 +132,39 @@ func loop_generic(_ x: T) -> T { return result } +@differentiable(reverse) +@_silgen_name("loop_context") +func loop_context(x: Float) -> Float { + let y = x + 1 + for _ in 0 ..< 1 {} + return y +} + +// CHECK-DATA-STRUCTURES-LABEL: Generated linear map tuples and branching trace enums for @loop_context: +// CHECK-DATA-STRUCTURES: (_: (Float) -> Float) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb2__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: (predecessor: _AD__loop_context_bb3__Pred__src_0_wrt_0) +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb0__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb1__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb2(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: case bb0((_: (Float) -> Float)) +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb2__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__loop_context_bb3__Pred__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: case bb1(Builtin.RawPointer) +// CHECK-DATA-STRUCTURES: } + +// CHECK-SIL-LABEL: sil private [ossa] @loop_contextTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float { +// CHECK-SIL: bb1([[LOOP_CONTEXT:%.*]] : $Builtin.RawPointer): +// CHECK-SIL: [[PB_TUPLE_ADDR:%.*]] = pointer_to_address [[LOOP_CONTEXT]] : $Builtin.RawPointer to [strict] $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-SIL: [[PB_TUPLE_CPY:%.*]] = load [copy] [[PB_TUPLE_ADDR]] : $*(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0) +// CHECK-SIL: br bb3({{.*}} : $Float, {{.*}} : $Float, [[PB_TUPLE_CPY]] : $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)) +// CHECK-SIL: bb3({{.*}} : $Float, {{.*}} : $Float, {{.*}} : @owned $(predecessor: _AD__loop_context_bb1__Pred__src_0_wrt_0)): + // Test `switch_enum`. enum Enum { @@ -164,7 +197,7 @@ func enum_notactive(_ e: Enum, _ x: Float) -> Float { // CHECK-SIL: [[BB3_PRED_PRED2:%.*]] = enum $_AD__enum_notactive_bb3__Pred__src_0_wrt_1, #_AD__enum_notactive_bb3__Pred__src_0_wrt_1.bb2!enumelt, [[BB2_PB_STRUCT]] : $(predecessor: _AD__enum_notactive_bb2__Pred__src_0_wrt_1, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float) // CHECK-SIL: br bb3({{.*}} : $Float, [[BB3_PRED_PRED2]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) -// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) +// CHECK-SIL: bb3([[ORIG_RES:%.*]] : $Float, [[BB3_PRED_ARG:%.*]] : @owned $_AD__enum_notactive_bb3__Pred__src_0_wrt_1) // CHECK-SIL: [[PULLBACK_REF:%.*]] = function_ref @enum_notactiveTJpUSpSr // CHECK-SIL: [[PB:%.*]] = partial_apply [callee_guaranteed] [[PULLBACK_REF]]([[BB3_PRED_ARG]]) // CHECK-SIL: [[VJP_RESULT:%.*]] = tuple ([[ORIG_RES]] : $Float, [[PB]] : $@callee_guaranteed (Float) -> Float)