From aacfea4b0269777a6cccd10fc5f721c7a8e27ec9 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 3 May 2022 20:45:06 +0200 Subject: [PATCH 1/3] Implement cross-file lookup of derivatives Look-up for functions with @derivative attributes defined in non-primary source files Fixes #55170 --- include/swift/AST/Attr.h | 10 +++++++ lib/AST/Attr.cpp | 7 +++++ lib/AST/Decl.cpp | 28 +++++++++++++++++++ lib/Parse/ParseDecl.cpp | 2 ++ lib/Sema/TypeCheckAttr.cpp | 12 ++++++-- lib/Serialization/Deserialization.cpp | 5 ++++ lib/Serialization/Serialization.cpp | 2 +- ...fferentiation_diagnostics_cross_file.swift | 5 +--- 8 files changed, 63 insertions(+), 8 deletions(-) diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 1a49d388c2267..6bfcd68a38d91 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1938,6 +1938,10 @@ class DerivativeAttr final friend TrailingObjects; friend class DerivativeAttrOriginalDeclRequest; + /// The declaration on which the `@derivative` attribute is declared. + /// May not be a valid declaration for `@derivative` attributes. + /// Resolved during parsing and deserialization. + Decl *OriginalDeclaration = nullptr; /// The base type for the referenced original declaration. This field is /// non-null only for parsed attributes that reference a qualified original /// declaration. This field is not serialized; type-checking uses it to @@ -1991,6 +1995,12 @@ class DerivativeAttr final DeclNameRefWithLoc original, IndexSubset *parameterIndices); + Decl *getOriginalDeclaration() const { return OriginalDeclaration; } + + /// Sets the original declaration on which this attribute is declared. + /// Should only be used by parsing and deserialization. + void setOriginalDeclaration(Decl *originalDeclaration); + TypeRepr *getBaseTypeRepr() const { return BaseTypeRepr; } DeclNameRefWithLoc getOriginalFunctionName() const { return OriginalFunctionName; diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 427e94c4336fb..8b98bb33fdc4e 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -2107,6 +2107,13 @@ void DerivativeAttr::setOriginalFunctionResolver( ResolverContextData = resolverContextData; } +void DerivativeAttr::setOriginalDeclaration(Decl *originalDeclaration) { + assert(originalDeclaration && "Original declaration must be non-null"); + assert(!OriginalDeclaration && + "Original declaration cannot have already been set"); + OriginalDeclaration = originalDeclaration; +} + TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, TypeRepr *baseTypeRepr, DeclNameRefWithLoc originalName, diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 81cd76ae1b5dc..bafef097a5934 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -20,6 +20,7 @@ #include "swift/AST/ASTContext.h" #include "swift/AST/ASTWalker.h" #include "swift/AST/ASTMangler.h" +#include "swift/AST/Attr.h" #include "swift/AST/CaptureInfo.h" #include "swift/AST/DiagnosticEngine.h" #include "swift/AST/DiagnosticsSema.h" @@ -8333,6 +8334,33 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() { ctx.loadDerivativeFunctionConfigurations(this, previousGeneration, *DerivativeFunctionConfigs); } + + class DerivativeFinder : public ASTWalker { + const AbstractFunctionDecl *AFD; + public: + DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {} + + bool walkToDeclPre(Decl *D) override { + if (auto *afd = dyn_cast(D)) { + for (auto *derAttr : afd->getAttrs().getAttributes()) { + // Resolve derivative function configurations from `@derivative` + // attributes by type-checking them. + if (AFD->getName().matchesRef(derAttr->getOriginalFunctionName().Name.getFullName())) { + (void)derAttr->getOriginalFunction(afd->getASTContext()); + return false; + } + } + } + + return true; + } + }; + + // Load derivative configurations from @derivative attributes defined in + // non-primary sources + DerivativeFinder finder(this); + getParent()->walkContext(finder); + return DerivativeFunctionConfigs->getArrayRef(); } diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 1b78f4e7ca892..547a06d9f9f5d 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -4455,6 +4455,8 @@ setOriginalDeclarationForDifferentiableAttributes(DeclAttributes attrs, Decl *D) { for (auto *attr : attrs.getAttributes()) const_cast(attr)->setOriginalDeclaration(D); + for (auto *attr : attrs.getAttributes()) + const_cast(attr)->setOriginalDeclaration(D); } /// Parse a single syntactic declaration and return a list of decl diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 109d911f1b7ef..19a7930f48334 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -4949,10 +4949,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { /// - Stores the attribute in `ASTContext::DerivativeAttrs`. /// /// \returns true on error, false on success. -static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, - DerivativeAttr *attr) { +static bool typeCheckDerivativeAttr(DerivativeAttr *attr) { // Note: Implementation must be idempotent because it may be called multiple // times for the same attribute. + Decl *D = attr->getOriginalDeclaration(); + auto &Ctx = D->getASTContext(); auto &diags = Ctx.Diags; // `@derivative` attribute requires experimental differentiable programming // to be enabled. @@ -5365,13 +5366,18 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, } void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { - if (typeCheckDerivativeAttr(Ctx, D, attr)) + if (typeCheckDerivativeAttr(attr)) attr->setInvalid(); } AbstractFunctionDecl * DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator, DerivativeAttr *attr) const { + // Try to resolve the original function + if (attr->isValid() && attr->OriginalFunction.isNull()) + if (typeCheckDerivativeAttr(attr)) + attr->setInvalid(); + // If the typechecker has resolved the original function, return it. if (auto *FD = attr->OriginalFunction.dyn_cast()) return FD; diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 9a6068e02478f..9d6c3dbf0de2c 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -15,6 +15,7 @@ #include "ModuleFile.h" #include "ModuleFormat.h" #include "swift/AST/ASTContext.h" +#include "swift/AST/Attr.h" #include "swift/AST/AutoDiff.h" #include "swift/AST/DiagnosticsSema.h" #include "swift/AST/Expr.h" @@ -2590,6 +2591,10 @@ static void setOriginalDeclarationAndParameterIndicesInDifferentiableAttributes( diffAttr->setOriginalDeclaration(decl); diffAttr->setParameterIndices(diffAttrParamIndicesMap[diffAttr]); } + for (auto *attr : tempAttrs.getAttributes()) { + auto *derAttr = const_cast(attr); + derAttr->setOriginalDeclaration(decl); + } } Decl *ModuleFile::getDecl(DeclID DID) { diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index c4ab6747b39e3..25aba2a6c18bc 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -2780,7 +2780,7 @@ class Serializer::DeclSerializer : public DeclVisitor { auto abbrCode = S.DeclTypeAbbrCodes[DerivativeDeclAttrLayout::Code]; auto *attr = cast(DA); auto &ctx = S.getASTContext(); - assert(attr->getOriginalFunction(ctx) && + assert(attr->getOriginalFunction(ctx) && attr->getOriginalDeclaration() && "`@derivative` attribute should have original declaration set " "during construction or parsing"); auto origDeclNameRef = attr->getOriginalFunctionName(); diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift index 54e13f9c0d918..539e6f1c34d2e 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift @@ -13,14 +13,11 @@ func crossFileDifferentiableAttr( } // TF-1272: Test original function with registered derivatives in other files. -// FIXME(TF-1272): Find a way to type-check `@derivative` attributes in other -// files. @differentiable(reverse) func crossFileDerivativeAttr( _ input: T ) -> T { - // expected-error @+2 {{expression is not differentiable}} - // expected-note @+1 {{cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files}} + // No error expected return input.identityDerivativeAttr() } From a764df9655dda9f8c04fa803d0dc70efb8338184 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Wed, 4 May 2022 16:21:08 +0200 Subject: [PATCH 2/3] Show hackish way to break the lookup cycle --- include/swift/AST/Decl.h | 3 ++- lib/AST/Decl.cpp | 11 +++++++---- lib/Sema/TypeCheckProtocol.cpp | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index e86fff132de50..1420a9ad12e7f 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -6266,7 +6266,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { public: /// Get all derivative function configurations. - ArrayRef getDerivativeFunctionConfigurations(); + ArrayRef + getDerivativeFunctionConfigurations(bool NonPrimaryLookup = true); /// Add the given derivative function configuration. void addDerivativeFunctionConfiguration(const AutoDiffConfig &config); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index bafef097a5934..b171dcc8008dd 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8311,7 +8311,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { } ArrayRef -AbstractFunctionDecl::getDerivativeFunctionConfigurations() { +AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool NonPrimaryLookup) { prepareDerivativeFunctionConfigurations(); // Resolve derivative function configurations from `@differentiable` @@ -8357,9 +8357,12 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations() { }; // Load derivative configurations from @derivative attributes defined in - // non-primary sources - DerivativeFinder finder(this); - getParent()->walkContext(finder); + // non-primary sources. Note that it might trigger lookup cycles if called + // from inside Sema stages + if (NonPrimaryLookup) { + DerivativeFinder finder(this); + getParent()->walkContext(finder); + } return DerivativeFunctionConfigs->getArrayRef(); } diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index bddbb877d9217..783e3566d19d2 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -379,7 +379,7 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, bool foundExactConfig = false; Optional supersetConfig = None; for (auto witnessConfig : - witnessAFD->getDerivativeFunctionConfigurations()) { + witnessAFD->getDerivativeFunctionConfigurations(/*NonLocal*/ false)) { // All the witness's derivative generic requirements must be satisfied // by the requirement's derivative generic requirements OR by the // conditional conformance requirements. From 6149146c18dce2f86b9d54e376a4072bb88d9bac Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Sun, 8 May 2022 19:11:11 +0200 Subject: [PATCH 3/3] Rename argument and address some other stylistic comments --- include/swift/AST/Decl.h | 6 ++++-- lib/AST/Decl.cpp | 9 +++++---- lib/Sema/TypeCheckAttr.cpp | 2 +- lib/Sema/TypeCheckProtocol.cpp | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 1420a9ad12e7f..573a640eb2164 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -6265,9 +6265,11 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr; public: - /// Get all derivative function configurations. + /// Get all derivative function configurations. If `lookInNonPrimarySources` + /// is true then lookup is done in non-primary sources as well. Note that + /// such lookup might end in cycles if done during sema stages. ArrayRef - getDerivativeFunctionConfigurations(bool NonPrimaryLookup = true); + getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true); /// Add the given derivative function configuration. void addDerivativeFunctionConfiguration(const AutoDiffConfig &config); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index b171dcc8008dd..2300941ae3b3a 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8311,7 +8311,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { } ArrayRef -AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool NonPrimaryLookup) { +AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) { prepareDerivativeFunctionConfigurations(); // Resolve derivative function configurations from `@differentiable` @@ -8345,7 +8345,8 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool NonPrimaryLookup) for (auto *derAttr : afd->getAttrs().getAttributes()) { // Resolve derivative function configurations from `@derivative` // attributes by type-checking them. - if (AFD->getName().matchesRef(derAttr->getOriginalFunctionName().Name.getFullName())) { + if (AFD->getName().matchesRef( + derAttr->getOriginalFunctionName().Name.getFullName())) { (void)derAttr->getOriginalFunction(afd->getASTContext()); return false; } @@ -8358,8 +8359,8 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool NonPrimaryLookup) // Load derivative configurations from @derivative attributes defined in // non-primary sources. Note that it might trigger lookup cycles if called - // from inside Sema stages - if (NonPrimaryLookup) { + // from inside Sema stages. + if (lookInNonPrimarySources) { DerivativeFinder finder(this); getParent()->walkContext(finder); } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 19a7930f48334..1c252905c65ba 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -5373,7 +5373,7 @@ void AttributeChecker::visitDerivativeAttr(DerivativeAttr *attr) { AbstractFunctionDecl * DerivativeAttrOriginalDeclRequest::evaluate(Evaluator &evaluator, DerivativeAttr *attr) const { - // Try to resolve the original function + // Try to resolve the original function. if (attr->isValid() && attr->OriginalFunction.isNull()) if (typeCheckDerivativeAttr(attr)) attr->setInvalid(); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 783e3566d19d2..6c3071a8db904 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -379,7 +379,8 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, bool foundExactConfig = false; Optional supersetConfig = None; for (auto witnessConfig : - witnessAFD->getDerivativeFunctionConfigurations(/*NonLocal*/ false)) { + witnessAFD->getDerivativeFunctionConfigurations( + /*lookInNonPrimarySources*/ false)) { // All the witness's derivative generic requirements must be satisfied // by the requirement's derivative generic requirements OR by the // conditional conformance requirements.