diff --git a/src/maison/config.py b/src/maison/config.py index 70853be..f198330 100644 --- a/src/maison/config.py +++ b/src/maison/config.py @@ -16,7 +16,7 @@ def _bootstrap_service(package_name: str) -> service.ConfigService: _config_parser = config_parser.ConfigParser() - pyproject_parser = parsers.PyprojectParser(package_name=package_name) + pyproject_parser = parsers.PyprojectParser(tool_name=package_name) toml_parser = parsers.TomlParser() ini_parser = parsers.IniParser() diff --git a/src/maison/config_parser.py b/src/maison/config_parser.py index cec66d4..7985631 100644 --- a/src/maison/config_parser.py +++ b/src/maison/config_parser.py @@ -13,14 +13,14 @@ class Parser(typing.Protocol): """Defines the interface for a `Parser` class that's used to parse a config.""" - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: - """Parse a config file. + def parse_config(self, file: typing.BinaryIO) -> typedefs.ConfigValues: + """Parse a config. Args: - file_path: the path to the config file + file: the binary stream of the config file Returns: - the config values + the parsed config """ ... @@ -42,18 +42,22 @@ def register_parser( key = (suffix, stem) self._parsers[key] = parser - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config( + self, + file_path: pathlib.Path, + file: typing.BinaryIO, + ) -> typedefs.ConfigValues: """See `Parser.parse_config`.""" key: ParserDictKey # First try (suffix, stem) key = (file_path.suffix, file_path.stem) if key in self._parsers: - return self._parsers[key].parse_config(file_path) + return self._parsers[key].parse_config(file) # Then fallback to (suffix, None) key = (file_path.suffix, None) if key in self._parsers: - return self._parsers[key].parse_config(file_path) + return self._parsers[key].parse_config(file) raise errors.UnsupportedConfigError(f"No parser registered for {file_path}") diff --git a/src/maison/disk_filesystem.py b/src/maison/disk_filesystem.py index a2c40c5..51a6347 100644 --- a/src/maison/disk_filesystem.py +++ b/src/maison/disk_filesystem.py @@ -55,3 +55,7 @@ def get_file_path( return path / file_name return None + + def open_file(self, path: pathlib.Path) -> typing.BinaryIO: + """See `Filesystem.open_file`.""" + return path.open(mode="rb") diff --git a/src/maison/parsers/ini.py b/src/maison/parsers/ini.py index c621fe7..17dd154 100644 --- a/src/maison/parsers/ini.py +++ b/src/maison/parsers/ini.py @@ -1,7 +1,8 @@ """A parser for .ini files.""" import configparser -import pathlib +import io +import typing from maison import typedefs @@ -12,8 +13,12 @@ class IniParser: Implements the `Parser` protocol """ - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config(self, file: typing.BinaryIO) -> typedefs.ConfigValues: """See the Parser.parse_config method.""" config = configparser.ConfigParser() - _ = config.read(file_path) + text_io = io.TextIOWrapper(file, encoding="utf-8") + try: + config.read_file(text_io) + except UnicodeDecodeError: + return {} return {section: dict(config.items(section)) for section in config.sections()} diff --git a/src/maison/parsers/pyproject.py b/src/maison/parsers/pyproject.py index b588e95..7b4130f 100644 --- a/src/maison/parsers/pyproject.py +++ b/src/maison/parsers/pyproject.py @@ -1,37 +1,19 @@ """A parser for pyproject.toml files.""" -import pathlib -import sys +from maison.parsers import toml -if sys.version_info >= (3, 11): - import tomllib -else: - import tomli as tomllib - -from maison import typedefs - - -class PyprojectParser: +class PyprojectParser(toml.TomlParser): """Responsible for parsing pyproject.toml files. Implements the `Parser` protocol """ - def __init__(self, package_name: str) -> None: + def __init__(self, tool_name: str) -> None: """Initialise the pyproject reader. Args: - package_name: the name of the package to look for in file, e.g. + tool_name: the name of the package to look for in file, e.g. `acme` part of `[tool.acme]`. """ - self._package_name = package_name - - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: - """See the Parser.parse_config method.""" - try: - with file_path.open(mode="rb") as fd: - pyproject_dict = dict(tomllib.load(fd)) - except FileNotFoundError: - return {} - return dict(pyproject_dict.get("tool", {}).get(self._package_name, {})) + super().__init__(section_key=("tool", tool_name)) diff --git a/src/maison/parsers/toml.py b/src/maison/parsers/toml.py index 3cf0260..0501aba 100644 --- a/src/maison/parsers/toml.py +++ b/src/maison/parsers/toml.py @@ -1,6 +1,5 @@ """A parser for .toml files.""" -import pathlib import sys @@ -9,6 +8,8 @@ else: import tomli as tomllib +import typing + from maison import typedefs @@ -18,10 +19,34 @@ class TomlParser: Implements the `Parser` protocol """ - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def __init__(self, section_key: typing.Optional[tuple[str, ...]] = None) -> None: + """Instantiate the class. + + Args: + section_key: an optional toml section key/identifier to search for + within the toml. For example if the toml file contains: + + [tool.my_section] + my_value = true + + then setting `section_key=("tool", "my_section")` will return + `{"my_value": True}` as the config values. + + """ + self.section_key = section_key or () + + def parse_config(self, file: typing.BinaryIO) -> typedefs.ConfigValues: """See the Parser.parse_config method.""" try: - with file_path.open(mode="rb") as fd: - return dict(tomllib.load(fd)) - except (FileNotFoundError, tomllib.TOMLDecodeError): + values = dict(tomllib.load(file)) + except tomllib.TOMLDecodeError: return {} + + current = values + for key in self.section_key: + if key in current and isinstance(current[key], dict): + current = current[key] + else: + return {} + + return current diff --git a/src/maison/protocols.py b/src/maison/protocols.py index b708e39..83aca9a 100644 --- a/src/maison/protocols.py +++ b/src/maison/protocols.py @@ -44,16 +44,33 @@ def get_file_path( Returns: The `Path` to the file if it exists or `None` if it doesn't """ + ... + + def open_file(self, path: pathlib.Path) -> typing.BinaryIO: + """Open a file. + + Args: + path: the path to the file + + Returns: + the opened file as a binary I/O stream + """ + ... class ConfigParser(typing.Protocol): """Defines the interface for a class that parses a config.""" - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config( + self, + file_path: pathlib.Path, + file: typing.BinaryIO, + ) -> typedefs.ConfigValues: """Parse a config. Args: file_path: the path to a config file. + file: the binary I/O stream of the file. Returns: the parsed config diff --git a/src/maison/service.py b/src/maison/service.py index 10e17c9..d188774 100644 --- a/src/maison/service.py +++ b/src/maison/service.py @@ -67,7 +67,8 @@ def get_config_values( config_values: typedefs.ConfigValues = {} for path in config_file_paths: - parsed_config = self.config_parser.parse_config(path) + file = self.filesystem.open_file(path=path) + parsed_config = self.config_parser.parse_config(file_path=path, file=file) config_values = utils.deep_merge(config_values, parsed_config) if not merge_configs: diff --git a/tests/integration_tests/parsers/test_ini.py b/tests/integration_tests/parsers/test_ini.py deleted file mode 100644 index 1f13edf..0000000 --- a/tests/integration_tests/parsers/test_ini.py +++ /dev/null @@ -1,90 +0,0 @@ -import pathlib -import tempfile -import textwrap -import typing - -import pytest - -from maison.parsers import ini - - -FileFactory = typing.Callable[[str], pathlib.Path] - - -@pytest.fixture -def tmp_ini_file() -> FileFactory: - """Helper to create a temporary ini file.""" - - def _create(content: str) -> pathlib.Path: - with tempfile.NamedTemporaryFile(mode="w+", suffix=".ini", delete=False) as tmp: - _ = tmp.write(content) - tmp.flush() - return pathlib.Path(tmp.name) - - return _create - - -class TestParseConfig: - def test_parse_single_section(self, tmp_ini_file: FileFactory): - ini_content = textwrap.dedent(""" - [database] - host = localhost - port = 5432 - """) - path = tmp_ini_file(ini_content) - - reader = ini.IniParser() - result = reader.parse_config(path) - - assert result == {"database": {"host": "localhost", "port": "5432"}} - - def test_parse_multiple_sections(self, tmp_ini_file: FileFactory): - ini_content = textwrap.dedent(""" - [database] - host = localhost - port = 5432 - - [api] - key = secret - endpoint = https://example.com - """) - path = tmp_ini_file(ini_content) - - reader = ini.IniParser() - result = reader.parse_config(path) - - assert result == { - "database": {"host": "localhost", "port": "5432"}, - "api": {"key": "secret", "endpoint": "https://example.com"}, - } - - def test_empty_file_returns_empty_dict(self, tmp_ini_file: FileFactory): - path = tmp_ini_file("") - - reader = ini.IniParser() - result = reader.parse_config(path) - - assert result == {} - - def test_missing_file_returns_empty_dict(self, tmp_path: pathlib.Path): - path = tmp_path / "nonexistent.ini" - - reader = ini.IniParser() - result = reader.parse_config(path) - - assert result == {} - - def test_overlapping_keys_in_different_sections(self, tmp_ini_file: FileFactory): - ini_content = textwrap.dedent(""" - [section1] - key = value1 - - [section2] - key = value2 - """) - path = tmp_ini_file(ini_content) - - reader = ini.IniParser() - result = reader.parse_config(path) - - assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} diff --git a/tests/integration_tests/parsers/test_pyproject.py b/tests/integration_tests/parsers/test_pyproject.py deleted file mode 100644 index 211bc09..0000000 --- a/tests/integration_tests/parsers/test_pyproject.py +++ /dev/null @@ -1,99 +0,0 @@ -import pathlib -import tempfile -import textwrap -import typing - -import pytest - -from maison.parsers import pyproject - - -FileFactory = typing.Callable[[str], pathlib.Path] - - -@pytest.fixture -def tmp_pyproject_file() -> FileFactory: - """Helper to create a temporary pyproject file.""" - - def _create(content: str) -> pathlib.Path: - with tempfile.NamedTemporaryFile( - mode="w+", suffix=".toml", delete=False - ) as tmp: - _ = tmp.write(content) - tmp.flush() - return pathlib.Path(tmp.name) - - return _create - - -class TestParseConfig: - def test_parse_tool_section_with_values(self, tmp_pyproject_file: FileFactory): - toml_content = textwrap.dedent(""" - [tool.myapp] - debug = true - retries = 3 - url = "https://example.com" - """) - path = tmp_pyproject_file(toml_content) - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {"debug": True, "retries": 3, "url": "https://example.com"} - - def test_returns_empty_dict_if_package_section_missing( - self, tmp_pyproject_file: FileFactory - ): - toml_content = textwrap.dedent(""" - [tool.otherapp] - enabled = true - """) - path = tmp_pyproject_file(toml_content) - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {} - - def test_returns_empty_dict_if_tool_table_missing( - self, tmp_pyproject_file: FileFactory - ): - toml_content = textwrap.dedent(""" - [build-system] - requires = ["setuptools"] - """) - path = tmp_pyproject_file(toml_content) - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {} - - def test_parse_nested_values_inside_package(self, tmp_pyproject_file: FileFactory): - toml_content = textwrap.dedent(""" - [tool.myapp.database] - host = "localhost" - port = 5432 - """) - path = tmp_pyproject_file(toml_content) - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {"database": {"host": "localhost", "port": 5432}} - - def test_empty_file_returns_empty_dict(self, tmp_pyproject_file: FileFactory): - path = tmp_pyproject_file("") - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {} - - def test_missing_file_raises_file_not_found(self, tmp_path: pathlib.Path): - path = tmp_path / "no_such_pyproject.toml" - - reader = pyproject.PyprojectParser("myapp") - result = reader.parse_config(path) - - assert result == {} diff --git a/tests/integration_tests/parsers/test_toml.py b/tests/integration_tests/parsers/test_toml.py deleted file mode 100644 index 63d088f..0000000 --- a/tests/integration_tests/parsers/test_toml.py +++ /dev/null @@ -1,103 +0,0 @@ -import pathlib -import tempfile -import textwrap -import typing - -import pytest - -from maison.parsers import toml - - -FileFactory = typing.Callable[[str], pathlib.Path] - - -@pytest.fixture -def tmp_toml_file() -> FileFactory: - """Helper to create a temporary toml file.""" - - def _create(content: str) -> pathlib.Path: - with tempfile.NamedTemporaryFile( - mode="w+", suffix=".toml", delete=False - ) as tmp: - _ = tmp.write(content) - tmp.flush() - return pathlib.Path(tmp.name) - - return _create - - -class TestParseConfig: - def test_parse_single_section(self, tmp_toml_file: FileFactory): - toml_content = textwrap.dedent(""" - [database] - host = "localhost" - port = 5432 - """) - path = tmp_toml_file(toml_content) - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == {"database": {"host": "localhost", "port": 5432}} - - def test_parse_multiple_sections(self, tmp_toml_file: FileFactory): - toml_content = textwrap.dedent(""" - [database] - host = "localhost" - port = 5432 - - [api] - key = "secret" - endpoint = "https://example.com" - """) - path = tmp_toml_file(toml_content) - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == { - "database": {"host": "localhost", "port": 5432}, - "api": {"key": "secret", "endpoint": "https://example.com"}, - } - - def test_empty_file_returns_empty_dict(self, tmp_toml_file: FileFactory): - path = tmp_toml_file("") - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == {} - - def test_missing_file_returns_empty_dict(self, tmp_path: pathlib.Path): - path = tmp_path / "nonexistent.toml" - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == {} - - def test_overlapping_keys_in_different_sections(self, tmp_toml_file: FileFactory): - toml_content = textwrap.dedent(""" - [section1] - key = "value1" - - [section2] - key = "value2" - """) - path = tmp_toml_file(toml_content) - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} - - def test_invalid_toml_returns_an_empty_dict(self, tmp_toml_file: FileFactory): - toml_content = textwrap.dedent(""" - blah - """) - path = tmp_toml_file(toml_content) - - reader = toml.TomlParser() - result = reader.parse_config(path) - - assert result == {} diff --git a/tests/integration_tests/test_disk_filesystem.py b/tests/integration_tests/test_disk_filesystem.py index f274022..588592f 100644 --- a/tests/integration_tests/test_disk_filesystem.py +++ b/tests/integration_tests/test_disk_filesystem.py @@ -43,3 +43,15 @@ def test_get_file_path_returns_none_if_not_found(self): result = fs.get_file_path("ghost.ini") assert result is None + + +class TestOpenFile: + def test_opens_file(self, tmp_path: pathlib.Path): + fs = disk_filesystem.DiskFilesystem() + + file = tmp_path / "thing.txt" + _ = file.write_text("hello") + + result = fs.open_file(path=file) + + assert result.read() == b"hello" diff --git a/tests/integration_tests/parsers/__init__.py b/tests/unit_tests/parsers/__init__.py similarity index 100% rename from tests/integration_tests/parsers/__init__.py rename to tests/unit_tests/parsers/__init__.py diff --git a/tests/unit_tests/parsers/test_ini.py b/tests/unit_tests/parsers/test_ini.py new file mode 100644 index 0000000..bbdfd7e --- /dev/null +++ b/tests/unit_tests/parsers/test_ini.py @@ -0,0 +1,71 @@ +import io +import textwrap + +from maison.parsers import ini + + +class TestParseConfig: + def test_parse_single_section(self): + ini_content = textwrap.dedent(""" + [database] + host = localhost + port = 5432 + """) + file = io.BytesIO(ini_content.encode()) + + reader = ini.IniParser() + result = reader.parse_config(file) + + assert result == {"database": {"host": "localhost", "port": "5432"}} + + def test_parse_multiple_sections(self): + ini_content = textwrap.dedent(""" + [database] + host = localhost + port = 5432 + + [api] + key = secret + endpoint = https://example.com + """) + file = io.BytesIO(ini_content.encode()) + + reader = ini.IniParser() + result = reader.parse_config(file) + + assert result == { + "database": {"host": "localhost", "port": "5432"}, + "api": {"key": "secret", "endpoint": "https://example.com"}, + } + + def test_empty_file_returns_empty_dict(self): + file = io.BytesIO() + + reader = ini.IniParser() + result = reader.parse_config(file) + + assert result == {} + + def test_invalid_bytes_returns_empty_dict(self): + ini_content = b"\xff\xfe\x00bad ini" + file = io.BytesIO(ini_content) + + reader = ini.IniParser() + result = reader.parse_config(file) + + assert result == {} + + def test_overlapping_keys_in_different_sections(self): + ini_content = textwrap.dedent(""" + [section1] + key = value1 + + [section2] + key = value2 + """) + file = io.BytesIO(ini_content.encode()) + + reader = ini.IniParser() + result = reader.parse_config(file) + + assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} diff --git a/tests/unit_tests/parsers/test_toml.py b/tests/unit_tests/parsers/test_toml.py new file mode 100644 index 0000000..9ed4751 --- /dev/null +++ b/tests/unit_tests/parsers/test_toml.py @@ -0,0 +1,90 @@ +import io +import textwrap + +from maison.parsers import toml + + +class TestParseConfig: + def test_parse_single_section(self): + toml_content = textwrap.dedent(""" + [database] + host = "localhost" + port = 5432 + """) + file = io.BytesIO(toml_content.encode()) + + reader = toml.TomlParser() + result = reader.parse_config(file) + + assert result == {"database": {"host": "localhost", "port": 5432}} + + def test_parse_multiple_sections(self): + toml_content = textwrap.dedent(""" + [database] + host = "localhost" + port = 5432 + + [api] + key = "secret" + endpoint = "https://example.com" + """) + file = io.BytesIO(toml_content.encode()) + + reader = toml.TomlParser() + result = reader.parse_config(file) + + assert result == { + "database": {"host": "localhost", "port": 5432}, + "api": {"key": "secret", "endpoint": "https://example.com"}, + } + + def test_empty_file_returns_empty_dict(self): + file = io.BytesIO(b"") + reader = toml.TomlParser() + result = reader.parse_config(file) + assert result == {} + + def test_invalid_toml_returns_empty_dict(self): + file = io.BytesIO(b"not valid toml!") + reader = toml.TomlParser() + result = reader.parse_config(file) + assert result == {} + + def test_overlapping_keys_in_different_sections(self): + toml_content = textwrap.dedent(""" + [section1] + key = "value1" + + [section2] + key = "value2" + """) + file = io.BytesIO(toml_content.encode()) + + reader = toml.TomlParser() + result = reader.parse_config(file) + + assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} + + def test_section_key_returns_subset_of_dict(self): + toml_content = textwrap.dedent(""" + [tool.section] + key = "value" + """) + file = io.BytesIO(toml_content.encode()) + + reader = toml.TomlParser(section_key=("tool", "section")) + result = reader.parse_config(file) + + assert result == {"key": "value"} + + def test_non_existent_section_key_returns_empty_dict(self): + toml_content = textwrap.dedent(""" + [tool.section] + key = "value" + """) + file = io.BytesIO(toml_content.encode()) + + reader = toml.TomlParser(section_key=("tool", "other_section")) + result = reader.parse_config(file) + + assert result == {} diff --git a/tests/unit_tests/test_config_reader.py b/tests/unit_tests/test_config_reader.py index 338e976..6801881 100644 --- a/tests/unit_tests/test_config_reader.py +++ b/tests/unit_tests/test_config_reader.py @@ -1,4 +1,6 @@ +import io import pathlib +import typing import pytest @@ -8,12 +10,12 @@ class FakePyprojectParser: - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config(self, file: typing.BinaryIO) -> typedefs.ConfigValues: return {"config": "pyproject"} class FakeTomlParser: - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config(self, file: typing.BinaryIO) -> typedefs.ConfigValues: return {"config": "toml"} @@ -26,7 +28,10 @@ def test_uses_parser_by_file_path_and_stem(self): suffix=".toml", parser=FakePyprojectParser(), stem="pyproject" ) - values = self.parser.parse_config(pathlib.Path("path/to/pyproject.toml")) + values = self.parser.parse_config( + file_path=pathlib.Path("path/to/pyproject.toml"), + file=io.BytesIO(b"file"), + ) assert values == {"config": "pyproject"} @@ -36,10 +41,16 @@ def test_falls_back_to_suffix(self): ) self.parser.register_parser(suffix=".toml", parser=FakeTomlParser()) - values = self.parser.parse_config(pathlib.Path("path/to/.acme.toml")) + values = self.parser.parse_config( + pathlib.Path("path/to/.acme.toml"), + file=io.BytesIO(b"file"), + ) assert values == {"config": "toml"} def test_raises_error_if_no_parser_available(self): with pytest.raises(errors.UnsupportedConfigError): - _ = self.parser.parse_config(pathlib.Path("path/to/.acme.toml")) + _ = self.parser.parse_config( + pathlib.Path("path/to/.acme.toml"), + file=io.BytesIO(b"file"), + ) diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 674a540..b478e81 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -1,3 +1,4 @@ +import io import pathlib import typing @@ -14,9 +15,16 @@ def get_file_path( return None return pathlib.Path(f"/path/to/{file_name}") + def open_file(self, path: pathlib.Path, mode: str = "rb") -> typing.BinaryIO: + return io.BytesIO(b"file") + class FakeConfigParser: - def parse_config(self, file_path: pathlib.Path) -> typedefs.ConfigValues: + def parse_config( + self, + file_path: pathlib.Path, + file: typing.BinaryIO, + ) -> typedefs.ConfigValues: return { "values": {file_path.stem: file_path.suffix}, }