Skip to content

Try to evaluate in try unify and postpone resolution of constants that contain inference variables #95179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,15 +687,25 @@ pub struct CombinedSnapshot<'a, 'tcx> {
impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
/// calls `tcx.try_unify_abstract_consts` after
/// canonicalizing the consts.
#[instrument(skip(self), level = "debug")]
pub fn try_unify_abstract_consts(
&self,
a: ty::Unevaluated<'tcx, ()>,
b: ty::Unevaluated<'tcx, ()>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
// Reject any attempt to unify two unevaluated constants that contain inference
// variables.
// FIXME `TyCtxt::const_eval_resolve` already rejects the resolution of those
// constants early, but the canonicalization below messes with that mechanism.
if a.substs.has_infer_types_or_consts() || b.substs.has_infer_types_or_consts() {
debug!("a or b contain infer vars in its substs -> cannot unify");
return false;
}

let canonical = self.canonicalize_query((a, b), &mut OriginalQueryValues::default());
debug!("canonical consts: {:?}", &canonical.value);

self.tcx.try_unify_abstract_consts(canonical.value)
self.tcx.try_unify_abstract_consts(param_env.and(canonical.value))
}

pub fn is_in_snapshot(&self) -> bool {
Expand Down Expand Up @@ -1598,22 +1608,27 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
///
/// This handles inferences variables within both `param_env` and `substs` by
/// performing the operation on their respective canonical forms.
#[instrument(skip(self), level = "debug")]
pub fn const_eval_resolve(
&self,
param_env: ty::ParamEnv<'tcx>,
unevaluated: ty::Unevaluated<'tcx>,
span: Option<Span>,
) -> EvalToConstValueResult<'tcx> {
let substs = self.resolve_vars_if_possible(unevaluated.substs);
debug!(?substs);

// Postpone the evaluation of constants whose substs depend on inference
// variables
if substs.has_infer_types_or_consts() {
debug!("has infer types or consts");
return Err(ErrorHandled::TooGeneric);
}

let param_env_erased = self.tcx.erase_regions(param_env);
let substs_erased = self.tcx.erase_regions(substs);
debug!(?param_env_erased);
debug!(?substs_erased);

let unevaluated = ty::Unevaluated {
def: unevaluated.def,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/mir/interpret/queries.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{ErrorHandled, EvalToConstValueResult, GlobalId};

use crate::mir;
use crate::ty::fold::TypeFoldable;
use crate::ty::subst::InternalSubsts;
use crate::ty::{self, TyCtxt};
use rustc_hir::def_id::DefId;
Expand Down Expand Up @@ -38,6 +39,13 @@ impl<'tcx> TyCtxt<'tcx> {
ct: ty::Unevaluated<'tcx>,
span: Option<Span>,
) -> EvalToConstValueResult<'tcx> {
// Cannot resolve `Unevaluated` constants that contain inference
// variables. We reject those here since `resolve_opt_const_arg`
// would fail otherwise
if ct.substs.has_infer_types_or_consts() {
return Err(ErrorHandled::TooGeneric);
}

match ty::Instance::resolve_opt_const_arg(self, param_env, ct.def, ct.substs) {
Ok(Some(instance)) => {
let cid = GlobalId { instance, promoted: ct.promoted };
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,12 @@ rustc_queries! {
}
}

query try_unify_abstract_consts(key: (
ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>
)) -> bool {
query try_unify_abstract_consts(key:
ty::ParamEnvAnd<'tcx, (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>
)>) -> bool {
desc {
|tcx| "trying to unify the generic constants {} and {}",
tcx.def_path_str(key.0.def.did), tcx.def_path_str(key.1.def.did)
tcx.def_path_str(key.value.0.def.did), tcx.def_path_str(key.value.1.def.did)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ pub fn super_relate_consts<'tcx, R: TypeRelation<'tcx>>(
(ty::ConstKind::Unevaluated(au), ty::ConstKind::Unevaluated(bu))
if tcx.features().generic_const_exprs =>
{
tcx.try_unify_abstract_consts((au.shrink(), bu.shrink()))
tcx.try_unify_abstract_consts(relation.param_env().and((au.shrink(), bu.shrink())))
}

// While this is slightly incorrect, it shouldn't matter for `min_const_generics`
Expand Down
88 changes: 59 additions & 29 deletions compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ use std::iter;
use std::ops::ControlFlow;

/// Check if a given constant can be evaluated.
#[instrument(skip(infcx), level = "debug")]
pub fn is_const_evaluatable<'cx, 'tcx>(
infcx: &InferCtxt<'cx, 'tcx>,
uv: ty::Unevaluated<'tcx, ()>,
param_env: ty::ParamEnv<'tcx>,
span: Span,
) -> Result<(), NotConstEvaluatable> {
debug!("is_const_evaluatable({:?})", uv);
let tcx = infcx.tcx;

if tcx.features().generic_const_exprs {
Expand Down Expand Up @@ -185,6 +185,7 @@ pub fn is_const_evaluatable<'cx, 'tcx>(
}
}

#[instrument(skip(tcx), level = "debug")]
fn satisfied_from_param_env<'tcx>(
tcx: TyCtxt<'tcx>,
ct: AbstractConst<'tcx>,
Expand All @@ -197,11 +198,12 @@ fn satisfied_from_param_env<'tcx>(
// Try to unify with each subtree in the AbstractConst to allow for
// `N + 1` being const evaluatable even if theres only a `ConstEvaluatable`
// predicate for `(N + 1) * 2`
let result =
walk_abstract_const(tcx, b_ct, |b_ct| match try_unify(tcx, ct, b_ct) {
let result = walk_abstract_const(tcx, b_ct, |b_ct| {
match try_unify(tcx, ct, b_ct, param_env) {
true => ControlFlow::BREAK,
false => ControlFlow::CONTINUE,
});
}
});

if let ControlFlow::Break(()) = result {
debug!("is_const_evaluatable: abstract_const ~~> ok");
Expand Down Expand Up @@ -570,11 +572,12 @@ pub(super) fn thir_abstract_const<'tcx>(
pub(super) fn try_unify_abstract_consts<'tcx>(
tcx: TyCtxt<'tcx>,
(a, b): (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>),
param_env: ty::ParamEnv<'tcx>,
) -> bool {
(|| {
if let Some(a) = AbstractConst::new(tcx, a)? {
if let Some(b) = AbstractConst::new(tcx, b)? {
return Ok(try_unify(tcx, a, b));
return Ok(try_unify(tcx, a, b, param_env));
}
}

Expand Down Expand Up @@ -619,32 +622,59 @@ where
recurse(tcx, ct, &mut f)
}

// Substitutes generics repeatedly to allow AbstractConsts to unify where a
// ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
// Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
#[inline]
#[instrument(skip(tcx), level = "debug")]
fn try_replace_substs_in_root<'tcx>(
tcx: TyCtxt<'tcx>,
mut abstr_const: AbstractConst<'tcx>,
) -> Option<AbstractConst<'tcx>> {
while let Node::Leaf(ct) = abstr_const.root(tcx) {
match AbstractConst::from_const(tcx, ct) {
Ok(Some(act)) => abstr_const = act,
Ok(None) => break,
Err(_) => return None,
}
}

Some(abstr_const)
}

/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(tcx), level = "debug")]
pub(super) fn try_unify<'tcx>(
tcx: TyCtxt<'tcx>,
mut a: AbstractConst<'tcx>,
mut b: AbstractConst<'tcx>,
a: AbstractConst<'tcx>,
b: AbstractConst<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
// We substitute generics repeatedly to allow AbstractConsts to unify where a
// ConstKind::Unevalated could be turned into an AbstractConst that would unify e.g.
// Param(N) should unify with Param(T), substs: [Unevaluated("T2", [Unevaluated("T3", [Param(N)])])]
while let Node::Leaf(a_ct) = a.root(tcx) {
match AbstractConst::from_const(tcx, a_ct) {
Ok(Some(a_act)) => a = a_act,
Ok(None) => break,
Err(_) => return true,
let a = match try_replace_substs_in_root(tcx, a) {
Some(a) => a,
None => {
return true;
}
}
while let Node::Leaf(b_ct) = b.root(tcx) {
match AbstractConst::from_const(tcx, b_ct) {
Ok(Some(b_act)) => b = b_act,
Ok(None) => break,
Err(_) => return true,
};

let b = match try_replace_substs_in_root(tcx, b) {
Some(b) => b,
None => {
return true;
}
}
};

match (a.root(tcx), b.root(tcx)) {
let a_root = a.root(tcx);
let b_root = b.root(tcx);
debug!(?a_root, ?b_root);

match (a_root, b_root) {
(Node::Leaf(a_ct), Node::Leaf(b_ct)) => {
let a_ct = a_ct.eval(tcx, param_env);
debug!("a_ct evaluated: {:?}", a_ct);
let b_ct = b_ct.eval(tcx, param_env);
debug!("b_ct evaluated: {:?}", b_ct);

if a_ct.ty() != b_ct.ty() {
return false;
}
Expand Down Expand Up @@ -678,23 +708,23 @@ pub(super) fn try_unify<'tcx>(
}
}
(Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => {
try_unify(tcx, a.subtree(al), b.subtree(bl))
&& try_unify(tcx, a.subtree(ar), b.subtree(br))
try_unify(tcx, a.subtree(al), b.subtree(bl), param_env)
&& try_unify(tcx, a.subtree(ar), b.subtree(br), param_env)
}
(Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => {
try_unify(tcx, a.subtree(av), b.subtree(bv))
try_unify(tcx, a.subtree(av), b.subtree(bv), param_env)
}
(Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args))
if a_args.len() == b_args.len() =>
{
try_unify(tcx, a.subtree(a_f), b.subtree(b_f))
try_unify(tcx, a.subtree(a_f), b.subtree(b_f), param_env)
&& iter::zip(a_args, b_args)
.all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn)))
.all(|(&an, &bn)| try_unify(tcx, a.subtree(an), b.subtree(bn), param_env))
}
(Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty))
if (a_ty == b_ty) && (a_kind == b_kind) =>
{
try_unify(tcx, a.subtree(a_operand), b.subtree(b_operand))
try_unify(tcx, a.subtree(a_operand), b.subtree(b_operand), param_env)
}
// use this over `_ => false` to make adding variants to `Node` less error prone
(Node::Cast(..), _)
Expand Down
26 changes: 26 additions & 0 deletions src/test/ui/const-generics/generic_const_exprs/eval-try-unify.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// build-pass

#![feature(generic_const_exprs)]
//~^ WARNING the feature `generic_const_exprs` is incomplete

trait Generic {
const ASSOC: usize;
}

impl Generic for u8 {
const ASSOC: usize = 17;
}
impl Generic for u16 {
const ASSOC: usize = 13;
}


fn uses_assoc_type<T: Generic, const N: usize>() -> [u8; N + T::ASSOC] {
[0; N + T::ASSOC]
}

fn only_generic_n<const N: usize>() -> [u8; N + 13] {
uses_assoc_type::<u16, N>()
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
warning: the feature `generic_const_exprs` is incomplete and may not be safe to use and/or cause compiler crashes
--> $DIR/eval-try-unify.rs:3:12
|
LL | #![feature(generic_const_exprs)]
| ^^^^^^^^^^^^^^^^^^^
|
= note: `#[warn(incomplete_features)]` on by default
= note: see issue #76560 <https://github.com/rust-lang/rust/issues/76560> for more information

warning: 1 warning emitted

13 changes: 1 addition & 12 deletions src/test/ui/const-generics/issues/issue-83765.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

trait TensorDimension {
const DIM : usize;
//~^ ERROR cycle detected when resolving instance `<LazyUpdim<T, {T::DIM}, DIM>
const ISSCALAR : bool = Self::DIM == 0;
fn is_scalar(&self) -> bool {Self::ISSCALAR}
}
Expand Down Expand Up @@ -42,22 +43,16 @@ impl<'a,T : Broadcastable,const DIM : usize> TensorDimension for LazyUpdim<'a,T,

impl<'a,T : Broadcastable,const DIM : usize> TensorSize for LazyUpdim<'a,T,{T::DIM},DIM> {
fn size(&self) -> [usize;DIM] {self.size}
//~^ ERROR method not compatible with trait
}

impl<'a,T : Broadcastable,const DIM : usize> Broadcastable for LazyUpdim<'a,T,{T::DIM},DIM>
{
type Element = T::Element;
fn bget(&self,index:[usize;DIM]) -> Option<Self::Element> {
//~^ ERROR method not compatible with trait
assert!(DIM >= T::DIM);
if !self.inbounds(index) {return None}
//~^ ERROR unconstrained generic constant
//~| ERROR mismatched types
let size = self.size();
//~^ ERROR unconstrained generic constant
let newindex : [usize;T::DIM] = Default::default();
//~^ ERROR the trait bound `[usize; _]: Default` is not satisfied
self.reference.bget(newindex)
}
}
Expand All @@ -76,20 +71,14 @@ impl<'a,R, T : Broadcastable, F : Fn(T::Element) -> R ,
const DIM: usize> TensorSize for BMap<'a,R,T,F,DIM> {

fn size(&self) -> [usize;DIM] {self.reference.size()}
//~^ ERROR unconstrained generic constant
//~| ERROR mismatched types
//~| ERROR method not compatible with trait
}

impl<'a,R, T : Broadcastable, F : Fn(T::Element) -> R ,
const DIM: usize> Broadcastable for BMap<'a,R,T,F,DIM> {

type Element = R;
fn bget(&self,index:[usize;DIM]) -> Option<Self::Element> {
//~^ ERROR method not compatible with trait
self.reference.bget(index).map(&self.closure)
//~^ ERROR unconstrained generic constant
//~| ERROR mismatched types
}
}

Expand Down
Loading