diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 302f077b5bd9..e1c4f71fa61b 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -51,13 +51,13 @@ def compile_format_re() -> Pattern[str]: The regexp is intentionally a bit wider to report better errors. """ key_re = r'(\((?P[^)]*)\))?' # (optional) parenthesised sequence of characters. - flags_re = r'(?P[#0\-+ ]*)' # (optional) sequence of flags. + flags_re = r'(?P[#0\-+ ]*)' # (optional) sequence of flags. width_re = r'(?P[1-9][0-9]*|\*)?' # (optional) minimum field width (* or numbers). precision_re = r'(?:\.(?P\*|[0-9]+)?)?' # (optional) . followed by * of numbers. length_mod_re = r'[hlL]?' # (optional) length modifier (unused). type_re = r'(?P.)?' # conversion type. format_re = '%' + key_re + flags_re + width_re + precision_re + length_mod_re + type_re - return re.compile('({})'.format(format_re)) + return re.compile(format_re) def compile_new_format_re(custom_spec: bool) -> Pattern[str]: @@ -83,8 +83,8 @@ def compile_new_format_re(custom_spec: bool) -> Pattern[str]: # This contains sign, flags (sign, # and/or 0), width, grouping (_ or ,) and precision. num_spec = r'(?P[+\- ]?#?0?)(?P\d+)?[_,]?(?P\.\d+)?' # The last element is type. - type = r'(?P.)?' # only some are supported, but we want to give a better error - format_spec = r'(?P:' + fill_align + num_spec + type + r')?' + conv_type = r'(?P.)?' # only some are supported, but we want to give a better error + format_spec = r'(?P:' + fill_align + num_spec + conv_type + r')?' else: # Custom types can define their own form_spec using __format__(). format_spec = r'(?P:.*)?' @@ -114,52 +114,30 @@ def compile_new_format_re(custom_spec: bool) -> Pattern[str]: class ConversionSpecifier: - def __init__(self, type: str, - key: Optional[str], - flags: Optional[str], - width: Optional[str], - precision: Optional[str], - format_spec: Optional[str] = None, - conversion: Optional[str] = None, - field: Optional[str] = None, - whole_seq: Optional[str] = None) -> None: - self.type = type - self.key = key - self.flags = flags - self.width = width - self.precision = precision + def __init__(self, match: Match[str], + start_pos: int = -1, + non_standard_format_spec: bool = False) -> None: + + self.whole_seq = match.group() + self.start_pos = start_pos + + m_dict = match.groupdict() + self.key = m_dict.get('key') + + # Replace unmatched optional groups with empty matches (for convenience). + self.conv_type = m_dict.get('type', '') + self.flags = m_dict.get('flags', '') + self.width = m_dict.get('width', '') + self.precision = m_dict.get('precision', '') + # Used only for str.format() calls (it may be custom for types with __format__()). - self.format_spec = format_spec - self.non_standard_format_spec = False + self.format_spec = m_dict.get('format_spec') + self.non_standard_format_spec = non_standard_format_spec # Used only for str.format() calls. - self.conversion = conversion + self.conversion = m_dict.get('conversion') # Full formatted expression (i.e. key plus following attributes and/or indexes). # Used only for str.format() calls. - self.field = field - self.whole_seq = whole_seq - - @classmethod - def from_match(cls, match: Match[str], - non_standard_spec: bool = False) -> 'ConversionSpecifier': - """Construct specifier from match object resulted from parsing str.format() call.""" - if non_standard_spec: - spec = cls(type='', - key=match.group('key'), - flags='', width='', precision='', - format_spec=match.group('format_spec'), - conversion=match.group('conversion'), - field=match.group('field')) - spec.non_standard_format_spec = True - return spec - # Replace unmatched optional groups with empty matches (for convenience). - return cls(type=match.group('type') or '', - key=match.group('key'), - flags=match.group('flags') or '', - width=match.group('width') or '', - precision=match.group('precision') or '', - format_spec=match.group('format_spec'), - conversion=match.group('conversion'), - field=match.group('field')) + self.field = m_dict.get('field') def has_key(self) -> bool: return self.key is not None @@ -168,6 +146,112 @@ def has_star(self) -> bool: return self.width == '*' or self.precision == '*' +def parse_conversion_specifiers(format_str: str) -> List[ConversionSpecifier]: + """Parse c-printf-style format string into list of conversion specifiers.""" + specifiers: List[ConversionSpecifier] = [] + for m in re.finditer(FORMAT_RE, format_str): + specifiers.append(ConversionSpecifier(m, start_pos=m.start())) + return specifiers + + +def parse_format_value(format_value: str, ctx: Context, msg: MessageBuilder, + nested: bool = False) -> Optional[List[ConversionSpecifier]]: + """Parse format string into list of conversion specifiers. + + The specifiers may be nested (two levels maximum), in this case they are ordered as + '{0:{1}}, {2:{3}{4}}'. Return None in case of an error. + """ + top_targets = find_non_escaped_targets(format_value, ctx, msg) + if top_targets is None: + return None + + result: List[ConversionSpecifier] = [] + for target, start_pos in top_targets: + match = FORMAT_RE_NEW.fullmatch(target) + if match: + conv_spec = ConversionSpecifier(match, start_pos=start_pos) + else: + custom_match = FORMAT_RE_NEW_CUSTOM.fullmatch(target) + if custom_match: + conv_spec = ConversionSpecifier( + custom_match, start_pos=start_pos, + non_standard_format_spec=True) + else: + msg.fail('Invalid conversion specifier in format string', + ctx, code=codes.STRING_FORMATTING) + return None + + if conv_spec.key and ('{' in conv_spec.key or '}' in conv_spec.key): + msg.fail('Conversion value must not contain { or }', + ctx, code=codes.STRING_FORMATTING) + return None + result.append(conv_spec) + + # Parse nested conversions that are allowed in format specifier. + if (conv_spec.format_spec and conv_spec.non_standard_format_spec and + ('{' in conv_spec.format_spec or '}' in conv_spec.format_spec)): + if nested: + msg.fail('Formatting nesting must be at most two levels deep', + ctx, code=codes.STRING_FORMATTING) + return None + sub_conv_specs = parse_format_value(conv_spec.format_spec, ctx, msg, + nested=True) + if sub_conv_specs is None: + return None + result.extend(sub_conv_specs) + return result + + +def find_non_escaped_targets(format_value: str, ctx: Context, + msg: MessageBuilder) -> Optional[List[Tuple[str, int]]]: + """Return list of raw (un-parsed) format specifiers in format string. + + Format specifiers don't include enclosing braces. We don't use regexp for + this because they don't work well with nested/repeated patterns + (both greedy and non-greedy), and these are heavily used internally for + representation of f-strings. + + Return None in case of an error. + """ + result = [] + next_spec = '' + pos = 0 + nesting = 0 + while pos < len(format_value): + c = format_value[pos] + if not nesting: + # Skip any paired '{{' and '}}', enter nesting on '{', report error on '}'. + if c == '{': + if pos < len(format_value) - 1 and format_value[pos + 1] == '{': + pos += 1 + else: + nesting = 1 + if c == '}': + if pos < len(format_value) - 1 and format_value[pos + 1] == '}': + pos += 1 + else: + msg.fail('Invalid conversion specifier in format string:' + ' unexpected }', ctx, code=codes.STRING_FORMATTING) + return None + else: + # Adjust nesting level, then either continue adding chars or move on. + if c == '{': + nesting += 1 + if c == '}': + nesting -= 1 + if nesting: + next_spec += c + else: + result.append((next_spec, pos - len(next_spec))) + next_spec = '' + pos += 1 + if nesting: + msg.fail('Invalid conversion specifier in format string:' + ' unmatched {', ctx, code=codes.STRING_FORMATTING) + return None + return result + + class StringFormatterChecker: """String interpolation/formatter type checker. @@ -214,107 +298,13 @@ def check_str_format_call(self, call: CallExpr, format_value: str) -> None: - 's' must not accept bytes - non-empty flags are only allowed for numeric types """ - conv_specs = self.parse_format_value(format_value, call) + conv_specs = parse_format_value(format_value, call, self.msg) if conv_specs is None: return if not self.auto_generate_keys(conv_specs, call): return self.check_specs_in_format_call(call, conv_specs, format_value) - def parse_format_value(self, format_value: str, ctx: Context, - nested: bool = False) -> Optional[List[ConversionSpecifier]]: - """Parse format string into list of conversion specifiers. - - The specifiers may be nested (two levels maximum), in this case they are ordered as - '{0:{1}}, {2:{3}{4}}'. Return None in case of an error. - """ - top_targets = self.find_non_escaped_targets(format_value, ctx) - if top_targets is None: - return None - - result: List[ConversionSpecifier] = [] - for target in top_targets: - match = FORMAT_RE_NEW.fullmatch(target) - if match: - conv_spec = ConversionSpecifier.from_match(match) - else: - custom_match = FORMAT_RE_NEW_CUSTOM.fullmatch(target) - if custom_match: - conv_spec = ConversionSpecifier.from_match(custom_match, - non_standard_spec=True) - else: - self.msg.fail('Invalid conversion specifier in format string', - ctx, code=codes.STRING_FORMATTING) - return None - - if conv_spec.key and ('{' in conv_spec.key or '}' in conv_spec.key): - self.msg.fail('Conversion value must not contain { or }', - ctx, code=codes.STRING_FORMATTING) - return None - result.append(conv_spec) - - # Parse nested conversions that are allowed in format specifier. - if (conv_spec.format_spec and conv_spec.non_standard_format_spec and - ('{' in conv_spec.format_spec or '}' in conv_spec.format_spec)): - if nested: - self.msg.fail('Formatting nesting must be at most two levels deep', - ctx, code=codes.STRING_FORMATTING) - return None - sub_conv_specs = self.parse_format_value(conv_spec.format_spec, ctx=ctx, - nested=True) - if sub_conv_specs is None: - return None - result.extend(sub_conv_specs) - return result - - def find_non_escaped_targets(self, format_value: str, ctx: Context) -> Optional[List[str]]: - """Return list of raw (un-parsed) format specifiers in format string. - - Format specifiers don't include enclosing braces. We don't use regexp for - this because they don't work well with nested/repeated patterns - (both greedy and non-greedy), and these are heavily used internally for - representation of f-strings. - - Return None in case of an error. - """ - result = [] - next_spec = '' - pos = 0 - nesting = 0 - while pos < len(format_value): - c = format_value[pos] - if not nesting: - # Skip any paired '{{' and '}}', enter nesting on '{', report error on '}'. - if c == '{': - if pos < len(format_value) - 1 and format_value[pos + 1] == '{': - pos += 1 - else: - nesting = 1 - if c == '}': - if pos < len(format_value) - 1 and format_value[pos + 1] == '}': - pos += 1 - else: - self.msg.fail('Invalid conversion specifier in format string:' - ' unexpected }', ctx, code=codes.STRING_FORMATTING) - return None - else: - # Adjust nesting level, then either continue adding chars or move on. - if c == '{': - nesting += 1 - if c == '}': - nesting -= 1 - if nesting: - next_spec += c - else: - result.append(next_spec) - next_spec = '' - pos += 1 - if nesting: - self.msg.fail('Invalid conversion specifier in format string:' - ' unmatched {', ctx, code=codes.STRING_FORMATTING) - return None - return result - def check_specs_in_format_call(self, call: CallExpr, specs: List[ConversionSpecifier], format_value: str) -> None: """Perform pairwise checks for conversion specifiers vs their replacements. @@ -341,7 +331,7 @@ def check_specs_in_format_call(self, call: CallExpr, call, code=codes.STRING_FORMATTING) continue # Adjust expected and actual types. - if not spec.type: + if not spec.conv_type: expected_type: Optional[Type] = AnyType(TypeOfAny.special_form) else: assert isinstance(call.callee, MemberExpr) @@ -349,7 +339,7 @@ def check_specs_in_format_call(self, call: CallExpr, format_str = call.callee.expr else: format_str = StrExpr(format_value) - expected_type = self.conversion_type(spec.type, call, format_str, + expected_type = self.conversion_type(spec.conv_type, call, format_str, format_call=True) if spec.conversion is not None: # If the explicit conversion is given, then explicit conversion is called _first_. @@ -376,7 +366,7 @@ def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExp repl: Expression, actual_type: Type, expected_type: Type) -> None: # TODO: try refactoring to combine this logic with % formatting. - if spec.type == 'c': + if spec.conv_type == 'c': if isinstance(repl, (StrExpr, BytesExpr)) and len(repl.value) != 1: self.msg.requires_int_or_char(call, format_call=True) c_typ = get_proper_type(self.chk.type_map[repl]) @@ -385,7 +375,7 @@ def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExp if isinstance(c_typ, LiteralType) and isinstance(c_typ.value, str): if len(c_typ.value) != 1: self.msg.requires_int_or_char(call, format_call=True) - if (not spec.type or spec.type == 's') and not spec.conversion: + if (not spec.conv_type or spec.conv_type == 's') and not spec.conversion: if self.chk.options.python_version >= (3, 0): if (has_type_component(actual_type, 'builtins.bytes') and not custom_special_method(actual_type, '__str__')): @@ -396,8 +386,8 @@ def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExp if spec.flags: numeric_types = UnionType([self.named_type('builtins.int'), self.named_type('builtins.float')]) - if (spec.type and spec.type not in NUMERIC_TYPES_NEW or - not spec.type and not is_subtype(actual_type, numeric_types) and + if (spec.conv_type and spec.conv_type not in NUMERIC_TYPES_NEW or + not spec.conv_type and not is_subtype(actual_type, numeric_types) and not custom_special_method(actual_type, '__format__')): self.msg.fail('Numeric flags are only allowed for numeric types', call, code=codes.STRING_FORMATTING) @@ -593,7 +583,7 @@ class User(TypedDict): spec=spec, ctx=ctx) # TODO: In Python 3, the bytes formatting has a more restricted set of options - # compared to string formatting. + # compared to string formatting. def check_str_interpolation(self, expr: FormatStringExpr, replacements: Expression) -> Type: @@ -601,7 +591,7 @@ def check_str_interpolation(self, expression: str % replacements. """ self.exprchk.accept(expr) - specifiers = self.parse_conversion_specifiers(expr.value) + specifiers = parse_conversion_specifiers(expr.value) has_mapping_keys = self.analyze_conversion_specifiers(specifiers, expr) if isinstance(expr, BytesExpr) and (3, 0) <= self.chk.options.python_version < (3, 5): self.msg.fail('Bytes formatting is only supported in Python 3.5 and later', @@ -627,22 +617,12 @@ def check_str_interpolation(self, else: assert False - def parse_conversion_specifiers(self, format: str) -> List[ConversionSpecifier]: - specifiers: List[ConversionSpecifier] = [] - for whole_seq, parens_key, key, flags, width, precision, type \ - in FORMAT_RE.findall(format): - if parens_key == '': - key = None - specifiers.append(ConversionSpecifier(type, key, flags, width, precision, - whole_seq=whole_seq)) - return specifiers - def analyze_conversion_specifiers(self, specifiers: List[ConversionSpecifier], context: Context) -> Optional[bool]: has_star = any(specifier.has_star() for specifier in specifiers) has_key = any(specifier.has_key() for specifier in specifiers) all_have_keys = all( - specifier.has_key() or specifier.type == '%' for specifier in specifiers + specifier.has_key() or specifier.conv_type == '%' for specifier in specifiers ) if has_key and has_star: @@ -717,7 +697,7 @@ def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier], mapping[key_str] = self.accept(v) for specifier in specifiers: - if specifier.type == '%': + if specifier.conv_type == '%': # %% is allowed in mappings, no checking is required continue assert specifier.key is not None @@ -725,7 +705,8 @@ def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier], self.msg.key_not_in_mapping(specifier.key, replacements) return rep_type = mapping[specifier.key] - expected_type = self.conversion_type(specifier.type, replacements, expr) + assert specifier.conv_type is not None + expected_type = self.conversion_type(specifier.conv_type, replacements, expr) if expected_type is None: return self.chk.check_subtype(rep_type, expected_type, replacements, @@ -733,7 +714,7 @@ def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier], 'expression has type', 'placeholder with key \'%s\' has type' % specifier.key, code=codes.STRING_FORMATTING) - if specifier.type == 's': + if specifier.conv_type == 's': self.check_s_special_cases(expr, rep_type, expr) else: rep_type = self.accept(replacements) @@ -789,13 +770,14 @@ def replacement_checkers(self, specifier: ConversionSpecifier, context: Context, checkers.append(self.checkers_for_star(context)) if specifier.precision == '*': checkers.append(self.checkers_for_star(context)) - if specifier.type == 'c': - c = self.checkers_for_c_type(specifier.type, context, expr) + + if specifier.conv_type == 'c': + c = self.checkers_for_c_type(specifier.conv_type, context, expr) if c is None: return None checkers.append(c) - elif specifier.type != '%': - c = self.checkers_for_regular_type(specifier.type, context, expr) + elif specifier.conv_type is not None and specifier.conv_type != '%': + c = self.checkers_for_regular_type(specifier.conv_type, context, expr) if c is None: return None checkers.append(c) @@ -824,20 +806,20 @@ def check_placeholder_type(self, typ: Type, expected_type: Type, context: Contex 'expression has type', 'placeholder has type', code=codes.STRING_FORMATTING) - def checkers_for_regular_type(self, type: str, + def checkers_for_regular_type(self, conv_type: str, context: Context, expr: FormatStringExpr) -> Optional[Checkers]: """Returns a tuple of check functions that check whether, respectively, a node or a type is compatible with 'type'. Return None in case of an error. """ - expected_type = self.conversion_type(type, context, expr) + expected_type = self.conversion_type(conv_type, context, expr) if expected_type is None: return None def check_type(typ: Type) -> bool: assert expected_type is not None ret = self.check_placeholder_type(typ, expected_type, context) - if ret and type == 's': + if ret and conv_type == 's': ret = self.check_s_special_cases(expr, typ, context) return ret diff --git a/mypyc/irbuild/format_str_tokenizer.py b/mypyc/irbuild/format_str_tokenizer.py index a09c1dffb597..d31201c02b81 100644 --- a/mypyc/irbuild/format_str_tokenizer.py +++ b/mypyc/irbuild/format_str_tokenizer.py @@ -1,10 +1,9 @@ """Tokenizers for three string formatting methods""" -import re from typing import List, Tuple from mypy.checkstrformat import ( - FORMAT_RE, ConversionSpecifier + ConversionSpecifier, parse_conversion_specifiers ) from mypyc.ir.ops import Value, Integer from mypyc.ir.rtypes import c_pyssize_t_rprimitive @@ -19,18 +18,13 @@ def tokenizer_printf_style(format_str: str) -> Tuple[List[str], List[ConversionS A list of string literals and a list of conversion operations """ literals: List[str] = [] - specifiers: List[ConversionSpecifier] = [] - last_end = 0 - - for m in re.finditer(FORMAT_RE, format_str): - whole_seq, parens_key, key, flags, width, precision, conversion_type = m.groups() - specifiers.append(ConversionSpecifier(conversion_type, key, flags, width, precision, - whole_seq=whole_seq)) + specifiers: List[ConversionSpecifier] = parse_conversion_specifiers(format_str) - cur_start = m.start(1) + last_end = 0 + for spec in specifiers: + cur_start = spec.start_pos literals.append(format_str[last_end:cur_start]) - last_end = cur_start + len(whole_seq) - + last_end = cur_start + len(spec.whole_seq) literals.append(format_str[last_end:]) return literals, specifiers diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 8dd1d10c2a24..18cd898ef493 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -13,10 +13,14 @@ """ from typing import Callable, Optional, Dict, Tuple, List +from typing_extensions import Final +from mypy.checkstrformat import parse_format_value +from mypy.errors import Errors +from mypy.messages import MessageBuilder from mypy.nodes import ( CallExpr, RefExpr, MemberExpr, NameExpr, TupleExpr, GeneratorExpr, - ListExpr, DictExpr, StrExpr, ARG_POS + ListExpr, DictExpr, StrExpr, ARG_POS, Context ) from mypy.types import AnyType, TypeOfAny @@ -388,17 +392,37 @@ def translate_dict_setdefault( return None +# The empty Context as an argument for parse_format_value(). +# It wouldn't be used since the code has passed the type-checking. +EMPTY_CONTEXT: Final = Context() + + @specialize_function('format', str_rprimitive) def translate_str_format( builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: if (isinstance(callee, MemberExpr) and isinstance(callee.expr, StrExpr) and expr.arg_kinds.count(ARG_POS) == len(expr.arg_kinds)): - format_str = callee.expr.value - if not can_optimize_format(format_str): + + # Creates an empty MessageBuilder here. + # It wouldn't be used since the code has passed the type-checking. + specifiers = parse_format_value(format_str, EMPTY_CONTEXT, + MessageBuilder(Errors(), {})) + if specifiers is None: return None - literals = split_braces(format_str) + literals = [] + last_pos = 0 + for spec in specifiers: + # Only empty curly brace is allowed + if spec.whole_seq: + return None + literals.append(format_str[last_pos:spec.start_pos-1]) + last_pos = spec.start_pos + len(spec.whole_seq) + 1 + literals.append(format_str[last_pos:]) + + # Deal with escaped {{ + literals = [x.replace('{{', '{').replace('}}', '}') for x in literals] # Convert variables to strings variables = [] @@ -416,61 +440,6 @@ def translate_str_format( return None -def can_optimize_format(format_str: str) -> bool: - # TODO - # Only empty braces can be optimized - prev = '' - for c in format_str: - if (c == '{' and prev == '{' - or c == '}' and prev == '}'): - prev = '' - continue - if (prev != '' and (c == '}' and prev != '{' - or prev == '{' and c != '}')): - return False - prev = c - return True - - -def split_braces(format_str: str) -> List[str]: - # This function can only be called after format_str passes - # 'can_optimize_format()'. - tmp_str = '' - ret_list = [] - prev = '' - for c in format_str: - # There are three cases: {, }, others - # when c is '}': prev is '{' -> match empty braces - # '}' -> merge into one } in literal - # others -> pass - # c is '{': prev is '{' -> merge into one { in literal - # '}' -> pass - # others -> pass - # c is others: add c into literal - clear_prev = True - if c == '}': - if prev == '{': - ret_list.append(tmp_str) - tmp_str = '' - elif prev == '}': - tmp_str += '}' - else: - clear_prev = False - elif c == '{': - if prev == '{': - tmp_str += '{' - else: - clear_prev = False - else: - tmp_str += c - clear_prev = False - prev = c - if clear_prev: - prev = '' - ret_list.append(tmp_str) - return ret_list - - @specialize_function('join', str_rprimitive) def translate_fstring( builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index a896308610ea..eb0c754a73f1 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -164,12 +164,12 @@ def f(s: str, num: int) -> None: s1 = "Hi! I'm {}, and I'm {} years old.".format(s, num) s2 = ''.format() s3 = 'abc'.format() - s3 = '}}{}{{{}}}{{{}'.format(num, num, num) + s4 = '}}{}{{{}}}{{{}'.format(num, num, num) [out] def f(s, num): s :: str num :: int - r0, r1, r2, r3, r4, s1, r5, s2, r6, s3, r7, r8, r9, r10, r11, r12, r13 :: str + r0, r1, r2, r3, r4, s1, r5, s2, r6, s3, r7, r8, r9, r10, r11, r12, r13, s4 :: str L0: r0 = CPyTagged_Str(num) r1 = "Hi! I'm " @@ -188,7 +188,7 @@ L0: r11 = '{' r12 = '}{' r13 = CPyStr_Build(6, r10, r7, r11, r8, r12, r9) - s3 = r13 + s4 = r13 return 1 [case testFStrings]