Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 100 additions & 19 deletions mypy/nativeparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import os
import time
from functools import cache
from importlib import metadata
from typing import Final, cast

import ast_serialize
Expand Down Expand Up @@ -161,6 +163,81 @@

TypeIgnores = list[tuple[int, list[str]]]

AST_SERIALIZE_REQUIREMENT: Final = ">=0.3.0,<1.0.0"
AST_SERIALIZE_MIN_VERSION: Final = (0, 3, 0)
AST_SERIALIZE_MAX_VERSION: Final = (1, 0, 0)


class NativeParserError(Exception):
"""Raised when the native parser cannot produce compatible serialized data."""


@cache
def ast_serialize_version() -> str | None:
"""Return the installed ast-serialize package version, if available."""
try:
return metadata.version("ast-serialize")
except metadata.PackageNotFoundError:
return None


def _parse_release(version: str) -> tuple[int, ...] | None:
release = version.split("+", 1)[0].split("-", 1)[0]
parts = []
for part in release.split("."):
digits = ""
for char in part:
if not char.isdigit():
break
digits += char
if not digits:
break
parts.append(int(digits))
return tuple(parts) if parts else None


def _is_release_less(left: tuple[int, ...], right: tuple[int, ...]) -> bool:
size = max(len(left), len(right))
left += (0,) * (size - len(left))
right += (0,) * (size - len(right))
return left < right


@cache
def _check_ast_serialize_version() -> None:
version = ast_serialize_version()
if version is None:
return
release = _parse_release(version)
if release is None:
return
if _is_release_less(release, AST_SERIALIZE_MIN_VERSION) or not _is_release_less(
release, AST_SERIALIZE_MAX_VERSION
):
raise NativeParserError(
f"Incompatible ast-serialize version {version} is installed; "
f"mypy requires ast-serialize{AST_SERIALIZE_REQUIREMENT}. "
"Upgrade ast-serialize or reinstall mypy's dependencies."
)


def _format_native_parser_exception(err: BaseException) -> str:
detail = str(err)
return f"{type(err).__name__}: {detail}" if detail else type(err).__name__


def invalid_ast_serialize_data_message(err: BaseException) -> str:
version = ast_serialize_version()
installed = f" (installed ast-serialize: {version})" if version is not None else ""
return (
"The native parser produced serialized AST data that mypy cannot read. "
"This usually means an incompatible ast-serialize version is installed"
f"{installed}; mypy requires ast-serialize{AST_SERIALIZE_REQUIREMENT}. "
"Upgrade ast-serialize or reinstall mypy's dependencies. "
f"Original error: {_format_native_parser_exception(err)}"
)


# There is no way to create reasonable fallbacks at this stage,
# they must be patched later.
_dummy_fallback: Final = Instance(MISSING_FALLBACK, [], -1)
Expand Down Expand Up @@ -257,25 +334,29 @@ def parse_to_binary_ast(
time.sleep(0.0001) # type: ignore[unreachable]
if time.time() - t0 > 10.0:
raise ImportError("Cannot import ast_serialize")
ast_bytes, errors, ignores, import_bytes, ast_data = ast_serialize.parse(
filename,
skip_function_bodies=skip_function_bodies,
python_version=options.python_version,
platform=options.platform,
always_true=options.always_true,
always_false=options.always_false,
cache_version=3,
)
return (
ast_bytes,
errors,
ignores,
import_bytes,
ast_data["is_partial_package"],
ast_data["uses_template_strings"],
ast_data["source_hash"],
ast_data["mypy_comments"],
)
_check_ast_serialize_version()
try:
ast_bytes, errors, ignores, import_bytes, ast_data = ast_serialize.parse(
filename,
skip_function_bodies=skip_function_bodies,
python_version=options.python_version,
platform=options.platform,
always_true=options.always_true,
always_false=options.always_false,
cache_version=3,
)
return (
ast_bytes,
errors,
ignores,
import_bytes,
ast_data["is_partial_package"],
ast_data["uses_template_strings"],
ast_data["source_hash"],
ast_data["mypy_comments"],
)
except (AssertionError, KeyError, TypeError, ValueError) as err:
raise NativeParserError(invalid_ast_serialize_data_message(err)) from err


def read_statement(state: State, data: ReadBuffer) -> Statement:
Expand Down
55 changes: 43 additions & 12 deletions mypy/parse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import re
from typing import NoReturn

from librt.internal import ReadBuffer

Expand All @@ -11,6 +12,14 @@
from mypy.options import Options


def _raise_native_parser_error(
fnam: str, module: str | None, errors: Errors, options: Options, message: str
) -> NoReturn:
errors.set_file(fnam, module, options=options)
errors.report(-1, None, message, blocker=True)
errors.raise_error()


def parse(
source: str | bytes,
fnam: str,
Expand All @@ -37,9 +46,12 @@ def parse(
ignore_errors = options.ignore_errors or fnam in errors.ignored_files
# If errors are ignored, we can drop many function bodies to speed up type checking.
strip_function_bodies = ignore_errors and not options.preserve_asts
tree, _, _ = mypy.nativeparse.native_parse(
fnam, options, skip_function_bodies=strip_function_bodies
)
try:
tree, _, _ = mypy.nativeparse.native_parse(
fnam, options, skip_function_bodies=strip_function_bodies
)
except mypy.nativeparse.NativeParserError as err:
_raise_native_parser_error(fnam, module, errors, options, str(err))
# Set is_stub based on file extension
tree.is_stub = fnam.endswith(".pyi")
# Note: tree.imports is populated directly by load_from_raw() with deserialized
Expand Down Expand Up @@ -69,16 +81,34 @@ def load_from_raw(
If imports_only is true, only deserialize imports and return a mostly
empty AST.
"""
from mypy.nativeparse import State, deserialize_imports, read_statements
from mypy.nativeparse import (
State,
deserialize_imports,
invalid_ast_serialize_data_message,
read_statements,
)

state = State(options)
if imports_only:
defs = []
else:
data = ReadBuffer(raw_data.defs)
n = read_int(data)
defs = read_statements(state, data, n)
imports = deserialize_imports(raw_data.imports)
try:
if imports_only:
defs = []
else:
data = ReadBuffer(raw_data.defs)
n = read_int(data)
defs = read_statements(state, data, n)
imports = deserialize_imports(raw_data.imports)
except (
AssertionError,
EOFError,
IndexError,
KeyError,
TypeError,
UnicodeDecodeError,
ValueError,
) as err:
_raise_native_parser_error(
fnam, module, errors, options, invalid_ast_serialize_data_message(err)
)

tree = MypyFile(defs, imports)
tree.path = fnam
Expand All @@ -93,7 +123,8 @@ def load_from_raw(
all_errors = raw_data.raw_errors + state.errors
errors.set_file(fnam, module, options=options)
for error in all_errors:
# Note we never raise in this function, so it should not be called in coordinator.
# Regular parse errors are reported here; invalid serialized native parser
# data is converted to a blocking error above.
report_parse_error(error, errors)
if imports_only:
# Preserve raw data when only de-serializing imports, it will be sent to
Expand Down
20 changes: 18 additions & 2 deletions mypy/test/test_nativeparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
read_int,
)
from mypy.config_parser import parse_mypy_comments
from mypy.errors import CompileError
from mypy.nodes import MypyFile, ParseError
from mypy.errors import CompileError, Errors
from mypy.nodes import FileRawData, MypyFile, ParseError
from mypy.options import Options
from mypy.parse import load_from_raw
from mypy.test.data import DataDrivenTestCase, DataSuite
from mypy.test.helpers import assert_string_arrays_equal
from mypy.util import get_mypy_comments
Expand Down Expand Up @@ -271,6 +272,21 @@ def locs(start_line: int, start_column: int, end_line: int, end_column: int) ->
+ [END_TAG, END_TAG]
)

def test_incompatible_binary_data_reports_clear_error(self) -> None:
raw_data = FileRawData(bytes([LITERAL_NONE]), b"", [], {}, False, False)
options = Options()
errors = Errors(options)

with self.assertRaises(CompileError) as cm:
load_from_raw("bad.py", "bad", raw_data, errors, options)

self.assertEqual(cm.exception.module_with_blocker, "bad")
self.assertEqual(len(cm.exception.messages), 1)
message = cm.exception.messages[0]
self.assertIn("bad.py: error: The native parser produced serialized AST data", message)
self.assertIn("incompatible ast-serialize version", message)
self.assertIn("Original error: AssertionError", message)


@contextlib.contextmanager
def temp_source(text: str) -> Iterator[str]:
Expand Down
Loading