@@ -220,10 +220,55 @@ pub(super) fn check_pat(tc_state: &mut TcFunState, pat: &mut ast::L<ast::Pat>, l
220
220
ast:: Pat :: Char ( _) => Ty :: char ( ) ,
221
221
222
222
ast:: Pat :: Or ( pat1, pat2) => {
223
+ // Collect binders for `pat1` and `pat2` in separate envs.
224
+ tc_state. env . enter ( ) ;
223
225
let pat1_ty = check_pat ( tc_state, pat1, level) ;
226
+ let pat1_binders = tc_state. env . exit ( ) ;
227
+
228
+ tc_state. env . enter ( ) ;
224
229
let pat2_ty = check_pat ( tc_state, pat2, level) ;
225
- // TODO: Check that the patterns bind the same variables of same types.
226
- // TODO: Any other checks here?
230
+ let pat2_binders = tc_state. env . exit ( ) ;
231
+
232
+ // Check that patterns bind the same variables.
233
+ let pat1_binder_keys: Set < & Id > = pat1_binders. keys ( ) . collect ( ) ;
234
+ let pat2_binder_keys: Set < & Id > = pat2_binders. keys ( ) . collect ( ) ;
235
+
236
+ if pat1_binder_keys != pat2_binder_keys {
237
+ let mut left_vars: Vec < Id > =
238
+ pat1_binder_keys. iter ( ) . map ( |id| ( * id) . clone ( ) ) . collect ( ) ;
239
+ left_vars. sort ( ) ;
240
+ let mut right_vars: Vec < Id > =
241
+ pat2_binder_keys. iter ( ) . map ( |id| ( * id) . clone ( ) ) . collect ( ) ;
242
+ right_vars. sort ( ) ;
243
+ panic ! (
244
+ "{}: Or pattern alternatives bind different set of variables:
245
+ Left = {}
246
+ Right = {}" ,
247
+ loc_display( & pat. loc) ,
248
+ left_vars. join( ", " ) ,
249
+ right_vars. join( ", " ) ,
250
+ )
251
+ }
252
+
253
+ // Unify pattern binders to make sure they bind the values with same types.
254
+ for binder in pat1_binder_keys {
255
+ let ty1 = pat1_binders. get ( binder) . unwrap ( ) ;
256
+ let ty2 = pat2_binders. get ( binder) . unwrap ( ) ;
257
+ unify (
258
+ ty1,
259
+ ty2,
260
+ tc_state. tys . tys . cons ( ) ,
261
+ tc_state. var_gen ,
262
+ level,
263
+ & pat. loc ,
264
+ ) ;
265
+ }
266
+
267
+ // Add bound variables back to the env.
268
+ for ( k, v) in pat1_binders {
269
+ tc_state. env . insert ( k, v) ;
270
+ }
271
+
227
272
unify (
228
273
& pat1_ty,
229
274
& pat2_ty,
@@ -232,6 +277,7 @@ pub(super) fn check_pat(tc_state: &mut TcFunState, pat: &mut ast::L<ast::Pat>, l
232
277
level,
233
278
& pat. loc ,
234
279
) ;
280
+
235
281
pat1_ty
236
282
}
237
283
}
0 commit comments