Skip to content

Commit 6741444

Browse files
committed
fix: remove global NatCast (Fin n) instance (#8620)
This PR removes the `NatCast (Fin n)` global instance (both the direct instance, and the indirect one via `Lean.Grind.Semiring`), as that instance causes causes `x < n` (for `x : Fin k`, `n : Nat`) to be elaborated as `x < ↑n` rather than `↑x < n`, which is undesirable. Note however that in Mathlib this happens anyway!
1 parent 87914f8 commit 6741444

File tree

7 files changed

+106
-31
lines changed

7 files changed

+106
-31
lines changed

src/Init/Data/BitVec/Lemmas.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ theorem ofFin_ofNat (n : Nat) :
319319
@[simp] theorem ofFin_neg {x : Fin (2 ^ w)} : ofFin (-x) = -(ofFin x) := by
320320
rfl
321321

322+
open Fin.NatCast in
322323
@[simp, norm_cast] theorem ofFin_natCast (n : Nat) : ofFin (n : Fin (2^w)) = (n : BitVec w) := by
323324
rfl
324325

@@ -337,6 +338,7 @@ theorem toFin_zero : toFin (0 : BitVec w) = 0 := rfl
337338
theorem toFin_one : toFin (1 : BitVec w) = 1 := by
338339
rw [toFin_inj]; simp only [ofNat_eq_ofNat, ofFin_ofNat]
339340

341+
open Fin.NatCast in
340342
@[simp, norm_cast] theorem toFin_natCast (n : Nat) : toFin (n : BitVec w) = (n : Fin (2^w)) := by
341343
rfl
342344

src/Init/Data/Fin/Lemmas.lean

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,53 @@ theorem dite_val {n : Nat} {c : Prop} [Decidable c] {x y : Fin n} :
102102
(if c then x else y).val = if c then x.val else y.val := by
103103
by_cases c <;> simp [*]
104104

105-
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
105+
namespace NatCast
106+
107+
/--
108+
This is not a global instance, but may be activated locally via `open Fin.NatCast in ...`.
109+
110+
This is not an instance because the `binop%` elaborator assumes that
111+
there are no non-trivial coercion loops,
112+
but this introduces a coercion from `Nat` to `Fin n` and back.
113+
114+
Non-trivial loops lead to undesirable and counterintuitive elaboration behavior.
115+
For example, for `x : Fin k` and `n : Nat`,
116+
it causes `x < n` to be elaborated as `x < ↑n` rather than `↑x < n`,
117+
silently introducing wraparound arithmetic.
118+
119+
Note: as of 2025-06-03, Mathlib has such a coercion for `Fin n` anyway!
120+
-/
121+
@[expose]
122+
def instNatCast (n : Nat) [NeZero n] : NatCast (Fin n) where
106123
natCast a := Fin.ofNat n a
107124

125+
attribute [scoped instance] instNatCast
126+
127+
end NatCast
128+
108129
@[expose]
109130
def intCast [NeZero n] (a : Int) : Fin n :=
110131
if 0 ≤ a then
111132
Fin.ofNat n a.natAbs
112133
else
113134
- Fin.ofNat n a.natAbs
114135

115-
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
136+
namespace IntCast
137+
138+
/--
139+
This is not a global instance, but may be activated locally via `open Fin.IntCast in ...`.
140+
141+
See the doc-string for `Fin.NatCast.instNatCast` for more details.
142+
-/
143+
@[expose]
144+
def instIntCast (n : Nat) [NeZero n] : IntCast (Fin n) where
116145
intCast := Fin.intCast
117146

147+
attribute [scoped instance] instIntCast
148+
149+
end IntCast
150+
151+
open IntCast in
118152
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
119153
(x : Fin n) = if 0 ≤ x then Fin.ofNat n x.natAbs else -Fin.ofNat n x.natAbs := rfl
120154

src/Init/Grind/CommRing/Basic.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ class CommRing (α : Type u) extends Ring α, CommSemiring α
7171
attribute [instance 100] Semiring.toAdd Semiring.toMul Semiring.toHPow Ring.toNeg Ring.toSub
7272

7373
-- This is a low-priority instance, to avoid conflicts with existing `OfNat`, `NatCast`, and `IntCast` instances.
74-
attribute [instance 100] Semiring.ofNat Semiring.natCast Ring.intCast
74+
attribute [instance 100] Semiring.ofNat
75+
76+
attribute [local instance] Semiring.natCast Ring.intCast
7577

7678
namespace Semiring
7779

src/Init/Grind/CommRing/Fin.lean

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,6 @@ namespace Lean.Grind
1414

1515
namespace Fin
1616

17-
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
18-
natCast a := Fin.ofNat n a
19-
20-
@[expose]
21-
def intCast [NeZero n] (a : Int) : Fin n :=
22-
if 0 ≤ a then
23-
Fin.ofNat n a.natAbs
24-
else
25-
- Fin.ofNat n a.natAbs
26-
27-
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
28-
intCast := Fin.intCast
29-
30-
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
31-
(x : Fin n) = if 0 ≤ x then Fin.ofNat n x.natAbs else -Fin.ofNat n x.natAbs := rfl
32-
3317
-- TODO: we should replace this at runtime with either repeated squaring,
3418
-- or a GMP accelerated function.
3519
@[expose]
@@ -78,18 +62,22 @@ theorem sub_eq_add_neg [NeZero n] (a b : Fin n) : a - b = a + -b := by
7862
cases a; cases b; simp [Fin.neg_def, Fin.sub_def, Fin.add_def, Nat.add_comm]
7963

8064
private theorem neg_neg [NeZero n] (a : Fin n) : - - a = a := by
81-
cases a; simp [Fin.neg_def, Fin.sub_def];
65+
cases a; simp [Fin.neg_def, Fin.sub_def]
8266
next a h => cases a; simp; next a =>
8367
rw [Nat.self_sub_mod n (a+1)]
8468
have : NeZero (n - (a + 1)) := ⟨by omega⟩
8569
rw [Nat.self_sub_mod, Nat.sub_sub_eq_min, Nat.min_eq_right (Nat.le_of_lt h)]
8670

71+
open Fin.NatCast Fin.IntCast in
8772
theorem intCast_neg [NeZero n] (i : Int) : Int.cast (R := Fin n) (-i) = - Int.cast (R := Fin n) i := by
88-
simp [Int.cast, IntCast.intCast, Fin.intCast]; split <;> split <;> try omega
73+
simp [Int.cast, IntCast.intCast, Fin.intCast]
74+
split <;> split <;> try omega
8975
next h₁ h₂ => simp [Int.le_antisymm h₁ h₂, Fin.neg_def]
9076
next => simp [Fin.neg_neg]
9177

9278
instance (n : Nat) [NeZero n] : CommRing (Fin n) where
79+
natCast := Fin.NatCast.instNatCast n
80+
intCast := Fin.IntCast.instIntCast n
9381
add_assoc := Fin.add_assoc
9482
add_comm := Fin.add_comm
9583
add_zero := Fin.add_zero

src/Init/Grind/CommRing/Poly.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ import Init.Grind.CommRing.Basic
1515
namespace Lean.Grind
1616
namespace CommRing
1717

18+
-- These are no longer global instances, so we need to turn them on here.
19+
attribute [local instance] Semiring.natCast Ring.intCast
20+
1821
abbrev Var := Nat
1922

2023
inductive Expr where

src/Lean/Meta/Tactic/Grind/Arith/CommRing/RingId.lean

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,37 @@ private def getPowFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Exp
5858
internalizeFn <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
5959

6060
private def getIntCastFn (type : Expr) (u : Level) (ringInst : Expr) : GoalM Expr := do
61-
let instType := mkApp (mkConst ``IntCast [u]) type
62-
let .some inst ← trySynthInstance instType |
63-
throwError "failed to find instance for ring intCast{indentExpr instType}"
6461
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [u]) type ringInst
65-
unless (← withDefault <| isDefEq inst inst') do
66-
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
62+
let instType := mkApp (mkConst ``IntCast [u]) type
63+
-- Note that `Ring.intCast` is not registered as a global instance
64+
-- (to avoid introducing unwanted coercions)
65+
-- so merely having a `Ring α` instance
66+
-- does not guarantee that an `IntCast α` will be available.
67+
-- When both are present we verify that they are defeq,
68+
-- and otherwise fall back to the field of the `Ring α` instance that we already have.
69+
let inst ← match (← trySynthInstance instType).toOption with
70+
| none => pure inst'
71+
| some inst =>
72+
unless (← withDefault <| isDefEq inst inst') do
73+
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
74+
pure inst
6775
internalizeFn <| mkApp2 (mkConst ``IntCast.intCast [u]) type inst
6876

6977
private def getNatCastFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Expr := do
70-
let instType := mkApp (mkConst ``NatCast [u]) type
71-
let .some inst ← trySynthInstance instType |
72-
throwError "failed to find instance for ring natCast{indentExpr instType}"
7378
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
74-
unless (← withDefault <| isDefEq inst inst') do
75-
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
79+
let instType := mkApp (mkConst ``NatCast [u]) type
80+
-- Note that `Semiring.natCast` is not registered as a global instance
81+
-- (to avoid introducing unwanted coercions)
82+
-- so merely having a `Semiring α` instance
83+
-- does not guarantee that an `NatCast α` will be available.
84+
-- When both are present we verify that they are defeq,
85+
-- and otherwise fall back to the field of the `Semiring α` instance that we already have.
86+
let inst ← match (← trySynthInstance instType).toOption with
87+
| none => pure inst'
88+
| some inst =>
89+
unless (← withDefault <| isDefEq inst inst') do
90+
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
91+
pure inst
7692
internalizeFn <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
7793

7894
/--

tests/lean/run/fin_coercions.lean

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
set_option pp.mvars false
2+
3+
-- We first verify that there is no global coercion from `Nat` to `Fin n`.
4+
-- Such a coercion would frequently introduce unexpected modular arithmetic.
5+
6+
/--
7+
error: type mismatch
8+
n
9+
has type
10+
Nat : Type
11+
but is expected to have type
12+
Fin 3 : Type
13+
---
14+
info: fun n => sorry : (n : Nat) → ?_ n
15+
-/
16+
#guard_msgs in #check fun (n : Nat) => (n : Fin 3)
17+
18+
-- This instance is available via `open Fin.NatCast in ...`
19+
20+
section
21+
22+
open Fin.NatCast
23+
24+
variable (m : Nat) (n : Fin 3)
25+
/-- info: n < ↑m : Prop -/
26+
#guard_msgs in #check n < m
27+
28+
end
29+
30+
example (x : Fin (n + 1)) (h : x < n) : Fin (n + 1) := x.succ.castLT (by simp [h])

0 commit comments

Comments
 (0)