diff --git a/.gitignore b/.gitignore index 3e38c56..c059054 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ +/scratch poetry.lock -tests/* __pycache__ .mypy_cache .ruff_cache +.venv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..226748d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-toml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 # Use the sha / tag you want to point at + hooks: + - id: mypy + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.4 + hooks: + # Run the linter. + - id: ruff + name: ruff + description: "Run 'ruff' for extremely fast Python linting" + entry: ruff check --force-exclude + language: python + args: [--fix, --exit-non-zero-on-fix] + require_serial: true + # Run the formatter. + - id: ruff-format diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..3bbe8f4 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "cSpell.words": ["ATTRIBS", "caissuer", "ocsp", "tlsserial"], + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/README.md b/README.md index b55e860..12d065f 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Options: --help Show this message and exit. ``` -## from a URL +### from a URL ```console ❯ poetry run tlsserial --url dell.com @@ -37,7 +37,7 @@ ca_issuers : http://aia.entrust.net/l1k-chain256.cer serial_number : 5AF6B00AD82F3B8FACCEF4123D36138C ``` -## from a file +### from a file ```console ❯ poetry run tlsserial --file ~/axiom.crt @@ -55,3 +55,10 @@ ocsp : http://ocsp.e3m02.amazontrust.com ca_issuers : http://crt.e3m02.amazontrust.com/e3m02.cer serial_number : 0C9E25D31C5E5ECABC2AB6F10D89C3AF ``` + + +## Development + +poetry install +poetry run pre-commit install +poetry run pytest -v diff --git a/experiments/getchain.py b/experiments/getchain.py new file mode 100644 index 0000000..31dae19 --- /dev/null +++ b/experiments/getchain.py @@ -0,0 +1,48 @@ +import logging +import socket +import sys + +import certifi +from OpenSSL import SSL, crypto + +hostname = "www.google.com" +port = 443 + +methods = [ + (SSL.SSLv23_METHOD, "SSL.SSLv23_METHOD"), + (SSL.TLSv1_METHOD, "SSL.TLSv1_METHOD"), + (SSL.TLSv1_1_METHOD, "SSL.TLSv1_1_METHOD"), + (SSL.TLSv1_2_METHOD, "SSL.TLSv1_2_METHOD"), +] + +for method, method_name in methods: + try: + print(f"\n-- Method {method_name}") + context = SSL.Context(method=method) + context.load_verify_locations(cafile=certifi.where()) + + conn = SSL.Connection( + context, socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) + conn.settimeout(5) + conn.connect((hostname, port)) + conn.setblocking(1) + conn.do_handshake() + conn.set_tlsext_host_name(hostname.encode()) + chain = conn.get_peer_cert_chain() + + def decode(x: crypto.X509Name) -> str: + return "/".join( + ["=".join(z.decode("utf-8") for z in y) for y in x.get_components()] + ) + + if chain: + for idx, cert in enumerate(chain): + print(f"{idx} subject: {decode(cert.get_subject())}") + print(f" issuer: {decode(cert.get_issuer())})") + print(f" serial: {cert.get_serial_number()}") + print(f' fingerprint: {cert.digest("sha1")}') + + conn.close() + except SSL.Error: + logging.error(f"<><> Method {method_name} failed due to {sys.exc_info()[0]}") diff --git a/pyproject.toml b/pyproject.toml index 8e1190a..22d66e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,59 @@ authors = ["Tom Matthews "] readme = "README.md" [tool.poetry.scripts] -tlsserial = 'tlsserial:main' +tlsserial = 'tlsserial.cli:main' [tool.poetry.dependencies] python = "^3.11" click = "^8.1.6" -cryptography = "^41.0" +cryptography = "^43.0.3" pendulum = "^3.0" - +certifi = "*" # Used in experiments +pyopenssl = "*" # Used in experiments [tool.poetry.group.dev.dependencies] -types-cryptography = "^3.3.23.2" -mypy = "^1.4.1" +mypy = "*" # with dev dependencies you almost always want the latest +ruff = "*" +pre-commit = "*" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + + +[tool.ruff] +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "D", # pydocstyle + "PL", # pylint + "FIX", # flake8-fixme + "G", # flake8-logging-format + "PTH", # flake8-use-pathlib + "UP", # pyupgrade + "S", # flake8-bandit + "A", # flake8-builtins + "S", # flake8-simplify + "ARG", # flake8-unused-arguments + "INT", # flake8-gettext + "Q", # flake8-quotes + "C4", # flake8-comprehensions + "ISC", # flake8-implicit-str-concat +] +fixable = ["ALL"] +unfixable = [] +[tool.ruff.lint.per-file-ignores] +#"__init__.py" = ["E402"] # ignore import errors in all __init__.py files +"**/{tests,docs,tools}/*" = [ + "E402", # ignore import errors in selected subdirectories. + "D100", # C0114 missing-module-docstring + "D103", # C0116 missing-function-docstring + "S101", # Use of `assert` detected +] + +[tool.mypy] +#disallow_untyped_defs = true +ignore_missing_imports = true +exclude = ['venv', '.venv', 'tests'] diff --git a/lib/__init__.py b/src/tlsserial/__init__.py similarity index 100% rename from lib/__init__.py rename to src/tlsserial/__init__.py diff --git a/src/tlsserial/__main__.py b/src/tlsserial/__main__.py new file mode 100644 index 0000000..4d990e1 --- /dev/null +++ b/src/tlsserial/__main__.py @@ -0,0 +1,18 @@ +""" +This file allows you to run the program directly from the python module: + + ~ $ python -m tlsserial --help + +""" + +import sys + +from .cli import main + +rc = 1 +try: + main() # pylint: disable=no-value-for-parameter # noqa + rc = 0 +except Exception as e: # pylint: disable=broad-exception-caught + print("Error:", e, file=sys.stderr) +sys.exit(rc) diff --git a/src/tlsserial/cli.py b/src/tlsserial/cli.py new file mode 100644 index 0000000..7bee2d9 --- /dev/null +++ b/src/tlsserial/cli.py @@ -0,0 +1,44 @@ +import logging +import os +from ssl import OPENSSL_VERSION + +import click + +from . import helper, tlsserial + + +# https://click.palletsprojects.com/en/8.1.x/quickstart/ +@click.command() +@click.option( + "--url", + cls=helper.MutuallyExclusiveOption, + mutually_exclusive=["file"], + help="host || host:port || https://host:port/other", +) +@click.option( + "--file", + cls=helper.MutuallyExclusiveOption, + mutually_exclusive=["url"], + help="filename containing a PEM certificate", +) +@click.option("--debug", is_flag=True, type=bool, default=False, help="Debug logging") +@click.option( + "--verbose", is_flag=True, type=bool, default=False, help="Verbose output" +) +def main(url, file, debug, verbose) -> None: + """tlsserial groks X509 certificates for your pleasure""" + default_level = "DEBUG" if debug else "INFO" + logging.basicConfig( + level=getattr(logging, os.getenv("LOGLEVEL", default_level).upper()), + format="[%(levelname)s] %(asctime)s - %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + logging.debug("Logging is set to DEBUG level") + if url: + tlsserial.handle_url(url, verbose) + elif file: + tlsserial.handle_file(file, verbose) + else: + click.echo(f"Library version : {OPENSSL_VERSION}") + ctx = click.get_current_context() + click.echo(ctx.get_help()) diff --git a/lib/color.py b/src/tlsserial/color.py similarity index 64% rename from lib/color.py rename to src/tlsserial/color.py index 0a2158d..94878f5 100644 --- a/lib/color.py +++ b/src/tlsserial/color.py @@ -1,5 +1,9 @@ -""" Colour functions for wrapping strings """ -from sys import __stdin__, __stdout__ +"""Colour functions for wrapping strings""" + +# TODO: take a look at the following if you want to take it further (probably overkill) +# https://github.com/Textualize/rich +# https://dslackw.gitlab.io/colored/user_guide/user_guide/ +import sys # Styles BOLD = "\x1b[1m" @@ -11,11 +15,13 @@ COSMOS = "\x1b[38;2;223;42;93m" # Red # Dont wrap with ANSI escape colour codes if we're not a TTY supporting that -IS_TTY_STDIN = __stdin__.isatty() -IS_TTY_STDOUT = __stdout__.isatty() + +IS_TTY_STDIN = sys.stdin.isatty() +IS_TTY_STDOUT = sys.stdout.isatty() def bold(text: str) -> str: + """bold string""" if IS_TTY_STDOUT: return BOLD + text + END else: @@ -23,6 +29,7 @@ def bold(text: str) -> str: def blue(text: str) -> str: + """blue string""" if IS_TTY_STDOUT: return SKY + text + END else: @@ -30,6 +37,7 @@ def blue(text: str) -> str: def orange(text: str) -> str: + """orange string""" if IS_TTY_STDOUT: return SMILE + text + END else: @@ -37,6 +45,7 @@ def orange(text: str) -> str: def red(text: str) -> str: + """red string""" if IS_TTY_STDOUT: return COSMOS + text + END else: diff --git a/src/tlsserial/happy_certificate.py b/src/tlsserial/happy_certificate.py new file mode 100644 index 0000000..b12a3a2 --- /dev/null +++ b/src/tlsserial/happy_certificate.py @@ -0,0 +1,209 @@ +from dataclasses import dataclass +from typing import Any, Dict, List + +import pendulum +from cryptography import x509 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa +from cryptography.x509 import ExtensionNotFound +from cryptography.x509.oid import ExtensionOID, NameOID + +NAME_ATTRIBS = [ + ("CN", NameOID.COMMON_NAME), + ("O", NameOID.ORGANIZATION_NAME), + ("OU", NameOID.ORGANIZATIONAL_UNIT_NAME), + ("L", NameOID.LOCALITY_NAME), + ("ST", NameOID.STATE_OR_PROVINCE_NAME), + ("C", NameOID.COUNTRY_NAME), + ("DC", NameOID.DOMAIN_COMPONENT), + ("E", NameOID.EMAIL_ADDRESS), +] + +@dataclass +class HappyCertificate: + """HappyCertificate is a happy certificate.""" + + cert: x509.Certificate + + + @property + def version(self) -> int: + """Return the x509 version.""" + return self.cert.version.value + 1 + + @property + def issuer(self) -> list[str]: + """Issuer.""" + return [ + f"[{a}] {n.value}" + for a, b in NAME_ATTRIBS + for n in self.cert.issuer.get_attributes_for_oid(b) + ] + + @property + def subject(self) -> List[str]: + """Subject.""" + return [ + f"[{a}] {n.value}" + for a, b in NAME_ATTRIBS + for n in self.cert.subject.get_attributes_for_oid(b) + ] + + @property + def sans(self) -> list[Any]: + """The Subject Alternative Names.""" + try: + return self.cert.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME + ).value.get_values_for_type(x509.DNSName) + except ExtensionNotFound: + return [] + + @property + def basic_constraints(self) -> Dict[str, str]: + """Return the CA BasicConstraint properties.""" + basic_constraints: Dict[str, str] = {} + try: + basic_constraints_object = self.cert.extensions.get_extension_for_oid( + ExtensionOID.BASIC_CONSTRAINTS + ).value + basic_constraints["ca"] = str(basic_constraints_object.ca) + basic_constraints["path_length"] = str( + basic_constraints_object.path_length + ) + except (ValueError, ExtensionNotFound): + pass + return basic_constraints + + @property + def key_usage(self) -> list[str]: + """Key usage.""" + key_usage = [] + try: + key_usage_object = self.cert.extensions.get_extension_for_oid( + ExtensionOID.KEY_USAGE + ).value + for attr, value in key_usage_object.__dict__.items(): + if value is True: + key_usage.append(attr.lstrip("_")) + except (UnboundLocalError, ValueError, ExtensionNotFound): + pass + return key_usage + + @property + def ext_key_usage(self) -> list[str]: + """Returns list of Extended key usages.""" + ext_key_usage = [] + try: + ext_key_usage_object = self.cert.extensions.get_extension_for_oid( + ExtensionOID.EXTENDED_KEY_USAGE + ).value + for usage in ext_key_usage_object.__dict__["_usages"]: + ext_key_usage.append(usage._name) + except (UnboundLocalError, ValueError, ExtensionNotFound): + pass + return ext_key_usage + + @property + def before_and_after(self) -> tuple: + """Returns tuple of not_after and not_before datetimes.""" + not_after = pendulum.parse(self.cert.not_valid_after.isoformat()) + not_before = pendulum.parse(self.cert.not_valid_before.isoformat()) + return not_before, not_after + + @property + def crls(self) -> str: + """Returns CRLs.""" + crls = "" + try: + crl_distribution_points = [ + crl_dp.full_name + for crl_dp in self.cert.extensions.get_extension_for_oid( + ExtensionOID.CRL_DISTRIBUTION_POINTS + ).value + ] + crls = " ".join( + crl.value for crl_list in crl_distribution_points for crl in crl_list + ) + except (ValueError, ExtensionNotFound): + pass + return crls + + @property + def ocsp_and_caissuer(self) -> tuple: + """Returns tuple of OCSP and CA Issuers locations.""" + ocsp = "" + ca_issuers = "" + try: + authorityInfoAccess = self.cert.extensions.get_extension_for_oid( + ExtensionOID.AUTHORITY_INFORMATION_ACCESS + ) + for access in authorityInfoAccess.value: + name = access.access_method._name + location = access.access_location._value + if "OCSP" in name: + ocsp = location + elif "caIssuers" in name: + ca_issuers = location + except (ValueError, ExtensionNotFound): + pass + return ocsp, ca_issuers + + @property + def key_type(self): + """Return the public key type (eg. RSA/DSA/etc).""" + public_key = self.cert.public_key() + key_type = f"{type(public_key).__name__.lstrip('_')}" + if isinstance(public_key, ec.EllipticCurvePublicKey): + key_type += f" {public_key.public_numbers().curve.name}" + return key_type + + @property + def key_bits(self) -> int: + """Returns the bit length of the public key.""" + key_bits = 0 + public_key = self.cert.public_key() + if isinstance(public_key, rsa.RSAPublicKey): + key_bits = public_key.key_size + elif isinstance(public_key, ec.EllipticCurvePublicKey): + key_bits = public_key.key_size + return key_bits + + @property + def key_factors(self) -> dict: + """Returns dict w/ modulus size and exponent from public key bits.""" + """or other key factors where appropriate""" + key_factors: Dict[str, int] = {} + public_key = self.cert.public_key() + if isinstance(public_key, rsa.RSAPublicKey): + key_factors["exponent"] = public_key.public_numbers().e + key_factors["n"] = public_key.public_numbers().n + elif isinstance(public_key, ec.EllipticCurvePublicKey): + key_factors["x"] = public_key.public_numbers().x + key_factors["y"] = public_key.public_numbers().y + return key_factors + + @property + def sig_algorithm(self) -> str | None: + """Return the signature algorithm.""" + if isinstance(self.cert.signature_hash_algorithm, hashes.HashAlgorithm): + sig_algo = self.cert.signature_algorithm_oid._name + else: + sig_algo = None + return sig_algo + + @property + def sig_algorithm_params(self) -> str: + """Return the signature algorithm parameters.""" + pss = self.cert.signature_algorithm_parameters + try: + if isinstance(pss, padding.PSS): + return "PSS" + elif isinstance(pss, padding.PKCS1v15): + return "PKCS1v15" + elif isinstance(pss, padding.ECDSA): + return "ECDSA" + else: + return "n/a" + except AttributeError: + return "n/a" diff --git a/lib/helper.py b/src/tlsserial/helper.py similarity index 67% rename from lib/helper.py rename to src/tlsserial/helper.py index 596d00e..a398a6f 100644 --- a/lib/helper.py +++ b/src/tlsserial/helper.py @@ -1,14 +1,15 @@ -""" Helper functions to parse pypy cryptography x509 objects """ +"""Helper functions to parse pypy cryptography x509 objects""" + import logging import os import socket import ssl from time import perf_counter -from typing import Dict, List # Lets do static type checking with mypy +from typing import Any, Dict, List # Lets do static type checking with mypy + +import pendulum from click import Option, UsageError from cryptography import x509 -from cryptography.x509 import DNSName, ExtensionNotFound -from cryptography.x509.oid import ExtensionOID, NameOID from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ( # dh, @@ -22,7 +23,8 @@ # x448, # x25519, ) -import pendulum +from cryptography.x509 import DNSName, ExtensionNotFound +from cryptography.x509.oid import ExtensionOID, NameOID NAME_ATTRIBS = ( ("CN", NameOID.COMMON_NAME), @@ -39,6 +41,8 @@ def timethis(func): + # TODO: If you are only going to log when debugging then dont do any + # calculations unless you are debugging """Sample decorator to report a function runtime in milliseconds""" def wrapper(*args, **kwargs): @@ -47,21 +51,28 @@ def wrapper(*args, **kwargs): retval = func(*args, **kwargs) time_after = perf_counter() time_diff = time_after - time_before - if debug: + # TODO: could this just be a log at debug level? + if debug: # noqa: F821 # __qualname__ returns the name of the func passed in logging.info(f"({func.__qualname__}) took {time_diff:.3f} seconds") return retval return wrapper + class MutuallyExclusiveOption(Option): + """Click helper to create mutually exclusive options""" + + # TODO: This should be with the CLI code def __init__(self, *args, **kwargs): - self.mutually_exclusive = set(kwargs.pop('mutually_exclusive', [])) - help = kwargs.get('help', '') + self.mutually_exclusive = set(kwargs.pop("mutually_exclusive", [])) + help = kwargs.get("help", "") if self.mutually_exclusive: - ex_str = ', '.join(self.mutually_exclusive) - kwargs['help'] = help + ( - ' NOTE: This argument is mutually exclusive with arguments: [' + ex_str + '].' + ex_str = ", ".join(self.mutually_exclusive) + kwargs["help"] = help + ( + " NOTE: This argument is mutually exclusive with arguments: [" + + ex_str + + "]." ) super(MutuallyExclusiveOption, self).__init__(*args, **kwargs) @@ -69,21 +80,21 @@ def handle_parse_result(self, ctx, opts, args): if self.mutually_exclusive.intersection(opts) and self.name in opts: raise UsageError( "Illegal usage: `{}` is mutually exclusive with arguments `{}`.".format( - self.name, - ', '.join(self.mutually_exclusive) + self.name, ", ".join(self.mutually_exclusive) ) ) - return super(MutuallyExclusiveOption, self).handle_parse_result( - ctx, - opts, - args - ) + return super(MutuallyExclusiveOption, self).handle_parse_result(ctx, opts, args) -def get_certs_from_host(host, port=443, timeout=8) -> tuple[None | List[x509.Certificate], str]: + +def get_certs_from_host( + host, port=443, timeout=8 +) -> tuple[None | List[x509.Certificate], None | List[x509.Certificate], str]: """Use ssl library to get certificate details from a host""" """Then use 'cryptography' to parse the certificate and return the ugly X509 object""" """Returns (certificate, certificate chain, return status message)""" + # TODO: break this up in to smaller testable methods or functions or even private functions + # TODO: https://peps.python.org/pep-0257/#multi-line-docstrings context = ssl.create_default_context() # We want to retrieve even expired certificates context.check_hostname = False @@ -91,19 +102,20 @@ def get_certs_from_host(host, port=443, timeout=8) -> tuple[None | List[x509.Cer try: with socket.create_connection((host, port), timeout) as connection: with context.wrap_socket(connection, server_hostname=host) as sock: + # TODO: See experiment for possible alternative # FIXME: We really shouldnt use private methods, but # cryptography doesn't expose the certificate chain yet # https://github.com/python/cpython/issues/62433 - sslobj_verified_chain = sock._sslobj.get_verified_chain() + sslobj_verified_chain = sock._sslobj.get_verified_chain() # type: ignore # pylint: disable=protected-access # [<_ssl.Certificate 'CN=expired.rootca1.demo.amazontrust.com'>, # <_ssl.Certificate 'CN=Amazon RSA 2048 M01,O=Amazon,C=US'>, # <_ssl.Certificate 'CN=Amazon Root CA 1,O=Amazon,C=US'>] ssl_chain: List = [] for _, cert in enumerate(sslobj_verified_chain): - for tup in cert.get_info()['subject']: + for tup in cert.get_info()["subject"]: # Each Certificate object has a get_info method, which # returns the subject in an awful tuple of tuples - if tup[0][0] == 'commonName': + if tup[0][0] == "commonName": common_name_val = f"[CN] {tup[0][1]}" ssl_chain.append(common_name_val) sock.settimeout(timeout) @@ -112,15 +124,21 @@ def get_certs_from_host(host, port=443, timeout=8) -> tuple[None | List[x509.Cer finally: sock.close() if cert_der is None: - return (None, "Failed to get peer certificate!") + return (None, None, "Failed to get peer certificate!") else: cert_pem = ssl.DER_cert_to_PEM_cert(cert_der) return ( # load_certificate takes a bytes object, so encode cert_pem x509.load_pem_x509_certificates(str.encode(cert_pem)), ssl_chain, - "SSL certificate" + "SSL certificate", ) + # TODO: do all you exceptions as close as possible to where the can occur. It + # helps with debugging. it may also make sense to handle the all in the caller. + # Have a look at the errors too, you may not need to write the message + # + # > except ssl.SSLError as err: + # > return (None, None, f"SSL Error: {err}") except socket.timeout: return (None, None, "Socket timeout!") except ssl.SSLEOFError: @@ -133,18 +151,21 @@ def get_certs_from_host(host, port=443, timeout=8) -> tuple[None | List[x509.Cer return (None, None, "Connection error!") -def get_certs_from_file(filename: str, mode="r") -> tuple[None | List[x509.Certificate], str]: - """Use ssl library to get certificate details from disk""" - """Then use 'cryptography' to parse the certificate and return the ugly X509 object""" +def get_certs_from_file( + filename: str, mode="r" +) -> tuple[None | List[x509.Certificate], str]: + """Use ssl library to get certificate details from disk + + Then use 'cryptography' to parse the certificate and return the ugly X509 object""" try: base = os.path.dirname(__file__) - with open(os.path.join(base, filename), mode) as file: + with open(os.path.join(base, filename), mode, encoding="utf-8") as file: return ( # load_certificate takes a bytes object, so encode cert_pem x509.load_pem_x509_certificates(str.encode(file.read())), - "SSL certificate" + "SSL certificate", ) - except ValueError as err: + except ValueError as err: # TODO: do all you exceptions as close as possible to where the can occur. return (None, f"{err}") except FileNotFoundError as err: return (None, f"{err}") @@ -157,8 +178,9 @@ def get_version(cert: x509.Certificate) -> int: return cert.version.value + 1 -def get_issuer(cert: x509.Certificate): +def get_issuer(cert: x509.Certificate) -> list[str]: """Issuer""" + # TODO: can be a nested comprehension I think issuer = [] for a, b in NAME_ATTRIBS: for n in cert.issuer.get_attributes_for_oid(b): @@ -166,8 +188,9 @@ def get_issuer(cert: x509.Certificate): return issuer -def get_subject(cert: x509.Certificate): +def get_subject(cert: x509.Certificate) -> List[str]: """Subject""" + # TODO: can be a nested comprehension I think subject = [] for a, b in NAME_ATTRIBS: for n in cert.subject.get_attributes_for_oid(b): @@ -175,8 +198,9 @@ def get_subject(cert: x509.Certificate): return subject -def get_sans(cert: x509.Certificate): +def get_sans(cert: x509.Certificate) -> list[Any]: """The Subject Alternative Names""" + # TODO: just return the value instead of assigning in the try. Effectively the same. sans = [] try: sans = cert.extensions.get_extension_for_oid( @@ -187,11 +211,12 @@ def get_sans(cert: x509.Certificate): return sans -def get_basic_constraints(cert: x509.Certificate): +def get_basic_constraints(cert: x509.Certificate) -> Dict[str, str]: """Return the CA BasicConstraint properties""" + # TODO: you can do this in a single try/except basic_constraints: Dict[str, str] = {} - basic_constraints['ca'] = "" - basic_constraints['path_length'] = "" + basic_constraints["ca"] = "" + basic_constraints["path_length"] = "" try: basic_constraints_object = cert.extensions.get_extension_for_oid( ExtensionOID.BASIC_CONSTRAINTS @@ -200,14 +225,14 @@ def get_basic_constraints(cert: x509.Certificate): except (ValueError, ExtensionNotFound): pass try: - basic_constraints['ca'] = str(basic_constraints_object.ca) - basic_constraints['path_length'] = str(basic_constraints_object.path_length) + basic_constraints["ca"] = str(basic_constraints_object.ca) + basic_constraints["path_length"] = str(basic_constraints_object.path_length) except (ValueError, ExtensionNotFound, UnboundLocalError): pass return basic_constraints -def get_key_usage(cert: x509.Certificate): +def get_key_usage(cert: x509.Certificate) -> list[str]: """Key usage""" key_usage = [] try: @@ -221,19 +246,28 @@ def get_key_usage(cert: x509.Certificate): # data_encipherment=False, key_agreement=False, key_cert_sign=False, # crl_sign=False, encipher_only=False, decipher_only=False )> try: + # TODO: classic comprehension + # key_usage = [ + # attr.lstrip("_") for attr, value in key_usage_object.__dict__.items() if value is True + # ] + # Use the __dict__ method to return only instance attributes for attr, value in key_usage_object.__dict__.items(): # Only return the enabled (True) Key Usage attributes if value is True: # No idea why the names are '_private'? key_usage.append(attr.lstrip("_")) - except (UnboundLocalError, ValueError): + except ( + UnboundLocalError, + ValueError, + ): # TODO: you are only getting these errors because you are passing the error above pass return key_usage -def get_ext_key_usage(cert: x509.Certificate) -> list: +def get_ext_key_usage(cert: x509.Certificate) -> list[str]: """Returns list of Extended key usages""" + # TODO same pattern as get_key_usage ext_key_usage = [] try: ext_key_usage_object = cert.extensions.get_extension_for_oid( @@ -253,12 +287,12 @@ def get_ext_key_usage(cert: x509.Certificate) -> list: def get_before_and_after(cert: x509.Certificate) -> tuple: - """Returns tuple of notAfter and notBefore datetimes""" + """Returns tuple of not_after and not_before datetimes""" # Ignore pyright parse export error, it's pendulums fault # https://github.com/sdispater/pendulum/pull/693 - notAfter = pendulum.parse(cert.not_valid_after.isoformat()) - notBefore = pendulum.parse(cert.not_valid_before.isoformat()) - return notBefore, notAfter + not_after = pendulum.parse(cert.not_valid_after.isoformat()) + not_before = pendulum.parse(cert.not_valid_before.isoformat()) + return not_before, not_after def get_crls(cert: x509.Certificate): @@ -274,20 +308,31 @@ def get_crls(cert: x509.Certificate): # )> # ])> crls = "" + # TODO: I think this is unnecessary, just assign your comprehension instead crl_distribution_points = [] try: + # TODO: + # # This would be simpler as comprehension + # crl_distribution_points = [ + # crl_dp.full_name for crl_dp in cert.extensions.get_extension_for_oid(ExtensionOID.CRL_DISTRIBUTION_POINTS ).value + # ] + # + # # Then the bunch of 'for's can be a nested comprehension that you can just return + # return " ".join([crl.value for crl in crl_list for crl_list in crl_distribution_points]) crl_distribution_points_object = cert.extensions.get_extension_for_oid( ExtensionOID.CRL_DISTRIBUTION_POINTS ).value [ - crl_distribution_points.append(crl_dp.full_name) + crl_distribution_points.append(crl_dp.full_name) # type: ignore for _, crl_dp in enumerate(crl_distribution_points_object) ] for crl_list in crl_distribution_points: crls_list = [] for crl in crl_list: crls_list.append(crl.value) - crls = " ".join(crls_list) + crls = " ".join( + crls_list + ) # TODO I think this is overwriting crls unless you really only want the last one except (ValueError, ExtensionNotFound): pass return crls @@ -305,9 +350,9 @@ def get_ocsp_and_caissuer(cert: x509.Certificate) -> tuple: name = access.access_method._name location = access.access_location._value if "OCSP" in name: - ocsp = location + ocsp = location # FIXME: Overwriting instead of appending? elif "caIssuers" in name: - ca_issuers = location + ca_issuers = location # FIXME: Overwriting instead of appending? except (ValueError, ExtensionNotFound): pass return ocsp, ca_issuers @@ -336,7 +381,7 @@ def get_key_bits(cert: x509.Certificate) -> int: def get_key_factors(cert: x509.Certificate) -> dict: """Returns dict w/ modulus size and exponent from public key bits""" - """or other key factors where appropriate""" + """or other key factors where appropriate""" # TODO: docstrings are multi line key_factors: Dict[str, int] = {} public_key = cert.public_key() # These are only of vague interest for CTF competitions, etc @@ -348,7 +393,15 @@ def get_key_factors(cert: x509.Certificate) -> dict: key_factors["y"] = public_key.public_numbers().y return key_factors -def get_sig_algorithm(cert: x509.Certificate): + +def get_sig_algorithm(cert: x509.Certificate) -> str | None: + """Return the signature algorithm""" + # TODO: if/else/return can quite often be written as a ternary return: + # + # return a if a > b else b + # + # # In this case it would be a bit unwieldy + # return cert.signature_algorithm_oid._name if isinstance(cert.signature_hash_algorithm, hashes.HashAlgorithm) else None if isinstance(cert.signature_hash_algorithm, hashes.HashAlgorithm): sig_algo = cert.signature_algorithm_oid._name else: @@ -357,7 +410,9 @@ def get_sig_algorithm(cert: x509.Certificate): def get_sig_algorithm_params(cert: x509.Certificate) -> str: + """Return the signature algorithm parameters""" pss = cert.signature_algorithm_parameters + # TODO: Take a look at https://arjancodes.com/blog/how-to-use-structural-pattern-matching-in-python/ try: if isinstance(pss, padding.PSS): return "PSS" @@ -369,4 +424,3 @@ def get_sig_algorithm_params(cert: x509.Certificate) -> str: return "n/a" except AttributeError: return "n/a" - diff --git a/lib/nice_certificate.py b/src/tlsserial/nice_certificate.py similarity index 98% rename from lib/nice_certificate.py rename to src/tlsserial/nice_certificate.py index 5da5526..c5d8331 100644 --- a/lib/nice_certificate.py +++ b/src/tlsserial/nice_certificate.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from datetime import datetime -from cryptography import x509 @dataclass(order=True) diff --git a/tlsserial.py b/src/tlsserial/tlsserial.py similarity index 65% rename from tlsserial.py rename to src/tlsserial/tlsserial.py index 6a9bc7b..484798d 100755 --- a/tlsserial.py +++ b/src/tlsserial/tlsserial.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -""" grab some things from a TLS cert """ +"""grab some things from a TLS cert""" + # TODO: # - Report TLS1.3 negotiation for url lookups as NIST SP 800-52 requires support by Jan 2024 # - Swap back to the pyOpenSSL lib to allow getting entire chain from a host? @@ -17,65 +18,50 @@ # - server temp keys # - server public keys # - TLS cipher? -import logging import re import sys -from ssl import OPENSSL_VERSION + import click from cryptography import x509 -from lib import helper -from lib.nice_certificate import NiceCertificate -from lib.color import bold, red, orange, blue - - -# https://click.palletsprojects.com/en/8.1.x/quickstart/ -@click.command() -@click.option( - "--url", - cls=helper.MutuallyExclusiveOption, - mutually_exclusive=["file"], - help="host || host:port || https://host:port/other", -) -@click.option( - "--file", - cls=helper.MutuallyExclusiveOption, - mutually_exclusive=["url"], - help="filename containing a PEM certificate", -) -@click.option("--debug", is_flag=True, type=bool, default=False) -def main(url, file, debug) -> None: - """tlsserial groks X509 certificates for your pleasure""" - level = logging.DEBUG - fmt = "[%(levelname)s] %(asctime)s - %(message)s" - logging.basicConfig(level=level, format=fmt) - - if url: +from . import helper +from .color import blue, bold, orange, red +from .nice_certificate import NiceCertificate + + +def handle_url(url: str, verbose: bool = False) -> None: + """host || host:port || https://host:port/other.""" + try: host, port = get_args(url) - # Assigns all certificates found to tuple cert([c1, c2, ...], "SSL cert") - cert_chain = helper.get_certs_from_host(host, port) - if cert_chain[0] is not None: - for cert in reversed(cert_chain[0]): - display(host, parse_x509(cert), debug) - else: - print(cert_chain[1]) - elif file: - host = "" - # Assigns all certificates found to tuple cert([c1, c2, ...], "SSL cert") - cert_chain = helper.get_certs_from_file(file) - if cert_chain[0] is not None: - for cert in reversed(cert_chain[0]): - display(host, parse_x509(cert), debug) - click.echo("") - else: - print(cert_chain[1]) + except ValueError as err: + print(err) + sys.exit(1) + # Assigns all certificates found to tuple cert([c1, c2, ...], "SSL cert") + cert_chain = helper.get_certs_from_host(host, port) + if cert_chain[0] is not None: + for cert in reversed(cert_chain[0]): + display(host, parse_x509(cert), verbose) else: - click.echo(f"Library version : {OPENSSL_VERSION}") - ctx = click.get_current_context() - click.echo(ctx.get_help()) + print(cert_chain[1]) + + +def handle_file(file: str, verbose: bool = False) -> None: + """ + filename containing a PEM certificate + """ + host = "" + # Assigns all certificates found to tuple cert([c1, c2, ...], "SSL cert") + cert_chain = helper.get_certs_from_file(file) + if cert_chain[0] is not None: + for cert in reversed(cert_chain[0]): + display(host, parse_x509(cert), verbose) + click.echo("") + else: + print(cert_chain[1]) def get_args(argv: str) -> tuple: + # TODO: https://docs.python.org/3/library/urllib.parse.html#url-parsing """ Try to extract a hostname and port from input string Returns a tuple of (host, port) @@ -89,16 +75,16 @@ def get_args(argv: str) -> tuple: return (args_matched[1], args_matched[2]) # host and default port return (args_matched[1], 443) - print(f"Error parsing the input : {argv}") - sys.exit(1) + raise ValueError(f"Error parsing the input : {argv}") def parse_x509(cert: x509.Certificate) -> NiceCertificate: - """Parse an ugly X509 object""" - """Return a NiceCertificate object """ + """Parse an ugly X509 object. + Return a NiceCertificate object. + """ # We use helper functions where parsing is gnarly - notBefore, notAfter = helper.get_before_and_after(cert) + not_before, not_after = helper.get_before_and_after(cert) ocsp, ca_issuers = helper.get_ocsp_and_caissuer(cert) return NiceCertificate( @@ -111,8 +97,8 @@ def parse_x509(cert: x509.Certificate) -> NiceCertificate: basic_constraints=helper.get_basic_constraints(cert), key_usage=helper.get_key_usage(cert), ext_key_usage=helper.get_ext_key_usage(cert), - not_before=notBefore, - not_after=notAfter, + not_before=not_before, + not_after=not_after, crls=helper.get_crls(cert), ocsp=ocsp, serial_as_int=cert.serial_number, @@ -125,7 +111,10 @@ def parse_x509(cert: x509.Certificate) -> NiceCertificate: def display(host: str, cert: NiceCertificate, debug: bool) -> None: - """Print nicely-formatted attributes of a NiceCertificate object""" + """Print nicely-formatted attributes of a NiceCertificate object.""" + # TODO: This function is long and hard to follow. a lot of the `elif`s could be separate + # functions and be individually testable. + # Maybe use a match/case instead of the big if/elif/else (>=py310) print_items = [ "version", "issuer", @@ -147,42 +136,45 @@ def display(host: str, cert: NiceCertificate, debug: bool) -> None: width = 24 matched_host = False + # TODO: Rather than having a lot of print statements here, split the responsibility of building the + # string from the output. Build up a string in one function and return it, then in another just print it. + # You can test a return value easily, harder to test stdout. for item in print_items: if "issuer" == item: print(f"{orange(f'{item:<{width}}')} : {' '.join(cert.issuer)}") elif "chain" == item: if len(cert.__getattribute__(item)) > 0: - print(f"{orange(f'{item:<{width}}')} " f": {orange(' » ').join(cert.__getattribute__(item))}") + print( + f"{orange(f'{item:<{width}}')} " + f": {orange(' » ').join(cert.__getattribute__(item))}" + ) elif "subject" == item: cert.subject = [ f"{c[:5]}{bold(blue(c[5:]))}" if c.endswith(f" {host}") else c for c in cert.subject ] - print(f"{orange(f'{item:<{width}}')} " f": {' '.join(cert.subject)}") + print(f"{orange(f'{item:<{width}}')} : {' '.join(cert.subject)}") elif "subject_alt_name" == item: for san in sorted(cert.sans): if host == str(san) and not matched_host: # Our host arg matches an exact SAN matched_host = True - print(f"{orange(f'{item:<{width}}')} " f": {bold(blue(san))}") + print(f"{orange(f'{item:<{width}}')} : {bold(blue(san))}") elif ( str(san).endswith(re.sub("^[a-z1-9_-]+", "*", host)) and not matched_host ): # Our host arg matches a wildcard SAN matched_host = True - print(f"{orange(f'{item:<{width}}')} " f": {orange(san)}") + print(f"{orange(f'{item:<{width}}')} : {orange(san)}") else: - print(f"{orange(f'{item:<{width}}')} " f": {san}") + print(f"{orange(f'{item:<{width}}')} : {san}") elif "basic_constraints" == item: # Lets highlight any certs which are CAs - if cert.basic_constraints['ca'] == 'True': - cert.basic_constraints['ca'] = orange('True') - for item in ['ca', 'path_length']: - print( - f"{orange(f'{item:<{width}}')} " - f": {cert.basic_constraints[item]}" - ) + if cert.basic_constraints["ca"] == "True": + cert.basic_constraints["ca"] = orange("True") + for item in ["ca", "path_length"]: + print(f"{orange(f'{item:<{width}}')} : {cert.basic_constraints[item]}") elif "serial_number" == item: print( f"{orange(f'{item:<{width}}')} " @@ -205,10 +197,7 @@ def display(host: str, cert: NiceCertificate, debug: bool) -> None: f": {cert.key_type} ({cert.key_bits} bit)" ) if debug: - print( - f"{orange(f'{item:<{width}}')} " - f": Factors: {cert.key_factors}" - ) + print(f"{orange(f'{item:<{width}}')} : Factors: {cert.key_factors}") elif "signature_algorithm" == item: print( f"{orange(f'{item:<{width}}')} " @@ -220,8 +209,4 @@ def display(host: str, cert: NiceCertificate, debug: bool) -> None: f": {', '.join(sorted(cert.__getattribute__(item)))}" ) else: - print(f"{orange(f'{item:<{width}}')} " f": {cert.__getattribute__(item)}") - - -if __name__ == "__main__": - main() + print(f"{orange(f'{item:<{width}}')} : {cert.__getattribute__(item)}") diff --git a/tests/test_happy_certificate.py b/tests/test_happy_certificate.py new file mode 100644 index 0000000..946f03c --- /dev/null +++ b/tests/test_happy_certificate.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass, field + +import pendulum +import pytest +from cryptography import x509 +from cryptography.hazmat.backends import default_backend + +from tlsserial.happy_certificate import HappyCertificate + + +@pytest.fixture +def test_cert() -> x509.Certificate: + """Load a test certificate.""" + with open("test_cert.pem", "rb") as f: + cert = x509.load_pem_x509_certificate(f.read(), default_backend()) + return cert + +def test_happy_certificate(test_cert): + """Test the HappyCertificate class.""" + happy_cert = HappyCertificate(test_cert) + + assert happy_cert.version == 3 + assert isinstance(happy_cert.issuer, list) + assert "[CN] localhost" in happy_cert.issuer + assert isinstance(happy_cert.subject, list) + assert "[CN] localhost" in happy_cert.subject + assert isinstance(happy_cert.sans, list) + # Add assertions for sans based on test_cert + + assert isinstance(happy_cert.basic_constraints, dict) + # Add assertions for basic_constraints based on test_cert + + assert isinstance(happy_cert.key_usage, list) + # Add assertions for key_usage based on test_cert + + assert isinstance(happy_cert.ext_key_usage, list) + # Add assertions for ext_key_usage based on test_cert + + not_before, not_after = happy_cert.before_and_after + assert isinstance(not_before, pendulum.DateTime) + assert isinstance(not_after, pendulum.DateTime) + assert not_before < not_after + + assert isinstance(happy_cert.crls, str) + # Add assertions for crls based on test_cert + + ocsp, ca_issuers = happy_cert.ocsp_and_caissuer + assert isinstance(ocsp, str) + assert isinstance(ca_issuers, str) + # Add assertions for ocsp and ca_issuers based on test_cert + + assert isinstance(happy_cert.key_type, str) + # Add assertions for key_type based on test_cert + + assert isinstance(happy_cert.key_bits, int) + # Add assertions for key_bits based on test_cert + + assert isinstance(happy_cert.key_factors, dict) + # Add assertions for key_factors based on test_cert + + assert isinstance(happy_cert.sig_algorithm, str) + # Add assertions for sig_algorithm based on test_cert + + assert isinstance(happy_cert.sig_algorithm_params, str) + # Add assertions for sig_algorithm_params based on test_cert + + assert print(happy_cert) == 'foo' \ No newline at end of file diff --git a/tests/test_helper.py b/tests/test_helper.py new file mode 100644 index 0000000..ea35571 --- /dev/null +++ b/tests/test_helper.py @@ -0,0 +1,171 @@ +from unittest import mock + +from cryptography.hazmat.primitives.asymmetric import padding + +from tlsserial import helper + + +def test_timethis(): + @helper.timethis + def func(): + pass + + func() + + +@mock.patch("tlsserial.helper.socket.create_connection") +def test_get_certs_from_host_success(mock_create_connection): + mock_socket = mock.MagicMock() + mock_socket._sslobj = mock.MagicMock() + mock_socket._sslobj.get_verified_chain.return_value = [] + mock_socket.getpeercert.return_value = b"cert" + mock_create_connection.return_value = mock_socket + certs, msg = helper.get_certs_from_host("example.com") + assert certs is not None + assert msg == "SSL certificate" + + +@mock.patch("tlsserial.helper.socket.create_connection") +def test_get_certs_from_host_failure(mock_create_connection): + mock_create_connection.side_effect = Exception("Error") + certs, msg = helper.get_certs_from_host("example.com") + assert certs is None + assert msg == "Error" + + +@mock.patch("tlsserial.helper.open") +def test_get_certs_from_file_success(mock_open): + mock_open.return_value = mock.MagicMock() + mock_open.return_value.__enter__.return_value.read.return_value = "cert" + certs, msg = helper.get_certs_from_file("test.pem") + assert certs is not None + assert msg == "SSL certificate" + + +@mock.patch("tlsserial.helper.open") +def test_get_certs_from_file_failure(mock_open): + mock_open.side_effect = FileNotFoundError("Error") + certs, msg = helper.get_certs_from_file("test.pem") + assert certs is None + assert msg == "Error" + + +def test_get_version(): + cert = mock.MagicMock() + cert.version.value = 1 + assert helper.get_version(cert) == 2 + + +def test_get_issuer(): + cert = mock.MagicMock() + cert.issuer.get_attributes_for_oid.return_value = [mock.MagicMock()] + assert helper.get_issuer(cert) == ["[CN] value"] + + +def test_get_subject(): + cert = mock.MagicMock() + cert.subject.get_attributes_for_oid.return_value = [mock.MagicMock()] + assert helper.get_subject(cert) == ["[CN] value"] + + +def test_get_sans(): + cert = mock.MagicMock() + cert.extensions.get_extension_for_oid.return_value.value.get_values_for_type.return_value = [ + "example.com" + ] + assert helper.get_sans(cert) == ["example.com"] + + +def test_get_basic_constraints(): + cert = mock.MagicMock() + cert.extensions.get_extension_for_oid.return_value.value.ca = True + cert.extensions.get_extension_for_oid.return_value.value.path_length = 1 + assert helper.get_basic_constraints(cert) == {"ca": "True", "path_length": "1"} + + +def test_get_key_usage(): + cert = mock.MagicMock() + cert.extensions.getextension_for_oid.return_value.value.__dict__ = { + "_digital_signature": True + } + assert helper.get_key_usage(cert) == ["digital_signature"] + + +def test_get_ext_key_usage(): + cert = mock.MagicMock() + cert.extensions.get_extension_for_oid.return_value.value.__dict__ = { + "_usages": [mock.MagicMock()] + } + cert.extensions.get_extension_for_oid.return_value.value._usages[ + 0 + ]._name = "serverAuth" + assert helper.get_ext_key_usage(cert) == ["serverAuth"] + + +def test_get_before_and_after(): + cert = mock.MagicMock() + cert.not_valid_after = mock.MagicMock() + cert.not_valid_before = mock.MagicMock() + cert.not_valid_after.isoformat.return_value = "2024-01-01T00:00:00" + cert.not_valid_before.isoformat.return_value = "2023-01-01T00:00:00" + assert helper.get_before_and_after(cert) == (mock.ANY, mock.ANY) + + +def test_get_crls(): + cert = mock.MagicMock() + cert.extensions.get_extension_for_oid.return_value.value = [mock.MagicMock()] + cert.extensions.get_extension_for_oid.return_value.value[0].full_name = [ + mock.MagicMock() + ] + cert.extensions.get_extension_for_oid.return_value.value[0].full_name[ + 0 + ].value = "http://example.com/crl" + assert helper.get_crls(cert) == "http://example.com/crl" + + +def test_get_ocsp_and_caissuer(): + cert = mock.MagicMock() + cert.extensions.get_extension_for_oid.return_value.value = [mock.MagicMock()] + cert.extensions.get_extension_for_oid.return_value.value[ + 0 + ].access_method._name = "OCSP" + cert.extensions.get_extension_for_oid.return_value.value[ + 0 + ].access_location._value = "http://example.com/ocsp" + assert helper.get_ocsp_and_caissuer(cert) == ("http://example.com/ocsp", "") + + +def test_get_key_type(): + cert = mock.MagicMock() + cert.public_key.return_value = mock.MagicMock() + cert.public_key.return_value.__class__.__name__ = "RSAPublicKey" + assert helper.get_key_type(cert) == "RSAPublicKey" + + +def test_get_key_bits(): + cert = mock.MagicMock() + cert.public_key.return_value = mock.MagicMock() + cert.public_key.return_value.key_size = 2048 + assert helper.get_key_bits(cert) == 2048 + + +def test_get_key_factors(): + cert = mock.MagicMock() + cert.public_key.return_value = mock.MagicMock() + cert.public_key.return_value.public_numbers.return_value = mock.MagicMock() + cert.public_key.return_value.public_numbers.return_value.e = 65537 + cert.public_key.return_value.public_numbers.return_value.n = 123456789 + assert helper.get_key_factors(cert) == {"exponent": 65537, "n": 123456789} + + +def test_get_sig_algorithm(): + cert = mock.MagicMock() + cert.signature_algorithm_oid._name = "sha256WithRSAEncryption" + assert helper.get_sig_algorithm(cert) == "sha256WithRSAEncryption" + + +def test_get_sig_algorithm_params(): + cert = mock.MagicMock() + cert.signature_algorithm_parameters = mock.MagicMock() + cert.signature_algorithm_parameters.__class__ = padding.PSS + assert helper.get_sig_algorithm_params() == "PSS" diff --git a/tests/test_tlsserial.py b/tests/test_tlsserial.py new file mode 100644 index 0000000..684b712 --- /dev/null +++ b/tests/test_tlsserial.py @@ -0,0 +1,80 @@ +from unittest import mock + +import pytest +from click.testing import CliRunner + +from tlsserial import cli, tlsserial + + +# We need to mock the helper functions to avoid making network calls +@mock.patch("tlsserial.helper.get_certs_from_host") +@mock.patch("tlsserial.helper.get_certs_from_file") +def test_cli_url(mock_get_certs_from_file, mock_get_certs_from_host): + runner = CliRunner() + result = runner.invoke(cli.main, ["--url", "example.com"]) + assert result.exit_code == 0 + mock_get_certs_from_host.assert_called_once() + mock_get_certs_from_file.assert_not_called() + + +@mock.patch("tlsserial.helper.get_certs_from_host") +@mock.patch("tlsserial.helper.get_certs_from_file") +def test_cli_file(mock_get_certs_from_file, mock_get_certs_from_host): + runner = CliRunner() + result = runner.invoke(cli.main, ["--file", "test.pem"]) + assert result.exit_code == 0 + mock_get_certs_from_file.assert_called_once() + mock_get_certs_from_host.assert_not_called() + + +@mock.patch("tlsserial.helper.get_certs_from_host") +@mock.patch("tlsserial.helper.get_certs_from_file") +def test_cli_no_args(mock_get_certs_from_file, mock_get_certs_from_host): + runner = CliRunner() + result = runner.invoke(cli.main) + assert result.exit_code == 0 + mock_get_certs_from_file.assert_not_called() + mock_get_certs_from_host.assert_not_called() + + +@mock.patch("tlsserial.helper.get_certs_from_host") +def test_handle_url(mock_get_certs_from_host): + mock_get_certs_from_host.return_value = ([mock.MagicMock()], "SSL cert") + tlsserial.handle_url("example.com") + mock_get_certs_from_host.assert_called_once_with("example.com", 443) + + +@mock.patch("tlsserial.helper.get_certs_from_file") +def test_handle_file(mock_get_certs_from_file): + mock_get_certs_from_file.return_value = ([mock.MagicMock()], "SSL cert") + tlsserial.handle_file("test.pem") + mock_get_certs_from_file.assert_called_once_with("test.pem") + + +def test_get_args(): + assert tlsserial.get_args("example.com") == ("example.com", 443) + assert tlsserial.get_args("example.com:8080") == ("example.com", "8080") + with pytest.raises(SystemExit): + tlsserial.get_args("invalid input") + + +# NOTE: Parametrized test cases allow you to use lots of examples + + +@pytest.mark.parametrize( + "input_str, expected", + [ + ("example.com", ("example.com", 443)), + ("example.com:8080", ("example.com", "8080")), + ("http://example.com:8080", ("example.com", "8080")), + ("https://example.com:8080", ("example.com", "8080")), + ("ldaps://example.com:8080", ("example.com", "8080")), + ], +) +def test_get_args_success(input_str, expected): + assert tlsserial.get_args(input_str) == expected + + +def test_get_args_failure(): + with pytest.raises(SystemExit): + tlsserial.get_args("invalid input")