diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 0ffd3a600eb17..79d9be1b1da2a 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -4390,6 +4390,89 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode, const clang::FunctionType *getClangFunctionType() const; + /// Returns the type of the derivative function for the given parameter + /// indices, result index, derivative function kind, derivative function + /// generic signature (optional), and other auxiliary parameters. + /// + /// Preconditions: + /// - Parameters corresponding to parameter indices must conform to + /// `Differentiable`. + /// - The result corresponding to the result index must conform to + /// `Differentiable`. + /// + /// Typing rules, given: + /// - Original function type: $(T0, T1, ...) -> (R0, R1, ...) + /// + /// Terminology: + /// - The derivative of a `Differentiable`-conforming type has the + /// `TangentVector` associated type. `TangentVector` is abbreviated as `Tan` + /// below. + /// - "wrt" parameters refers to parameters indicated by the parameter + /// indices. + /// - "wrt" result refers to the result indicated by the result index. + /// + /// JVP derivative type: + /// - Takes original parameters. + /// - Returns original results, followed by a differential function, which + /// takes "wrt" parameter derivatives and returns a "wrt" result derivative. + /// + /// $(T0, ...) -> (R0, ..., (T0.Tan, T1.Tan, ...) -> R0.Tan) + /// ^~~~~~~ ^~~~~~~~~~~~~~~~~~~ ^~~~~~ + /// original results | derivatives wrt params | derivative wrt result + /// + /// VJP derivative type: + /// - Takes original parameters. + /// - Returns original results, followed by a pullback function, which + /// takes a "wrt" result derivative and returns "wrt" parameter derivatives. + /// + /// $(T0, ...) -> (R0, ..., (R0.Tan) -> (T0.Tan, T1.Tan, ...)) + /// ^~~~~~~ ^~~~~~ ^~~~~~~~~~~~~~~~~~~ + /// original results | derivative wrt result | derivatives wrt params + /// + /// A "constrained derivative generic signature" is computed from + /// `derivativeFunctionGenericSignature`, if specified. Otherwise, it is + /// computed from the original generic signature. A "constrained derivative + /// generic signature" requires all "wrt" parameters to conform to + /// `Differentiable`; this is important for correctness. + /// + /// This "constrained derivative generic signature" is used for + /// parameter/result type lowering. It is used as the actual generic signature + /// of the derivative function type iff the original function type has a + /// generic signature and not all generic parameters are bound to concrete + /// types. Otherwise, no derivative generic signature is used. + /// + /// Other properties of the original function type are copied exactly: + /// `ExtInfo`, coroutine kind, callee convention, yields, optional error + /// result, witness method conformance, etc. + /// + /// Special cases: + /// - Reabstraction thunks have special derivative type calculation. The + /// original function-typed last parameter is transformed into a + /// `@differentiable` function-typed parameter in the derivative type. This + /// is necessary for the differentiation transform to support reabstraction + /// thunk differentiation because the function argument is opaque and cannot + /// be differentiated. Instead, the argument is made `@differentiable` and + /// reabstraction thunk JVP/VJP callers are responsible for passing a + /// `@differentiable` function. + /// - TODO(TF-1036): Investigate more efficient reabstraction thunk + /// derivative approaches. The last argument can simply be a + /// corresponding derivative function, instead of a `@differentiable` + /// function - this is more direct. It may be possible to implement + /// reabstraction thunk derivatives using "reabstraction thunks for + /// the original function's derivative", avoiding extra code generation. + /// + /// Caveats: + /// - We may support multiple result indices instead of a single result index + /// eventually. At the SIL level, this enables differentiating wrt multiple + /// function results. At the Swift level, this enables differentiating wrt + /// multiple tuple elements for tuple-returning functions. + CanSILFunctionType getAutoDiffDerivativeFunctionType( + IndexSubset *parameterIndices, unsigned resultIndex, + AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC, + LookupConformanceFn lookupConformance, + CanGenericSignature derivativeFunctionGenericSignature = nullptr, + bool isReabstractionThunk = false); + ExtInfo getExtInfo() const { return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType()); } diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index df9384141bf26..b491b338cbad4 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -22,6 +22,7 @@ #include "swift/AST/DiagnosticsSIL.h" #include "swift/AST/ForeignInfo.h" #include "swift/AST/GenericEnvironment.h" +#include "swift/AST/GenericSignatureBuilder.h" #include "swift/AST/Module.h" #include "swift/AST/ModuleLoader.h" #include "swift/AST/ProtocolConformance.h" @@ -190,6 +191,196 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const { return nullptr; } +// Returns the canonical generic signature for an autodiff derivative function +// given an existing derivative function generic signature. All +// differentiability parameters are required to conform to `Differentiable`. +static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature( + CanGenericSignature derivativeFnGenSig, + ArrayRef originalParameters, + IndexSubset *parameterIndices, ModuleDecl *module) { + if (!derivativeFnGenSig) + return nullptr; + auto &ctx = module->getASTContext(); + GenericSignatureBuilder builder(ctx); + // Add derivative function generic signature. + builder.addGenericSignature(derivativeFnGenSig); + // All differentiability parameters are required to conform to + // `Differentiable`. + auto source = + GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); + auto *differentiableProtocol = + ctx.getProtocol(KnownProtocolKind::Differentiable); + for (unsigned paramIdx : parameterIndices->getIndices()) { + auto paramType = originalParameters[paramIdx].getInterfaceType(); + Requirement req(RequirementKind::Conformance, paramType, + differentiableProtocol->getDeclaredType()); + builder.addRequirement(req, source, module); + } + return std::move(builder) + .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true) + ->getCanonicalSignature(); +} + +CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( + IndexSubset *parameterIndices, unsigned resultIndex, + AutoDiffDerivativeFunctionKind kind, TypeConverter &TC, + LookupConformanceFn lookupConformance, + CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) { + auto &ctx = getASTContext(); + + // Returns true if `index` is a differentiability parameter index. + auto isDiffParamIndex = [&](unsigned index) -> bool { + return index < parameterIndices->getCapacity() && + parameterIndices->contains(index); + }; + + // Calculate differentiability parameter infos. + SmallVector diffParams; + for (auto valueAndIndex : enumerate(getParameters())) + if (isDiffParamIndex(valueAndIndex.index())) + diffParams.push_back(valueAndIndex.value()); + + // Get the canonical derivative function generic signature. + if (!derivativeFnGenSig) + derivativeFnGenSig = getSubstGenericSignature(); + derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature( + derivativeFnGenSig, getParameters(), parameterIndices, &TC.M); + + // Given a type, returns its formal SIL parameter info. + auto getTangentParameterInfoForOriginalResult = + [&](CanType tanType, ResultConvention origResConv) -> SILParameterInfo { + AbstractionPattern pattern(derivativeFnGenSig, tanType); + auto &tl = + TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal()); + ParameterConvention conv; + switch (origResConv) { + case ResultConvention::Owned: + case ResultConvention::Autoreleased: + conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned + : ParameterConvention::Direct_Guaranteed; + break; + case ResultConvention::Unowned: + case ResultConvention::UnownedInnerPointer: + conv = ParameterConvention::Direct_Unowned; + break; + case ResultConvention::Indirect: + conv = ParameterConvention::Indirect_In_Guaranteed; + break; + } + return {tanType, conv}; + }; + + // Given a type, returns its formal SIL result info. + auto getTangentResultInfoForOriginalParameter = + [&](CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { + AbstractionPattern pattern(derivativeFnGenSig, tanType); + auto &tl = + TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal()); + ResultConvention conv; + switch (origParamConv) { + case ParameterConvention::Direct_Owned: + case ParameterConvention::Direct_Guaranteed: + case ParameterConvention::Direct_Unowned: + conv = + tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned; + break; + case ParameterConvention::Indirect_In: + case ParameterConvention::Indirect_Inout: + case ParameterConvention::Indirect_In_Constant: + case ParameterConvention::Indirect_In_Guaranteed: + case ParameterConvention::Indirect_InoutAliasable: + conv = ResultConvention::Indirect; + break; + } + return {tanType, conv}; + }; + + CanSILFunctionType closureType; + switch (kind) { + case AutoDiffDerivativeFunctionKind::JVP: { + SmallVector differentialParams; + for (auto ¶m : diffParams) { + auto paramTan = + param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); + assert(paramTan && "Parameter type does not have a tangent space?"); + differentialParams.push_back( + {paramTan->getCanonicalType(), param.getConvention()}); + } + SmallVector differentialResults; + auto &result = getResults()[resultIndex]; + auto resultTan = + result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); + assert(resultTan && "Result type does not have a tangent space?"); + differentialResults.push_back( + {resultTan->getCanonicalType(), result.getConvention()}); + closureType = SILFunctionType::get( + /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None, + ParameterConvention::Direct_Guaranteed, differentialParams, {}, + differentialResults, None, getSubstitutions(), + isGenericSignatureImplied(), ctx); + break; + } + case AutoDiffDerivativeFunctionKind::VJP: { + SmallVector pullbackParams; + auto &origRes = getResults()[resultIndex]; + auto resultTan = + origRes.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); + assert(resultTan && "Result type does not have a tangent space?"); + pullbackParams.push_back(getTangentParameterInfoForOriginalResult( + resultTan->getCanonicalType(), origRes.getConvention())); + SmallVector pullbackResults; + for (auto ¶m : diffParams) { + auto paramTan = + param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance); + assert(paramTan && "Parameter type does not have a tangent space?"); + pullbackResults.push_back(getTangentResultInfoForOriginalParameter( + paramTan->getCanonicalType(), param.getConvention())); + } + closureType = SILFunctionType::get( + /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None, + ParameterConvention::Direct_Guaranteed, pullbackParams, {}, + pullbackResults, {}, getSubstitutions(), isGenericSignatureImplied(), + ctx); + break; + } + } + + SmallVector newParameters; + newParameters.reserve(getNumParameters()); + for (auto ¶m : getParameters()) { + newParameters.push_back(param.getWithInterfaceType( + param.getInterfaceType()->getCanonicalType(derivativeFnGenSig))); + } + // TODO(TF-1124): Upstream reabstraction thunk derivative typing rules. + // Blocked by TF-1125: `SILFunctionType::getWithDifferentiability`. + SmallVector newResults; + newResults.reserve(getNumResults() + 1); + for (auto &result : getResults()) { + newResults.push_back(result.getWithInterfaceType( + result.getInterfaceType()->getCanonicalType(derivativeFnGenSig))); + } + newResults.push_back({closureType->getCanonicalType(derivativeFnGenSig), + ResultConvention::Owned}); + // Derivative function type has a generic signature only if the original + // function type does, and if `derivativeFnGenSig` does not have all concrete + // generic parameters. + CanGenericSignature canGenSig; + if (getSubstGenericSignature() && derivativeFnGenSig && + !derivativeFnGenSig->areAllParamsConcrete()) + canGenSig = derivativeFnGenSig; + // If original function is `@convention(c)`, the derivative function should + // have `@convention(thin)`. IRGen does not support `@convention(c)` functions + // with multiple results. + auto extInfo = getExtInfo(); + if (getRepresentation() == SILFunctionTypeRepresentation::CFunctionPointer) + extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin); + return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(), + getCalleeConvention(), newParameters, getYields(), + newResults, getOptionalErrorResult(), + getSubstitutions(), isGenericSignatureImplied(), + ctx, getWitnessMethodConformanceOrInvalid()); +} + static CanType getKnownType(Optional &cacheSlot, ASTContext &C, StringRef moduleName, StringRef typeName) { if (!cacheSlot) {