# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import functools
import re
from collections.abc import Callable, Collection, Generator, Mapping, Sequence
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass, replace
from pathlib import Path
from types import FunctionType, MappingProxyType, UnionType
from typing import (
Any,
ClassVar,
Union,
get_args,
get_origin,
)
from libcst import (
BatchableCSTVisitor,
Comma,
Comment,
CSTNode,
Decorator,
EmptyLine,
IndentedBlock,
Module,
SimpleStatementSuite,
TrailingWhitespace,
)
from libcst.metadata import (
CodePosition,
CodeRange,
ParentNodeProvider,
ProviderT,
)
from .ftypes import (
FileContent,
Invalid,
LintViolation,
NodeReplacement,
Valid,
VisitHook,
VisitorMethod,
parse_lint_ignore_comment,
)
# A source pattern may be supplied as text or as raw bytes; text patterns are
# UTF-8 encoded by _normalize_source_pattern before being matched.
SourcePattern = str | bytes
# A documentation reference: either a bare string or a pair of strings
# (LintRule.REFERENCES describes these as URLs or ``(label, URL)`` pairs).
RuleReference = str | tuple[str, str]
def rule_name_from_class_name(class_name: str) -> str:
    """
    Derive a kebab-case rule name from a CamelCase class name.

    The class name is split into words, the words are joined with dashes and
    lower-cased, and any leftover run of characters other than ``a-z`` / ``0-9``
    is collapsed into a single dash. Leading and trailing dashes are removed.
    """
    # Alternatives are tried in this order at every position:
    #   1. capitals followed by digits, ending before another capital or at the end
    #   2. a run of capitals (acronym) ending before a capitalised word or at the end
    #   3. a lowercase word with an optional leading capital
    #   4. a bare run of digits
    word_pattern = r"[A-Z]+[0-9]+(?=[A-Z]|$)|[A-Z]+(?=[A-Z][a-z]|$)|[A-Z]?[a-z]+|[0-9]+"
    dashed = "-".join(re.findall(word_pattern, class_name)).lower()
    collapsed = re.sub(r"[^a-z0-9]+", "-", dashed)
    return collapsed.strip("-")
def _normalize_source_pattern(pattern: SourcePattern) -> bytes:
return pattern.encode("utf-8") if isinstance(pattern, str) else pattern
def _source_pattern_matches(source: FileContent, pattern: SourcePattern) -> bool:
    """
    Report whether *pattern* occurs in *source*, tolerating whitespace variations.

    A literal substring hit always counts. Otherwise three pattern shapes get a
    relaxed regular-expression search:

    * ``name(``  -- the identifier, optional whitespace, then ``(``
    * ``.name``  -- a dot, optional whitespace, then the identifier
    * ``name ``  -- the identifier followed by at least one space, tab or form feed

    Any other pattern that is not literally present does not match.
    """
    needle = _normalize_source_pattern(pattern)
    if needle in source:
        return True

    identifier = rb"[A-Za-z_][A-Za-z0-9_]*"
    any_whitespace = rb"[ \t\f\r\n]*"

    def is_identifier(candidate: bytes) -> bool:
        return re.fullmatch(identifier, candidate) is not None

    # The first shape whose remainder is a plain identifier decides the outcome.
    if needle.endswith(b"(") and is_identifier(needle[:-1]):
        relaxed = re.escape(needle[:-1]) + any_whitespace + rb"\("
    elif needle.startswith(b".") and is_identifier(needle[1:]):
        relaxed = rb"\." + any_whitespace + re.escape(needle[1:])
    elif needle.endswith(b" ") and is_identifier(needle[:-1]):
        relaxed = re.escape(needle[:-1]) + rb"[ \t\f]+"
    else:
        return False
    return re.search(relaxed, source) is not None
class RuleConfigurationError(ValueError):
    """Raised when a rule setting has an unsupported type or an unacceptable value."""
# Unique sentinel used as the default of ``RuleSetting.default``.
# NOTE(review): presumably means "no default supplied"; the code that consumes
# it is not visible in this part of the file -- confirm.
_RULE_SETTING_MISSING = object()
# The types that _is_scalar_setting_type accepts as scalar setting types.
_SCALAR_SETTING_TYPES = (str, int, float, bool)
# Per-class cache of visitor method names, read and filled by LintRule._visitor_names.
_VISITOR_NAMES_BY_RULE_TYPE: dict[type[LintRule], tuple[str, ...]] = {}
def _is_no_op_visitor(member: FunctionType) -> bool:
return vars(member).get("_is_no_op") is True
def _is_scalar_setting_type(value: object) -> bool:
    """Return True when *value* is a class listed in ``_SCALAR_SETTING_TYPES``."""
    if not isinstance(value, type):
        return False
    return value in _SCALAR_SETTING_TYPES
def _is_instance_for_type(value: object, expected: type[object]) -> bool:
if expected is bool:
return isinstance(value, bool)
if expected is int:
return type(value) is int
if expected is float:
return type(value) is float
if expected is str:
return isinstance(value, str)
return isinstance(value, expected)
def _type_name(expected_type: object) -> str:
name = getattr(expected_type, "__name__", None)
if name is not None:
return name
return repr(expected_type)
def _validate_union_value(
    value: object,
    expected_type: object,
    *,
    setting_name: str,
    rule_name: str,
    path: str,
) -> None:
    """
    Accept *value* if it validates against any member of the union *expected_type*.

    Members are tried in declaration order and the first one that validates
    wins. Failures of individual members are discarded.

    Raises:
        RuleConfigurationError: when no member of the union accepts the value.
    """
    options = get_args(expected_type)
    for option_type in options:
        try:
            _validate_value_for_type(
                value,
                option_type,
                setting_name=setting_name,
                rule_name=rule_name,
                path=path,
            )
        except RuleConfigurationError:
            # This member rejected the value; move on to the next one.
            continue
        else:
            return
    expected = " | ".join(map(_type_name, options))
    raise RuleConfigurationError(
        f"{rule_name}: setting {setting_name!r} at {path} expected {expected}, got {type(value)!r}"
    )
def _validate_list_value(
    value: object,
    expected_type: object,
    *,
    setting_name: str,
    rule_name: str,
    path: str,
) -> None:
    """
    Validate *value* against a parameterised ``list[...]`` type.

    The declared type must carry exactly one type argument, *value* must be a
    ``list``, and each element is validated against the element type with its
    index appended to *path* (``path[0]``, ``path[1]``, ...).

    Raises:
        RuleConfigurationError: for an unsupported declaration, a non-list
            value, or an element that fails validation.
    """
    type_args = get_args(expected_type)
    if len(type_args) != 1:
        raise RuleConfigurationError(
            f"{rule_name}: unsupported list type for setting {setting_name!r}: {expected_type!r}"
        )
    if not isinstance(value, list):
        raise RuleConfigurationError(
            f"{rule_name}: setting {setting_name!r} at {path} expected {expected_type!r}, got {type(value)!r}"
        )
    (element_type,) = type_args
    for position, element in enumerate(value):
        _validate_value_for_type(
            element,
            element_type,
            setting_name=setting_name,
            rule_name=rule_name,
            path=f"{path}[{position}]",
        )
def _validate_dict_value(
    value: object,
    expected_type: object,
    *,
    setting_name: str,
    rule_name: str,
    path: str,
) -> None:
    """
    Validate *value* against a parameterised ``dict[str, ...]`` type.

    Only declarations with exactly two type arguments whose key type is ``str``
    are supported. *value* must be a ``Mapping``; each key must be a string and
    each item is validated against the declared item type with the key appended
    to *path* (``path.key``).

    Raises:
        RuleConfigurationError: for an unsupported declaration, a non-mapping
            value, a non-string key, or an item that fails validation.
    """
    type_args = get_args(expected_type)
    if not (len(type_args) == 2 and type_args[0] is str):
        raise RuleConfigurationError(
            f"{rule_name}: unsupported dict type for setting {setting_name!r}: {expected_type!r}"
        )
    if not isinstance(value, Mapping):
        raise RuleConfigurationError(
            f"{rule_name}: setting {setting_name!r} at {path} expected {expected_type!r}, got {type(value)!r}"
        )
    entry_type = type_args[1]
    for key, entry in value.items():
        if not isinstance(key, str):
            raise RuleConfigurationError(
                f"{rule_name}: setting {setting_name!r} at {path} expected string keys, got {type(key)!r}"
            )
        _validate_value_for_type(
            entry,
            entry_type,
            setting_name=setting_name,
            rule_name=rule_name,
            path=f"{path}.{key}",
        )
def _validate_scalar_value(
    value: object,
    expected_type: object,
    *,
    setting_name: str,
    rule_name: str,
    path: str,
) -> None:
    """
    Validate *value* against one of the supported scalar setting types.

    Raises:
        RuleConfigurationError: when *expected_type* is not a supported scalar
            type, or when *value* is not an instance of it under the strict
            rules of ``_is_instance_for_type``.
    """
    supported = _is_scalar_setting_type(expected_type)
    if not supported:
        raise RuleConfigurationError(
            f"{rule_name}: unsupported type for setting {setting_name!r}: {expected_type!r}"
        )
    # Narrows the static type; the check above already guarantees it at runtime.
    assert isinstance(expected_type, type)
    if _is_instance_for_type(value, expected_type):
        return
    raise RuleConfigurationError(
        f"{rule_name}: setting {setting_name!r} at {path} expected {expected_type!r}, got {type(value)!r}"
    )
def _validate_value_for_type(
    value: object,
    expected_type: object,
    *,
    setting_name: str,
    rule_name: str,
    path: str,
) -> None:
    """
    Validate *value* against *expected_type*, dispatching on the type's origin.

    Unions (``typing.Union`` and ``X | Y``), ``list[...]`` and ``dict[...]`` go
    to their dedicated validators; everything else is treated as a scalar.

    Raises:
        RuleConfigurationError: propagated from the selected validator.
    """
    origin = get_origin(expected_type)
    if origin in (Union, UnionType):
        check = _validate_union_value
    elif origin is list:
        check = _validate_list_value
    elif origin is dict:
        check = _validate_dict_value
    else:
        check = _validate_scalar_value
    check(
        value,
        expected_type,
        setting_name=setting_name,
        rule_name=rule_name,
        path=path,
    )
@dataclass(frozen=True)
class RuleSetting:
    """
    Immutable declaration of a single typed setting for a lint rule.

    Holds the declared type, an optional default, an optional custom validator
    callable, and a human-readable description.
    """

    # Declared type that values are checked against.
    value_type: object
    # Default value; the module sentinel is used when none is given.
    default: object = _RULE_SETTING_MISSING
    # Optional extra check applied after the type check has passed.
    validator: Callable[[object], object] | None = None
    # Free-form description of the setting.
    description: str = ""

    def _validate_type(
        self,
        *,
        value: object,
        setting_name: str,
        rule_name: str,
    ) -> None:
        """Check *value* against ``value_type``, using the setting name as the root path."""
        _validate_value_for_type(
            value,
            self.value_type,
            setting_name=setting_name,
            rule_name=rule_name,
            path=setting_name,
        )

    def validate(
        self,
        value: object,
        *,
        setting_name: str,
        rule_name: str,
    ) -> object:
        """
        Validate *value* and return the value to use for the setting.

        The type check runs first. If a validator is configured it is then
        called with the value:

        * any exception it raises is re-raised as ``RuleConfigurationError``;
        * a result of ``False`` is a validation failure;
        * a result of ``None`` or ``True`` keeps the original value;
        * any other result replaces the value.

        Raises:
            RuleConfigurationError: on a type mismatch or a validator failure.
        """
        self._validate_type(value=value, setting_name=setting_name, rule_name=rule_name)
        if not self.validator:
            return value
        try:
            outcome = self.validator(value)
        except Exception as error:
            raise RuleConfigurationError(
                f"{rule_name}: setting {setting_name!r} failed validation: {error}"
            ) from error
        if outcome is False:
            raise RuleConfigurationError(
                f"{rule_name}: setting {setting_name!r} failed validation"
            )
        if outcome is None or outcome is True:
            return value
        return outcome
class LintRule(BatchableCSTVisitor):
    """
    Lint rule implemented using LibCST.

    To build a new lint rule, subclass this and `Implement a CST visitor
    <https://libcst.readthedocs.io/en/latest/tutorial.html#Build-Visitor-or-Transformer>`_.
    When a lint rule violation should be reported, use the :meth:`report` method.
    """

    METADATA_DEPENDENCIES: ClassVar[Collection[ProviderT]] = ()
    """
    Required LibCST metadata providers
    """

    TAGS: set[str] = set()
    "Arbitrary classification tags for use in configuration/selection"

    NAME: ClassVar[str] = ""
    "Explicit public rule name. Defaults to kebab-case generated from the class name."

    PYTHON_VERSION: str = ""
    """
    Compatible target Python versions, in `PEP 440 version specifier`__ format.

    __ https://peps.python.org/pep-0440/#version-specifiers
    """

    VALID: ClassVar[Sequence[str | Valid]]
    "Test cases that should produce no errors/reports"

    INVALID: ClassVar[Sequence[str | Invalid]]
    "Test cases that are expected to produce errors, with optional replacements"

    SETTINGS: ClassVar[dict[str, RuleSetting]] = {}
    "Optional typed configuration settings for this lint rule."

    REFERENCES: ClassVar[Sequence[RuleReference]] = ()
    "External references for documentation, as URLs or ``(label, URL)`` pairs."

    # Patterns consulted by should_lint_file(); an empty tuple means "lint every file".
    SOURCE_PATTERNS: ClassVar[tuple[SourcePattern, ...]] = ()

    AUTOFIX = False  # set by __init_subclass__
    """
    Whether the lint rule contains an autofix.

    Set to ``True`` automatically when :attr:`INVALID` contains at least one
    test case that provides an expected replacement.
    """

    name: ClassVar[str] = ""
    """
    Canonical kebab-case name of this lint rule.
    """

    def __init__(self) -> None:
        """Start with no recorded violations, ignore directives enabled, and empty settings."""
        self._violations: list[LintViolation] = []
        self._lint_ignore_enabled = True
        self.settings: Mapping[str, Any] = MappingProxyType({})

    def __init_subclass__(cls) -> None:
        """Compute the canonical :attr:`name` and :attr:`AUTOFIX` flag for each subclass."""
        super().__init_subclass__()
        # Only a NAME declared directly on this subclass counts; an inherited NAME
        # is skipped so every subclass gets its own generated name by default.
        cls.name = cls.__dict__.get("NAME") or rule_name_from_class_name(cls.__name__)
        invalid: list[str | Invalid] = getattr(cls, "INVALID", [])
        cls.AUTOFIX = any(
            isinstance(case, Invalid) and bool(case.expected_replacement) for case in invalid
        )

    def __str__(self) -> str:
        return self.qualified_name()

    @classmethod
    def qualified_name(cls) -> str:
        """Return ``<module>:<rule name>`` for this rule class."""
        return f"{cls.__module__}:{cls.name}"

    def should_lint_file(self, source: FileContent, _path: Path) -> bool:
        """
        Decide whether this rule needs to run on a file.

        Returns ``True`` when no :attr:`SOURCE_PATTERNS` are declared, or when at
        least one declared pattern matches *source*. ``_path`` is not consulted.
        """
        return not self.SOURCE_PATTERNS or any(
            _source_pattern_matches(source, pattern) for pattern in self.SOURCE_PATTERNS
        )

    # Optional hook; when set, get_visitors() runs every visitor call inside the
    # context manager returned by ``_visit_hook(<rule name>.<visitor name>)``.
    _visit_hook: VisitHook | None = None

    def _node_trailing_whitespace(self, node: CSTNode) -> TrailingWhitespace | None:
        """
        Find the trailing whitespace associated with *node*, if any.

        It is either an attribute of the node itself or, for nodes with a
        ``body``, taken from the suite's trailing whitespace or the indented
        block's header.
        """
        trailing_whitespace = getattr(node, "trailing_whitespace", None)
        if trailing_whitespace is not None:
            return trailing_whitespace
        body = getattr(node, "body", None)
        if isinstance(body, SimpleStatementSuite):
            return body.trailing_whitespace
        if isinstance(body, IndentedBlock):
            return body.header
        return None

    def _yield_comment_value(
        self, trailing_whitespace: TrailingWhitespace | None
    ) -> Generator[str, None, None]:
        """Yield the comment text carried by *trailing_whitespace*, when there is one."""
        if trailing_whitespace and trailing_whitespace.comment:
            yield trailing_whitespace.comment.value

    def _yield_empty_line_comments(
        self, empty_lines: Sequence[EmptyLine] | None
    ) -> Generator[str, None, None]:
        """Yield the comment text of every line in *empty_lines* that has a comment."""
        if empty_lines is None:
            return
        for line in empty_lines:
            if line.comment:
                yield line.comment.value

    def _yield_direct_node_comments(self, node: CSTNode) -> Generator[str, None, None]:
        """
        Yield the comments attached directly to *node*, without visiting parents.

        Order: the node itself when it is a ``Comment``, its trailing whitespace,
        the line after its comma, the line before its right bracket, the empty
        lines after its left bracket, its ``lines_after_decorators``, and finally
        its ``leading_lines``.
        """
        if isinstance(node, Comment):
            yield node.value
        yield from self._yield_comment_value(self._node_trailing_whitespace(node))

        comma = getattr(node, "comma", None)
        if isinstance(comma, Comma):
            first_line = getattr(comma.whitespace_after, "first_line", None)
            yield from self._yield_comment_value(first_line)

        right_bracket = getattr(node, "rbracket", None)
        if right_bracket is not None:
            first_line = getattr(right_bracket.whitespace_before, "first_line", None)
            yield from self._yield_comment_value(first_line)

        left_bracket = getattr(node, "lbracket", None)
        if left_bracket is not None:
            yield from self._yield_empty_line_comments(
                getattr(left_bracket.whitespace_after, "empty_lines", None)
            )

        yield from self._yield_empty_line_comments(getattr(node, "lines_after_decorators", None))
        yield from self._yield_empty_line_comments(getattr(node, "leading_lines", None))

    def _should_stop_comment_search(self, node: CSTNode) -> bool:
        """Return True when *node* has ``leading_lines`` and is not a ``Decorator``."""
        return getattr(node, "leading_lines", None) is not None and not isinstance(node, Decorator)

    def ignore_lint(self, node: CSTNode) -> bool:
        """
        Whether to ignore a violation for a given node.

        Returns true if any ``# rattle: ignore[...]`` directive matches the current
        rule by name, or if the directive has no rule names listed.
        """
        if not self._lint_ignore_enabled:
            return False

        rule_names = (self.name,)
        # NOTE(review): node_comments() is not defined in the visible part of this
        # class; confirm it is provided elsewhere. The _yield_direct_node_comments
        # and _should_stop_comment_search helpers look like its building blocks.
        for comment in self.node_comments(node):
            directive = parse_lint_ignore_comment(comment)
            if directive is None:
                continue
            if directive.names is None:
                # A directive without a rule list silences every rule.
                return True
            for name in (n.strip() for n in directive.names.split(",")):
                if name in rule_names:
                    return True
        return False

    def report(
        self,
        node: CSTNode,
        message: str,
        *,
        position: CodePosition | CodeRange | None = None,
        position_node: CSTNode | None = None,
        replacement: NodeReplacement[CSTNode] | None = None,
    ) -> None:
        """
        Report a lint rule violation.

        The optional `position` parameter can override the location where the
        violation is reported. By default, the entire span of `node` is used. If
        `position` is a `CodePosition`, it is widened into a range that runs from
        that position to the start of the following line.

        The optional `replacement` parameter can be used to provide an auto-fix for this
        lint violation. Replacing `node` with `replacement` should make the lint
        violation go away.
        """
        if self.ignore_lint(node):
            # TODO: consider logging/reporting this somewhere?
            return

        if isinstance(position, CodePosition):
            # Turn the bare position into a range ending at column 0 of the next line.
            end = replace(position, line=position.line + 1, column=0)
            position = CodeRange(start=position, end=end)

        self._violations.append(
            LintViolation(
                self.name,
                range=position,
                message=message,
                node=node,
                replacement=replacement,
                position_node=position_node,
            )
        )

    @classmethod
    def _visitor_names(cls) -> tuple[str, ...]:
        """
        Return the sorted names of the ``visit_*`` / ``leave_*`` methods that do real work.

        Every class in the MRO that derives from ``BatchableCSTVisitor`` is
        scanned for plain functions with a visitor name that are not marked as
        no-ops. The result is cached per rule class.
        """
        cached = _VISITOR_NAMES_BY_RULE_TYPE.get(cls)
        # Compare with None rather than truthiness: an empty tuple is a valid
        # cached answer for a rule without visitors and must not be recomputed.
        if cached is not None:
            return cached

        names: set[str] = set()
        for rule_type in reversed(cls.__mro__):
            if not issubclass(rule_type, BatchableCSTVisitor):
                continue
            for name, member in vars(rule_type).items():
                if not isinstance(member, FunctionType):
                    continue
                method_name = member.__name__
                if method_name.startswith(("visit_", "leave_")) and not _is_no_op_visitor(member):
                    names.add(name)

        result = tuple(sorted(names))
        _VISITOR_NAMES_BY_RULE_TYPE[cls] = result
        return result

    def get_visitors(self) -> Mapping[str, VisitorMethod]:
        """
        Return the visitor methods of this rule, keyed by method name.

        Only the names from :meth:`_visitor_names` are included. When
        ``_visit_hook`` is set, each visitor is wrapped so that it runs inside
        ``_visit_hook("<rule name>.<visitor name>")``.
        """
        visitors = super().get_visitors()
        visitor_names = self._visitor_names()
        if self._visit_hook is None:
            return {name: visitors[name] for name in visitor_names}

        def _wrap(name: str, func: VisitorMethod) -> VisitorMethod:
            @functools.wraps(func)
            def wrapper(node: CSTNode) -> None:
                assert self._visit_hook is not None
                with self._visit_hook(name):
                    return func(node)

            return wrapper

        return {name: _wrap(f"{self.name}.{name}", visitors[name]) for name in visitor_names}
# Names exported by ``from <this module> import *``.
__all__ = [
"LintRule",
"RuleConfigurationError",
"RuleReference",
"RuleSetting",
]