diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index ced5a901f8a1e..f40b05ee93584 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3095,10 +3095,16 @@ ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none, ERROR(derivative_attr_nonfinal_class_init_unsupported,none, "cannot register derivative for 'init' in a non-final class; consider " "making %0 final", (Type)) +ERROR(derivative_attr_unsupported_accessor_kind,none, + "cannot register derivative for %0", (/*accessorKind*/ DescriptiveDeclKind)) // TODO(SR-13096): Remove this temporary diagnostic. ERROR(derivative_attr_class_setter_unsupported,none, "cannot yet register derivative for class property or subscript setters", ()) +// TODO(TF-982): Remove this temporary diagnostic. +NOTE(derivative_attr_protocol_requirement_unsupported,none, + "cannot yet register derivative default implementation for protocol " + "requirements", ()) ERROR(derivative_attr_original_already_has_derivative,none, "a derivative already exists for %0", (DeclName)) NOTE(derivative_attr_duplicate_note,none, @@ -3134,15 +3140,25 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none, "%1", (Type, Type)) // Automatic differentiation attributes -ERROR(autodiff_attr_original_decl_invalid_kind,none, - "%0 is not a 'func', 'init', 'subscript', or 'var' computed property " - "declaration", (DeclNameRef)) -ERROR(autodiff_attr_accessor_not_found,none, - "%0 does not have a '%1' accessor", (DeclNameRef, StringRef)) -ERROR(autodiff_attr_original_decl_none_valid_found,none, - "could not find function %0 with expected type %1", (DeclNameRef, Type)) -ERROR(autodiff_attr_original_decl_not_same_type_context,none, - "%0 is not defined in the current type context", (DeclNameRef)) +ERROR(autodiff_attr_original_decl_ambiguous,none, + "referenced declaration %0 is ambiguous", (DeclNameRef)) +NOTE(autodiff_attr_original_decl_ambiguous_candidate,none, + "candidate %0 found here", (DescriptiveDeclKind)) +ERROR(autodiff_attr_original_decl_none_valid,none, + "referenced declaration %0 could not be resolved", (DeclNameRef)) +NOTE(autodiff_attr_original_decl_invalid_kind,none, + "candidate %0 is not a 'func', 'init', 'subscript', or 'var' computed " + "property declaration", (DescriptiveDeclKind)) +NOTE(autodiff_attr_original_decl_missing_accessor,none, + "candidate %0 does not have a %1", + (DescriptiveDeclKind, /*accessorDeclKind*/ DescriptiveDeclKind)) +NOTE(autodiff_attr_original_decl_type_mismatch,none, + "candidate %0 does not have " + "%select{expected type|type equal to or less constrained than}2 %1", + (DescriptiveDeclKind, Type, /*hasGenericSignature*/ bool)) +NOTE(autodiff_attr_original_decl_not_same_type_context,none, + "candidate %0 is not defined in the current type context", + (DescriptiveDeclKind)) ERROR(autodiff_attr_original_void_result,none, "cannot differentiate void function %0", (DeclName)) ERROR(autodiff_attr_original_multiple_semantic_results,none, diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 8883713e5aa3f..e29d1dc17952f 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3604,26 +3604,71 @@ static IndexSubset *computeDifferentiabilityParameters( return IndexSubset::get(ctx, parameterBits); } -// Returns the function declaration corresponding to the given function name and -// lookup context. If the base type of the function is specified, member lookup -// is performed. Otherwise, unqualified lookup is performed. +/// Returns the `DescriptiveDeclKind` corresponding to the given `AccessorKind`. +/// Used for diagnostics. +static DescriptiveDeclKind getAccessorDescriptiveDeclKind(AccessorKind kind) { + switch (kind) { + case AccessorKind::Get: + return DescriptiveDeclKind::Getter; + case AccessorKind::Set: + return DescriptiveDeclKind::Setter; + case AccessorKind::Read: + return DescriptiveDeclKind::ReadAccessor; + case AccessorKind::Modify: + return DescriptiveDeclKind::ModifyAccessor; + case AccessorKind::WillSet: + return DescriptiveDeclKind::WillSet; + case AccessorKind::DidSet: + return DescriptiveDeclKind::DidSet; + case AccessorKind::Address: + return DescriptiveDeclKind::Addressor; + case AccessorKind::MutableAddress: + return DescriptiveDeclKind::MutableAddressor; + } +} + +/// An abstract function declaration lookup error. +enum class AbstractFunctionDeclLookupErrorKind { + /// No lookup candidates could be found. + NoCandidatesFound, + /// There are multiple valid lookup candidates. + CandidatesAmbiguous, + /// Lookup candidate does not have the expected type. + CandidateTypeMismatch, + /// Lookup candidate is in the wrong type context. + CandidateWrongTypeContext, + /// Lookup candidate does not have the requested accessor. + CandidateMissingAccessor, + /// Lookup candidate is a protocol requirement. + CandidateProtocolRequirement, + /// Lookup candidate could be resolved to an `AbstractFunctionDecl`. + CandidateNotFunctionDeclaration +}; + +// Returns the function declaration corresponding to the given base type +// (optional), function name, and lookup context. +// +// If the base type of the function is specified, member lookup is performed. +// Otherwise, unqualified lookup is performed. +// // If the function declaration cannot be resolved, emits a diagnostic and // returns nullptr. +// +// Used for resolving the referenced declaration in `@derivative` and +// `@transpose` attributes. static AbstractFunctionDecl *findAbstractFunctionDecl( - DeclNameRef funcName, SourceLoc funcNameLoc, - Optional accessorKind, Type baseType, - DeclContext *lookupContext, - const std::function &isValidCandidate, - const std::function &noneValidDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic, - const std::function &missingAccessorDiagnostic, - NameLookupOptions lookupOptions, - const Optional> - &hasValidTypeCtx, - const Optional> &invalidTypeCtxDiagnostic) { + DeclAttribute *attr, Type baseType, DeclNameRefWithLoc funcNameWithLoc, + DeclContext *lookupContext, NameLookupOptions lookupOptions, + const llvm::function_ref( + AbstractFunctionDecl *)> &isValidCandidate, + AnyFunctionType *expectedOriginalFnType) { + assert(lookupContext); auto &ctx = lookupContext->getASTContext(); - AbstractFunctionDecl *resolvedCandidate = nullptr; + auto &diags = ctx.Diags; + + auto funcName = funcNameWithLoc.Name; + auto funcNameLoc = funcNameWithLoc.Loc; + auto maybeAccessorKind = funcNameWithLoc.AccessorKind; // Perform lookup. LookupResult results; @@ -3634,22 +3679,25 @@ static AbstractFunctionDecl *findAbstractFunctionDecl( if (baseType) { results = TypeChecker::lookupMember(lookupContext, baseType, funcName); } else { - results = TypeChecker::lookupUnqualified(lookupContext, funcName, - funcNameLoc, lookupOptions); + results = TypeChecker::lookupUnqualified( + lookupContext, funcName, funcNameLoc.getBaseNameLoc(), lookupOptions); + } + + // Error if no candidates were found. + if (results.empty()) { + diags.diagnose(funcNameLoc, diag::cannot_find_in_scope, funcName, + funcName.isOperator()); + return nullptr; } - // Initialize error flags. - bool notFunction = false; - bool wrongTypeContext = false; - bool ambiguousFuncDecl = false; - bool foundInvalid = false; - bool missingAccessor = false; + // Track invalid and valid candidates. + using LookupErrorKind = AbstractFunctionDeclLookupErrorKind; + SmallVector, 2> invalidCandidates; + SmallVector validCandidates; // Filter lookup results. for (auto choice : results) { - auto decl = choice.getValueDecl(); - if (!decl) - continue; + auto *decl = choice.getValueDecl(); // Cast the candidate to an `AbstractFunctionDecl`. auto *candidate = dyn_cast(decl); // If the candidate is an `AbstractStorageDecl`, use one of its accessors as @@ -3657,67 +3705,117 @@ static AbstractFunctionDecl *findAbstractFunctionDecl( if (auto *asd = dyn_cast(decl)) { // If accessor kind is specified, use corresponding accessor from the // candidate. Otherwise, use the getter by default. - if (accessorKind != None) { - candidate = asd->getOpaqueAccessor(accessorKind.getValue()); - // Error if candidate is missing the requested accessor. - if (!candidate) - missingAccessor = true; - } else - candidate = asd->getOpaqueAccessor(AccessorKind::Get); - } else if (accessorKind != None) { - missingAccessor = true; + auto accessorKind = maybeAccessorKind.getValueOr(AccessorKind::Get); + candidate = asd->getOpaqueAccessor(accessorKind); + // Error if candidate is missing the requested accessor. + if (!candidate) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateMissingAccessor}); + continue; + } } - if (!candidate) { - notFunction = true; + // Error if the candidate is not an `AbstractStorageDecl` but an accessor is + // requested. + else if (maybeAccessorKind.hasValue()) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateMissingAccessor}); continue; } - if (hasValidTypeCtx && !(*hasValidTypeCtx)(candidate)) { - wrongTypeContext = true; + // Error if candidate is not a `AbstractFunctionDecl`. + if (!candidate) { + invalidCandidates.push_back( + {decl, LookupErrorKind::CandidateNotFunctionDeclaration}); continue; } - if (!isValidCandidate(candidate)) { - foundInvalid = true; + // Error if candidate is not valid. + auto invalidCandidateKind = isValidCandidate(candidate); + if (invalidCandidateKind.hasValue()) { + invalidCandidates.push_back({candidate, *invalidCandidateKind}); continue; } - if (resolvedCandidate) { - ambiguousFuncDecl = true; - resolvedCandidate = nullptr; - break; - } - resolvedCandidate = candidate; - } - - // If function declaration was resolved, return it. - if (resolvedCandidate && !missingAccessor) - return resolvedCandidate; - - // Otherwise, emit the appropriate diagnostic and return nullptr. - if (results.empty()) { - ctx.Diags.diagnose(funcNameLoc, diag::cannot_find_in_scope, funcName, + // Otherwise, record valid candidate. + validCandidates.push_back(candidate); + } + // If there are no valid candidates, emit diagnostics for invalid candidates. + if (validCandidates.empty()) { + assert(!invalidCandidates.empty()); + diags.diagnose(funcNameLoc, diag::autodiff_attr_original_decl_none_valid, + funcName); + for (auto invalidCandidatePair : invalidCandidates) { + auto *invalidCandidate = invalidCandidatePair.first; + auto invalidCandidateKind = invalidCandidatePair.second; + auto declKind = invalidCandidate->getDescriptiveKind(); + switch (invalidCandidateKind) { + case AbstractFunctionDeclLookupErrorKind::NoCandidatesFound: + diags.diagnose(invalidCandidate, diag::cannot_find_in_scope, funcName, funcName.isOperator()); + break; + case AbstractFunctionDeclLookupErrorKind::CandidatesAmbiguous: + diags.diagnose(invalidCandidate, diag::attr_ambiguous_reference_to_decl, + funcName, attr->getAttrName()); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch: { + // If the expected original function type has a generic signature, emit + // "candidate does not have type equal to or less constrained than ..." + // diagnostic. + // + // This is significant because derivative/transpose functions may have + // more constrained generic signatures than their referenced original + // declarations. + if (auto genSig = expectedOriginalFnType->getOptGenericSignature()) { + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_type_mismatch, + declKind, expectedOriginalFnType, + /*hasGenericSignature*/ true); + break; + } + // Otherwise, emit a "candidate does not have expected type ..." error. + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_type_mismatch, + declKind, expectedOriginalFnType, + /*hasGenericSignature*/ false); + break; + } + case AbstractFunctionDeclLookupErrorKind::CandidateWrongTypeContext: + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_not_same_type_context, + declKind); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateMissingAccessor: { + auto accessorKind = maybeAccessorKind.getValueOr(AccessorKind::Get); + auto accessorDeclKind = getAccessorDescriptiveDeclKind(accessorKind); + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_missing_accessor, + declKind, accessorDeclKind); + break; + } + case AbstractFunctionDeclLookupErrorKind::CandidateProtocolRequirement: + diags.diagnose(invalidCandidate, + diag::derivative_attr_protocol_requirement_unsupported); + break; + case AbstractFunctionDeclLookupErrorKind::CandidateNotFunctionDeclaration: + diags.diagnose(invalidCandidate, + diag::autodiff_attr_original_decl_invalid_kind, + declKind); + break; + } + } return nullptr; } - if (ambiguousFuncDecl) { - ambiguousDiagnostic(); - return nullptr; - } - if (missingAccessor) { - missingAccessorDiagnostic(); - return nullptr; - } - if (wrongTypeContext) { - assert(invalidTypeCtxDiagnostic && - "Type context diagnostic should've been specified"); - (*invalidTypeCtxDiagnostic)(); - return nullptr; - } - if (foundInvalid) { - noneValidDiagnostic(); + // Error if there are multiple valid candidates. + if (validCandidates.size() > 1) { + diags.diagnose(funcNameLoc, diag::autodiff_attr_original_decl_ambiguous, + funcName); + for (auto *validCandidate : validCandidates) { + auto declKind = validCandidate->getDescriptiveKind(); + diags.diagnose(validCandidate, + diag::autodiff_attr_original_decl_ambiguous_candidate, + declKind); + } return nullptr; } - assert(notFunction && "Expected 'not a function' error"); - notFunctionDiagnostic(); - return nullptr; + // Success if there is one unambiguous valid candidate. + return validCandidates.front(); } // Checks that the `candidate` function type equals the `required` function @@ -4434,60 +4532,43 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, return target->requirementsNotSatisfiedBy(source).empty(); }; - auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { - // TODO(TF-982): Allow derivatives on protocol requirements. - if (isa(originalCandidate->getDeclContext())) - return false; - return checkFunctionSignature( - cast(originalFnType->getCanonicalType()), - originalCandidate->getInterfaceType()->getCanonicalType(), - checkGenericSignatureSatisfied); - }; - - auto noneValidDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_none_valid_found, - originalName.Name, originalFnType); - }; - auto ambiguousDiagnostic = [&]() { - diags.diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl, - originalName.Name, attr->getAttrName()); - }; - auto notFunctionDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_invalid_kind, - originalName.Name); - }; - auto missingAccessorDiagnostic = [&]() { - auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get); - auto accessorLabel = getAccessorLabel(accessorKind); - diags.diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found, - originalName.Name, accessorLabel); - }; - - std::function invalidTypeContextDiagnostic = [&]() { - diags.diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_not_same_type_context, - originalName.Name); - }; - // Returns true if the derivative function and original function candidate are // defined in compatible type contexts. If the derivative function and the // original function candidate have different parents, return false. - std::function hasValidTypeContext = - [&](AbstractFunctionDecl *func) { - // Check if both functions are top-level. - if (!derivative->getInnermostTypeContext() && - !func->getInnermostTypeContext()) - return true; - // Check if both functions are defined in the same type context. - if (auto typeCtx1 = derivative->getInnermostTypeContext()) - if (auto typeCtx2 = func->getInnermostTypeContext()) { - return typeCtx1->getSelfNominalTypeDecl() == - typeCtx2->getSelfNominalTypeDecl(); - } - return derivative->getParent() == func->getParent(); - }; + auto hasValidTypeContext = [&](AbstractFunctionDecl *originalCandidate) { + // Check if both functions are top-level. + if (!derivative->getInnermostTypeContext() && + !originalCandidate->getInnermostTypeContext()) + return true; + // Check if both functions are defined in the same type context. + if (auto typeCtx1 = derivative->getInnermostTypeContext()) + if (auto typeCtx2 = originalCandidate->getInnermostTypeContext()) { + return typeCtx1->getSelfNominalTypeDecl() == + typeCtx2->getSelfNominalTypeDecl(); + } + return derivative->getParent() == originalCandidate->getParent(); + }; + + auto isValidOriginalCandidate = [&](AbstractFunctionDecl *originalCandidate) + -> Optional { + // Error if the original candidate is a protocol requirement. Derivative + // registration does not yet support protocol requirements. + // TODO(TF-982): Allow default derivative implementations for protocol + // requirements. + if (isa(originalCandidate->getDeclContext())) + return AbstractFunctionDeclLookupErrorKind::CandidateProtocolRequirement; + // Error if the original candidate is not defined in a type context + // compatible with the derivative function. + if (!hasValidTypeContext(originalCandidate)) + return AbstractFunctionDeclLookupErrorKind::CandidateWrongTypeContext; + // Error if the original candidate does not have the expected type. + if (!checkFunctionSignature( + cast(originalFnType->getCanonicalType()), + originalCandidate->getInterfaceType()->getCanonicalType(), + checkGenericSignatureSatisfied)) + return AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch; + return None; + }; Type baseType; if (auto *baseTypeRepr = attr->getBaseTypeRepr()) { @@ -4507,16 +4588,30 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, derivativeTypeCtx = derivative->getParent(); assert(derivativeTypeCtx); + // Diagnose unsupported original accessor kinds. + // Currently, only getters and setters are supported. + if (originalName.AccessorKind.hasValue()) { + if (*originalName.AccessorKind != AccessorKind::Get && + *originalName.AccessorKind != AccessorKind::Set) { + attr->setInvalid(); + diags.diagnose( + originalName.Loc, diag::derivative_attr_unsupported_accessor_kind, + getAccessorDescriptiveDeclKind(*originalName.AccessorKind)); + return true; + } + } + // Look up original function. auto *originalAFD = findAbstractFunctionDecl( - originalName.Name, originalName.Loc.getBaseNameLoc(), - originalName.AccessorKind, baseType, derivativeTypeCtx, isValidOriginal, - noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, - missingAccessorDiagnostic, lookupOptions, hasValidTypeContext, - invalidTypeContextDiagnostic); - if (!originalAFD) + attr, baseType, originalName, derivativeTypeCtx, lookupOptions, + isValidOriginalCandidate, originalFnType); + if (!originalAFD) { + attr->setInvalid(); return true; + } + // Diagnose original stored properties. Stored properties cannot have custom + // registered derivatives. if (auto *accessorDecl = dyn_cast(originalAFD)) { // Diagnose original stored properties. Stored properties cannot have custom // registered derivatives. @@ -5036,46 +5131,17 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { return target->requirementsNotSatisfiedBy(source).empty(); }; - auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { - return checkFunctionSignature( - cast(expectedOriginalFnType->getCanonicalType()), - originalCandidate->getInterfaceType()->getCanonicalType(), - checkGenericSignatureSatisfied); - }; - - auto noneValidDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_none_valid_found, - originalName.Name, expectedOriginalFnType); - }; - auto ambiguousDiagnostic = [&]() { - diagnose(originalName.Loc, diag::attr_ambiguous_reference_to_decl, - originalName.Name, attr->getAttrName()); - }; - auto notFunctionDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_invalid_kind, - originalName.Name); - }; - auto missingAccessorDiagnostic = [&]() { - auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get); - auto accessorLabel = getAccessorLabel(accessorKind); - diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found, - originalName.Name, accessorLabel); - }; - - std::function invalidTypeContextDiagnostic = [&]() { - diagnose(originalName.Loc, - diag::autodiff_attr_original_decl_not_same_type_context, - originalName.Name); + auto isValidOriginalCandidate = [&](AbstractFunctionDecl *originalCandidate) + -> Optional { + // Error if the original candidate does not have the expected type. + if (!checkFunctionSignature( + cast(expectedOriginalFnType->getCanonicalType()), + originalCandidate->getInterfaceType()->getCanonicalType(), + checkGenericSignatureSatisfied)) + return AbstractFunctionDeclLookupErrorKind::CandidateTypeMismatch; + return None; }; - // Returns true if the transpose function and original function candidate are - // defined in compatible type contexts. If the transpose function and the - // original function candidate have different parents, return false. - std::function hasValidTypeContext = - [&](AbstractFunctionDecl *decl) { return true; }; - auto resolution = TypeResolution::forContextual(transpose->getDeclContext(), None); Type baseType; @@ -5094,10 +5160,8 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) { if (attr->getBaseTypeRepr()) funcLoc = attr->getBaseTypeRepr()->getLoc(); auto *originalAFD = findAbstractFunctionDecl( - originalName.Name, funcLoc, originalName.AccessorKind, baseType, - transposeTypeCtx, isValidOriginal, noneValidDiagnostic, - ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic, - lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); + attr, baseType, originalName, transposeTypeCtx, lookupOptions, + isValidOriginalCandidate, expectedOriginalFnType); if (!originalAFD) { attr->setInvalid(); return; diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 0b17860a1b51a..598dc3600879a 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -1,5 +1,7 @@ // RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s +// Swift.AdditiveArithmetic:3:17: note: cannot yet register derivative default implementation for protocol requirements + import _Differentiation // Dummy `Differentiable`-conforming type. @@ -86,7 +88,8 @@ func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) // Test incorrect `@derivative` declaration type. -// expected-note @+1 {{'incorrectDerivativeType' defined here}} +// expected-note @+2 {{'incorrectDerivativeType' defined here}} +// expected-note @+1 {{candidate global function does not have expected type '(Int) -> Int'}} func incorrectDerivativeType(_ x: Float) -> Float { return x } @@ -106,7 +109,7 @@ func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) { func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) { return (x, { $0 }) } -// expected-error @+1 {{could not find function 'incorrectDerivativeType' with expected type '(Int) -> Int'}} +// expected-error @+1 {{referenced declaration 'incorrectDerivativeType' could not be resolved}} @derivative(of: incorrectDerivativeType) func vjpResultNotDifferentiable(x: Int) -> ( value: Int, pullback: (Int) -> Int @@ -357,6 +360,7 @@ extension Wrapper where T: Differentiable, T == T.TangentVector { class Super { @differentiable + // expected-note @+1 {{candidate instance method is not defined in the current type context}} func foo(_ x: Float) -> Float { return x } @@ -370,7 +374,7 @@ class Super { class Sub: Super { // TODO(TF-649): Enable `@derivative` to override derivatives for original // declaration defined in superclass. - // expected-error @+1 {{'foo' is not defined in the current type context}} + // expected-error @+1 {{referenced declaration 'foo' could not be resolved}} @derivative(of: foo) override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { @@ -412,7 +416,11 @@ extension Class: Differentiable where T: Differentiable {} // Test computed properties. extension Struct { - var computedProperty: T { x } + var computedProperty: T { + get { x } + set { x = newValue } + _modify { yield &x } + } } extension Struct where T: Differentiable & AdditiveArithmetic { @derivative(of: computedProperty) @@ -425,13 +433,20 @@ extension Struct where T: Differentiable & AdditiveArithmetic { fatalError() } - // expected-error @+1 {{'computedProperty' does not have a 'set' accessor}} @derivative(of: computedProperty.set) mutating func vjpPropertySetter(_ newValue: T) -> ( value: (), pullback: (inout TangentVector) -> T.TangentVector ) { fatalError() } + + // expected-error @+1 {{cannot register derivative for _modify accessor}} + @derivative(of: computedProperty._modify) + mutating func vjpPropertyModify(_ newValue: T) -> ( + value: (), pullback: (inout TangentVector) -> T.TangentVector + ) { + fatalError() + } } // Test initializers. @@ -463,10 +478,13 @@ extension Struct { get { 1 } set {} } + subscript(float float: Float) -> Float { get { 1 } set {} } + + // expected-note @+1 {{candidate subscript does not have a setter}} subscript(x: T) -> T { x } } extension Struct where T: Differentiable & AdditiveArithmetic { @@ -516,7 +534,6 @@ extension Struct where T: Differentiable & AdditiveArithmetic { return (x, { _ in .zero }) } - // expected-error @+1 {{'subscript' does not have a 'set' accessor}} @derivative(of: subscript.set) mutating func vjpSubscriptSetter(_ newValue: Float) -> ( value: (), pullback: (inout TangentVector) -> Float @@ -546,7 +563,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic { } // Error: original subscript has no setter. - // expected-error @+1 {{'subscript(_:)' does not have a 'set' accessor}} + // expected-error @+1 {{referenced declaration 'subscript(_:)' could not be resolved}} @derivative(of: subscript(_:).set, wrt: self) mutating func vjpSubscriptGeneric_NoSetter(x: T) -> ( value: T, pullback: (T.TangentVector) -> TangentVector @@ -604,8 +621,10 @@ func jvpDuplicate2(_ x: Float) -> (value: Float, differential: (Float) -> Float) // Test invalid original declaration kind. +// expected-note @+1 {{candidate var does not have a getter}} var globalVariable: Float -// expected-error @+1 {{'globalVariable' is not a 'func', 'init', 'subscript', or 'var' computed property declaration}} + +// expected-error @+1 {{referenced declaration 'globalVariable' could not be resolved}} @derivative(of: globalVariable) func invalidOriginalDeclaration(x: Float) -> ( value: Float, differential: (Float) -> (Float) @@ -617,10 +636,12 @@ func invalidOriginalDeclaration(x: Float) -> ( protocol P1 {} protocol P2 {} +// expected-note @+1 {{candidate global function found here}} func ambiguous(_ x: T) -> T { x } +// expected-note @+1 {{candidate global function found here}} func ambiguous(_ x: T) -> T { x } -// expected-error @+1 {{ambiguous reference to 'ambiguous' in '@derivative' attribute}} +// expected-error @+1 {{referenced declaration 'ambiguous' is ambiguous}} @derivative(of: ambiguous) func jvpAmbiguous(x: T) -> (value: T, differential: (T.TangentVector) -> (T.TangentVector)) @@ -632,11 +653,14 @@ func jvpAmbiguous(x: T) // Original declarations are invalid because they have extra generic // requirements unsatisfied by the `@derivative` function. +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (x: T) -> T'}} func invalid(x: T) -> T { x } -// expected-error @+1 {{could not find function 'invalid' with expected type ' (x: T) -> T'}} +// expected-error @+1 {{referenced declaration 'invalid' could not be resolved}} @derivative(of: invalid) func jvpInvalid(x: T) -> ( value: T, differential: (T.TangentVector) -> T.TangentVector @@ -647,9 +671,10 @@ func jvpInvalid(x: T) -> ( // Test invalid derivative type context: instance vs static method mismatch. struct InvalidTypeContext { + // expected-note @+1 {{candidate static method does not have type equal to or less constrained than ' (InvalidTypeContext) -> (T) -> T'}} static func staticMethod(_ x: T) -> T { x } - // expected-error @+1 {{could not find function 'staticMethod' with expected type ' (InvalidTypeContext) -> (T) -> T'}} + // expected-error @+1 {{referenced declaration 'staticMethod' could not be resolved}} @derivative(of: staticMethod) func jvpStatic(_ x: T) -> ( value: T, differential: (T.TangentVector) -> (T.TangentVector) @@ -688,13 +713,11 @@ extension HasStoredProperty { // TODO(TF-982): Lift this restriction and add proper support. protocol ProtocolRequirementDerivative { + // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} func requirement(_ x: Float) -> Float } extension ProtocolRequirementDerivative { - // NOTE: the error is misleading because `findAbstractFunctionDecl` in - // TypeCheckAttr.cpp is not setup to show customized error messages for - // invalid original function candidates. - // expected-error @+1 {{could not find function 'requirement' with expected type ' (Self) -> (Float) -> Float'}} + // expected-error @+1 {{referenced declaration 'requirement' could not be resolved}} @derivative(of: requirement) func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() @@ -862,15 +885,19 @@ extension InoutParameters { // Test original/derivative function `inout` parameter mismatches. extension InoutParameters { + // expected-note @+1 {{candidate instance method does not have expected type '(InoutParameters) -> (inout Float) -> Void'}} func inoutParameterMismatch(_ x: Float) {} - // expected-error @+1 {{could not find function 'inoutParameterMismatch' with expected type '(InoutParameters) -> (inout Float) -> Void'}} + + // expected-error @+1 {{referenced declaration 'inoutParameterMismatch' could not be resolved}} @derivative(of: inoutParameterMismatch) func vjpInoutParameterMismatch(_ x: inout Float) -> (value: Void, pullback: (inout Float) -> Void) { fatalError() } + // expected-note @+1 {{candidate instance method does not have expected type '(inout InoutParameters) -> (Float) -> Void'}} func mutatingMismatch(_ x: Float) {} - // expected-error @+1 {{could not find function 'mutatingMismatch' with expected type '(inout InoutParameters) -> (Float) -> Void'}} + + // expected-error @+1 {{referenced declaration 'mutatingMismatch' could not be resolved}} @derivative(of: mutatingMismatch) mutating func vjpMutatingMismatch(_ x: Float) -> (value: Void, pullback: (inout Float) -> Void) { fatalError() @@ -891,7 +918,7 @@ extension FloatingPoint where Self: Differentiable { } extension Differentiable where Self: AdditiveArithmetic { - // expected-error @+1 {{'+' is not defined in the current type context}} + // expected-error @+1 {{referenced declaration '+' could not be resolved}} @derivative(of: +) static func vjpPlus(x: Self, y: Self) -> ( value: Self, @@ -903,7 +930,7 @@ extension Differentiable where Self: AdditiveArithmetic { extension AdditiveArithmetic where Self: Differentiable, Self == Self.TangentVector { - // expected-error @+1 {{could not find function '+' with expected type ' (Self) -> (Self, Self) -> Self'}} + // expected-error @+1 {{referenced declaration '+' could not be resolved}} @derivative(of: +) func vjpPlusInstanceMethod(x: Self, y: Self) -> ( value: Self, pullback: (Self) -> (Self, Self) @@ -927,13 +954,14 @@ extension HasADefaultImplementation { // Test default derivatives of requirements. protocol HasADefaultDerivative { + // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} func req(_ x: Float) -> Float } extension HasADefaultDerivative { - // TODO(TF-982): Make this ok. - // expected-error @+1 {{could not find function 'req'}} + // TODO(TF-982): Support default derivatives for protocol requirements. + // expected-error @+1 {{referenced declaration 'req' could not be resolved}} @derivative(of: req) - func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + func vjpReq(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { (x, { 10 * $0 }) } } @@ -1080,11 +1108,12 @@ fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value // Test invalid reference to an accessor of a non-storage declaration. +// expected-note @+1 {{candidate global function does not have a getter}} func function(_ x: Float) -> Float { x } -// expected-error @+1 {{'function' does not have a 'get' accessor}} +// expected-error @+1 {{referenced declaration 'function' could not be resolved}} @derivative(of: function(_:).get) func vjpFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() @@ -1127,9 +1156,10 @@ extension Float { // Test original function with opaque result type. +// expected-note @+1 {{candidate global function does not have expected type '(Float) -> Float'}} func opaqueResult(_ x: Float) -> some Differentiable { x } -// expected-error @+1 {{could not find function 'opaqueResult' with expected type '(Float) -> Float'}} +// expected-error @+1 {{referenced declaration 'opaqueResult' could not be resolved}} @derivative(of: opaqueResult) func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { fatalError() diff --git a/test/AutoDiff/Sema/transpose_attr_type_checking.swift b/test/AutoDiff/Sema/transpose_attr_type_checking.swift index ebacb07a9d054..a49ff941a49d3 100644 --- a/test/AutoDiff/Sema/transpose_attr_type_checking.swift +++ b/test/AutoDiff/Sema/transpose_attr_type_checking.swift @@ -174,12 +174,13 @@ func missingSelfRequirementT(x: T) -> T { return x } +// expected-note @+1 {{candidate global function does not have type equal to or less constrained than ' (T) -> T'}} func differentGenericConstraint(x: T) -> T where T == T.TangentVector { return x } -// expected-error @+1 {{could not find function 'differentGenericConstraint' with expected type ' (T) -> T'}} +// expected-error @+1 {{referenced declaration 'differentGenericConstraint' could not be resolved}} @transpose(of: differentGenericConstraint, wrt: 0) func differentGenericConstraintT(x: T) -> T where T == T.TangentVector { @@ -472,6 +473,7 @@ extension Float { // Test non-`func` original declarations. +// expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} struct Struct {} extension Struct: Equatable where T: Equatable {} extension Struct: Differentiable & AdditiveArithmetic @@ -496,7 +498,9 @@ extension Struct where T: Differentiable & AdditiveArithmetic { // Test initializers. extension Struct { + // expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} init(_ x: Float) {} + // expected-note @+1 {{candidate initializer does not have type equal to or less constrained than ' (Struct) -> (Float) -> Struct'}} init(_ x: T, y: Float) {} } @@ -513,7 +517,7 @@ extension Struct where T: Differentiable, T == T.TangentVector { // Test instance transpose for static original intializer. // TODO(TF-1015): Add improved instance/static member mismatch error. - // expected-error @+1 {{could not find function 'init' with expected type ' (Struct) -> (Float) -> Struct'}} + // expected-error @+1 {{referenced declaration 'init' could not be resolved}} @transpose(of: init, wrt: 0) func vjpInitStaticMismatch(_ x: Self) -> Float { fatalError() @@ -550,6 +554,7 @@ extension Struct where T: Differentiable & AdditiveArithmetic { // Check that `@transpose` attribute rejects stored property original declarations. struct StoredProperty: Differentiable & AdditiveArithmetic { + // expected-note @+1 {{candidate getter does not have expected type '(StoredProperty) -> () -> StoredProperty'}} var stored: Float typealias TangentVector = StoredProperty static var zero: StoredProperty { StoredProperty(stored: 0) } @@ -563,7 +568,7 @@ struct StoredProperty: Differentiable & AdditiveArithmetic { // Note: `@transpose` support for instance members is currently too limited // to properly register a transpose for a non-`Self`-typed member. - // expected-error @+1 {{could not find function 'stored' with expected type '(StoredProperty) -> () -> StoredProperty'}} + // expected-error @+1 {{referenced declaration 'stored' could not be resolved}} @transpose(of: stored, wrt: self) static func vjpStored(v: Self) -> Self { fatalError()