diff --git a/include/swift/AST/Builtins.def b/include/swift/AST/Builtins.def index fb271e7e1ba88..286ebdb1bfd6a 100644 --- a/include/swift/AST/Builtins.def +++ b/include/swift/AST/Builtins.def @@ -985,14 +985,14 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskInGroup, /// is a pure value and therefore we can consider it as readnone). BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special) -// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject -BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "", Special) +// autoDiffCreateLinearMapContextWithType: (T.Type) -> Builtin.NativeObject +BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContextWithType, "autoDiffCreateLinearMapContextWithType", "", Special) // autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special) -// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer -BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special) +// autoDiffAllocateSubcontextWithType: (Builtin.NativeObject, T.Type) -> Builtin.RawPointer +BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontextWithType, "autoDiffAllocateSubcontextWithType", "", Special) /// Build a Builtin.Executor value from an "ordinary" serial executor /// reference. diff --git a/include/swift/Runtime/RuntimeFunctions.def b/include/swift/Runtime/RuntimeFunctions.def index 812f807d971ac..5a788b02bf7a7 100644 --- a/include/swift/Runtime/RuntimeFunctions.def +++ b/include/swift/Runtime/RuntimeFunctions.def @@ -2273,12 +2273,12 @@ FUNCTION(TaskGroupDestroy, ATTRS(NoUnwind), EFFECT(Concurrency)) -// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t); -FUNCTION(AutoDiffCreateLinearMapContext, - swift_autoDiffCreateLinearMapContext, SwiftCC, +// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContextWithType(const Metadata *); +FUNCTION(AutoDiffCreateLinearMapContextWithType, + swift_autoDiffCreateLinearMapContextWithType, SwiftCC, DifferentiationAvailability, RETURNS(RefCountedPtrTy), - ARGS(SizeTy), + ARGS(TypeMetadataPtrTy), ATTRS(NoUnwind, ArgMemOnly), EFFECT(AutoDiff)) @@ -2291,12 +2291,12 @@ FUNCTION(AutoDiffProjectTopLevelSubcontext, ATTRS(NoUnwind, ArgMemOnly), EFFECT(AutoDiff)) -// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t); -FUNCTION(AutoDiffAllocateSubcontext, - swift_autoDiffAllocateSubcontext, SwiftCC, +// void *swift_autoDiffAllocateSubcontextWithType(AutoDiffLinearMapContext *, const Metadata *); +FUNCTION(AutoDiffAllocateSubcontextWithType, + swift_autoDiffAllocateSubcontextWithType, SwiftCC, DifferentiationAvailability, RETURNS(Int8PtrTy), - ARGS(RefCountedPtrTy, SizeTy), + ARGS(RefCountedPtrTy, TypeMetadataPtrTy), ATTRS(NoUnwind, ArgMemOnly), EFFECT(AutoDiff)) diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index 035cbf9beee5a..fe168edc629bc 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -1609,7 +1609,8 @@ static ValueDecl *getBuildComplexEqualitySerialExecutorRef(ASTContext &ctx, static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx, Identifier id) { return getBuiltinFunction( - id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType); + ctx, id, _thin, _generics(_unrestricted), + _parameters(_metatype(_typeparam(0))), _nativeObject); } static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx, @@ -1621,8 +1622,8 @@ static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx, static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx, Identifier id) { return getBuiltinFunction( - id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)}, - ctx.TheRawPointerType); + ctx, id, _thin, _generics(_unrestricted), + _parameters(_nativeObject, _metatype(_typeparam(0))), _rawPointer); } static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) { @@ -2966,13 +2967,13 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) { case BuiltinValueKind::HopToActor: return getHopToActor(Context, Id); - case BuiltinValueKind::AutoDiffCreateLinearMapContext: + case BuiltinValueKind::AutoDiffCreateLinearMapContextWithType: return getAutoDiffCreateLinearMapContext(Context, Id); case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext: return getAutoDiffProjectTopLevelSubcontext(Context, Id); - case BuiltinValueKind::AutoDiffAllocateSubcontext: + case BuiltinValueKind::AutoDiffAllocateSubcontextWithType: return getAutoDiffAllocateSubcontext(Context, Id); } diff --git a/lib/IRGen/GenBuiltin.cpp b/lib/IRGen/GenBuiltin.cpp index c697557a38684..5dfaebc1d0318 100644 --- a/lib/IRGen/GenBuiltin.cpp +++ b/lib/IRGen/GenBuiltin.cpp @@ -1307,9 +1307,10 @@ void irgen::emitBuiltinCall(IRGenFunction &IGF, const BuiltinInfo &Builtin, return; } - if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContext) { - auto topLevelSubcontextSize = args.claimNext(); - out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextSize) + if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContextWithType) { + auto topLevelSubcontextMetaType = args.claimNext(); + out.add(emitAutoDiffCreateLinearMapContextWithType( + IGF, topLevelSubcontextMetaType) .getAddress()); return; } @@ -1322,12 +1323,13 @@ void irgen::emitBuiltinCall(IRGenFunction &IGF, const BuiltinInfo &Builtin, return; } - if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontext) { + if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontextWithType) { Address allocatorAddr(args.claimNext(), IGF.IGM.RefCountedStructTy, IGF.IGM.getPointerAlignment()); - auto size = args.claimNext(); - out.add( - emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, size).getAddress()); + auto subcontextMetatype = args.claimNext(); + out.add(emitAutoDiffAllocateSubcontextWithType(IGF, allocatorAddr, + subcontextMetatype) + .getAddress()); return; } diff --git a/lib/IRGen/GenCall.cpp b/lib/IRGen/GenCall.cpp index c5f264d0e2c8f..c92ac7e86887c 100644 --- a/lib/IRGen/GenCall.cpp +++ b/lib/IRGen/GenCall.cpp @@ -5479,11 +5479,13 @@ IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) { return fnPtr; } -Address irgen::emitAutoDiffCreateLinearMapContext( - IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize) { +Address irgen::emitAutoDiffCreateLinearMapContextWithType( + IRGenFunction &IGF, llvm::Value *topLevelSubcontextMetatype) { + topLevelSubcontextMetatype = IGF.Builder.CreateBitCast( + topLevelSubcontextMetatype, IGF.IGM.TypeMetadataPtrTy); auto *call = IGF.Builder.CreateCall( - IGF.IGM.getAutoDiffCreateLinearMapContextFunctionPointer(), - {topLevelSubcontextSize}); + IGF.IGM.getAutoDiffCreateLinearMapContextWithTypeFunctionPointer(), + {topLevelSubcontextMetatype}); call->setDoesNotThrow(); call->setCallingConv(IGF.IGM.SwiftCC); return Address(call, IGF.IGM.RefCountedStructTy, @@ -5500,11 +5502,13 @@ Address irgen::emitAutoDiffProjectTopLevelSubcontext( return Address(call, IGF.IGM.Int8Ty, IGF.IGM.getPointerAlignment()); } -Address irgen::emitAutoDiffAllocateSubcontext( - IRGenFunction &IGF, Address context, llvm::Value *size) { +Address irgen::emitAutoDiffAllocateSubcontextWithType( + IRGenFunction &IGF, Address context, llvm::Value *subcontextMetatype) { + subcontextMetatype = + IGF.Builder.CreateBitCast(subcontextMetatype, IGF.IGM.TypeMetadataPtrTy); auto *call = IGF.Builder.CreateCall( - IGF.IGM.getAutoDiffAllocateSubcontextFunctionPointer(), - {context.getAddress(), size}); + IGF.IGM.getAutoDiffAllocateSubcontextWithTypeFunctionPointer(), + {context.getAddress(), subcontextMetatype}); call->setDoesNotThrow(); call->setCallingConv(IGF.IGM.SwiftCC); return Address(call, IGF.IGM.Int8Ty, IGF.IGM.getPointerAlignment()); diff --git a/lib/IRGen/GenCall.h b/lib/IRGen/GenCall.h index c64e97dc8a264..70879ce628d05 100644 --- a/lib/IRGen/GenCall.h +++ b/lib/IRGen/GenCall.h @@ -261,12 +261,15 @@ namespace irgen { CanSILFunctionType fnType, Explosion &result, Explosion &error); - Address emitAutoDiffCreateLinearMapContext( - IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize); + Address emitAutoDiffCreateLinearMapContextWithType( + IRGenFunction &IGF, llvm::Value *topLevelSubcontextMetatype); + Address emitAutoDiffProjectTopLevelSubcontext( IRGenFunction &IGF, Address context); - Address emitAutoDiffAllocateSubcontext( - IRGenFunction &IGF, Address context, llvm::Value *size); + + Address + emitAutoDiffAllocateSubcontextWithType(IRGenFunction &IGF, Address context, + llvm::Value *subcontextMetatype); FunctionPointer getFunctionPointerForDispatchCall(IRGenModule &IGM, const FunctionPointer &fn); diff --git a/lib/SIL/IR/OperandOwnership.cpp b/lib/SIL/IR/OperandOwnership.cpp index 7fe14a6e059a4..faeb4a6141f5b 100644 --- a/lib/SIL/IR/OperandOwnership.cpp +++ b/lib/SIL/IR/OperandOwnership.cpp @@ -943,7 +943,7 @@ BUILTIN_OPERAND_OWNERSHIP(InstantaneousUse, InitializeDistributedRemoteActor) BUILTIN_OPERAND_OWNERSHIP(InstantaneousUse, InitializeNonDefaultDistributedActor) -BUILTIN_OPERAND_OWNERSHIP(PointerEscape, AutoDiffAllocateSubcontext) +BUILTIN_OPERAND_OWNERSHIP(PointerEscape, AutoDiffAllocateSubcontextWithType) BUILTIN_OPERAND_OWNERSHIP(PointerEscape, AutoDiffProjectTopLevelSubcontext) // FIXME: ConvertTaskToJob is documented as taking NativePointer. It's operand's @@ -955,8 +955,7 @@ BUILTIN_OPERAND_OWNERSHIP(BitwiseEscape, BuildComplexEqualitySerialExecutorRef) BUILTIN_OPERAND_OWNERSHIP(BitwiseEscape, BuildDefaultActorExecutorRef) BUILTIN_OPERAND_OWNERSHIP(BitwiseEscape, BuildMainActorExecutorRef) -BUILTIN_OPERAND_OWNERSHIP(TrivialUse, AutoDiffCreateLinearMapContext) - +BUILTIN_OPERAND_OWNERSHIP(TrivialUse, AutoDiffCreateLinearMapContextWithType) #undef BUILTIN_OPERAND_OWNERSHIP #define SHOULD_NEVER_VISIT_BUILTIN(ID) \ diff --git a/lib/SIL/IR/ValueOwnership.cpp b/lib/SIL/IR/ValueOwnership.cpp index 5cdc0b01c15e8..40c1bbcd05726 100644 --- a/lib/SIL/IR/ValueOwnership.cpp +++ b/lib/SIL/IR/ValueOwnership.cpp @@ -571,9 +571,9 @@ CONSTANT_OWNERSHIP_BUILTIN(None, InitializeDefaultActor) CONSTANT_OWNERSHIP_BUILTIN(None, DestroyDefaultActor) CONSTANT_OWNERSHIP_BUILTIN(None, InitializeDistributedRemoteActor) CONSTANT_OWNERSHIP_BUILTIN(None, InitializeNonDefaultDistributedActor) -CONSTANT_OWNERSHIP_BUILTIN(Owned, AutoDiffCreateLinearMapContext) +CONSTANT_OWNERSHIP_BUILTIN(Owned, AutoDiffCreateLinearMapContextWithType) CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffProjectTopLevelSubcontext) -CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffAllocateSubcontext) +CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffAllocateSubcontextWithType) CONSTANT_OWNERSHIP_BUILTIN(None, GetCurrentExecutor) CONSTANT_OWNERSHIP_BUILTIN(None, ResumeNonThrowingContinuationReturning) CONSTANT_OWNERSHIP_BUILTIN(None, ResumeThrowingContinuationReturning) diff --git a/lib/SIL/Utils/MemAccessUtils.cpp b/lib/SIL/Utils/MemAccessUtils.cpp index 6cf1a89f14f3a..b672f69550c50 100644 --- a/lib/SIL/Utils/MemAccessUtils.cpp +++ b/lib/SIL/Utils/MemAccessUtils.cpp @@ -2567,8 +2567,8 @@ static void visitBuiltinAddress(BuiltinInst *builtin, case BuiltinValueKind::CancelAsyncTask: case BuiltinValueKind::CreateAsyncTask: case BuiltinValueKind::CreateAsyncTaskInGroup: - case BuiltinValueKind::AutoDiffCreateLinearMapContext: - case BuiltinValueKind::AutoDiffAllocateSubcontext: + case BuiltinValueKind::AutoDiffCreateLinearMapContextWithType: + case BuiltinValueKind::AutoDiffAllocateSubcontextWithType: case BuiltinValueKind::InitializeDefaultActor: case BuiltinValueKind::InitializeDistributedRemoteActor: case BuiltinValueKind::InitializeNonDefaultDistributedActor: diff --git a/lib/SILGen/SILGenBuiltin.cpp b/lib/SILGen/SILGenBuiltin.cpp index d3b3833aef385..e07bc18b0868a 100644 --- a/lib/SILGen/SILGenBuiltin.cpp +++ b/lib/SILGen/SILGenBuiltin.cpp @@ -1708,16 +1708,15 @@ static ManagedValue emitBuiltinHopToActor(SILGenFunction &SGF, SILLocation loc, return ManagedValue::forObjectRValueWithoutOwnership(SGF.emitEmptyTuple(loc)); } -static ManagedValue emitBuiltinAutoDiffCreateLinearMapContext( +static ManagedValue emitBuiltinAutoDiffCreateLinearMapContextWithType( SILGenFunction &SGF, SILLocation loc, SubstitutionMap subs, ArrayRef args, SGFContext C) { ASTContext &ctx = SGF.getASTContext(); auto *builtinApply = SGF.B.createBuiltin( loc, - ctx.getIdentifier( - getBuiltinName(BuiltinValueKind::AutoDiffCreateLinearMapContext)), - SILType::getNativeObjectType(ctx), - subs, + ctx.getIdentifier(getBuiltinName( + BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)), + SILType::getNativeObjectType(ctx), subs, /*args*/ {args[0].getValue()}); return SGF.emitManagedRValueWithCleanup(builtinApply); } @@ -1736,16 +1735,15 @@ static ManagedValue emitBuiltinAutoDiffProjectTopLevelSubcontext( return ManagedValue::forObjectRValueWithoutOwnership(builtinApply); } -static ManagedValue emitBuiltinAutoDiffAllocateSubcontext( +static ManagedValue emitBuiltinAutoDiffAllocateSubcontextWithType( SILGenFunction &SGF, SILLocation loc, SubstitutionMap subs, ArrayRef args, SGFContext C) { ASTContext &ctx = SGF.getASTContext(); auto *builtinApply = SGF.B.createBuiltin( loc, ctx.getIdentifier( - getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)), - SILType::getRawPointerType(ctx), - subs, + getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontextWithType)), + SILType::getRawPointerType(ctx), subs, /*args*/ {args[0].borrow(SGF, loc).getValue(), args[1].getValue()}); return ManagedValue::forObjectRValueWithoutOwnership(builtinApply); } diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index 7cb715b278553..50098dda2c173 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -17,6 +17,8 @@ #define DEBUG_TYPE "differentiation" +#include "swift/AST/Types.h" + #include "swift/SILOptimizer/Differentiation/VJPCloner.h" #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" #include "swift/SILOptimizer/Differentiation/ADContext.h" @@ -118,15 +120,21 @@ class VJPCloner::Implementation final auto pullbackTupleType = remapASTType(pullbackInfo.getLinearMapTupleType(returnBB)->getCanonicalType()); Builder.setInsertionPoint(vjp->getEntryBlock()); - auto topLevelSubcontextSize = emitMemoryLayoutSize( - Builder, original->getLocation(), pullbackTupleType); + + auto pbTupleMetatypeType = + CanMetatypeType::get(pullbackTupleType, MetatypeRepresentation::Thick); + auto pbTupleMetatypeSILType = + SILType::getPrimitiveObjectType(pbTupleMetatypeType); + auto pbTupleMetatype = + Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); + // Create an context. pullbackContextValue = Builder.createBuiltin( original->getLocation(), - getASTContext().getIdentifier( - getBuiltinName(BuiltinValueKind::AutoDiffCreateLinearMapContext)), - SILType::getNativeObjectType(getASTContext()), - SubstitutionMap(), {topLevelSubcontextSize}); + getASTContext().getIdentifier(getBuiltinName( + BuiltinValueKind::AutoDiffCreateLinearMapContextWithType)), + SILType::getNativeObjectType(getASTContext()), SubstitutionMap(), + {pbTupleMetatype}); borrowedPullbackContextValue = Builder.createBeginBorrow( original->getLocation(), pullbackContextValue); LLVM_DEBUG(getADDebugStream() @@ -148,8 +156,8 @@ class VJPCloner::Implementation final return builtinAutoDiffAllocateSubcontextGenericSignature; auto &ctx = getASTContext(); auto *decl = cast(getBuiltinValueDecl( - ctx, ctx.getIdentifier( - getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)))); + ctx, ctx.getIdentifier(getBuiltinName( + BuiltinValueKind::AutoDiffAllocateSubcontextWithType)))); builtinAutoDiffAllocateSubcontextGenericSignature = decl->getGenericSignature(); assert(builtinAutoDiffAllocateSubcontextGenericSignature); @@ -1067,14 +1075,21 @@ EnumInst *VJPCloner::Implementation::buildPredecessorEnumValue( assert(enumEltType == rawPtrType); auto pbTupleType = remapASTType(pullbackInfo.getLinearMapTupleType(predBB)->getCanonicalType()); - SILValue pbTupleSize = - emitMemoryLayoutSize(Builder, loc, pbTupleType); + + auto pbTupleMetatypeType = + CanMetatypeType::get(pbTupleType, MetatypeRepresentation::Thick); + auto pbTupleMetatypeSILType = + SILType::getPrimitiveObjectType(pbTupleMetatypeType); + auto pbTupleMetatype = + Builder.createMetatype(original->getLocation(), pbTupleMetatypeSILType); + auto rawBufferValue = builder.createBuiltin( loc, - getASTContext().getIdentifier( - getBuiltinName(BuiltinValueKind::AutoDiffAllocateSubcontext)), + getASTContext().getIdentifier(getBuiltinName( + BuiltinValueKind::AutoDiffAllocateSubcontextWithType)), rawPtrType, SubstitutionMap(), - {borrowedPullbackContextValue, pbTupleSize}); + {borrowedPullbackContextValue, pbTupleMetatype}); + auto typedBufferValue = builder.createPointerToAddress( loc, rawBufferValue, pbTupleVal->getType().getAddressType(), diff --git a/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp b/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp index 01069fe3e2774..20359097fcf3d 100644 --- a/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp +++ b/lib/SILOptimizer/Transforms/AccessEnforcementReleaseSinking.cpp @@ -146,7 +146,7 @@ static bool isBarrier(SILInstruction *inst) { case BuiltinValueKind::COWBufferForReading: case BuiltinValueKind::GetCurrentAsyncTask: case BuiltinValueKind::GetCurrentExecutor: - case BuiltinValueKind::AutoDiffCreateLinearMapContext: + case BuiltinValueKind::AutoDiffCreateLinearMapContextWithType: case BuiltinValueKind::EndAsyncLet: case BuiltinValueKind::EndAsyncLetLifetime: case BuiltinValueKind::CreateTaskGroup: @@ -199,7 +199,7 @@ static bool isBarrier(SILInstruction *inst) { case BuiltinValueKind::ResumeThrowingContinuationReturning: case BuiltinValueKind::ResumeThrowingContinuationThrowing: case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext: - case BuiltinValueKind::AutoDiffAllocateSubcontext: + case BuiltinValueKind::AutoDiffAllocateSubcontextWithType: case BuiltinValueKind::AddressOfBorrowOpaque: case BuiltinValueKind::UnprotectedAddressOfBorrowOpaque: return true; diff --git a/stdlib/public/runtime/AutoDiffSupport.cpp b/stdlib/public/runtime/AutoDiffSupport.cpp index 3873e6165a7f4..372bcdac9447f 100644 --- a/stdlib/public/runtime/AutoDiffSupport.cpp +++ b/stdlib/public/runtime/AutoDiffSupport.cpp @@ -13,7 +13,7 @@ #include "AutoDiffSupport.h" #include "swift/ABI/Metadata.h" #include "swift/Runtime/HeapObject.h" - +#include "llvm/ADT/SmallVector.h" #include using namespace swift; @@ -47,6 +47,13 @@ AutoDiffLinearMapContext::AutoDiffLinearMapContext() : HeapObject(&linearMapContextHeapMetadata) { } +AutoDiffLinearMapContext::AutoDiffLinearMapContext( + const Metadata *topLevelLinearMapContextMetadata) + : HeapObject(&linearMapContextHeapMetadata) { + allocatedContextObjects.push_back(AllocatedContextObjectRecord{ + topLevelLinearMapContextMetadata, projectTopLevelSubcontext()}); +} + void *AutoDiffLinearMapContext::projectTopLevelSubcontext() const { auto offset = alignTo( sizeof(AutoDiffLinearMapContext), alignof(AutoDiffLinearMapContext)); @@ -58,6 +65,16 @@ void *AutoDiffLinearMapContext::allocate(size_t size) { return allocator.Allocate(size, alignof(AutoDiffLinearMapContext)); } +void *AutoDiffLinearMapContext::allocateSubcontext( + const Metadata *contextObjectMetadata) { + auto size = contextObjectMetadata->vw_size(); + auto align = contextObjectMetadata->vw_alignment(); + auto *contextObjectPtr = allocator.Allocate(size, align); + allocatedContextObjects.push_back( + AllocatedContextObjectRecord{contextObjectMetadata, contextObjectPtr}); + return contextObjectPtr; +} + AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContext( size_t topLevelLinearMapStructSize) { auto allocationSize = alignTo( @@ -68,11 +85,31 @@ AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContext( } void *swift::swift_autoDiffProjectTopLevelSubcontext( - AutoDiffLinearMapContext *allocator) { - return allocator->projectTopLevelSubcontext(); + AutoDiffLinearMapContext *linearMapContext) { + return static_cast(linearMapContext->projectTopLevelSubcontext()); } void *swift::swift_autoDiffAllocateSubcontext( AutoDiffLinearMapContext *allocator, size_t size) { return allocator->allocate(size); } + +AutoDiffLinearMapContext *swift::swift_autoDiffCreateLinearMapContextWithType( + const Metadata *topLevelLinearMapContextMetadata) { + assert(topLevelLinearMapContextMetadata->getValueWitnesses() != nullptr); + auto topLevelLinearMapContextSize = + topLevelLinearMapContextMetadata->vw_size(); + auto allocationSize = alignTo(sizeof(AutoDiffLinearMapContext), + alignof(AutoDiffLinearMapContext)) + + topLevelLinearMapContextSize; + auto *buffer = (AutoDiffLinearMapContext *)malloc(allocationSize); + return ::new (buffer) + AutoDiffLinearMapContext(topLevelLinearMapContextMetadata); +} + +void *swift::swift_autoDiffAllocateSubcontextWithType( + AutoDiffLinearMapContext *linearMapContext, + const Metadata *linearMapSubcontextMetadata) { + assert(linearMapSubcontextMetadata->getValueWitnesses() != nullptr); + return linearMapContext->allocateSubcontext(linearMapSubcontextMetadata); +} diff --git a/stdlib/public/runtime/AutoDiffSupport.h b/stdlib/public/runtime/AutoDiffSupport.h index 4fb63d82e062b..b3135ef794628 100644 --- a/stdlib/public/runtime/AutoDiffSupport.h +++ b/stdlib/public/runtime/AutoDiffSupport.h @@ -14,31 +14,95 @@ #define SWIFT_RUNTIME_AUTODIFF_SUPPORT_H #include "swift/Runtime/HeapObject.h" +#include "swift/ABI/Metadata.h" #include "swift/Runtime/Config.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Allocator.h" namespace swift { - /// A data structure responsible for efficiently allocating closure contexts for /// linear maps such as pullbacks, including recursive branching trace enum /// case payloads. class AutoDiffLinearMapContext : public HeapObject { + /// A simple wrapper around a context object allocated by the + /// `AutoDiffLinearMapContext` type. This type knows all the "physical" + /// properties and behavior of the allocated context object by way of + /// storing the allocated type's `TypeMetadata`. It uses this information + /// to ensure that the allocated context object is destroyed/deinitialized + /// properly, upon its own destruction. + class [[nodiscard]] AllocatedContextObjectRecord final { + const Metadata *contextObjectMetadata; + OpaqueValue *contextObjectPtr; + + public: + AllocatedContextObjectRecord(const Metadata *contextObjectMetadata, + OpaqueValue *contextObjectPtr) + : contextObjectMetadata(contextObjectMetadata), + contextObjectPtr(contextObjectPtr) {} + + AllocatedContextObjectRecord(const Metadata *contextObjectMetadata, + void *contextObjectPtr) + : AllocatedContextObjectRecord( + contextObjectMetadata, + static_cast(contextObjectPtr)) {} + + ~AllocatedContextObjectRecord() { + if (contextObjectMetadata != nullptr && contextObjectPtr != nullptr) { + contextObjectMetadata->vw_destroy(contextObjectPtr); + } + } + + AllocatedContextObjectRecord(const AllocatedContextObjectRecord &) = delete; + + AllocatedContextObjectRecord( + AllocatedContextObjectRecord &&other) noexcept { + this->contextObjectMetadata = other.contextObjectMetadata; + this->contextObjectPtr = other.contextObjectPtr; + other.contextObjectMetadata = nullptr; + other.contextObjectPtr = nullptr; + } + + size_t size() const { return contextObjectMetadata->vw_size(); } + + size_t align() const { return contextObjectMetadata->vw_alignment(); } + }; + private: /// The underlying allocator. // TODO: Use a custom allocator so that the initial slab can be // tail-allocated. llvm::BumpPtrAllocator allocator; + /// Storage for `AllocatedContextObjectRecord`s, corresponding to the + /// subcontext allocations performed by the type. + llvm::SmallVector allocatedContextObjects; + public: - /// Creates a linear map context. + /// DEPRECATED - Use overloaded constructor taking a `const Metadata *` + /// parameter instead. This constructor might be removed as it leads to memory + /// leaks. AutoDiffLinearMapContext(); + + AutoDiffLinearMapContext(const Metadata *topLevelLinearMapContextMetadata); + /// Returns the address of the tail-allocated top-level subcontext. void *projectTopLevelSubcontext() const; + /// Allocates memory for a new subcontext. + /// + /// DEPRECATED - Use `allocateSubcontext` instead. This + /// method might be removed as it leads to memory leaks. void *allocate(size_t size); + + /// Allocates memory for a new subcontext. + void *allocateSubcontext(const Metadata *contextObjectMetadata); }; /// Creates a linear map context with a tail-allocated top-level subcontext. +/// +/// DEPRECATED - Use `swift_autoDiffCreateLinearMapContextWithType` instead. +/// This builtin might be removed as it leads to memory leaks. SWIFT_RUNTIME_EXPORT SWIFT_CC(swift) AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext( size_t topLevelSubcontextSize); @@ -48,9 +112,21 @@ SWIFT_RUNTIME_EXPORT SWIFT_CC(swift) void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *); /// Allocates memory for a new subcontext. +/// +/// DEPRECATED - Use `swift_autoDiffAllocateSubcontextWithType` instead. This +/// builtin might be removed as it leads to memory leaks. SWIFT_RUNTIME_EXPORT SWIFT_CC(swift) void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t size); -} +/// Creates a linear map context with a tail-allocated top-level subcontext. +SWIFT_RUNTIME_EXPORT SWIFT_CC(swift) + AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContextWithType( + const Metadata *topLevelLinearMapContextMetadata); +/// Allocates memory for a new subcontext. +SWIFT_RUNTIME_EXPORT + SWIFT_CC(swift) void *swift_autoDiffAllocateSubcontextWithType( + AutoDiffLinearMapContext *, + const Metadata *linearMapSubcontextMetadata); +} // namespace swift #endif /* SWIFT_RUNTIME_AUTODIFF_SUPPORT_H */ diff --git a/test/AutoDiff/IRGen/runtime.swift b/test/AutoDiff/IRGen/runtime.swift index 09c848fafcc04..8a753356212a8 100644 --- a/test/AutoDiff/IRGen/runtime.swift +++ b/test/AutoDiff/IRGen/runtime.swift @@ -3,22 +3,17 @@ import Swift import _Differentiation -struct ExamplePullbackStruct { - var pb0: (T.TangentVector) -> T.TangentVector -} - -@_silgen_name("test_context_builtins") -func test_context_builtins() { - let pbStruct = ExamplePullbackStruct(pb0: { $0 }) - let context = Builtin.autoDiffCreateLinearMapContext(Builtin.sizeof(type(of: pbStruct))) +@_silgen_name("test_context_builtins_with_type") +func test_context_builtins_with_type(t: T) { + let context = Builtin.autoDiffCreateLinearMapContextWithType(T.self) let topLevelSubctxAddr = Builtin.autoDiffProjectTopLevelSubcontext(context) - UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: pbStruct, as: type(of: pbStruct)) - let newBuffer = Builtin.autoDiffAllocateSubcontext(context, Builtin.sizeof(type(of: pbStruct))) - UnsafeMutableRawPointer(newBuffer).storeBytes(of: pbStruct, as: type(of: pbStruct)) + UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: t, as: T.self) + let newBuffer = Builtin.autoDiffAllocateSubcontextWithType(context, T.self) + UnsafeMutableRawPointer(newBuffer).storeBytes(of: t, as: T.self) } -// CHECK-LABEL: define{{.*}}@test_context_builtins() +// CHECK-LABEL: define{{.*}}@test_context_builtins_with_type(ptr noalias nocapture %0, ptr %T) // CHECK: entry: -// CHECK: [[CTX:%.*]] = call swiftcc ptr @swift_autoDiffCreateLinearMapContext({{i[0-9]+}} {{.*}}) +// CHECK: [[CTX:%.*]] = call swiftcc ptr @swift_autoDiffCreateLinearMapContextWithType(ptr %T) // CHECK: call swiftcc ptr @swift_autoDiffProjectTopLevelSubcontext(ptr [[CTX]]) -// CHECK: [[BUF:%.*]] = call swiftcc ptr @swift_autoDiffAllocateSubcontext(ptr [[CTX]], {{i[0-9]+}} {{.*}}) +// CHECK: [[BUF:%.*]] = call swiftcc ptr @swift_autoDiffAllocateSubcontextWithType(ptr [[CTX]], ptr %T) diff --git a/test/AutoDiff/SILGen/autodiff_builtins.swift b/test/AutoDiff/SILGen/autodiff_builtins.swift index 903ba227fbf71..b369deec7809c 100644 --- a/test/AutoDiff/SILGen/autodiff_builtins.swift +++ b/test/AutoDiff/SILGen/autodiff_builtins.swift @@ -84,24 +84,20 @@ func applyDerivative_f1_vjp(t0: T) -> (T // CHECK: copy_addr [take] [[D_RESULT_BUFFER_0_FOR_LOAD]] to [init] [[ORIG_RESULT_OUT_PARAM]] // CHECK: return [[PULLBACK]] -struct ExamplePullbackStruct { - var pb0: (T.TangentVector) -> T.TangentVector -} -@_silgen_name("test_context_builtins") -func test_context_builtins() { - let pbStruct = ExamplePullbackStruct(pb0: { $0 }) - let context = Builtin.autoDiffCreateLinearMapContext(Builtin.sizeof(type(of: pbStruct))) +@_silgen_name("test_context_builtins_with_type") +func test_context_builtins_with_type(t: T) { + let context = Builtin.autoDiffCreateLinearMapContextWithType(T.self) let topLevelSubctxAddr = Builtin.autoDiffProjectTopLevelSubcontext(context) - UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: pbStruct, as: type(of: pbStruct)) - let newBuffer = Builtin.autoDiffAllocateSubcontext(context, Builtin.sizeof(type(of: pbStruct))) - UnsafeMutableRawPointer(newBuffer).storeBytes(of: pbStruct, as: type(of: pbStruct)) + UnsafeMutableRawPointer(topLevelSubctxAddr).storeBytes(of: t, as: T.self) + let newBuffer = Builtin.autoDiffAllocateSubcontextWithType(context, T.self) + UnsafeMutableRawPointer(newBuffer).storeBytes(of: t, as: T.self) } -// CHECK-LABEL: sil{{.*}}@test_context_builtins -// CHECK: bb0: -// CHECK: [[CTX:%.*]] = builtin "autoDiffCreateLinearMapContext"({{%.*}} : $Builtin.Word) : $Builtin.NativeObject -// CHECK: [[BORROWED_CTX:%.*]] = begin_borrow [lexical] [[CTX]] : $Builtin.NativeObject -// CHECK: [[BUF:%.*]] = builtin "autoDiffProjectTopLevelSubcontext"([[BORROWED_CTX]] : $Builtin.NativeObject) : $Builtin.RawPointer -// CHECK: [[BUF:%.*]] = builtin "autoDiffAllocateSubcontext"([[BORROWED_CTX]] : $Builtin.NativeObject, {{.*}} : $Builtin.Word) : $Builtin.RawPointer +// CHECK-LABEL: sil{{.*}}@test_context_builtins_with_type : $@convention(thin) (@in_guaranteed T) -> () { +// CHECK: bb0({{%.*}} : $*T): +// CHECK: [[CTX:%.*]] = builtin "autoDiffCreateLinearMapContextWithType"({{%.*}} : $@thick T.Type) : $Builtin.NativeObject // users: {{.*}} +// CHECK: [[BORROWED_CTX:%.*]] = begin_borrow [lexical] [[CTX]] : $Builtin.NativeObject // users: {{.*}} +// CHECK: [[BUF:%.*]] = builtin "autoDiffProjectTopLevelSubcontext"([[BORROWED_CTX]] : $Builtin.NativeObject) : $Builtin.RawPointer // users: {{.*}} +// CHECK: [[BUF:%.*]] = builtin "autoDiffAllocateSubcontextWithType"([[BORROWED_CTX]] : $Builtin.NativeObject, {{.*}} : $@thick T.Type) : $Builtin.RawPointer // users: {{.*}} // CHECK: destroy_value [[CTX]] diff --git a/test/AutoDiff/stdlib/callee_differential_not_leaked_in_func_with_loops.swift b/test/AutoDiff/stdlib/callee_differential_not_leaked_in_func_with_loops.swift new file mode 100644 index 0000000000000..8de10d1d02d2b --- /dev/null +++ b/test/AutoDiff/stdlib/callee_differential_not_leaked_in_func_with_loops.swift @@ -0,0 +1,60 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +import _Differentiation +import StdlibUnittest + +// When the original function contains loops, we allocate a context object +// on the heap. This context object may store non-trivial objects, such as closures, +// that need to be freed explicitly, at program exit. This test verifies that the +// autodiff runtime destroys and deallocates any such objects. + +extension LifetimeTracked: AdditiveArithmetic { + public static var zero: LifetimeTracked { fatalError() } + public static func + (lhs: LifetimeTracked, rhs: LifetimeTracked) -> LifetimeTracked {fatalError()} + public static func - (lhs: LifetimeTracked, rhs: LifetimeTracked) -> LifetimeTracked {fatalError()} +} + +extension LifetimeTracked: Differentiable { + public typealias TangentVector = LifetimeTracked + public func move(by: LifetimeTracked) {fatalError()} +} + +extension LifetimeTracked { + // The original differentiable callee. + func callee(_: Float) -> Float { 42 } + + // The callee differential (pullback in this case), that is + // captured in the context object allocated on the heap in the + // presence of loops. + // + // If the autodiff runtime does not free this callee differential + // properly, the `LifetimeTracked` instance that it captures will + // also not be freed and we will have a detectable memory leak. + @derivative(of: callee, wrt: (self, f)) + func calleeDifferential(f: Float) -> (value: Float, pullback: (Float) -> (LifetimeTracked, Float)) { + return ( + value: f, + pullback: { x in (self, x) } + ) + } +} + +@differentiable(reverse) +func f(ltti: LifetimeTracked) -> Float { + for _ in 0..<1 { + } + return ltti.callee(0xDEADBEEF) +} + +var Tests = TestSuite("CalleeDifferentialLeakTest") + +Tests.test("dontLeakCalleeDifferential") { + do { + let ltti = LifetimeTracked(0xDEADBEEF) + let _ = valueWithPullback(at: ltti, of: f) + } + expectEqual(0, LifetimeTracked.instances) +} + +runAllTests() \ No newline at end of file