From 06b05c5dbeab24dec5e1a19ba27b5b2643cf517c Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sun, 5 Jul 2020 23:31:43 -0700 Subject: [PATCH] [AutoDiff] Improve `@derivative` and `@transpose` diagnostics. Improve `@derivative` and `@transpose` type-checking diagnostics for resolving the referenced original declaration. Previously, an error was produced on one invalid candidate at the attribute's location. This did not indicate the invalid candidate's location or the total number of invalid candidates. Now: - Diagnostic notes are produced on all invalid candidates at their location. Invalid candidates' descriptive declaration kind are shown for clarity. - Derivative registration for protocol requirements (not yet supported, TF-982) now has a clear, dedicated diagnostic. - The "original declaration type mismatch" diagnostic is improved for expected original function types with generic signatures. The message now accurately reads "candidate does not have type equal to *or less constrained than* ...", instead of "candidate does not have expected type ...". Resolves SR-13150. Paves the way for future diagnostic improvements: SR-13151, SR-13152. --- include/swift/AST/DiagnosticsSema.def | 34 +- lib/Sema/TypeCheckAttr.cpp | 414 ++++++++++-------- .../Sema/derivative_attr_type_checking.swift | 78 +++- .../Sema/transpose_attr_type_checking.swift | 11 +- 4 files changed, 326 insertions(+), 211 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index ced5a901f8a1e..f40b05ee93584 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3095,10 +3095,16 @@ ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none, ERROR(derivative_attr_nonfinal_class_init_unsupported,none, "cannot register derivative for 'init' in a non-final class; consider " "making %0 final", (Type)) +ERROR(derivative_attr_unsupported_accessor_kind,none, + "cannot register derivative for %0", (/*accessorKind*/ DescriptiveDeclKind)) // TODO(SR-13096): Remove this temporary diagnostic. ERROR(derivative_attr_class_setter_unsupported,none, "cannot yet register derivative for class property or subscript setters", ()) +// TODO(TF-982): Remove this temporary diagnostic. +NOTE(derivative_attr_protocol_requirement_unsupported,none, + "cannot yet register derivative default implementation for protocol " + "requirements", ()) ERROR(derivative_attr_original_already_has_derivative,none, "a derivative already exists for %0", (DeclName)) NOTE(derivative_attr_duplicate_note,none, @@ -3134,15 +3140,25 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none, "%1", (Type, Type)) // Automatic differentiation attributes -ERROR(autodiff_attr_original_decl_invalid_kind,none, - "%0 is not a 'func', 'init', 'subscript', or 'var' computed property " - "declaration", (DeclNameRef)) -ERROR(autodiff_attr_accessor_not_found,none, - "%0 does not have a '%1' accessor", (DeclNameRef, StringRef)) -ERROR(autodiff_attr_original_decl_none_valid_found,none, - "could not find function %0 with expected type %1", (DeclNameRef, Type)) -ERROR(autodiff_attr_original_decl_not_same_type_context,none, - "%0 is not defined in the current type context", (DeclNameRef)) +ERROR(autodiff_attr_original_decl_ambiguous,none, + "referenced declaration %0 is ambiguous", (DeclNameRef)) +NOTE(autodiff_attr_original_decl_ambiguous_candidate,none, + "candidate %0 found here", (DescriptiveDeclKind)) +ERROR(autodiff_attr_original_decl_none_valid,none, + "referenced declaration %0 could not be resolved", (DeclNameRef)) +NOTE(autodiff_attr_original_decl_invalid_kind,none, + "candidate %0 is not a 'func', 'init', 'subscript', or 'var' computed " + "property declaration", (DescriptiveDeclKind)) +NOTE(autodiff_attr_original_decl_missing_accessor,none, + "candidate %0 does not have a %1", + (DescriptiveDeclKind, /*accessorDeclKind*/ DescriptiveDeclKind)) +NOTE(autodiff_attr_original_decl_type_mismatch,none, + "candidate %0 does not have " + "%select{expected type|type equal to or less constrained than}2 %1", + (DescriptiveDeclKind, Type, /*hasGenericSignature*/ bool)) +NOTE(autodiff_attr_original_decl_not_same_type_context,none, + "candidate %0 is not defined in the current type context", + (DescriptiveDeclKind)) ERROR(autodiff_attr_original_void_result,none, "cannot differentiate void function %0", (DeclName)) ERROR(autodiff_attr_original_multiple_semantic_results,none, diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 8883713e5aa3f..e29d1dc17952f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3604,26 +3604,71 @@ static IndexSubset *computeDifferentiabilityParameters( return IndexSubset::get(ctx, parameterBits); } -// Returns the function declaration corresponding to the given function name and -// lookup context. If the base type of the function is specified, member lookup -// is performed. Otherwise, unqualified lookup is performed. +/// Returns the `DescriptiveDeclKind` corresponding to the given `AccessorKind`. +/// Used for diagnostics. +static DescriptiveDeclKind getAccessorDescriptiveDeclKind(AccessorKind kind) { + switch (kind) { + case AccessorKind::Get: + return DescriptiveDeclKind::Getter; + case AccessorKind::Set: + return DescriptiveDeclKind::Setter; + case AccessorKind::Read: + return DescriptiveDeclKind::ReadAccessor; + case AccessorKind::Modify: + return DescriptiveDeclKind::ModifyAccessor; + case AccessorKind::WillSet: + return DescriptiveDeclKind::WillSet; + case AccessorKind::DidSet: + return DescriptiveDeclKind::DidSet; + case AccessorKind::Address: + return DescriptiveDeclKind::Addressor; + case AccessorKind::MutableAddress: + return DescriptiveDeclKind::MutableAddressor; + } +} + +/// An abstract function declaration lookup error. +enum class AbstractFunctionDeclLookupErrorKind { + /// No lookup candidates could be found. + NoCandidatesFound, + /// There are multiple valid lookup candidates. + CandidatesAmbiguous, + /// Lookup candidate does not have the expected type. + CandidateTypeMismatch, + /// Lookup candidate is in the wrong type context. + CandidateWrongTypeContext, + /// Lookup candidate does not have the requested accessor. + CandidateMissingAccessor, + /// Lookup candidate is a protocol requirement. + CandidateProtocolRequirement, + /// Lookup candidate could be resolved to an `AbstractFunctionDecl`. + CandidateNotFunctionDeclaration +}; + +// Returns the function declaration corresponding to the given base type +// (optional), function name, and lookup context. +// +// If the base type of the function is specified, member lookup is performed. +// Otherwise, unqualified lookup is performed. +// // If the function declaration cannot be resolved, emits a diagnostic and // returns nullptr. +// +// Used for resolving the referenced declaration in `@derivative` and +// `@transpose` attributes. static AbstractFunctionDecl *findAbstractFunctionDecl( - DeclNameRef funcName, SourceLoc funcNameLoc, - Optional accessorKind, Type baseType, - DeclContext *lookupContext, - const std::function &isValidCandidate, - const std::function &noneValidDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic, - const std::function &missingAccessorDiagnostic, - NameLookupOptions lookupOptions, - const Optional> - &hasValidTypeCtx, - const Optional> &invalidTypeCtxDiagnostic) { + DeclAttribute *attr, Type baseType, DeclNameRefWithLoc funcNameWithLoc, + DeclContext *lookupContext, NameLookupOptions lookupOptions, + const llvm::function_ref( + AbstractFunctionDecl *)> &isValidCandidate, + AnyFunctionType *expectedOriginalFnType) { + assert(lookupContext); auto &ctx = lookupContext->getASTContext(); - AbstractFunctionDecl *resolvedCandidate = nullptr; + auto &diags = ctx.Diags; + + auto funcName = funcNameWithLoc.Name; + auto funcNameLoc = funcNameWithLoc.Loc; + auto maybeAccessorKind = funcNameWithLoc.AccessorKind; // Perform lookup. LookupResult results; @@ -3634,22 +3679,25 @@ static AbstractFunctionDecl *findAbstractFunctionDecl( if (baseType) { results = TypeChecker::lookupMember(lookupContext, baseType, funcName); } else { - results = TypeChecker::lookupUnqualified(lookupContext, funcName, - funcNameLoc, lookupOptions); + results = TypeChecker::lookupUnqualified( + lookupContext, funcName, funcNameLoc.getBaseNameLoc(), lookupOptions); + } + + // Error if no candidates were found. + if (results.empty()) { + diags.diagnose(funcNameLoc, diag::cannot_find_in_scope, funcName, + funcName.isOperator()); + return nullptr; } - // Initialize error flags. - bool notFunction = false; - bool wrongTypeContext = false; - bool ambiguousFuncDecl = false; - bool foundInvalid = false; - bool missingAccessor = false; + // Track invalid and valid candidates. + using LookupErrorKind = AbstractFunctionDeclLookupErrorKind; + SmallVector, 2> invalidCandidates; + SmallVector validCandidates; // Filter lookup results. for (auto choice : results) { - auto decl = choice.getValueDecl(); - if (!decl) - continue; + auto *decl = choice.getValueDecl(); // Cast the candidate to an `AbstractFunctionDecl`. auto *candidate = dyn_cast(decl); // If the candidate is an `AbstractStorageDecl`, use one of its accessors as @@ -3657,67 +3705,117 @@ static AbstractFunctionDecl *findAbstractFunctionDecl( if (auto *asd = dyn_cast(decl)) { // If accessor kind is specified, use corresponding accessor from the // candidate. Otherwise, use the getter by default. - if (accessorKind != None) { - candidate = asd->getOpaqueAccessor(accessorKind.getValue()); - // Error if candidate is missing the requested accessor. - if (!candidate) - missingAccessor = true; - } else - candidate = asd->getOpaqueAccessor(AccessorKind::Get); - } else if (accessorKind != None) { - missingAccessor = true; + auto accessorKind = maybeAccessorKind.getValueOr(AccessorKind::Get); + candidate = asd->getOpaqueAccessor(accessorKind); + // Error if candidate is missing the requested accessor. + if (!candidate) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateMissingAccessor}); + continue; + } } - if (!candidate) { - notFunction = true; + // Error if the candidate is not an `AbstractStorageDecl` but an accessor is + // requested. + else if (maybeAccessorKind.hasValue()) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateMissingAccessor}); continue; } - if (hasValidTypeCtx && !(*hasValidTypeCtx)(candidate)) { - wrongTypeContext = true; + // Error if candidate is not a `AbstractFunctionDecl`. + if (!candidate) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateNotFunctionDeclaration}); continue; } - if (!isValidCandidate(candidate)) { - foundInvalid = true; + // Error if candidate is not valid. + auto invalidCandidateKind = isValidCandidate(candidate); + if (invalidCandidateKind.hasValue()) { + invalidCandidates.push_back({candidate, *invalidCandidateKind}); continue; } - if (resolvedCandidate) { - ambiguousFuncDecl = true; - resolvedCandidate = nullptr; - break; - } - resolvedCandidate = candidate; - } - - // If function declaration was resolved, return it. - if (resolvedCandidate && !missingAccessor) - return resolvedCandidate; - - // Otherwise, emit the appropriate diagnostic and return nullptr. - if (results.empty()) { - ctx.Diags.diagnose(funcNameLoc, diag::cannot_find_in_scope, funcName, + // Otherwise, record valid candidate. + validCandidates.push_back(candidate); + } + // If there are no valid candidates, emit diagnostics for invalid candidates. + if (validCandidates.empty()) { + assert(!invalidCandidates.empty()); + diags.diagnose(funcNameLoc, diag::autodiff_attr_original_decl_none_valid, + funcName); + for (auto invalidCandidatePair : invalidCandidates) { + auto *invalidCandidate = invalidCandidatePair.first; + auto invalidCandidateKind = invalidCandidatePair.second; + auto declKind = invalidCandidate->getDescriptiveKind(); + switch (invalidCandidateKind) { + case AbstractFunctionDeclLookupErrorKind::NoCandidatesFound: + diags.diagnose(invalidCandidate, diag::cannot_find_in_scope, funcName, funcName.isOperator()); + break; + case AbstractFunctionDeclLookupErrorKind::CandidatesAmbiguous: + diags.diagnose(invalidCandidate, diag::attr_ambiguous_reference_to_decl, + funcName, attr->getAttrName()); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch: { + // If the expected original function type has a generic signature, emit + // "candidate does not have type equal to or less constrained than ..." + // diagnostic. + // + // This is significant because derivative/transpose functions may have + // more constrained generic signatures than their referenced original + // declarations. + if (auto genSig = expectedOriginalFnType->getOptGenericSignature()) { + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_type_mismatch, + declKind, expectedOriginalFnType, + /*hasGenericSignature*/ true); + break; + } + // Otherwise, emit a "candidate does not have expected type ..." error. + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_type_mismatch, + declKind, expectedOriginalFnType, + /*hasGenericSignature*/ false); + break; + } + case AbstractFunctionDeclLookupErrorKind::CandidateWrongTypeContext: + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_not_same_type_context, + declKind); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateMissingAccessor: { + auto accessorKind = maybeAccessorKind.getValueOr(AccessorKind::Get); + auto accessorDeclKind = getAccessorDescriptiveDeclKind(accessorKind); + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_missing_accessor, + declKind, accessorDeclKind); + break; + } + case AbstractFunctionDeclLookupErrorKind::CandidateProtocolRequirement: + diags.diagnose(invalidCandidate, + diag::derivative_attr_protocol_requirement_unsupported); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateNotFunctionDeclaration: + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_invalid_kind, + declKind); + break; + } + } return nullptr; } - if (ambiguousFuncDecl) { - ambiguousDiagnostic(); - return nullptr; - } - if (missingAccessor) { - missingAccessorDiagnostic(); - return nullptr; - } - if (wrongTypeContext) { - assert(invalidTypeCtxDiagnostic && - "Type context diagnostic should've been specified"); - (*invalidTypeCtxDiagnostic)(); - return nullptr; - } - if (foundInvalid) { - noneValidDiagnostic(); + // Error if there are multiple valid candidates. + if (validCandidates.size() > 1) { + diags.diagnose(funcNameLoc, diag::autodiff_attr_original_decl_ambiguous, + funcName); + for (auto *validCandidate : validCandidates) { + auto declKind = validCandidate->getDescriptiveKind(); + diags.diagnose(validCandidate, + diag::autodiff_attr_original_decl_ambiguous_candidate, + declKind); + } return nullptr; } - assert(notFunction && "Expected 'not a function' error"); - notFunctionDiagnostic(); - return nullptr; + // Success if there is one unambiguous valid candidate. + return validCandidates.front(); } // Checks that the `candidate` function type equals the `required` function @@ -4434,60 +4532,43 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, return target->requirementsNotSatisfiedBy(source).empty(); }; - auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { - // TODO(TF-982): Allow derivatives on protocol requirements. - if (isa(originalCandidate->getDeclContext())) - return false; - return checkFunctionSignature( - cast(originalFnType->getCanonicalType()), - originalCandidate->getInterfaceType()->getCanonicalType(), - checkGenericSignatureSatisfied); - }; - - auto noneValidDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_none_valid_found, - originalName.Name, originalFnType); - }; - auto ambiguousDiagnostic = [&]() { - diags.diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl, - originalName.Name, attr->getAttrName()); - }; - auto notFunctionDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_invalid_kind, - originalName.Name); - }; - auto missingAccessorDiagnostic = [&]() { - auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get); - auto accessorLabel = getAccessorLabel(accessorKind); - diags.diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found, - originalName.Name, accessorLabel); - }; - - std::function invalidTypeContextDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_not_same_type_context, - originalName.Name); - }; - // Returns true if the derivative function and original function candidate are // defined in compatible type contexts. If the derivative function and the // original function candidate have different parents, return false. - std::function hasValidTypeContext = - [&](AbstractFunctionDecl *func) { - // Check if both functions are top-level. - if (!derivative->getInnermostTypeContext() && - !func->getInnermostTypeContext()) - return true; - // Check if both functions are defined in the same type context. - if (auto typeCtx1 = derivative->getInnermostTypeContext()) - if (auto typeCtx2 = func->getInnermostTypeContext()) { - return typeCtx1->getSelfNominalTypeDecl() == - typeCtx2->getSelfNominalTypeDecl(); - } - return derivative->getParent() == func->getParent(); - }; + auto hasValidTypeContext = [&](AbstractFunctionDecl *originalCandidate) { + // Check if both functions are top-level. + if (!derivative->getInnermostTypeContext() && + !originalCandidate->getInnermostTypeContext()) + return true; + // Check if both functions are defined in the same type context. + if (auto typeCtx1 = derivative->getInnermostTypeContext()) + if (auto typeCtx2 = originalCandidate->getInnermostTypeContext()) { + return typeCtx1->getSelfNominalTypeDecl() == + typeCtx2->getSelfNominalTypeDecl(); + } + return derivative->getParent() == originalCandidate->getParent(); + }; + + auto isValidOriginalCandidate = [&](AbstractFunctionDecl *originalCandidate) + -> Optional { + // Error if the original candidate is a protocol requirement. Derivative + // registration does not yet support protocol requirements. + // TODO(TF-982): Allow default derivative implementations for protocol + // requirements. + if (isa(originalCandidate->getDeclContext())) + return AbstractFunctionDeclLookupErrorKind::CandidateProtocolRequirement; + // Error if the original candidate is not defined in a type context + // compatible with the derivative function. + if (!hasValidTypeContext(originalCandidate)) + return AbstractFunctionDeclLookupErrorKind::CandidateWrongTypeContext; + // Error if the original candidate does not have the expected type. + if (!checkFunctionSignature( + cast(originalFnType->getCanonicalType()), + originalCandidate->getInterfaceType()->getCanonicalType(), + checkGenericSignatureSatisfied)) + return AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch; + return None; + }; Type baseType; if (auto *baseTypeRepr = attr->getBaseTypeRepr()) { @@ -4507,16 +4588,30 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, derivativeTypeCtx = derivative->getParent(); assert(derivativeTypeCtx); + // Diagnose unsupported original accessor kinds. + // Currently, only getters and setters are supported. + if (originalName.AccessorKind.hasValue()) { + if (*originalName.AccessorKind != AccessorKind::Get && + *originalName.AccessorKind != AccessorKind::Set) { + attr->setInvalid(); + diags.diagnose( + originalName.Loc, diag::derivative_attr_unsupported_accessor_kind, + getAccessorDescriptiveDeclKind(*originalName.AccessorKind)); + return true; + } + } + // Look up original function. auto *originalAFD = findAbstractFunctionDecl( - originalName.Name, originalName.Loc.getBaseNameLoc(), - originalName.AccessorKind, baseType, derivativeTypeCtx, isValidOriginal, - noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, - missingAccessorDiagnostic, lookupOptions, hasValidTypeContext, - invalidTypeContextDiagnostic); - if (!originalAFD) + attr, baseType, originalName, derivativeTypeCtx, lookupOptions, + isValidOriginalCandidate, originalFnType); + if (!originalAFD) { + attr->setInvalid(); return true; + } + // Diagnose original stored properties. Stored properties cannot have custom + // registered derivatives. if (auto *accessorDecl = dyn_cast(originalAFD)) { // Diagnose original stored properties. Stored properties cannot have custom // registered derivatives. @@ -5036,46 +5131,17 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { return target->requirementsNotSatisfiedBy(source).empty(); }; - auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { - return checkFunctionSignature( - cast(expectedOriginalFnType->getCanonicalType()), - originalCandidate->getInterfaceType()->getCanonicalType(), - checkGenericSignatureSatisfied); - }; - - auto noneValidDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_none_valid_found, - originalName.Name, expectedOriginalFnType); - }; - auto ambiguousDiagnostic = [&]() { - diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl, - originalName.Name, attr->getAttrName()); - }; - auto notFunctionDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_invalid_kind, - originalName.Name); - }; - auto missingAccessorDiagnostic = [&]() { - auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get); - auto accessorLabel = getAccessorLabel(accessorKind); - diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found, - originalName.Name, accessorLabel); - }; - - std::function invalidTypeContextDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_not_same_type_context, - originalName.Name); + auto isValidOriginalCandidate = [&](AbstractFunctionDecl *originalCandidate) + -> Optional { + // Error if the original candidate does not have the expected type. + if (!checkFunctionSignature( + cast(expectedOriginalFnType->getCanonicalType()), + originalCandidate->getInterfaceType()->getCanonicalType(), + checkGenericSignatureSatisfied)) + return AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch; + return None; }; - // Returns true if the transpose function and original function candidate are - // defined in compatible type contexts. If the transpose function and the - // original function candidate have different parents, return false. - std::function hasValidTypeContext = - [&](AbstractFunctionDecl *decl) { return true; }; - auto resolution = TypeResolution::forContextual(transpose->getDeclContext(), None); Type baseType; @@ -5094,10 +5160,8 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { if (attr->getBaseTypeRepr()) funcLoc = attr->getBaseTypeRepr()->getLoc(); auto *originalAFD = findAbstractFunctionDecl( - originalName.Name, funcLoc, originalName.AccessorKind, baseType, - transposeTypeCtx, isValidOriginal, noneValidDiagnostic, - ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic, - lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); + attr, baseType, originalName, transposeTypeCtx, lookupOptions, + isValidOriginalCandidate, expectedOriginalFnType); if (!originalAFD) { attr->setInvalid(); return; diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 0b17860a1b51a..598dc3600879a 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -1,5 +1,7 @@ // RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s +// Swift.AdditiveArithmetic:3:17: note: cannot yet register derivative default implementation for protocol requirements + import _Differentiation // Dummy `Differentiable`-conforming type. @@ -86,7 +88,8 @@ func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) // Test incorrect `@derivative` declaration type. -// expected-note @+1 {{'incorrectDerivativeType' defined here}} +// expected-note @+2 {{'incorrectDerivativeType' defined here}} +// expected-note @+1 {{candidate global function does not have expected type '(Int) -> Int'}} func incorrectDerivativeType(_ x: Float) -> Float { return x } @@ -106,7 +109,7 @@ func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) { func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) { return (x, { $0 }) } -// expected-error @+1 {{could not find function 'incorrectDerivativeType' with expected type '(Int) -> Int'}} +// expected-error @+1 {{referenced declaration 'incorrectDerivativeType' could not be resolved}} @derivative(of: incorrectDerivativeType) func vjpResultNotDifferentiable(x: Int) -> ( value: Int, pullback: (Int) -> Int @@ -357,6 +360,7 @@ extension Wrapper where T: Differentiable, T == T.TangentVector { class Super { @differentiable + // expected-note @+1 {{candidate instance method is not defined in the current type context}} func foo(_ x: Float) -> Float { return x } @@ -370,7 +374,7 @@ class Super { class Sub: Super { // TODO(TF-649): Enable `@derivative` to override derivatives for original // declaration defined in superclass. - // expected-error @+1 {{'foo' is not defined in the current type context}} + // expected-error @+1 {{referenced declaration 'foo' could not be resolved}} @derivative(of: foo) override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { @@ -412,7 +416,11 @@ extension Class: Differentiable where T: Differentiable {} // Test computed properties. extension Struct { - var computedProperty: T { x } + var computedProperty: T { + get { x } + set { x = newValue } + _modify { yield &x } + } } extension Struct where T: Differentiable & AdditiveArithmetic { @derivative(of: computedProperty) @@ -425,13 +433,20 @@ extension Struct where T: Differentiable & AdditiveArithmetic { fatalError() } - // expected-error @+1 {{'computedProperty' does not have a 'set' accessor}} @derivative(of: computedProperty.set) mutating func vjpPropertySetter(_ newValue: T) -> ( value: (), pullback: (inout TangentVector) -> T.TangentVector ) { fatalError() } + + // expected-error @+1 {{cannot register derivative for _modify accessor}} + @derivative(of: computedProperty._modify) + mutating func vjpPropertyModify(_ newValue: T) -> ( + value: (), pullback: (inout TangentVector) -> T.TangentVector + ) { + fatalError() + } } // Test initializers. @@ -463,10 +478,13 @@ extension Struct { get { 1 } set {} } + subscript(float float: Float) -> Float { get { 1 } set {} } + + // expected-note @+1 {{candidate subscript does not have a setter}} subscript(x: T) -> T { x } } extension Struct where T: Differentiable & AdditiveArithmetic { @@ -516,7 +534,6 @@ extension Struct where T: Differentiable & AdditiveArithmetic { return (x, { _ in .zero }) } - // expected-error @+1 {{'subscript' does not have a 'set' accessor}} @derivative(of: subscript.set) mutating func vjpSubscriptSetter(_ newValue: Float) -> ( value: (), pullback: (inout TangentVector) -> Float @@ -546,7 +563,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic { } // Error: original subscript has no setter. - // expected-error @+1 {{'subscript(_:)' does not have a 'set' accessor}} + // expected-error @+1 {{referenced declaration 'subscript(_:)' could not be resolved}} @derivative(of: subscript(_:).set, wrt: self) mutating func vjpSubscriptGeneric_NoSetter(x: T) -> ( value: T, pullback: (T.TangentVector) -> TangentVector @@ -604,8 +621,10 @@ func jvpDuplicate2(_ x: Float) -> (value: Float, differential: (Float) -> Float) // Test invalid original declaration kind. +// expected-note @+1 {{candidate var does not have a getter}} var globalVariable: Float -// expected-error @+1 {{'globalVariable' is not a 'func', 'init', 'subscript', or 'var' computed property declaration}} + +// expected-error @+1 {{referenced declaration 'globalVariable' could not be resolved}} @derivative(of: globalVariable) func invalidOriginalDeclaration(x: Float) -> ( value: Float, differential: (Float) -> (Float) @@ -617,10 +636,12 @@ func invalidOriginalDeclaration(x: Float) -> ( protocol P1 {} protocol P2 {} +// expected-note @+1 {{candidate global function found here}} func ambiguous(_ x: T) -> T { x } +// expected-note @+1 {{candidate global function found here}} func ambiguous(_ x: T) -> T { x } -// expected-error @+1 {{ambiguous reference to 'ambiguous' in '@derivative' attribute}} +// expected-error @+1 {{referenced declaration 'ambiguous' is ambiguous}} @derivative(of: ambiguous) func jvpAmbiguous(x: T) -> (value: T, differential: (T.TangentVector) -> (T.TangentVector)) @@ -632,11 +653,14 @@ func jvpAmbiguous(x: T) // Original declarations are invalid because they have extra generic // requirements unsatisfied by the `@derivative` function. +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } -// expected-error @+1 {{could not find function 'invalid' with expected type ' (x: T) -> T'}} +// expected-error @+1 {{referenced declaration 'invalid' could not be resolved}} @derivative(of: invalid) func jvpInvalid(x: T) -> ( value: T, differential: (T.TangentVector) -> T.TangentVector @@ -647,9 +671,10 @@ func jvpInvalid(x: T) -> ( // Test invalid derivative type context: instance vs static method mismatch. struct InvalidTypeContext { + // expected-note @+1 {{candidate static method does not have type equal to or less constrained than ' (InvalidTypeContext) -> (T) -> T'}} static func staticMethod(_ x: T) -> T { x } - // expected-error @+1 {{could not find function 'staticMethod' with expected type ' (InvalidTypeContext) -> (T) -> T'}} + // expected-error @+1 {{referenced declaration 'staticMethod' could not be resolved}} @derivative(of: staticMethod) func jvpStatic(_ x: T) -> ( value: T, differential: (T.TangentVector) -> (T.TangentVector) @@ -688,13 +713,11 @@ extension HasStoredProperty { // TODO(TF-982): Lift this restriction and add proper support. protocol ProtocolRequirementDerivative { + // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} func requirement(_ x: Float) -> Float } extension ProtocolRequirementDerivative { - // NOTE: the error is misleading because `findAbstractFunctionDecl` in - // TypeCheckAttr.cpp is not setup to show customized error messages for - // invalid original function candidates. - // expected-error @+1 {{could not find function 'requirement' with expected type ' (Self) -> (Float) -> Float'}} + // expected-error @+1 {{referenced declaration 'requirement' could not be resolved}} @derivative(of: requirement) func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() @@ -862,15 +885,19 @@ extension InoutParameters { // Test original/derivative function `inout` parameter mismatches. extension InoutParameters { + // expected-note @+1 {{candidate instance method does not have expected type '(InoutParameters) -> (inout Float) -> Void'}} func inoutParameterMismatch(_ x: Float) {} - // expected-error @+1 {{could not find function 'inoutParameterMismatch' with expected type '(InoutParameters) -> (inout Float) -> Void'}} + + // expected-error @+1 {{referenced declaration 'inoutParameterMismatch' could not be resolved}} @derivative(of: inoutParameterMismatch) func vjpInoutParameterMismatch(_ x: inout Float) -> (value: Void, pullback: (inout Float) -> Void) { fatalError() } + // expected-note @+1 {{candidate instance method does not have expected type '(inout InoutParameters) -> (Float) -> Void'}} func mutatingMismatch(_ x: Float) {} - // expected-error @+1 {{could not find function 'mutatingMismatch' with expected type '(inout InoutParameters) -> (Float) -> Void'}} + + // expected-error @+1 {{referenced declaration 'mutatingMismatch' could not be resolved}} @derivative(of: mutatingMismatch) mutating func vjpMutatingMismatch(_ x: Float) -> (value: Void, pullback: (inout Float) -> Void) { fatalError() @@ -891,7 +918,7 @@ extension FloatingPoint where Self: Differentiable { } extension Differentiable where Self: AdditiveArithmetic { - // expected-error @+1 {{'+' is not defined in the current type context}} + // expected-error @+1 {{referenced declaration '+' could not be resolved}} @derivative(of: +) static func vjpPlus(x: Self, y: Self) -> ( value: Self, @@ -903,7 +930,7 @@ extension Differentiable where Self: AdditiveArithmetic { extension AdditiveArithmetic where Self: Differentiable, Self == Self.TangentVector { - // expected-error @+1 {{could not find function '+' with expected type ' (Self) -> (Self, Self) -> Self'}} + // expected-error @+1 {{referenced declaration '+' could not be resolved}} @derivative(of: +) func vjpPlusInstanceMethod(x: Self, y: Self) -> ( value: Self, pullback: (Self) -> (Self, Self) @@ -927,13 +954,14 @@ extension HasADefaultImplementation { // Test default derivatives of requirements. protocol HasADefaultDerivative { + // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} func req(_ x: Float) -> Float } extension HasADefaultDerivative { - // TODO(TF-982): Make this ok. - // expected-error @+1 {{could not find function 'req'}} + // TODO(TF-982): Support default derivatives for protocol requirements. + // expected-error @+1 {{referenced declaration 'req' could not be resolved}} @derivative(of: req) - func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + func vjpReq(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { (x, { 10 * $0 }) } } @@ -1080,11 +1108,12 @@ fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value // Test invalid reference to an accessor of a non-storage declaration. +// expected-note @+1 {{candidate global function does not have a getter}} func function(_ x: Float) -> Float { x } -// expected-error @+1 {{'function' does not have a 'get' accessor}} +// expected-error @+1 {{referenced declaration 'function' could not be resolved}} @derivative(of: function(_:).get) func vjpFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() @@ -1127,9 +1156,10 @@ extension Float { // Test original function with opaque result type. +// expected-note @+1 {{candidate global function does not have expected type '(Float) -> Float'}} func opaqueResult(_ x: Float) -> some Differentiable { x } -// expected-error @+1 {{could not find function 'opaqueResult' with expected type '(Float) -> Float'}} +// expected-error @+1 {{referenced declaration 'opaqueResult' could not be resolved}} @derivative(of: opaqueResult) func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() diff --git a/test/AutoDiff/Sema/transpose_attr_type_checking.swift b/test/AutoDiff/Sema/transpose_attr_type_checking.swift index ebacb07a9d054..a49ff941a49d3 100644 --- a/test/AutoDiff/Sema/transpose_attr_type_checking.swift +++ b/test/AutoDiff/Sema/transpose_attr_type_checking.swift @@ -174,12 +174,13 @@ func missingSelfRequirementT(x: T) -> T { return x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (T) -> T'}} func differentGenericConstraint(x: T) -> T where T == T.TangentVector { return x } -// expected-error @+1 {{could not find function 'differentGenericConstraint' with expected type ' (T) -> T'}} +// expected-error @+1 {{referenced declaration 'differentGenericConstraint' could not be resolved}} @transpose(of: differentGenericConstraint, wrt: 0) func differentGenericConstraintT(x: T) -> T where T == T.TangentVector { @@ -472,6 +473,7 @@ extension Float { // Test non-`func` original declarations. +// expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} struct Struct {} extension Struct: Equatable where T: Equatable {} extension Struct: Differentiable & AdditiveArithmetic @@ -496,7 +498,9 @@ extension Struct where T: Differentiable & AdditiveArithmetic { // Test initializers. extension Struct { + // expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} init(_ x: Float) {} + // expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} init(_ x: T, y: Float) {} } @@ -513,7 +517,7 @@ extension Struct where T: Differentiable, T == T.TangentVector { // Test instance transpose for static original intializer. // TODO(TF-1015): Add improved instance/static member mismatch error. - // expected-error @+1 {{could not find function 'init' with expected type ' (Struct) -> (Float) -> Struct'}} + // expected-error @+1 {{referenced declaration 'init' could not be resolved}} @transpose(of: init, wrt: 0) func vjpInitStaticMismatch(_ x: Self) -> Float { fatalError() @@ -550,6 +554,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic { // Check that `@transpose` attribute rejects stored property original declarations. struct StoredProperty: Differentiable & AdditiveArithmetic { + // expected-note @+1 {{candidate getter does not have expected type '(StoredProperty) -> () -> StoredProperty'}} var stored: Float typealias TangentVector = StoredProperty static var zero: StoredProperty { StoredProperty(stored: 0) } @@ -563,7 +568,7 @@ struct StoredProperty: Differentiable & AdditiveArithmetic { // Note: `@transpose` support for instance members is currently too limited // to properly register a transpose for a non-`Self`-typed member. - // expected-error @+1 {{could not find function 'stored' with expected type '(StoredProperty) -> () -> StoredProperty'}} + // expected-error @+1 {{referenced declaration 'stored' could not be resolved}} @transpose(of: stored, wrt: self) static func vjpStored(v: Self) -> Self { fatalError()