Skip to content

Commit 1575ec7

Browse files
committed
ccnv: skip conversions implied by terms having the same type
When converting [fun x : A => e] and [fun x : A' => e'] if I know they have the same type [forall x : A, B] I know that [A == A'] (Pi-injectivity). When converting constructors we can skip parameters. For perf reasons we switch the order of comparing arguments so that it's left to right, otherwise the optimisation is skipping an opportunity to fail early leading to eg +100% compile time for mathcomp-odd-order (1.4ks to 2.8ks). The main entry point [clos_gen_conv] does not assume there is a common type. Indeed some callers do not preserve that property, leading to errors. I encountered the following before I stopped assuming: - [nsatz] in 2145.v and in the test suite Nsatz.v - some [apply] call in HoTT (in field_of_fractions.v) When converting types (which have a common type), since optimisations are on lambdas and constructors they're not blocked. This accounts for all kernel calls except module subtyping (rare). In the future we may want to expose a way to pass [idtypes:true] so that subtyping and the well-behaved parts of unification could take advantage.
1 parent 8345302 commit 1575ec7

File tree

3 files changed

+99
-55
lines changed

3 files changed

+99
-55
lines changed

clib/cArray.ml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ sig
3737
val fold_left_i : (int -> 'a -> 'b -> 'a) -> 'a -> 'b array -> 'a
3838
val fold_right2 :
3939
('a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
40+
val fold_right2_i :
41+
(int -> 'a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
4042
val fold_left2 :
4143
('a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
4244
val fold_left3 :
@@ -243,6 +245,16 @@ let fold_right2 f v1 v2 a =
243245
if Array.length v2 <> lv1 then invalid_arg "Array.fold_right2";
244246
fold a lv1
245247

248+
let fold_right2_i f v1 v2 a =
249+
let lv1 = Array.length v1 in
250+
let rec fold a n =
251+
if n=0 then a
252+
else
253+
let k = n-1 in
254+
fold (f k (uget v1 k) (uget v2 k) a) k in
255+
if Array.length v2 <> lv1 then invalid_arg "Array.fold_right2";
256+
fold a lv1
257+
246258
let fold_left2 f a v1 v2 =
247259
let lv1 = Array.length v1 in
248260
let rec fold a n =

clib/cArray.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ sig
6262
val fold_left_i : (int -> 'a -> 'b -> 'a) -> 'a -> 'b array -> 'a
6363
val fold_right2 :
6464
('a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
65+
val fold_right2_i :
66+
(int -> 'a -> 'b -> 'c -> 'c) -> 'a array -> 'b array -> 'c -> 'c
6567
val fold_left2 :
6668
('a -> 'b -> 'c -> 'a) -> 'a -> 'b array -> 'c array -> 'a
6769
val fold_left3 :

kernel/reduction.ml

Lines changed: 85 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -288,30 +288,43 @@ let conv_table_key infos k1 k2 cuniv =
288288
| RelKey n, RelKey n' when Int.equal n n' -> cuniv
289289
| _ -> raise NotConvertible
290290

291-
let compare_stacks f fmind lft1 stk1 lft2 stk2 cuniv =
292-
let rec cmp_rec pstk1 pstk2 cuniv =
291+
let zlapp_skip skip = function
292+
| Zlapp a ->
293+
let la = Array.length a in
294+
if la >= skip then 0 else skip - la
295+
| _ -> (* assert (skip == 0); *) 0
296+
297+
let compare_stacks ~skip f fmind lft1 stk1 lft2 stk2 cuniv =
298+
let rec cmp_rec skip pstk1 pstk2 cuniv =
293299
match (pstk1,pstk2) with
294300
| (z1::s1, z2::s2) ->
295-
let cu1 = cmp_rec s1 s2 cuniv in
296-
(match (z1,z2) with
297-
| (Zlapp a1,Zlapp a2) ->
298-
Array.fold_right2 f a1 a2 cu1
299-
| (Zlproj (c1,l1),Zlproj (c2,l2)) ->
300-
if not (Constant.equal c1 c2) then
301-
raise NotConvertible
302-
else cu1
303-
| (Zlfix(fx1,a1),Zlfix(fx2,a2)) ->
304-
let cu2 = f fx1 fx2 cu1 in
305-
cmp_rec a1 a2 cu2
306-
| (Zlcase(ci1,l1,p1,br1),Zlcase(ci2,l2,p2,br2)) ->
307-
if not (fmind ci1.ci_ind ci2.ci_ind) then
308-
raise NotConvertible;
309-
let cu2 = f (l1,p1) (l2,p2) cu1 in
310-
Array.fold_right2 (fun c1 c2 -> f (l1,c1) (l2,c2)) br1 br2 cu2
311-
| _ -> assert false)
312-
| _ -> cuniv in
301+
let cuniv = compare_stack_el skip z1 z2 cuniv in
302+
cmp_rec (zlapp_skip skip z1) s1 s2 cuniv
303+
| _ -> cuniv
304+
and compare_stack_el skip z1 z2 cuniv =
305+
match (z1,z2) with
306+
| (Zlapp a1,Zlapp a2) ->
307+
if skip == 0
308+
then Array.fold_left2 (fun cuniv a1 a2 -> f a1 a2 cuniv) cuniv a1 a2
309+
else Array.fold_left2_i (fun i cuniv a1 a2 ->
310+
if i < skip then cuniv else f a1 a2 cuniv)
311+
cuniv a1 a2
312+
| (Zlproj (c1,l1),Zlproj (c2,l2)) ->
313+
if not (Constant.equal c1 c2) then
314+
raise NotConvertible
315+
else cuniv
316+
| (Zlfix(fx1,a1),Zlfix(fx2,a2)) ->
317+
let cuniv = f fx1 fx2 cuniv in
318+
cmp_rec 0 a1 a2 cuniv
319+
| (Zlcase(ci1,l1,p1,br1),Zlcase(ci2,l2,p2,br2)) ->
320+
if not (fmind ci1.ci_ind ci2.ci_ind) then
321+
raise NotConvertible;
322+
let cuniv = f (l1,p1) (l2,p2) cuniv in
323+
Array.fold_right2 (fun c1 c2 -> f (l1,c1) (l2,c2)) br1 br2 cuniv
324+
| _ -> assert false
325+
in
313326
if compare_stack_shape stk1 stk2 then
314-
cmp_rec (pure_stack lft1 stk1) (pure_stack lft2 stk2) cuniv
327+
cmp_rec skip (pure_stack lft1 stk1) (pure_stack lft2 stk2) cuniv
315328
else raise NotConvertible
316329

317330
type conv_tab = {
@@ -325,12 +338,27 @@ type conv_tab = {
325338
(** The same heap separation invariant must hold for the fconstr arguments
326339
passed to each respective side of the conversion function below. *)
327340

341+
(** About the [idtyps] argument:
342+
343+
We should only be converting well-typed terms with a common type,
344+
but some callers (outside the kernel) do not respect this
345+
invariant.
346+
347+
We assume that terms are well-typed (leading to "Conversion test
348+
raised an anomaly"). This means that when converting applications
349+
if the heads convert the arguments have a common type. Then we can
350+
use that information for optimisation, skipping some conversions.
351+
352+
eg if I have [P : (forall x : A, B) -> foo] and my conversion
353+
problem is [P (fun x : A0 => e) == P (fun x : A1 => e')] when I
354+
convert the lambdas I may assume [A0 == A == A1]. *)
355+
328356
(* Conversion between [lft1]term1 and [lft2]term2 *)
329-
let rec ccnv cv_pb l2r infos lft1 lft2 term1 term2 cuniv =
330-
eqappr cv_pb l2r infos (lft1, (term1,[])) (lft2, (term2,[])) cuniv
357+
let rec ccnv ~idtyps cv_pb l2r infos lft1 lft2 term1 term2 cuniv =
358+
eqappr ~idtyps cv_pb l2r infos (lft1, (term1,[])) (lft2, (term2,[])) cuniv
331359

332360
(* Conversion between [lft1](hd1 v1) and [lft2](hd2 v2) *)
333-
and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
361+
and eqappr ~idtyps cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
334362
Control.check_for_interrupt ();
335363
(* First head reduce both terms *)
336364
let ninfos = infos_with_reds infos.cnv_inf betaiotazeta in
@@ -394,52 +422,52 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
394422
| Some def1 -> ((lft1, (def1, v1)), appr2)
395423
| None -> raise NotConvertible)
396424
in
397-
eqappr cv_pb l2r infos app1 app2 cuniv)
425+
eqappr ~idtyps cv_pb l2r infos app1 app2 cuniv)
398426

399427
| (FProj (p1,c1), FProj (p2, c2)) ->
400428
(* Projections: prefer unfolding to first-order unification,
401429
which will happen naturally if the terms c1, c2 are not in constructor
402430
form *)
403431
(match unfold_projection infos.cnv_inf p1 with
404432
| Some s1 ->
405-
eqappr cv_pb l2r infos (lft1, (c1, (s1 :: v1))) appr2 cuniv
433+
eqappr ~idtyps cv_pb l2r infos (lft1, (c1, (s1 :: v1))) appr2 cuniv
406434
| None ->
407435
match unfold_projection infos.cnv_inf p2 with
408436
| Some s2 ->
409-
eqappr cv_pb l2r infos appr1 (lft2, (c2, (s2 :: v2))) cuniv
437+
eqappr ~idtyps cv_pb l2r infos appr1 (lft2, (c2, (s2 :: v2))) cuniv
410438
| None ->
411439
if Constant.equal (Projection.constant p1) (Projection.constant p2)
412440
&& compare_stack_shape v1 v2 then
413441
let el1 = el_stack lft1 v1 in
414442
let el2 = el_stack lft2 v2 in
415-
let u1 = ccnv CONV l2r infos el1 el2 c1 c2 cuniv in
443+
let u1 = ccnv ~idtyps:false CONV l2r infos el1 el2 c1 c2 cuniv in
416444
convert_stacks l2r infos lft1 lft2 v1 v2 u1
417445
else (* Two projections in WHNF: unfold *)
418446
raise NotConvertible)
419447

420448
| (FProj (p1,c1), t2) ->
421449
(match unfold_projection infos.cnv_inf p1 with
422450
| Some s1 ->
423-
eqappr cv_pb l2r infos (lft1, (c1, (s1 :: v1))) appr2 cuniv
451+
eqappr ~idtyps cv_pb l2r infos (lft1, (c1, (s1 :: v1))) appr2 cuniv
424452
| None ->
425453
(match t2 with
426454
| FFlex fl2 ->
427455
(match unfold_reference infos.cnv_inf infos.rgt_tab fl2 with
428456
| Some def2 ->
429-
eqappr cv_pb l2r infos appr1 (lft2, (def2, v2)) cuniv
457+
eqappr ~idtyps cv_pb l2r infos appr1 (lft2, (def2, v2)) cuniv
430458
| None -> raise NotConvertible)
431459
| _ -> raise NotConvertible))
432460

433461
| (t1, FProj (p2,c2)) ->
434462
(match unfold_projection infos.cnv_inf p2 with
435463
| Some s2 ->
436-
eqappr cv_pb l2r infos appr1 (lft2, (c2, (s2 :: v2))) cuniv
464+
eqappr ~idtyps cv_pb l2r infos appr1 (lft2, (c2, (s2 :: v2))) cuniv
437465
| None ->
438466
(match t1 with
439467
| FFlex fl1 ->
440468
(match unfold_reference infos.cnv_inf infos.lft_tab fl1 with
441469
| Some def1 ->
442-
eqappr cv_pb l2r infos (lft1, (def1, v1)) appr2 cuniv
470+
eqappr ~idtyps cv_pb l2r infos (lft1, (def1, v1)) appr2 cuniv
443471
| None -> raise NotConvertible)
444472
| _ -> raise NotConvertible))
445473

@@ -453,17 +481,19 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
453481
let (_,ty2,bd2) = destFLambda mk_clos hd2 in
454482
let el1 = el_stack lft1 v1 in
455483
let el2 = el_stack lft2 v2 in
456-
let cuniv = ccnv CONV l2r infos el1 el2 ty1 ty2 cuniv in
457-
ccnv CONV l2r infos (el_lift el1) (el_lift el2) bd1 bd2 cuniv
484+
let cuniv = if idtyps then cuniv (* Pi injectivity *)
485+
else ccnv ~idtyps:true CONV l2r infos el1 el2 ty1 ty2 cuniv
486+
in
487+
ccnv ~idtyps CONV l2r infos (el_lift el1) (el_lift el2) bd1 bd2 cuniv
458488

459489
| (FProd (_,c1,c2), FProd (_,c'1,c'2)) ->
460490
if not (is_empty_stack v1 && is_empty_stack v2) then
461491
anomaly (Pp.str "conversion was given ill-typed terms (FProd).");
462492
(* Luo's system *)
463493
let el1 = el_stack lft1 v1 in
464494
let el2 = el_stack lft2 v2 in
465-
let cuniv = ccnv CONV l2r infos el1 el2 c1 c'1 cuniv in
466-
ccnv cv_pb l2r infos (el_lift el1) (el_lift el2) c2 c'2 cuniv
495+
let cuniv = ccnv ~idtyps:true CONV l2r infos el1 el2 c1 c'1 cuniv in
496+
ccnv ~idtyps:true cv_pb l2r infos (el_lift el1) (el_lift el2) c2 c'2 cuniv
467497

468498
(* Eta-expansion on the fly *)
469499
| (FLambda _, _) ->
@@ -473,7 +503,7 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
473503
anomaly (Pp.str "conversion was given unreduced term (FLambda).")
474504
in
475505
let (_,_ty1,bd1) = destFLambda mk_clos hd1 in
476-
eqappr CONV l2r infos
506+
eqappr ~idtyps CONV l2r infos
477507
(el_lift lft1, (bd1, [])) (el_lift lft2, (hd2, eta_expand_stack v2)) cuniv
478508
| (_, FLambda _) ->
479509
let () = match v2 with
@@ -482,7 +512,7 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
482512
anomaly (Pp.str "conversion was given unreduced term (FLambda).")
483513
in
484514
let (_,_ty2,bd2) = destFLambda mk_clos hd2 in
485-
eqappr CONV l2r infos
515+
eqappr ~idtyps CONV l2r infos
486516
(el_lift lft1, (hd1, eta_expand_stack v1)) (el_lift lft2, (bd2, [])) cuniv
487517

488518
(* only one constant, defined var or defined rel *)
@@ -495,7 +525,7 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
495525
unfoldings, we perform reduction with all flags on. *)
496526
let all = RedFlags.red_add_transparent all (RedFlags.red_transparent (info_flags infos.cnv_inf)) in
497527
let r1 = whd_stack (infos_with_reds infos.cnv_inf all) infos.lft_tab def1 v1 in
498-
eqappr cv_pb l2r infos (lft1, r1) appr2 cuniv
528+
eqappr ~idtyps cv_pb l2r infos (lft1, r1) appr2 cuniv
499529
| None ->
500530
match c2 with
501531
| FConstruct ((ind2,j2),u2) ->
@@ -512,7 +542,7 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
512542
(** Symmetrical case of above. *)
513543
let all = RedFlags.red_add_transparent all (RedFlags.red_transparent (info_flags infos.cnv_inf)) in
514544
let r2 = whd_stack (infos_with_reds infos.cnv_inf all) infos.rgt_tab def2 v2 in
515-
eqappr cv_pb l2r infos appr1 (lft2, r2) cuniv
545+
eqappr ~idtyps cv_pb l2r infos appr1 (lft2, r2) cuniv
516546
| None ->
517547
match c1 with
518548
| FConstruct ((ind1,j1),u1) ->
@@ -540,17 +570,17 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
540570

541571
| (FConstruct ((ind1,j1),u1), FConstruct ((ind2,j2),u2)) ->
542572
if Int.equal j1 j2 && eq_ind ind1 ind2 then
543-
if Univ.Instance.length u1 = 0 || Univ.Instance.length u2 = 0 then
544-
let cuniv = convert_instances ~flex:false u1 u2 cuniv in
545-
convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
573+
let mind = Environ.lookup_mind (fst ind1) (info_env infos.cnv_inf) in
574+
let nargs = CClosure.stack_args_size v1 in
575+
if not (Int.equal nargs (CClosure.stack_args_size v2))
576+
then raise NotConvertible
546577
else
547-
let mind = Environ.lookup_mind (fst ind1) (info_env infos.cnv_inf) in
548-
let nargs = CClosure.stack_args_size v1 in
549-
if not (Int.equal nargs (CClosure.stack_args_size v2))
550-
then raise NotConvertible
551-
else
552-
let cuniv = convert_constructors (mind, snd ind1, j1) nargs u1 u2 cuniv in
553-
convert_stacks l2r infos lft1 lft2 v1 v2 cuniv
578+
let cuniv = convert_constructors (mind, snd ind1, j1) nargs u1 u2 cuniv in
579+
let skip = if idtyps (* Ind injectivity *)
580+
then mind.Declarations.mind_nparams
581+
else 0
582+
in
583+
convert_stacks ~skip l2r infos lft1 lft2 v1 v2 cuniv
554584
else raise NotConvertible
555585

556586
(* Eta expansion of records *)
@@ -610,9 +640,9 @@ and eqappr cv_pb l2r infos (lft1,st1) (lft2,st2) cuniv =
610640
| (FRel _ | FAtom _ | FInd _ | FFix _ | FCoFix _
611641
| FProd _ | FEvar _), _ -> raise NotConvertible
612642

613-
and convert_stacks l2r infos lft1 lft2 stk1 stk2 cuniv =
614-
compare_stacks
615-
(fun (l1,t1) (l2,t2) cuniv -> ccnv CONV l2r infos l1 l2 t1 t2 cuniv)
643+
and convert_stacks ?(skip=0) l2r infos lft1 lft2 stk1 stk2 cuniv =
644+
compare_stacks ~skip
645+
(fun (l1,t1) (l2,t2) cuniv -> ccnv ~idtyps:true CONV l2r infos l1 l2 t1 t2 cuniv)
616646
(eq_ind)
617647
lft1 stk1 lft2 stk2 cuniv
618648

@@ -624,7 +654,7 @@ and convert_vect l2r infos lft1 lft2 v1 v2 cuniv =
624654
let rec fold n cuniv =
625655
if n >= lv1 then cuniv
626656
else
627-
let cuniv = ccnv CONV l2r infos lft1 lft2 v1.(n) v2.(n) cuniv in
657+
let cuniv = ccnv ~idtyps:true CONV l2r infos lft1 lft2 v1.(n) v2.(n) cuniv in
628658
fold (n+1) cuniv in
629659
fold 0 cuniv
630660
else raise NotConvertible
@@ -637,7 +667,7 @@ let clos_gen_conv trans cv_pb l2r evars env univs t1 t2 =
637667
lft_tab = create_tab ();
638668
rgt_tab = create_tab ();
639669
} in
640-
ccnv cv_pb l2r infos el_id el_id (inject t1) (inject t2) univs
670+
ccnv ~idtyps:false cv_pb l2r infos el_id el_id (inject t1) (inject t2) univs
641671

642672

643673
let check_eq univs u u' =

0 commit comments

Comments
 (0)