@@ -6007,24 +6007,13 @@ mod tests {
6007
6007
use crate :: semantic_index:: definition:: Definition ;
6008
6008
use crate :: semantic_index:: symbol:: FileScopeId ;
6009
6009
use crate :: semantic_index:: { global_scope, semantic_index, symbol_table, use_def_map} ;
6010
+ use crate :: types:: check_types;
6010
6011
use ruff_db:: files:: { system_path_to_file, File } ;
6011
6012
use ruff_db:: system:: DbWithTestSystem ;
6012
6013
use ruff_db:: testing:: assert_function_query_was_not_run;
6013
6014
6014
6015
use super :: * ;
6015
6016
6016
- #[ track_caller]
6017
- fn assert_public_type ( db : & TestDb , file_name : & str , symbol_name : & str , expected : & str ) {
6018
- let file = system_path_to_file ( db, file_name) . expect ( "file to exist" ) ;
6019
-
6020
- let ty = global_symbol ( db, file, symbol_name) . expect_type ( ) ;
6021
- assert_eq ! (
6022
- ty. display( db) . to_string( ) ,
6023
- expected,
6024
- "Mismatch for symbol '{symbol_name}' in '{file_name}'"
6025
- ) ;
6026
- }
6027
-
6028
6017
#[ track_caller]
6029
6018
fn get_symbol < ' db > (
6030
6019
db : & ' db TestDb ,
@@ -6049,64 +6038,66 @@ mod tests {
6049
6038
symbol ( db, scope, symbol_name)
6050
6039
}
6051
6040
6041
+ #[ track_caller]
6042
+ fn assert_diagnostic_messages ( diagnostics : & TypeCheckDiagnostics , expected : & [ & str ] ) {
6043
+ let messages: Vec < & str > = diagnostics
6044
+ . iter ( )
6045
+ . map ( |diagnostic| diagnostic. message ( ) )
6046
+ . collect ( ) ;
6047
+ assert_eq ! ( & messages, expected) ;
6048
+ }
6049
+
6050
+ #[ track_caller]
6051
+ fn assert_file_diagnostics ( db : & TestDb , filename : & str , expected : & [ & str ] ) {
6052
+ let file = system_path_to_file ( db, filename) . unwrap ( ) ;
6053
+ let diagnostics = check_types ( db, file) ;
6054
+
6055
+ assert_diagnostic_messages ( diagnostics, expected) ;
6056
+ }
6057
+
6052
6058
#[ test]
6053
6059
fn not_literal_string ( ) -> anyhow:: Result < ( ) > {
6054
6060
let mut db = setup_db ( ) ;
6055
6061
let content = format ! (
6056
6062
r#"
6057
- v = not "{y}"
6058
- w = not 10*"{y}"
6059
- x = not "{y}"*10
6060
- z = not 0*"{y}"
6061
- u = not (-100)*"{y}"
6062
- "# ,
6063
+ from typing_extensions import assert_type
6064
+
6065
+ assert_type(not "{y}", bool)
6066
+ assert_type(not 10*"{y}", bool)
6067
+ assert_type(not "{y}"*10, bool)
6068
+ assert_type(not 0*"{y}", Literal[True])
6069
+ assert_type(not (-100)*"{y}", Literal[True])
6070
+ "# ,
6063
6071
y = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1 ) ,
6064
6072
) ;
6065
6073
db. write_dedented ( "src/a.py" , & content) ?;
6066
6074
6067
- assert_public_type ( & db, "src/a.py" , "v" , "bool" ) ;
6068
- assert_public_type ( & db, "src/a.py" , "w" , "bool" ) ;
6069
- assert_public_type ( & db, "src/a.py" , "x" , "bool" ) ;
6070
- assert_public_type ( & db, "src/a.py" , "z" , "Literal[True]" ) ;
6071
- assert_public_type ( & db, "src/a.py" , "u" , "Literal[True]" ) ;
6075
+ assert_file_diagnostics ( & db, "src/a.py" , & [ ] ) ;
6072
6076
6073
6077
Ok ( ( ) )
6074
6078
}
6075
6079
6076
6080
#[ test]
6077
6081
fn multiplied_string ( ) -> anyhow:: Result < ( ) > {
6078
6082
let mut db = setup_db ( ) ;
6079
-
6080
- db. write_dedented (
6081
- "src/a.py" ,
6082
- & format ! (
6083
- r#"
6084
- w = 2 * "hello"
6085
- x = "goodbye" * 3
6086
- y = "a" * {y}
6087
- z = {z} * "b"
6088
- a = 0 * "hello"
6089
- b = -3 * "hello"
6083
+ let content = format ! (
6084
+ r#"
6085
+ from typing_extensions import assert_type
6086
+
6087
+ assert_type(2 * "hello", Literal["hellohello"])
6088
+ assert_type("goodbye" * 3, Literal["goodbyegoodbyegoodbye"])
6089
+ assert_type("a" * {y}, Literal["{a_repeated}"])
6090
+ assert_type({z} * "b", LiteralString)
6091
+ assert_type(0 * "hello", Literal[""])
6092
+ assert_type(-3 * "hello", Literal[""])
6090
6093
"# ,
6091
- y = TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE ,
6092
- z = TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1
6093
- ) ,
6094
- ) ?;
6095
-
6096
- assert_public_type ( & db, "src/a.py" , "w" , r#"Literal["hellohello"]"# ) ;
6097
- assert_public_type ( & db, "src/a.py" , "x" , r#"Literal["goodbyegoodbyegoodbye"]"# ) ;
6098
- assert_public_type (
6099
- & db,
6100
- "src/a.py" ,
6101
- "y" ,
6102
- & format ! (
6103
- r#"Literal["{}"]"# ,
6104
- "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE )
6105
- ) ,
6094
+ y = TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE ,
6095
+ z = TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1 ,
6096
+ a_repeated = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE ) ,
6106
6097
) ;
6107
- assert_public_type ( & db , "src/a.py" , "z" , "LiteralString" ) ;
6108
- assert_public_type ( & db , "src/a.py" , "a" , r#"Literal[""]"# ) ;
6109
- assert_public_type ( & db, "src/a.py" , "b" , r#"Literal[""]"# ) ;
6098
+ db . write_dedented ( "src/a.py" , & content ) ? ;
6099
+
6100
+ assert_file_diagnostics ( & db, "src/a.py" , & [ ] ) ;
6110
6101
6111
6102
Ok ( ( ) )
6112
6103
}
@@ -6116,21 +6107,20 @@ mod tests {
6116
6107
let mut db = setup_db ( ) ;
6117
6108
let content = format ! (
6118
6109
r#"
6119
- v = "{y}"
6120
- w = 10*"{y}"
6121
- x = "{y}"*10
6122
- z = 0*"{y}"
6123
- u = (-100)*"{y}"
6124
- "# ,
6110
+ from typing_extensions import assert_type
6111
+
6112
+ assert_type("{y}", LiteralString)
6113
+ assert_type(10*"{y}", LiteralString)
6114
+ assert_type("{y}"*10, LiteralString)
6115
+ assert_type(0*"{y}", Literal[""])
6116
+ assert_type((-100)*"{y}", Literal[""])
6117
+ "# ,
6125
6118
y = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1 ) ,
6126
6119
) ;
6127
6120
db. write_dedented ( "src/a.py" , & content) ?;
6128
6121
6129
- assert_public_type ( & db, "src/a.py" , "v" , "LiteralString" ) ;
6130
- assert_public_type ( & db, "src/a.py" , "w" , "LiteralString" ) ;
6131
- assert_public_type ( & db, "src/a.py" , "x" , "LiteralString" ) ;
6132
- assert_public_type ( & db, "src/a.py" , "z" , r#"Literal[""]"# ) ;
6133
- assert_public_type ( & db, "src/a.py" , "u" , r#"Literal[""]"# ) ;
6122
+ assert_file_diagnostics ( & db, "src/a.py" , & [ ] ) ;
6123
+
6134
6124
Ok ( ( ) )
6135
6125
}
6136
6126
@@ -6139,16 +6129,17 @@ mod tests {
6139
6129
let mut db = setup_db ( ) ;
6140
6130
let content = format ! (
6141
6131
r#"
6142
- w = "{y}"
6143
- x = "a" + "{z}"
6144
- "# ,
6132
+ from typing_extensions import assert_type
6133
+
6134
+ assert_type("{y}", LiteralString)
6135
+ assert_type("a" + "{z}", LiteralString)
6136
+ "# ,
6145
6137
y = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1 ) ,
6146
6138
z = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE ) ,
6147
6139
) ;
6148
6140
db. write_dedented ( "src/a.py" , & content) ?;
6149
6141
6150
- assert_public_type ( & db, "src/a.py" , "w" , "LiteralString" ) ;
6151
- assert_public_type ( & db, "src/a.py" , "x" , "LiteralString" ) ;
6142
+ assert_file_diagnostics ( & db, "src/a.py" , & [ ] ) ;
6152
6143
6153
6144
Ok ( ( ) )
6154
6145
}
@@ -6158,19 +6149,18 @@ mod tests {
6158
6149
let mut db = setup_db ( ) ;
6159
6150
let content = format ! (
6160
6151
r#"
6161
- v = "{y}"
6162
- w = "{y}" + "a"
6163
- x = "a" + "{y}"
6164
- z = "{y}" + "{y}"
6165
- "# ,
6152
+ from typing_extensions import assert_type
6153
+
6154
+ assert_type("{y}", LiteralString)
6155
+ assert_type("{y}" + "a", LiteralString)
6156
+ assert_type("a" + "{y}", LiteralString)
6157
+ assert_type("{y}" + "{y}", LiteralString)
6158
+ "# ,
6166
6159
y = "a" . repeat( TypeInferenceBuilder :: MAX_STRING_LITERAL_SIZE + 1 ) ,
6167
6160
) ;
6168
6161
db. write_dedented ( "src/a.py" , & content) ?;
6169
6162
6170
- assert_public_type ( & db, "src/a.py" , "v" , "LiteralString" ) ;
6171
- assert_public_type ( & db, "src/a.py" , "w" , "LiteralString" ) ;
6172
- assert_public_type ( & db, "src/a.py" , "x" , "LiteralString" ) ;
6173
- assert_public_type ( & db, "src/a.py" , "z" , "LiteralString" ) ;
6163
+ assert_file_diagnostics ( & db, "src/a.py" , & [ ] ) ;
6174
6164
6175
6165
Ok ( ( ) )
6176
6166
}
@@ -6257,22 +6247,22 @@ mod tests {
6257
6247
6258
6248
db. write_files ( [
6259
6249
( "/src/a.py" , "from foo import x" ) ,
6260
- ( "/src/foo.py" , "x = 10\n def foo(): ..." ) ,
6250
+ ( "/src/foo.py" , "x: int = 10\n def foo(): ..." ) ,
6261
6251
] ) ?;
6262
6252
6263
6253
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6264
6254
let x_ty = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6265
6255
6266
- assert_eq ! ( x_ty. display( & db) . to_string( ) , "Literal[10] " ) ;
6256
+ assert_eq ! ( x_ty. display( & db) . to_string( ) , "int " ) ;
6267
6257
6268
6258
// Change `x` to a different value
6269
- db. write_file ( "/src/foo.py" , "x = 20 \n def foo(): ..." ) ?;
6259
+ db. write_file ( "/src/foo.py" , "x: bool = True \n def foo(): ..." ) ?;
6270
6260
6271
6261
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6272
6262
6273
6263
let x_ty_2 = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6274
6264
6275
- assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "Literal[20] " ) ;
6265
+ assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "bool " ) ;
6276
6266
6277
6267
Ok ( ( ) )
6278
6268
}
@@ -6283,23 +6273,23 @@ mod tests {
6283
6273
6284
6274
db. write_files ( [
6285
6275
( "/src/a.py" , "from foo import x" ) ,
6286
- ( "/src/foo.py" , "x = 10\n def foo(): y = 1" ) ,
6276
+ ( "/src/foo.py" , "x: int = 10\n def foo(): y = 1" ) ,
6287
6277
] ) ?;
6288
6278
6289
6279
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6290
6280
let x_ty = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6291
6281
6292
- assert_eq ! ( x_ty. display( & db) . to_string( ) , "Literal[10] " ) ;
6282
+ assert_eq ! ( x_ty. display( & db) . to_string( ) , "int " ) ;
6293
6283
6294
- db. write_file ( "/src/foo.py" , "x = 10\n def foo(): pass" ) ?;
6284
+ db. write_file ( "/src/foo.py" , "x: int = 10\n def foo(): pass" ) ?;
6295
6285
6296
6286
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6297
6287
6298
6288
db. clear_salsa_events ( ) ;
6299
6289
6300
6290
let x_ty_2 = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6301
6291
6302
- assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "Literal[10] " ) ;
6292
+ assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "int " ) ;
6303
6293
6304
6294
let events = db. take_salsa_events ( ) ;
6305
6295
@@ -6319,23 +6309,23 @@ mod tests {
6319
6309
6320
6310
db. write_files ( [
6321
6311
( "/src/a.py" , "from foo import x" ) ,
6322
- ( "/src/foo.py" , "x = 10\n y = 20 " ) ,
6312
+ ( "/src/foo.py" , "x: int = 10\n y: bool = True " ) ,
6323
6313
] ) ?;
6324
6314
6325
6315
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6326
6316
let x_ty = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6327
6317
6328
- assert_eq ! ( x_ty. display( & db) . to_string( ) , "Literal[10] " ) ;
6318
+ assert_eq ! ( x_ty. display( & db) . to_string( ) , "int " ) ;
6329
6319
6330
- db. write_file ( "/src/foo.py" , "x = 10\n y = 30 " ) ?;
6320
+ db. write_file ( "/src/foo.py" , "x: int = 10\n y: bool = False " ) ?;
6331
6321
6332
6322
let a = system_path_to_file ( & db, "/src/a.py" ) . unwrap ( ) ;
6333
6323
6334
6324
db. clear_salsa_events ( ) ;
6335
6325
6336
6326
let x_ty_2 = global_symbol ( & db, a, "x" ) . expect_type ( ) ;
6337
6327
6338
- assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "Literal[10] " ) ;
6328
+ assert_eq ! ( x_ty_2. display( & db) . to_string( ) , "int " ) ;
6339
6329
6340
6330
let events = db. take_salsa_events ( ) ;
6341
6331
0 commit comments