diff --git a/src/unexport/rule.py b/src/unexport/rule.py index fdfb61f..37d2ece 100644 --- a/src/unexport/rule.py +++ b/src/unexport/rule.py @@ -129,3 +129,18 @@ def _rule_node_is_all_item(node) -> bool: and isinstance(node.value.func.value, ast.Name) and node.value.func.value.id == "__all__" ) + + +@Rule.register((ast.Name,)) # type: ignore +def _rule_node_is_typevar(node) -> bool: + modules: list[str] = [] + if isinstance(node.parent.value, ast.Call): + if "TypeVar" == getattr(node.parent.value.func, "id", None): + modules.extend(body.module for body in node.parent.parent.body if isinstance(body, ast.ImportFrom)) + elif "typing" == getattr(node.parent.value.func.value, "id", None) and "TypeVar" == getattr(node.parent.value.func, "attr", None): # type: ignore # noqa: E501 + for body in node.parent.parent.body: + if isinstance(body, ast.Import): + modules.extend([import_alias.name for import_alias in body.names]) + if "typing" in modules: + return False + return True diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 69d220b..31c3032 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -178,6 +178,66 @@ def x():... XXX = 1 # unexport: not-public """, ), + ( + """\ + from typing import TypeVar + + T = TypeVar("T") + """, + """\ + from typing import TypeVar + + T = TypeVar("T") + """, + ), + ( + """\ + from typing import TypeVar + + T = TypeVar("T") + + def func(): + pass + """, + """\ + from typing import TypeVar + + __all__ = ["func"] + + T = TypeVar("T") + + def func(): + pass + """, + ), + ( + """\ + import typing + + T = typing.TypeVar("T") + """, + """\ + import typing + + T = typing.TypeVar("T") + """, + ), + ( + """\ + class TypeVar: + def __init__(self, name):... + + T = TypeVar("T") + """, + """\ + __all__ = ["T", "TypeVar"] + + class TypeVar: + def __init__(self, name):... + + T = TypeVar("T") + """, + ), ]