From f78aed71a38acbeb542c5f8872b8c855e4cf7b0d Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:48:40 +0800 Subject: [PATCH 01/10] updated project setup use ruff ito replace black, isort and flake8 updated github action moved all project setup into pyproject.toml updated makefile for better testing and initial setup --- .github/workflows/pythonapp.yml | 104 ++++++++++++++++++---------- .gitignore | 3 +- .pre-commit-config.yaml | 48 ++++++------- makefile | 63 ++++++++++------- pyproject.toml | 104 ++++++++++++++++++++++------ pytest.ini | 3 - requirements.txt | 7 +- setup.py | 116 -------------------------------- 8 files changed, 213 insertions(+), 235 deletions(-) delete mode 100644 pytest.ini delete mode 100644 setup.py diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 316f69ef..411d37d1 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -1,73 +1,107 @@ -name: Test Application +name: CI -on: [push, pull_request] +on: + push: + branches: [main, master] + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: - build: + lint: + name: Lint (ruff) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + + - name: Install dev dependencies + run: make init-ci + + - name: Lint & format check + run: make check + + test: + name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest + needs: lint services: postgres: - image: postgres:10.8 + image: postgres:16 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres ports: - # will assign a random free host port - 5432/tcp - # needed because the postgres container does not provide a healthcheck - options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 mysql: - image: mysql:5.7 + image: mysql:8.0 env: MYSQL_ALLOW_EMPTY_PASSWORD: yes MYSQL_DATABASE: orm ports: - - 3306 - options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + - 3306/tcp + options: >- + --health-cmd "mysqladmin ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 strategy: + fail-fast: false matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] - name: Python ${{ matrix.python-version }} + steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - make init-ci - - name: Test with pytest + cache: pip + + - name: Install test dependencies + run: make init-ci + + - name: Run migrations env: POSTGRES_DATABASE_HOST: localhost + POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }} POSTGRES_DATABASE_DATABASE: postgres POSTGRES_DATABASE_USER: postgres POSTGRES_DATABASE_PASSWORD: postgres - POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }} MYSQL_DATABASE_HOST: localhost + MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }} MYSQL_DATABASE_DATABASE: orm MYSQL_DATABASE_USER: root - MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }} DB_CONFIG_PATH: tests/integrations/config/database.py run: | python orm migrate --connection postgres python orm migrate --connection mysql - make test - lint: - runs-on: ubuntu-latest - name: Lint - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.6 - uses: actions/setup-python@v3 - with: - python-version: 3.12 - - name: Install Flake8 - run: | - pip install flake8-pyproject - - name: Lint - run: make lint + + - name: Run tests + env: + POSTGRES_DATABASE_HOST: localhost + POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }} + POSTGRES_DATABASE_DATABASE: postgres + POSTGRES_DATABASE_USER: postgres + POSTGRES_DATABASE_PASSWORD: postgres + MYSQL_DATABASE_HOST: localhost + MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }} + MYSQL_DATABASE_DATABASE: orm + MYSQL_DATABASE_USER: root + DB_CONFIG_PATH: tests/integrations/config/database.py + run: make test diff --git a/.gitignore b/.gitignore index 713b32f2..68817677 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,4 @@ coverage.xml *.log build /orm.sqlite3 -/.bootstrapped-pip -/.ignore-pre-commit +/.bootstrapped-* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3bac65e..932dd384 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,32 +1,24 @@ repos: - - repo: https://github.com/pycqa/isort - rev: 6.0.1 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 hooks: - - id: isort - args: [--profile=black] - exclude: | - (?x)( - ^build| - ^conda - ) + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-added-large-files + args: [--maxkb=500] + - id: debug-statements - - repo: https://github.com/psf/black - rev: 25.1.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.6 hooks: - - id: black - exclude: | - (?x)( - ^build| - ^conda - ) - - - repo: https://github.com/pycqa/flake8 - rev: 7.1.2 - hooks: - - id: flake8 - additional_dependencies: [flake8-pyproject] - exclude: | - (?x)( - ^build| - ^conda - ) + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + exclude: &exclude_patterns | + (?x)( + ^build/| + ^conda/ + ) + - id: ruff-format + exclude: *exclude_patterns diff --git a/makefile b/makefile index 1da923aa..b53efc5b 100644 --- a/makefile +++ b/makefile @@ -1,54 +1,65 @@ SHELL := /bin/bash -init: .env .bootstrapped-pip .git/hooks/pre-commit -init-ci: - touch .ignore-pre-commit - make init +.PHONY: init +init: .env .bootstrapped-dev -.bootstrapped-pip: requirements.txt +.PHONY: init-ci +init-ci: .env .bootstrapped-tests + +.bootstrapped-tests: pip install -r requirements.txt - touch .bootstrapped-pip + touch .bootstrapped-tests -.git/hooks/pre-commit: - @if ! test -e ".ignore-pre-commit"; then \ - pip install pre-commit; \ - pre-commit install --install-hooks; \ - fi +.bootstrapped-dev: .bootstrapped-tests + pip install pre-commit faker + pre-commit install + touch .bootstrapped-dev .env: cp .env-example .env # Create MySQL Database # Create Postgres Database -test: init + +.PHONY: test +test: .bootstrapped-tests python -m pytest tests + +.PHONY: ci ci: make test -check: format sort lint -lint: - flake8 src/masoniteorm tests -format: init - black src/masoniteorm tests/ -sort: init - isort src/masoniteorm tests/ + +.PHONY: check +check: format lint + +.PHONY: lint +lint: .bootstrapped-tests + ruff check --fix --exit-non-zero-on-fix src/masoniteorm tests + +format: .bootstrapped-tests + ruff format --check src/masoniteorm tests/ + coverage: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ python -m coveralls + show: python -m pytest --cov-report term --cov-report html --cov=src/masoniteorm tests/ + cov: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ + publish: - pip install twine + pip install build twine make test - python setup.py sdist + python -m build twine upload dist/* - rm -fr build dist .egg masonite.egg-info - rm -rf dist/* + rm -rf build dist *.egg-info + pub: - python setup.py sdist + python -m build twine upload dist/* - rm -fr build dist .egg masonite.egg-info - rm -rf dist/* + rm -rf build dist *.egg-info + pypirc: cp .pypirc ~/.pypirc diff --git a/pyproject.toml b/pyproject.toml index e141db59..26214889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,21 +1,85 @@ -[tool.black] -target-version = ['py38'] -include = '\.pyi?$' -line-length = 79 - -[tool.isort] -profile = "black" -multi_line_output = 3 -include_trailing_comma = true -force_grid_wrap = 0 -use_parentheses = true -ensure_newline_before_comments = true - -[tool.flake8] -ignore = ['E501', 'E203', 'E128', 'E402', 'E731', 'F821', 'E712', 'W503', 'F811'] -#max-line-length = 79 -#max-complexity = 18 -per-file-ignores = [ - '__init__.py:F401', - 'setup.py:E266', +[build-system] +requires = ["setuptools>=68", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "masonite-orm" +version = "3.0.0" +description = "The Official Masonite ORM" +readme = { file = "README.md", content-type = "text/markdown" } +license = { text = "MIT" } +authors = [{ name = "Joe Mancuso", email = "joe@masoniteproject.com" }] +requires-python = ">=3.8" +keywords = ["Masonite", "MasoniteFramework", "Python", "ORM"] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Topic :: Software Development :: Build Tools", + "Environment :: Web Environment", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Framework :: Masonite", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = [ + "inflection>=0.3,<0.6", + "pendulum>=3.0,<4.0", + "cleo>=0.8.0,<2.0", ] + +[project.optional-dependencies] +seeder = ["faker"] + +[project.urls] +Homepage = "https://github.com/masoniteframework/orm" + +[project.scripts] +masonite-orm = "masoniteorm.commands.Entry:application.run" + +[tool.setuptools.package-dir] +"" = "src" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +env = [ + "D:DB_CONFIG_PATH=tests/integrations/config/database", +] + +[tool.ruff] +target-version = "py38" +line-length = 79 +exclude = ["build", "conda"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "UP", # pyupgrade +] +ignore = [ +# "E501", # line too long +# "E203", # whitespace before ':' + "E402", # module-level import not at top of file + "E731", # do not assign a lambda expression +# "E712", # comparison to True/False +# "F821", # undefined name +# "F811", # redefinition of unused name +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] # imported but unused + +[tool.ruff.lint.isort] +force-sort-within-sections = true +combine-as-imports = true +forced-separate = ["tests"] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index a3af8f8c..00000000 --- a/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -env = - D:DB_CONFIG_PATH=tests/integrations/config/database diff --git a/requirements.txt b/requirements.txt index ca245376..b608f926 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,7 @@ -flake8-pyproject -black -isort -faker +ruff pytest pytest-env -pytest-cov +coverage pymysql inflection>=0.3 psycopg2-binary diff --git a/setup.py b/setup.py deleted file mode 100644 index c58f78ba..00000000 --- a/setup.py +++ /dev/null @@ -1,116 +0,0 @@ -from setuptools import setup - -with open("README.md", "r") as fh: - long_description = fh.read() - -setup( - name="masonite-orm", - # Versions should comply with PEP440. For a discussion on single-sourcing - # the version across setup.py and the project code, see - # https://packaging.python.org/en/latest/single_source_version.html - version="3.0.0", - package_dir={"": "src"}, - description="The Official Masonite ORM", - long_description=long_description, - long_description_content_type="text/markdown", - # The project's main homepage. - url="https://github.com/masoniteframework/orm", - # Author details - author="Joe Mancuso", - author_email="joe@masoniteproject.com", - # Choose your license - license="MIT", - # If your package should include things you specify in your MANIFEST.in file - # Use this option if your package needs to include files that are not python files - # like html templates or css files - include_package_data=True, - # List run-time dependencies here. These will be installed by pip when - # your project is installed. For an analysis of "install_requires" vs pip's - # requirements files see: - # https://packaging.python.org/en/latest/requirements.html - install_requires=[ - "inflection>=0.3,<0.6", - "pendulum>=3.0,<4.0", - "cleo>=0.8.0,<2.0", - ], - # See https://pypi.python.org/pypi?%3Aaction=list_classifiers - classifiers=[ - # How mature is this project? Common values are - # 3 - Alpha - # 4 - Beta - # 5 - Production/Stable - "Development Status :: 5 - Production/Stable", - # Indicate who your project is intended for - "Intended Audience :: Developers", - "Topic :: Software Development :: Build Tools", - "Environment :: Web Environment", - # Pick your license as you wish (should match "license" above) - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - # Specify the Python versions you support here. In particular, ensure - # that you indicate whether you support Python 2, Python 3 or both. - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Framework :: Masonite", - "Topic :: Software Development :: Libraries :: Python Modules", - "Framework :: Masonite", - ], - # What does your project relate to? - keywords="Masonite, MasoniteFramework, Python, ORM", - # You can just specify the packages manually here if your project is - # simple. Or you can use find_packages(). - packages=[ - "masoniteorm", - "masoniteorm.collection", - "masoniteorm.commands", - "masoniteorm.connections", - "masoniteorm.expressions", - "masoniteorm.factories", - "masoniteorm.helpers", - "masoniteorm.migrations", - "masoniteorm.models", - "masoniteorm.observers", - "masoniteorm.pagination", - "masoniteorm.providers", - "masoniteorm.query", - "masoniteorm.query.grammars", - "masoniteorm.query.processors", - "masoniteorm.relationships", - "masoniteorm.schema", - "masoniteorm.schema.platforms", - "masoniteorm.scopes", - "masoniteorm.seeds", - ], - # List additional groups of dependencies here (e.g. development - # dependencies). You can install these using the following syntax, - # for example: - # $ pip install -e .[dev,test] - # $ pip install your-package[dev,test] - extras_require={ - "test": ["coverage", "pytest", "faker"], - "seeder": ["faker"], - }, - # If there are data files included in your packages that need to be - # installed, specify them here. If using Python 2.6 or less, then these - # have to be included in MANIFEST.in as well. - ## package_data={ - ## 'sample': [], - ## }, - # Although 'package_data' is the preferred approach, in some case you may - # need to place data files outside of your packages. See: - # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa - # In this case, 'data_file' will be installed into '/my_data' - ## data_files=[('my_data', ['data/data_file.txt'])], - # To provide executable scripts, use entry points in preference to the - # "scripts" keyword. Entry points provide cross-platform support and allow - # pip to create the appropriate form of executable for the target platform. - entry_points={ - "console_scripts": [ - "masonite-orm = masoniteorm.commands.Entry:application.run", - ], - }, -) From a93e73c4a0bee354d79e82b8019856f3b8477f95 Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 15:02:14 +0800 Subject: [PATCH 02/10] updated project spec --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 26214889..4a984878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ env = [ [tool.ruff] target-version = "py38" -line-length = 79 +line-length = 88 exclude = ["build", "conda"] [tool.ruff.lint] @@ -78,6 +78,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] # imported but unused +"tests/**/*.py" = ["E501"] [tool.ruff.lint.isort] force-sort-within-sections = true From 5d98903b47bbf873b43de7844a23cc12786777aa Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 15:05:40 +0800 Subject: [PATCH 03/10] linted and formatted tests formating cleanup sorted imports removed duplicate tests other fixes --- tests/User.py | 2 +- tests/collection/test_collection.py | 4 +- tests/commands/test_shell.py | 1 + tests/eagers/test_eager.py | 12 ++---- tests/factories/test_factories.py | 4 +- tests/models/test_models.py | 14 ++----- .../mssql/builder/test_mssql_query_builder.py | 17 +++----- .../test_mssql_query_builder_relationships.py | 5 +-- .../grammar/test_mssql_update_grammar.py | 2 +- .../mssql/schema/test_mssql_schema_builder.py | 3 +- .../schema/test_mssql_schema_builder_alter.py | 3 +- .../builder/test_mysql_builder_transaction.py | 1 + tests/mysql/builder/test_query_builder.py | 41 ++++--------------- .../builder/test_query_builder_scopes.py | 5 +-- .../grammar/test_mysql_select_grammar.py | 28 +++++-------- .../grammar/test_mysql_update_grammar.py | 2 +- .../model/test_accessors_and_mutators.py | 4 +- tests/mysql/model/test_model.py | 8 +--- .../relationships/test_belongs_to_many.py | 8 +--- .../relationships/test_has_many_through.py | 3 +- .../relationships/test_has_one_through.py | 3 +- .../mysql/relationships/test_relationships.py | 8 +--- .../mysql/schema/test_mysql_schema_builder.py | 33 ++------------- .../schema/test_mysql_schema_builder_alter.py | 23 ++++------- tests/mysql/scopes/test_can_use_scopes.py | 5 +-- tests/mysql/scopes/test_soft_delete.py | 1 + .../builder/test_postgres_query_builder.py | 13 ++---- .../builder/test_postgres_transaction.py | 1 + tests/postgres/grammar/test_select_grammar.py | 32 +++++++-------- tests/postgres/grammar/test_update_grammar.py | 2 +- .../schema/test_postgres_schema_builder.py | 6 ++- .../test_postgres_schema_builder_alter.py | 5 ++- .../builder/test_sqlite_builder_insert.py | 1 + .../builder/test_sqlite_builder_pagination.py | 1 + .../builder/test_sqlite_query_builder.py | 31 +++++--------- ...test_sqlite_query_builder_eager_loading.py | 1 + ...test_sqlite_query_builder_relationships.py | 5 +-- .../sqlite/builder/test_sqlite_transaction.py | 1 + .../grammar/test_sqlite_select_grammar.py | 36 +++++++--------- .../grammar/test_sqlite_update_grammar.py | 2 +- tests/sqlite/models/test_attach_detach.py | 1 + tests/sqlite/models/test_observers.py | 1 + tests/sqlite/models/test_sqlite_model.py | 37 +++++------------ ...st_sqlite_has_many_through_relationship.py | 29 ++++--------- ...est_sqlite_has_one_through_relationship.py | 21 +++------- .../relationships/test_sqlite_polymorphic.py | 1 + .../test_sqlite_relationships.py | 9 +++- .../schema/test_sqlite_schema_builder.py | 13 +++--- .../test_sqlite_schema_builder_alter.py | 21 ++++------ tests/sqlite/schema/test_table.py | 3 +- 50 files changed, 184 insertions(+), 329 deletions(-) diff --git a/tests/User.py b/tests/User.py index 452f0b4d..a2bdb229 100644 --- a/tests/User.py +++ b/tests/User.py @@ -1,4 +1,4 @@ -""" User Model """ +"""User Model""" from src.masoniteorm import Model diff --git a/tests/collection/test_collection.py b/tests/collection/test_collection.py index f1822be1..1687b37c 100644 --- a/tests/collection/test_collection.py +++ b/tests/collection/test_collection.py @@ -3,6 +3,7 @@ from src.masoniteorm.collection import Collection from src.masoniteorm.factories import Factory as factory from src.masoniteorm.models import Model + from tests.User import User @@ -583,8 +584,7 @@ def test_json(self): self.assertEqual( json_data, - '[{"name": "Corentin", "age": 10}, ' - '{"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}]', + '[{"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}]', ) def test_contains(self): diff --git a/tests/commands/test_shell.py b/tests/commands/test_shell.py index 8d432121..7de7b08f 100644 --- a/tests/commands/test_shell.py +++ b/tests/commands/test_shell.py @@ -1,4 +1,5 @@ import unittest + from cleo import CommandTester from src.masoniteorm.commands import ShellCommand diff --git a/tests/eagers/test_eager.py b/tests/eagers/test_eager.py index 98adb5d7..b162c180 100644 --- a/tests/eagers/test_eager.py +++ b/tests/eagers/test_eager.py @@ -14,9 +14,7 @@ def test_can_register_string_eager_load(self): [{"profile": ["user"]}], ) self.assertEqual( - EagerRelations() - .register("profile.user", "profile.logo") - .get_eagers(), + EagerRelations().register("profile.user", "profile.logo").get_eagers(), [{"profile": ["user", "logo"]}], ) self.assertEqual( @@ -39,9 +37,7 @@ def test_can_register_tuple_eager_load(self): [["profile", "user"]], ) self.assertEqual( - EagerRelations() - .register(("profile.name", "profile.user")) - .get_eagers(), + EagerRelations().register(("profile.name", "profile.user")).get_eagers(), [{"profile": ["name", "user"]}], ) @@ -54,9 +50,7 @@ def test_can_register_list_eager_load(self): [["profile", "user"]], ) self.assertEqual( - EagerRelations() - .register(["profile.name", "profile.user"]) - .get_eagers(), + EagerRelations().register(["profile.name", "profile.user"]).get_eagers(), [{"profile": ["name", "user"]}], ) self.assertEqual( diff --git a/tests/factories/test_factories.py b/tests/factories/test_factories.py index 63ff77f5..868264b0 100644 --- a/tests/factories/test_factories.py +++ b/tests/factories/test_factories.py @@ -36,9 +36,7 @@ def test_can_make_single(self): self.assertIsInstance(user, User) def test_can_make_several(self): - users = factory(User).make( - [{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}] - ) + users = factory(User).make([{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}]) self.assertEqual(users.count(), 2) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index c2b759ed..9263d01f 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,6 +1,6 @@ import datetime -import unittest from decimal import Decimal +import unittest import pendulum @@ -50,9 +50,7 @@ class ModelWithBaseModel(BaseModel): class TestModels(unittest.TestCase): def test_model_can_access_str_dates_as_pendulum(self): - model = ModelTest.hydrate( - {"user": "joe", "due_date": "2020-11-28 11:42:07"} - ) + model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"}) self.assertTrue(model.user) self.assertTrue(model.due_date) @@ -71,9 +69,7 @@ def test_model_can_access_str_dates_as_pendulum_from_correct_datetimes( self.assertEqual(model.get_new_date("2020-11-28 11:42:07").hour, 11) def test_model_can_access_str_dates_on_relationships(self): - model = ModelTest.hydrate( - {"user": "joe", "due_date": "2020-11-28 11:42:07"} - ) + model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"}) model.add_relation( { "profile": ModelTest.hydrate( @@ -199,9 +195,7 @@ def test_valid_json_cast(self): self.assertEqual(type(model.payload), dict) - model = ModelTest.hydrate( - {"payload": "{'this': 'should', 'throw': 'error'}"} - ) + model = ModelTest.hydrate({"payload": "{'this': 'should', 'throw': 'error'}"}) self.assertEqual(model.payload, None) diff --git a/tests/mssql/builder/test_mssql_query_builder.py b/tests/mssql/builder/test_mssql_query_builder.py index ea8cb484..16cf68c4 100644 --- a/tests/mssql/builder/test_mssql_query_builder.py +++ b/tests/mssql/builder/test_mssql_query_builder.py @@ -2,6 +2,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder + from tests.utils import MockConnectionFactory @@ -115,9 +116,7 @@ def test_add_select_no_table(self): builder.add_select( "other_test", lambda q: q.max("updated_at").table("different_table"), - ).add_select( - "some_alias", lambda q: q.max("updated_at").table("another_table") - ) + ).add_select("some_alias", lambda q: q.max("updated_at").table("another_table")) self.assertEqual( builder.to_sql(), @@ -168,9 +167,7 @@ def test_where(self): def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") - self.assertEqual( - builder.to_sql(), "SELECT * FROM [users] WHERE EXISTS 'name'" - ) + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] WHERE EXISTS 'name'") def test_limit(self): builder = self.get_builder() @@ -245,9 +242,7 @@ def test_count(self): def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") - self.assertEqual( - builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC" - ) + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC") def test_order_by_desc(self): builder = self.get_builder() @@ -445,9 +440,7 @@ def test_latest_multiple(self): def test_oldest(self): builder = self.get_builder() builder.oldest("email") - self.assertEqual( - builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC" - ) + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC") def test_oldest_multiple(self): builder = self.get_builder() diff --git a/tests/mssql/builder/test_mssql_query_builder_relationships.py b/tests/mssql/builder/test_mssql_query_builder_relationships.py index ff070d40..7adaa259 100644 --- a/tests/mssql/builder/test_mssql_query_builder_relationships.py +++ b/tests/mssql/builder/test_mssql_query_builder_relationships.py @@ -6,6 +6,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MSSQLGrammar from src.masoniteorm.relationships import belongs_to + from tests.utils import MockConnectionFactory load_dotenv(".env") @@ -92,9 +93,7 @@ def test_has_reference_to_self_using_class(self): def test_where_has_query(self): builder = self.get_builder() - sql = builder.where_has( - "articles", lambda q: q.where("active", 1) - ).to_sql() + sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql() self.assertEqual( sql, """SELECT * FROM [users] WHERE EXISTS (""" diff --git a/tests/mssql/grammar/test_mssql_update_grammar.py b/tests/mssql/grammar/test_mssql_update_grammar.py index 49c6e4ab..aa4c6f17 100644 --- a/tests/mssql/grammar/test_mssql_update_grammar.py +++ b/tests/mssql/grammar/test_mssql_update_grammar.py @@ -1,8 +1,8 @@ import unittest +from src.masoniteorm.expressions import Raw from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MSSQLGrammar -from src.masoniteorm.expressions import Raw class TestMSSQLUpdateGrammar(unittest.TestCase): diff --git a/tests/mssql/schema/test_mssql_schema_builder.py b/tests/mssql/schema/test_mssql_schema_builder.py index 4205bc7e..71f07bb7 100644 --- a/tests/mssql/schema/test_mssql_schema_builder.py +++ b/tests/mssql/schema/test_mssql_schema_builder.py @@ -1,10 +1,11 @@ import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MSSQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MSSQLPlatform +from tests.integrations.config.database import DATABASES + class TestMSSQLSchemaBuilder(unittest.TestCase): maxDiff = None diff --git a/tests/mssql/schema/test_mssql_schema_builder_alter.py b/tests/mssql/schema/test_mssql_schema_builder_alter.py index f1b323e2..e23e6acf 100644 --- a/tests/mssql/schema/test_mssql_schema_builder_alter.py +++ b/tests/mssql/schema/test_mssql_schema_builder_alter.py @@ -1,11 +1,12 @@ import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MSSQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MSSQLPlatform from src.masoniteorm.schema.Table import Table +from tests.integrations.config.database import DATABASES + class TestMySQLSchemaBuilderAlter(unittest.TestCase): maxDiff = None diff --git a/tests/mysql/builder/test_mysql_builder_transaction.py b/tests/mysql/builder/test_mysql_builder_transaction.py index 9ee7950a..54cebed8 100644 --- a/tests/mysql/builder/test_mysql_builder_transaction.py +++ b/tests/mysql/builder/test_mysql_builder_transaction.py @@ -5,6 +5,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar + from tests.integrations.config.database import DB if os.getenv("RUN_MYSQL_DATABASE") == "True": diff --git a/tests/mysql/builder/test_query_builder.py b/tests/mysql/builder/test_query_builder.py index 85252f6a..0a33d096 100644 --- a/tests/mysql/builder/test_query_builder.py +++ b/tests/mysql/builder/test_query_builder.py @@ -6,6 +6,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import has_many + from tests.integrations.config.database import DATABASES from tests.utils import MockConnectionFactory @@ -160,9 +161,7 @@ def test_find_with_builder_and_list(self): builder = self.get_builder() builder._model = None builder.find([10, 20, 30], column="age", query=True) - sql = ( - """SELECT * FROM `users` WHERE `users`.`age` IN ('10','20','30')""" - ) + sql = """SELECT * FROM `users` WHERE `users`.`age` IN ('10','20','30')""" self.assertEqual(builder.to_sql(), sql) def test_find_with_builder_without_column(self): @@ -514,38 +513,18 @@ def test_or_where(self): )() self.assertEqual(builder.to_sql(), sql) - def test_or_where(self): - builder = self.get_builder() - builder.where("age", "20").or_where("age", "<", 20) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() - self.assertEqual(builder.to_sql(), sql) - def test_where_like_as_operator(self): builder = self.get_builder() builder.where("age", "like", "%name%") sql = getattr(self, "where_like")() self.assertEqual(builder.to_sql(), sql) - def test_where_like(self): - builder = self.get_builder() - builder.where_like("age", "%name%") - sql = getattr(self, "where_like")() - self.assertEqual(builder.to_sql(), sql) - def test_where_not_like_as_operator(self): builder = self.get_builder() builder.where("age", "not like", "%name%") sql = getattr(self, "where_not_like")() self.assertEqual(builder.to_sql(), sql) - def test_where_not_like(self): - builder = self.get_builder() - builder.where_not_like("age", "%name%") - sql = getattr(self, "where_not_like")() - self.assertEqual(builder.to_sql(), sql) - def test_can_call_with_multi_tables(self): builder = self.get_builder() sql = ( @@ -811,9 +790,7 @@ def where_column(self): """ builder.where_column('name', 'username') """ - return ( - "SELECT * FROM `users` WHERE `users`.`name` = `users`.`username`" - ) + return "SELECT * FROM `users` WHERE `users`.`name` = `users`.`username`" def where_null(self): """ @@ -849,9 +826,7 @@ def not_between(self): """ builder.not_between('id', 2, 5) """ - return ( - "SELECT * FROM `users` WHERE `users`.`id` NOT BETWEEN '2' AND '5'" - ) + return "SELECT * FROM `users` WHERE `users`.`id` NOT BETWEEN '2' AND '5'" def having(self): """ @@ -905,7 +880,9 @@ def or_where(self): builder = self.get_builder() builder.where('age', '20').or_where('age','<', 20) """ - return "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'" + return ( + "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'" + ) def where_like(self): """ @@ -951,9 +928,7 @@ def update_lock(self): builder = self.get_builder() builder.truncate() """ - return ( - "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" - ) + return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" def test_latest(self): builder = self.get_builder() diff --git a/tests/mysql/builder/test_query_builder_scopes.py b/tests/mysql/builder/test_query_builder_scopes.py index 100cc6a3..4686cf22 100644 --- a/tests/mysql/builder/test_query_builder_scopes.py +++ b/tests/mysql/builder/test_query_builder_scopes.py @@ -3,6 +3,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.scopes import SoftDeleteScope + from tests.integrations.config.database import DATABASES from tests.utils import MockConnectionFactory @@ -65,6 +66,4 @@ def test_global_scope_remove_from_class(self): def test_global_scope_adds_method(self): builder = self.get_builder().set_global_scope(SoftDeleteScope()) - self.assertEqual( - builder.with_trashed().to_sql(), "SELECT * FROM `users`" - ) + self.assertEqual(builder.with_trashed().to_sql(), "SELECT * FROM `users`") diff --git a/tests/mysql/grammar/test_mysql_select_grammar.py b/tests/mysql/grammar/test_mysql_select_grammar.py index aa4d6b40..a87dd772 100644 --- a/tests/mysql/grammar/test_mysql_select_grammar.py +++ b/tests/mysql/grammar/test_mysql_select_grammar.py @@ -59,17 +59,13 @@ def can_compile_with_max_and_columns(self): """ self.builder.select('username').max('age').to_sql() """ - return ( - "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`" - ) + return "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`" def can_compile_with_max_and_columns_different_order(self): """ self.builder.max('age').select('username').to_sql() """ - return ( - "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`" - ) + return "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`" def can_compile_with_order_by(self): """ @@ -180,7 +176,9 @@ def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ - return "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'" + return ( + "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'" + ) def can_grouped_where(self): """ @@ -288,9 +286,7 @@ def can_compile_between(self): """ builder.between('age', 18, 21).to_sql() """ - return ( - "SELECT * FROM `users` WHERE `users`.`age` BETWEEN '18' AND '21'" - ) + return "SELECT * FROM `users` WHERE `users`.`age` BETWEEN '18' AND '21'" def can_compile_not_between(self): """ @@ -337,9 +333,7 @@ def can_compile_first_or_fail(self): builder = self.get_builder() builder.where("is_admin", "=", True).first_or_fail() """ - return ( - """SELECT * FROM `users` WHERE `users`.`is_admin` = '1' LIMIT 1""" - ) + return """SELECT * FROM `users` WHERE `users`.`is_admin` = '1' LIMIT 1""" def where_not_like(self): """ @@ -471,9 +465,7 @@ def update_lock(self): builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ - return ( - "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" - ) + return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" def can_user_where_raw_and_where(self): """ @@ -488,7 +480,9 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')""" def where_date(self): - return """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'""" + return ( + """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'""" + ) def or_where_null(self): return """SELECT * FROM `users` WHERE `users`.`column1` IS NULL OR `users`.`column2` IS NULL""" diff --git a/tests/mysql/grammar/test_mysql_update_grammar.py b/tests/mysql/grammar/test_mysql_update_grammar.py index 0212ec8f..dde37bb3 100644 --- a/tests/mysql/grammar/test_mysql_update_grammar.py +++ b/tests/mysql/grammar/test_mysql_update_grammar.py @@ -1,9 +1,9 @@ import inspect import unittest +from src.masoniteorm.expressions import Raw from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar -from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: diff --git a/tests/mysql/model/test_accessors_and_mutators.py b/tests/mysql/model/test_accessors_and_mutators.py index 93c90e2c..609785b2 100644 --- a/tests/mysql/model/test_accessors_and_mutators.py +++ b/tests/mysql/model/test_accessors_and_mutators.py @@ -30,9 +30,7 @@ def test_can_get_accessor(self): self.assertTrue(user.is_admin is True, f"{user.is_admin} is not True") def test_mutator(self): - user = SetUser.hydrate( - {"email": "joe@masoniteproject.com", "is_admin": 1} - ) + user = SetUser.hydrate({"email": "joe@masoniteproject.com", "is_admin": 1}) user.name = "joe" diff --git a/tests/mysql/model/test_model.py b/tests/mysql/model/test_model.py index 152949df..0961c4d1 100644 --- a/tests/mysql/model/test_model.py +++ b/tests/mysql/model/test_model.py @@ -8,7 +8,6 @@ from src.masoniteorm.collection import Collection from src.masoniteorm.exceptions import ModelNotFound from src.masoniteorm.models import Model -from tests.User import User class ProfileFillable(Model): @@ -306,8 +305,7 @@ def test_access_as_date(self): user = User.hydrate( { "name": "Joe", - "created_at": datetime.datetime.now() - + datetime.timedelta(days=1), + "created_at": datetime.datetime.now() + datetime.timedelta(days=1), } ) @@ -323,9 +321,7 @@ def test_serialize_with_dirty_attribute(self): profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) profile.age = 18 - self.assertEqual( - profile.serialize(), {"age": 18, "name": "Joe", "id": 1} - ) + self.assertEqual(profile.serialize(), {"age": 18, "name": "Joe", "id": 1}) def test_attribute_check_with_hasattr(self): self.assertFalse(hasattr(Profile(), "__password__")) diff --git a/tests/mysql/relationships/test_belongs_to_many.py b/tests/mysql/relationships/test_belongs_to_many.py index 118f7cca..69494f9b 100644 --- a/tests/mysql/relationships/test_belongs_to_many.py +++ b/tests/mysql/relationships/test_belongs_to_many.py @@ -48,7 +48,7 @@ class MySQLRelationships(unittest.TestCase): def test_belongs_to_many(self): sql = Permission.where_has( - "role", lambda query: (query.where("slug", "users")) + "role", lambda query: query.where("slug", "users") ).to_sql() self.assertEqual( @@ -85,11 +85,7 @@ def test_belongs_to_many_or_where_has(self): ) def test_belongs_to_many_or_doesnt_have(self): - sql = ( - Role.where("name", "role_name") - .or_doesnt_have("permissions") - .to_sql() - ) + sql = Role.where("name", "role_name").or_doesnt_have("permissions").to_sql() self.assertEqual( sql, diff --git a/tests/mysql/relationships/test_has_many_through.py b/tests/mysql/relationships/test_has_many_through.py index 3c6d5b7e..f341a52d 100644 --- a/tests/mysql/relationships/test_has_many_through.py +++ b/tests/mysql/relationships/test_has_many_through.py @@ -1,10 +1,11 @@ import unittest +from dotenv import load_dotenv + from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_many_through, ) -from dotenv import load_dotenv load_dotenv(".env") diff --git a/tests/mysql/relationships/test_has_one_through.py b/tests/mysql/relationships/test_has_one_through.py index 4337dc83..0078b1dc 100644 --- a/tests/mysql/relationships/test_has_one_through.py +++ b/tests/mysql/relationships/test_has_one_through.py @@ -1,10 +1,11 @@ import unittest +from dotenv import load_dotenv + from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_one_through, ) -from dotenv import load_dotenv load_dotenv(".env") diff --git a/tests/mysql/relationships/test_relationships.py b/tests/mysql/relationships/test_relationships.py index 0829b439..a1eddcf0 100644 --- a/tests/mysql/relationships/test_relationships.py +++ b/tests/mysql/relationships/test_relationships.py @@ -52,9 +52,7 @@ def test_or_has(self): ) def test_or_has_nested(self): - sql = ( - User.where("name", "Joe").or_has("profile.identification").to_sql() - ) + sql = User.where("name", "Joe").or_has("profile.identification").to_sql() self.assertEqual( sql, @@ -179,9 +177,7 @@ def test_joins(self): ) def test_join_on(self): - sql = User.join_on( - "profile", lambda q: (q.where("active", 1)) - ).to_sql() + sql = User.join_on("profile", lambda q: q.where("active", 1)).to_sql() self.assertEqual( sql, diff --git a/tests/mysql/schema/test_mysql_schema_builder.py b/tests/mysql/schema/test_mysql_schema_builder.py index b3aefc03..c83a3c84 100644 --- a/tests/mysql/schema/test_mysql_schema_builder.py +++ b/tests/mysql/schema/test_mysql_schema_builder.py @@ -2,7 +2,6 @@ import unittest from src.masoniteorm import Model -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MySQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MySQLPlatform @@ -76,8 +75,8 @@ def test_can_add_columns_with_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") - blueprint.unique("name"), - blueprint.unique("name", name="table_unique"), + (blueprint.unique("name"),) + (blueprint.unique("name", name="table_unique"),) self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( @@ -112,30 +111,6 @@ def test_can_add_table_comment(self): ], ) - def test_can_add_columns_with_foreign_key_constaint(self): - with self.schema.create("users") as blueprint: - blueprint.string("name").unique() - blueprint.integer("age") - blueprint.integer("profile_id") - blueprint.foreign("profile_id").references("id").on("profiles") - blueprint.foreign_id("post_id").references("id").on("posts") - blueprint.foreign_id_for(Discussion).references("id").on("discussions") - - self.assertEqual(len(blueprint.table.added_columns), 3) - self.assertEqual( - blueprint.to_sql(), - [ - "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, " - "`age` INT(11) NOT NULL, " - "`profile_id` INT(11) NOT NULL, " - "`post_id` BIGINT UNSIGNED NOT NULL, " - "CONSTRAINT users_name_unique UNIQUE (name), " - "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`), " - "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`post_id`) REFERENCES `posts`(`id`)), " - "CONSTRAINT users_discussions_id_foreign FOREIGN KEY (`discussion_id`) REFERENCES `posts`(`id`))" - ], - ) - def test_can_add_columns_with_foreign_key_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() @@ -316,7 +291,7 @@ def test_can_have_default_blank_string(self): self.assertEqual( blueprint.to_sql(), - ["CREATE TABLE `users` (" "`profile_id` VARCHAR(255) NOT NULL DEFAULT '')"], + ["CREATE TABLE `users` (`profile_id` VARCHAR(255) NOT NULL DEFAULT '')"], ) def test_can_have_float_type(self): @@ -325,7 +300,7 @@ def test_can_have_float_type(self): self.assertEqual( blueprint.to_sql(), - ["CREATE TABLE `users` (" "`amount` FLOAT(19, 4) NOT NULL)"], + ["CREATE TABLE `users` (`amount` FLOAT(19, 4) NOT NULL)"], ) def test_has_table(self): diff --git a/tests/mysql/schema/test_mysql_schema_builder_alter.py b/tests/mysql/schema/test_mysql_schema_builder_alter.py index 0bd4a6b0..bba8007e 100644 --- a/tests/mysql/schema/test_mysql_schema_builder_alter.py +++ b/tests/mysql/schema/test_mysql_schema_builder_alter.py @@ -4,6 +4,7 @@ from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MySQLPlatform from src.masoniteorm.schema.Table import Table + from tests.integrations.config.database import DATABASES @@ -69,9 +70,7 @@ def test_can_add_column_after(self): self.assertEqual(len(blueprint.table.added_columns), 1) - sql = [ - "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `age`" - ] + sql = ["ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `age`"] self.assertEqual(blueprint.to_sql(), sql) @@ -130,9 +129,9 @@ def test_alter_drop1(self): def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on( - "playlists" - ).on_delete("cascade") + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( + "cascade" + ) sql = [ "ALTER TABLE `users` ADD `playlist_id` INT UNSIGNED NULL", @@ -145,9 +144,7 @@ def test_alter_drop_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign("users_playlist_id_foreign") - sql = [ - "ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign" - ] + sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) @@ -155,9 +152,7 @@ def test_alter_drop_foreign_key_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign(["playlist_id"]) - sql = [ - "ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign" - ] + sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) @@ -314,9 +309,7 @@ def test_can_add_column_enum(self): def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: - blueprint.enum("status", ["active", "inactive"]).default( - "active" - ).change() + blueprint.enum("status", ["active", "inactive"]).default("active").change() self.assertEqual(len(blueprint.table.changed_columns), 1) diff --git a/tests/mysql/scopes/test_can_use_scopes.py b/tests/mysql/scopes/test_can_use_scopes.py index 354fd0d4..c38207db 100644 --- a/tests/mysql/scopes/test_can_use_scopes.py +++ b/tests/mysql/scopes/test_can_use_scopes.py @@ -2,7 +2,6 @@ from src.masoniteorm.models import Model from src.masoniteorm.scopes import SoftDeletesMixin, scope -from tests.User import User class User(Model): @@ -36,6 +35,4 @@ def test_active_scope_with_params(self): def test_can_chain_scopes(self): sql = "SELECT * FROM `users` WHERE `users`.`active` = '2' AND `users`.`gender` = 'W' AND `users`.`name` = 'joe'" - self.assertEqual( - sql, User.active(2).gender("W").where("name", "joe").to_sql() - ) + self.assertEqual(sql, User.active(2).gender("W").where("name", "joe").to_sql()) diff --git a/tests/mysql/scopes/test_soft_delete.py b/tests/mysql/scopes/test_soft_delete.py index 7d6689b4..a1e9318d 100644 --- a/tests/mysql/scopes/test_soft_delete.py +++ b/tests/mysql/scopes/test_soft_delete.py @@ -4,6 +4,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.scopes import SoftDeleteScope, SoftDeletesMixin + from tests.integrations.config.database import DATABASES from tests.utils import MockConnectionFactory diff --git a/tests/postgres/builder/test_postgres_query_builder.py b/tests/postgres/builder/test_postgres_query_builder.py index 1431b9f9..4860fa89 100644 --- a/tests/postgres/builder/test_postgres_query_builder.py +++ b/tests/postgres/builder/test_postgres_query_builder.py @@ -4,6 +4,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar + from tests.utils import MockConnectionFactory @@ -657,9 +658,7 @@ def where_not_in(self): """ builder.where_not_in('id', [1, 2, 3]) """ - return ( - """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" - ) + return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" def where_in(self): """ @@ -671,9 +670,7 @@ def between(self): """ builder.between('id', 2, 5) """ - return ( - """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" - ) + return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" def not_between(self): """ @@ -747,9 +744,7 @@ def where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") """ - return ( - """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" - ) + return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" def truncate(self): """ diff --git a/tests/postgres/builder/test_postgres_transaction.py b/tests/postgres/builder/test_postgres_transaction.py index 2e7eda8b..ae1c64ff 100644 --- a/tests/postgres/builder/test_postgres_transaction.py +++ b/tests/postgres/builder/test_postgres_transaction.py @@ -5,6 +5,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar + from tests.integrations.config.database import DATABASES if os.getenv("RUN_POSTGRES_DATABASE") == "True": diff --git a/tests/postgres/grammar/test_select_grammar.py b/tests/postgres/grammar/test_select_grammar.py index f55d06a4..475d6eee 100644 --- a/tests/postgres/grammar/test_select_grammar.py +++ b/tests/postgres/grammar/test_select_grammar.py @@ -17,9 +17,7 @@ def can_compile_with_columns(self): """ self.builder.select('username', 'password').to_sql() """ - return ( - """SELECT "users"."username", "users"."password" FROM "users\"""" - ) + return """SELECT "users"."username", "users"."password" FROM "users\"""" def can_compile_with_where(self): """ @@ -73,7 +71,9 @@ def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ - return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" + return ( + """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" + ) def can_compile_with_group_by(self): """ @@ -109,7 +109,9 @@ def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ - return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" + return ( + """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" + ) def can_compile_where_raw(self): """ @@ -160,9 +162,7 @@ def can_compile_where_column(self): self.builder.where_column('name', 'email').to_sql() """ - return ( - """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" - ) + return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" def can_compile_or_where(self): """ @@ -301,9 +301,7 @@ def can_compile_not_between(self): def test_can_compile_where_raw(self): to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql() - self.assertEqual( - to_sql, """SELECT * FROM "users" WHERE "age" = '18'""" - ) + self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_having_raw(self): to_sql = ( @@ -344,9 +342,7 @@ def test_can_compile_select_raw(self): def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() - self.assertEqual( - to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""" - ) + self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""") def can_compile_first_or_fail(self): """ @@ -360,9 +356,7 @@ def where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ - return ( - """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" - ) + return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" def where_like(self): """ @@ -502,7 +496,9 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): - return """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" + return ( + """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" + ) def or_where_null(self): return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" diff --git a/tests/postgres/grammar/test_update_grammar.py b/tests/postgres/grammar/test_update_grammar.py index 76d19f7f..88d458d3 100644 --- a/tests/postgres/grammar/test_update_grammar.py +++ b/tests/postgres/grammar/test_update_grammar.py @@ -2,9 +2,9 @@ import unittest from src.masoniteorm.connections import PostgresConnection +from src.masoniteorm.expressions import Raw from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar -from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: diff --git a/tests/postgres/schema/test_postgres_schema_builder.py b/tests/postgres/schema/test_postgres_schema_builder.py index 9e04b3dd..c2270b34 100644 --- a/tests/postgres/schema/test_postgres_schema_builder.py +++ b/tests/postgres/schema/test_postgres_schema_builder.py @@ -1,10 +1,11 @@ import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import PostgresConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import PostgresPlatform +from tests.integrations.config.database import DATABASES + class TestPostgresSchemaBuilder(unittest.TestCase): maxDiff = None @@ -377,6 +378,7 @@ def test_can_add_enum(self): self.assertEqual( blueprint.to_sql(), [ - 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL ' 'DEFAULT \'active\')' + "CREATE TABLE \"users\" (\"status\" VARCHAR(255) CHECK(status IN ('active', 'inactive')) NOT NULL " + "DEFAULT 'active')" ], ) diff --git a/tests/postgres/schema/test_postgres_schema_builder_alter.py b/tests/postgres/schema/test_postgres_schema_builder_alter.py index d77d632b..9d929dd2 100644 --- a/tests/postgres/schema/test_postgres_schema_builder_alter.py +++ b/tests/postgres/schema/test_postgres_schema_builder_alter.py @@ -1,11 +1,12 @@ import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import PostgresConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import PostgresPlatform from src.masoniteorm.schema.Table import Table +from tests.integrations.config.database import DATABASES + class TestPostgresSchemaBuilderAlter(unittest.TestCase): maxDiff = None @@ -309,7 +310,7 @@ def test_can_add_column_enum(self): self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ - 'ALTER TABLE "users" ADD COLUMN "status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'', + "ALTER TABLE \"users\" ADD COLUMN \"status\" VARCHAR(255) CHECK(status IN ('active', 'inactive')) NOT NULL DEFAULT 'active'", ] self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/sqlite/builder/test_sqlite_builder_insert.py b/tests/sqlite/builder/test_sqlite_builder_insert.py index c09135d5..285b38cb 100644 --- a/tests/sqlite/builder/test_sqlite_builder_insert.py +++ b/tests/sqlite/builder/test_sqlite_builder_insert.py @@ -4,6 +4,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar + from tests.integrations.config.database import DB diff --git a/tests/sqlite/builder/test_sqlite_builder_pagination.py b/tests/sqlite/builder/test_sqlite_builder_pagination.py index 99a5e99b..f15f1d99 100644 --- a/tests/sqlite/builder/test_sqlite_builder_pagination.py +++ b/tests/sqlite/builder/test_sqlite_builder_pagination.py @@ -4,6 +4,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar + from tests.integrations.config.database import DB diff --git a/tests/sqlite/builder/test_sqlite_query_builder.py b/tests/sqlite/builder/test_sqlite_query_builder.py index 3c3597eb..89234729 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder.py +++ b/tests/sqlite/builder/test_sqlite_query_builder.py @@ -1,7 +1,7 @@ import inspect import os -import unittest from pathlib import Path +import unittest import pytest @@ -16,6 +16,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar + from tests.utils import MockConnectionFactory @@ -518,9 +519,7 @@ def test_group_by_raw(self): def test_group_by_multiple(self): builder = self.get_builder(table="payments") - builder.select("user_id").min("salary").group_by("user_id").group_by( - "salary" - ) + builder.select("user_id").min("salary").group_by("user_id").group_by("salary") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") @@ -854,7 +853,9 @@ def order_by_multiple(self): """ builder.order_by('email', 'asc') """ - return """SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, "active" ASC""" + return ( + """SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, "active" ASC""" + ) def order_by_raw(self): """ @@ -896,9 +897,7 @@ def where_not_in(self): """ builder.where_not_in('id', [1, 2, 3]) """ - return ( - """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" - ) + return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" def where_in(self): """ @@ -910,9 +909,7 @@ def between(self): """ builder.between('id', 2, 5) """ - return ( - """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" - ) + return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" def not_between(self): """ @@ -998,24 +995,18 @@ def where_not_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") """ - return ( - """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" - ) + return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" def test_when(self): builder = self.get_builder() - sql = builder.when( - 19 > 18, lambda q: q.where("age_restricted", 1) - ).to_sql() + sql = builder.when(19 > 18, lambda q: q.where("age_restricted", 1)).to_sql() return self.assertEqual( sql, """SELECT * FROM "users" WHERE "users"."age_restricted" = '1'""", ) builder = self.get_builder() - sql = builder.when( - 17 > 18, lambda q: q.where("age_restricted", 1) - ).to_sql() + sql = builder.when(17 > 18, lambda q: q.where("age_restricted", 1)).to_sql() return self.assertEqual(sql, """SELECT * FROM "users\"""") def truncate(self): diff --git a/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py b/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py index be95dead..79a036a7 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py +++ b/tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py @@ -5,6 +5,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to, has_many + from tests.integrations.config.database import DB diff --git a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py index eb1e24b8..b364ecda 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py +++ b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py @@ -7,6 +7,7 @@ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to + from tests.integrations.config.database import DB load_dotenv(".env") @@ -87,9 +88,7 @@ def test_where_doesnt_have(self): def test_where_has_query(self): builder = self.get_builder() - sql = builder.where_has( - "articles", lambda q: q.where("active", 1) - ).to_sql() + sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE EXISTS (""" diff --git a/tests/sqlite/builder/test_sqlite_transaction.py b/tests/sqlite/builder/test_sqlite_transaction.py index f367944f..c507c9fd 100644 --- a/tests/sqlite/builder/test_sqlite_transaction.py +++ b/tests/sqlite/builder/test_sqlite_transaction.py @@ -5,6 +5,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar + from tests.integrations.config.database import DB diff --git a/tests/sqlite/grammar/test_sqlite_select_grammar.py b/tests/sqlite/grammar/test_sqlite_select_grammar.py index ad88acd6..45575648 100644 --- a/tests/sqlite/grammar/test_sqlite_select_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_select_grammar.py @@ -24,9 +24,7 @@ def can_compile_with_columns(self): """ self.builder.select('username', 'password').to_sql() """ - return ( - """SELECT "users"."username", "users"."password" FROM "users\"""" - ) + return """SELECT "users"."username", "users"."password" FROM "users\"""" def can_compile_with_where(self): """ @@ -80,7 +78,9 @@ def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ - return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" + return ( + """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" + ) def can_compile_with_group_by(self): """ @@ -116,7 +116,9 @@ def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ - return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" + return ( + """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" + ) def can_compile_where_raw(self): """ @@ -161,9 +163,7 @@ def can_compile_where_column(self): self.builder.where_column('name', 'email').to_sql() """ - return ( - """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" - ) + return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" def can_compile_or_where(self): """ @@ -293,9 +293,7 @@ def can_compile_not_between(self): def test_can_compile_where_raw(self): to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql() - self.assertEqual( - to_sql, """SELECT * FROM "users" WHERE "age" = '18'""" - ) + self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_where_raw_and_where_with_multiple_bindings(self): query = self.builder.where_raw( @@ -313,27 +311,21 @@ def test_can_compile_select_raw(self): def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() - self.assertEqual( - to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""" - ) + self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""") def can_compile_first_or_fail(self): """ builder = self.get_builder() builder.where("is_admin", "=", True).first_or_fail() """ - return ( - """SELECT * FROM "users" WHERE "users"."is_admin" = '1' LIMIT 1""" - ) + return """SELECT * FROM "users" WHERE "users"."is_admin" = '1' LIMIT 1""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ - return ( - """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" - ) + return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" def where_like(self): """ @@ -473,7 +465,9 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): - return """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" + return ( + """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" + ) def or_where_null(self): return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" diff --git a/tests/sqlite/grammar/test_sqlite_update_grammar.py b/tests/sqlite/grammar/test_sqlite_update_grammar.py index 4ee26543..0c8e06e0 100644 --- a/tests/sqlite/grammar/test_sqlite_update_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_update_grammar.py @@ -1,9 +1,9 @@ import inspect import unittest +from src.masoniteorm.expressions import Raw from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar -from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: diff --git a/tests/sqlite/models/test_attach_detach.py b/tests/sqlite/models/test_attach_detach.py index df528406..f1bc71cf 100644 --- a/tests/sqlite/models/test_attach_detach.py +++ b/tests/sqlite/models/test_attach_detach.py @@ -4,6 +4,7 @@ from src.masoniteorm.relationships import belongs_to, has_one from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform + from tests.integrations.config.database import DATABASES diff --git a/tests/sqlite/models/test_observers.py b/tests/sqlite/models/test_observers.py index d4096b41..d6cf291c 100644 --- a/tests/sqlite/models/test_observers.py +++ b/tests/sqlite/models/test_observers.py @@ -1,6 +1,7 @@ import unittest from src.masoniteorm.models import Model + from tests.integrations.config.database import DB diff --git a/tests/sqlite/models/test_sqlite_model.py b/tests/sqlite/models/test_sqlite_model.py index 530c6641..2f0ef9b8 100644 --- a/tests/sqlite/models/test_sqlite_model.py +++ b/tests/sqlite/models/test_sqlite_model.py @@ -4,6 +4,7 @@ from src.masoniteorm.relationships import belongs_to_many from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform + from tests.integrations.config.database import DATABASES @@ -58,9 +59,7 @@ def test_update_specific_record(self): self.assertEqual( sql, - """UPDATE "users" SET "name" = 'joe' WHERE "id" = '{}'""".format( - user.id - ), + f"""UPDATE "users" SET "name" = 'joe' WHERE "id" = '{user.id}'""", ) def test_update_all_records(self): @@ -71,9 +70,7 @@ def test_update_all_records(self): def test_can_find_list(self): sql = User.find(1, query=True).to_sql() - self.assertEqual( - sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""" - ) + self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""") sql = User.find([1, 2, 3], query=True).to_sql() @@ -118,21 +115,15 @@ def test_update_only_changed_attributes(self): # unchanged name attribute is not updated self.assertEqual( sql, - """UPDATE "users" SET "username" = 'new' WHERE "id" = '{}'""".format( - user.id - ), + f"""UPDATE "users" SET "username" = 'new' WHERE "id" = '{user.id}'""", ) def test_can_force_update_on_method(self): user = User.first() - sql = user.update( - {"name": user.name, "username": "new"}, force=True - ).to_sql() + sql = user.update({"name": user.name, "username": "new"}, force=True).to_sql() self.assertEqual( sql, - """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( - user.id - ), + f"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{user.id}'""", ) def test_can_force_update_on_model(self): @@ -140,21 +131,15 @@ def test_can_force_update_on_model(self): sql = user.update({"name": user.name, "username": "new"}).to_sql() self.assertEqual( sql, - """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( - user.id - ), + f"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{user.id}'""", ) def test_force_update(self): user = User.first() - sql = user.force_update( - {"name": user.name, "username": "new"} - ).to_sql() + sql = user.force_update({"name": user.name, "username": "new"}).to_sql() self.assertEqual( sql, - """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( - user.id - ), + f"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{user.id}'""", ) def test_update_is_not_done_when_no_changes(self): @@ -175,9 +160,7 @@ class ModelUser(Model): __connection__ = "dev" __table__ = "users" - count = ( - User.where_not_null("id").not_between("age", 1, 2).get().count() - ) + count = User.where_not_null("id").not_between("age", 1, 2).get().count() self.assertEqual(count, 0) def test_get_columns(self): diff --git a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py index baf68eae..4b10780e 100644 --- a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py +++ b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py @@ -3,10 +3,11 @@ from src.masoniteorm.collection import Collection from src.masoniteorm.models import Model from src.masoniteorm.relationships import has_many_through -from tests.integrations.config.database import DATABASES from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform +from tests.integrations.config.database import DATABASES + class Enrolment(Model): __table__ = "enrolment" @@ -26,11 +27,7 @@ class Course(Model): __fillable__ = ["course_id", "name"] @has_many_through( - None, - "in_course_id", - "active_student_id", - "course_id", - "student_id" + None, "in_course_id", "active_student_id", "course_id", "student_id" ) def students(self): return [Student, Enrolment] @@ -103,16 +100,10 @@ def test_has_many_through_can_eager_load(self): self.assertEqual(student2.name, "Bob") # check .first() and .get() produce the same result - single = ( - Course.where("name", "History 101") - .with_("students") - .first() - ) + single = Course.where("name", "History 101").with_("students").first() self.assertIsInstance(single.students, Collection) - single_get = ( - Course.where("name", "History 101").with_("students").get() - ) + single_get = Course.where("name", "History 101").with_("students").get() print(single.students) print(single_get.first().students) @@ -124,11 +115,7 @@ def test_has_many_through_can_eager_load(self): self.assertEqual(single_name, single_get_name) def test_has_many_through_eager_load_can_be_empty(self): - courses = ( - Course.where("name", "Biology 302") - .with_("students") - .get() - ) + courses = Course.where("name", "Biology 302").with_("students").get() self.assertIsNone(courses.first().students) def test_has_many_through_can_get_related(self): @@ -138,7 +125,5 @@ def test_has_many_through_can_get_related(self): self.assertEqual(course.students.count(), 2) def test_has_many_through_has_query(self): - courses = Course.where_has( - "students", lambda query: query.where("name", "Bob") - ) + courses = Course.where_has("students", lambda query: query.where("name", "Bob")) self.assertEqual(courses.count(), 2) diff --git a/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py index b66a2e93..73988429 100644 --- a/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py +++ b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py @@ -4,6 +4,7 @@ from src.masoniteorm.relationships import has_one_through from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform + from tests.integrations.config.database import DATABASES @@ -24,9 +25,7 @@ class IncomingShipment(Model): __connection__ = "dev" __fillable__ = ["shipment_id", "name", "from_port_id"] - @has_one_through( - None, "from_port_id", "port_country_id", "port_id", "country_id" - ) + @has_one_through(None, "from_port_id", "port_country_id", "port_id", "country_id") def from_country(self): return [Country, Port] @@ -39,9 +38,7 @@ def setUp(self): platform=SQLitePlatform, ).on("dev") - with self.schema.create_table_if_not_exists( - "incoming_shipments" - ) as table: + with self.schema.create_table_if_not_exists("incoming_shipments") as table: table.integer("shipment_id").primary() table.string("name") table.integer("from_port_id") @@ -117,9 +114,7 @@ def setUp(self): ) def test_has_one_through_can_eager_load(self): - shipments = ( - IncomingShipment.where("name", "Bread").with_("from_country").get() - ) + shipments = IncomingShipment.where("name", "Bread").with_("from_country").get() self.assertEqual(shipments.count(), 2) shipment1 = shipments.shift() @@ -137,9 +132,7 @@ def test_has_one_through_can_eager_load(self): .first() ) single_get = ( - IncomingShipment.where("name", "Tractor Parts") - .with_("from_country") - .get() + IncomingShipment.where("name", "Tractor Parts").with_("from_country").get() ) self.assertEqual(single.from_country.country_id, 10) self.assertEqual(single_get.count(), 1) @@ -151,9 +144,7 @@ def test_has_one_through_can_eager_load(self): def test_has_one_through_eager_load_can_be_empty(self): shipments = ( IncomingShipment.where("name", "Bread") - .where_has( - "from_country", lambda query: query.where("name", "Ueaguay") - ) + .where_has("from_country", lambda query: query.where("name", "Ueaguay")) .with_( "from_country", ) diff --git a/tests/sqlite/relationships/test_sqlite_polymorphic.py b/tests/sqlite/relationships/test_sqlite_polymorphic.py index 6c87317a..f0158e0f 100644 --- a/tests/sqlite/relationships/test_sqlite_polymorphic.py +++ b/tests/sqlite/relationships/test_sqlite_polymorphic.py @@ -2,6 +2,7 @@ from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, morph_to + from tests.integrations.config.database import DB diff --git a/tests/sqlite/relationships/test_sqlite_relationships.py b/tests/sqlite/relationships/test_sqlite_relationships.py index a3be5246..866a2a16 100644 --- a/tests/sqlite/relationships/test_sqlite_relationships.py +++ b/tests/sqlite/relationships/test_sqlite_relationships.py @@ -1,6 +1,13 @@ import unittest + from src.masoniteorm.models import Model -from src.masoniteorm.relationships import belongs_to, has_many, has_one, belongs_to_many +from src.masoniteorm.relationships import ( + belongs_to, + belongs_to_many, + has_many, + has_one, +) + from tests.integrations.config.database import DB diff --git a/tests/sqlite/schema/test_sqlite_schema_builder.py b/tests/sqlite/schema/test_sqlite_schema_builder.py index 461abd55..b1e771c5 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder.py @@ -2,6 +2,7 @@ from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform + from tests.integrations.config.database import DATABASES @@ -113,9 +114,9 @@ def test_can_add_columns_with_foreign_key_constraint_name(self): blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") - blueprint.foreign("profile_id", name="profile_foreign").references( - "id" - ).on("profiles") + blueprint.foreign("profile_id", name="profile_foreign").references("id").on( + "profiles" + ) self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( @@ -259,9 +260,9 @@ def test_can_advanced_table_creation2(self): blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.integer("author_id").unsigned().nullable() - blueprint.foreign("author_id").references("id").on( - "users" - ).on_delete("set null") + blueprint.foreign("author_id").references("id").on("users").on_delete( + "set null" + ) blueprint.text("description") blueprint.timestamps() diff --git a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py index 677b1445..82580aa0 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py @@ -3,6 +3,7 @@ from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform from src.masoniteorm.schema.Table import Table + from tests.integrations.config.database import DATABASES @@ -143,9 +144,7 @@ def test_timestamp_alter_add_nullable_column(self): self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_on_table_schema_table(self): - schema = Schema(connection="dev", connection_details=DATABASES).on( - "dev" - ) + schema = Schema(connection="dev", connection_details=DATABASES).on("dev") with schema.table("table_schema") as blueprint: blueprint.drop_column("name") @@ -166,9 +165,9 @@ def test_alter_add_primary(self): def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on( - "playlists" - ).on_delete("cascade").on_update("SET NULL") + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( + "cascade" + ).on_update("SET NULL") table = Table("users") table.add_column("age", "string") @@ -190,9 +189,9 @@ def test_alter_add_column_and_foreign_key(self): def test_alter_add_foreign_key_only(self): with self.schema.table("users") as blueprint: - blueprint.foreign("playlist_id").references("id").on( - "playlists" - ).on_delete("cascade").on_update("set null") + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( + "cascade" + ).on_update("set null") table = Table("users") table.add_column("age", "string") @@ -225,9 +224,7 @@ def test_can_add_column_enum(self): def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: - blueprint.enum("status", ["active", "inactive"]).default( - "active" - ).change() + blueprint.enum("status", ["active", "inactive"]).default("active").change() blueprint.table.from_table = Table("users") diff --git a/tests/sqlite/schema/test_table.py b/tests/sqlite/schema/test_table.py index 2e73a4ab..464f56d0 100644 --- a/tests/sqlite/schema/test_table.py +++ b/tests/sqlite/schema/test_table.py @@ -1,10 +1,11 @@ import unittest -from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import SQLiteConnection from src.masoniteorm.schema import Column, Table from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform +from tests.integrations.config.database import DATABASES + class TestTable(unittest.TestCase): maxDiff = None From f9acc2a78e87a37a01829ff2bc1b683cde745777 Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:01:21 +0800 Subject: [PATCH 04/10] updated linting and formatting settings --- makefile | 9 ++++++++- pyproject.toml | 12 +++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/makefile b/makefile index b53efc5b..74755b9e 100644 --- a/makefile +++ b/makefile @@ -30,25 +30,30 @@ ci: make test .PHONY: check -check: format lint +check: lint format .PHONY: lint lint: .bootstrapped-tests ruff check --fix --exit-non-zero-on-fix src/masoniteorm tests +.PHONY: format format: .bootstrapped-tests ruff format --check src/masoniteorm tests/ +.PHONY: coverage coverage: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ python -m coveralls +.PHONY: show show: python -m pytest --cov-report term --cov-report html --cov=src/masoniteorm tests/ +.PHONY: cov cov: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ +.PHONY: publish publish: pip install build twine make test @@ -56,10 +61,12 @@ publish: twine upload dist/* rm -rf build dist *.egg-info +.PHONY: pub pub: python -m build twine upload dist/* rm -rf build dist *.egg-info +.PHONY: pypirc pypirc: cp .pypirc ~/.pypirc diff --git a/pyproject.toml b/pyproject.toml index 4a984878..03f376de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ env = [ [tool.ruff] target-version = "py38" -line-length = 88 +line-length = 110 exclude = ["build", "conda"] [tool.ruff.lint] @@ -67,7 +67,6 @@ select = [ "UP", # pyupgrade ] ignore = [ -# "E501", # line too long # "E203", # whitespace before ':' "E402", # module-level import not at top of file "E731", # do not assign a lambda expression @@ -76,9 +75,16 @@ ignore = [ # "F811", # redefinition of unused name ] -[tool.ruff.lint.per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "__init__.py" = ["F401"] # imported but unused "tests/**/*.py" = ["E501"] +"src/masoniteorm/schema/platforms/*.py" = ["E501"] +"src/masoniteorm/relationships/*.py" = ["E501"] +"src/masoniteorm/query/processors/*.py" = ["E501"] +"src/masoniteorm/query/grammars/*.py" = ["E501"] +"src/masoniteorm/connections/*.py" = ["E501"] +"src/masoniteorm/commands/*.py" = ["E501"] +"*.pyi" = ["E501"] [tool.ruff.lint.isort] force-sort-within-sections = true From 1e5c35fe0269541760ff01399eb86abf8da198ed Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:02:50 +0800 Subject: [PATCH 05/10] lint anf format src files sorted imports removed duplicate methods tidy --- src/masoniteorm/collection/Collection.py | 32 +- .../commands/MakeMigrationCommand.py | 8 +- src/masoniteorm/commands/MakeModelCommand.py | 22 +- .../commands/MakeObserverCommand.py | 14 +- src/masoniteorm/commands/MakeSeedCommand.py | 12 +- .../commands/MigrateRollbackCommand.py | 4 +- src/masoniteorm/commands/ShellCommand.py | 29 +- src/masoniteorm/config.py | 18 +- src/masoniteorm/connections/BaseConnection.py | 6 +- .../connections/ConnectionFactory.py | 10 +- .../connections/ConnectionResolver.py | 12 +- .../connections/MSSQLConnection.py | 6 +- .../connections/MySQLConnection.py | 8 +- .../connections/PostgresConnection.py | 4 +- src/masoniteorm/expressions/expressions.py | 21 +- src/masoniteorm/helpers/misc.py | 4 +- src/masoniteorm/migrations/Migration.py | 95 +--- src/masoniteorm/models/Model.py | 100 +--- src/masoniteorm/models/Model.pyi | 71 +-- src/masoniteorm/observers/ObservesEvents.py | 2 +- src/masoniteorm/pagination/BasePaginator.py | 3 +- .../pagination/LengthAwarePaginator.py | 4 +- src/masoniteorm/pagination/SimplePaginator.py | 4 +- src/masoniteorm/providers/ORMProvider.py | 26 +- src/masoniteorm/query/EagerRelation.py | 4 +- src/masoniteorm/query/QueryBuilder.py | 455 +++++------------- src/masoniteorm/query/grammars/BaseGrammar.py | 125 ++--- .../query/grammars/SQLiteGrammar.py | 2 +- .../query/processors/MSSQLPostProcessor.py | 4 +- .../query/processors/MySQLPostProcessor.py | 4 +- .../query/processors/SQLitePostProcessor.py | 4 +- src/masoniteorm/relationships/BelongsTo.py | 20 +- .../relationships/BelongsToMany.py | 70 +-- src/masoniteorm/relationships/HasMany.py | 15 +- .../relationships/HasManyThrough.py | 74 +-- src/masoniteorm/relationships/HasOne.py | 12 +- .../relationships/HasOneThrough.py | 70 +-- src/masoniteorm/relationships/MorphMany.py | 14 +- src/masoniteorm/relationships/MorphOne.py | 14 +- src/masoniteorm/schema/Blueprint.py | 183 ++----- src/masoniteorm/schema/Schema.py | 52 +- src/masoniteorm/schema/Table.py | 8 +- src/masoniteorm/schema/TableDiff.py | 4 +- .../schema/platforms/MSSQLPlatform.py | 75 +-- .../schema/platforms/MySQLPlatform.py | 118 +---- src/masoniteorm/schema/platforms/Platform.py | 20 +- .../schema/platforms/PostgresPlatform.py | 115 +---- .../schema/platforms/SQLitePlatform.py | 100 +--- src/masoniteorm/scopes/SoftDeleteScope.py | 16 +- src/masoniteorm/seeds/Seeder.py | 8 +- .../testing/BaseTestCaseSelectGrammar.py | 370 ++++---------- 51 files changed, 639 insertions(+), 1832 deletions(-) diff --git a/src/masoniteorm/collection/Collection.py b/src/masoniteorm/collection/Collection.py index bed2a345..4087effc 100644 --- a/src/masoniteorm/collection/Collection.py +++ b/src/masoniteorm/collection/Collection.py @@ -1,7 +1,8 @@ +# ruff: noqa: E501 +from functools import reduce import json import operator import random -from functools import reduce class Collection: @@ -77,7 +78,7 @@ def avg(self, key=None): If a key is given it will return the average of all the values of the key. - Keyword Arguments: + Arguments: key {string} -- The key to use to find the average of all the values of that key. (default: {None}) Returns: @@ -188,12 +189,10 @@ def flatten(self): def _flatten(items): if isinstance(items, dict): for v in items.values(): - for x in _flatten(v): - yield x + yield from _flatten(v) elif isinstance(items, list): for i in items: - for j in _flatten(i): - yield j + yield from _flatten(i) else: yield items @@ -246,7 +245,7 @@ def merge(self, items): if isinstance(items, Collection): items = items._items elif not isinstance(items, list): - raise ValueError("Unable to merge uncompatible types") + raise ValueError("Unable to merge incompatible types") items = self.__get_items(items) @@ -276,9 +275,7 @@ def pluck(self, value, key=None, keep_nulls=True): if k == value: if key: - attributes[self._data_get(item, key)] = self._data_get( - item, value - ) + attributes[self._data_get(item, key)] = self._data_get(item, value) else: attributes.append(v) @@ -310,9 +307,7 @@ def random(self, count=None): if collection_count == 0: return None elif count and count > collection_count: - raise ValueError( - "count argument must be inferior to collection length." - ) + raise ValueError("count argument must be inferior to collection length.") elif count: self._items = random.sample(self._items, k=count) return self @@ -426,9 +421,7 @@ def where(self, key, *args): if isinstance(item, dict): comparison = item.get(key) else: - comparison = ( - getattr(item, key) if hasattr(item, key) else False - ) + comparison = getattr(item, key) if hasattr(item, key) else False if self._make_comparison(comparison, value, op): attributes.append(item) return self.__class__(attributes) @@ -482,9 +475,7 @@ def where_not_in(self, key, args: list) -> "Collection": def zip(self, items): items = self.__get_items(items) if not isinstance(items, list): - raise ValueError( - "The 'items' parameter must be a list or a Collection" - ) + raise ValueError("The 'items' parameter must be a list or a Collection") _items = [] for x, y in zip(self, items): @@ -550,8 +541,7 @@ def _make_comparison(self, a, b, op): return operators[op](str(a), str(b)) def __iter__(self): - for item in self._items: - yield item + yield from self._items def __eq__(self, other): other = self.__get_items(other) diff --git a/src/masoniteorm/commands/MakeMigrationCommand.py b/src/masoniteorm/commands/MakeMigrationCommand.py index df87d993..e22c7834 100644 --- a/src/masoniteorm/commands/MakeMigrationCommand.py +++ b/src/masoniteorm/commands/MakeMigrationCommand.py @@ -47,11 +47,7 @@ def handle(self): file_name = f"{now.strftime('%Y_%m_%d_%H%M%S')}_{name}.py" - with open( - os.path.join(os.getcwd(), migration_directory, file_name), "w" - ) as fp: + with open(os.path.join(os.getcwd(), migration_directory, file_name), "w") as fp: fp.write(output) - self.info( - f"Migration file created: {os.path.join(migration_directory, file_name)}" - ) + self.info(f"Migration file created: {os.path.join(migration_directory, file_name)}") diff --git a/src/masoniteorm/commands/MakeModelCommand.py b/src/masoniteorm/commands/MakeModelCommand.py index 50b1cc98..9282a8cb 100644 --- a/src/masoniteorm/commands/MakeModelCommand.py +++ b/src/masoniteorm/commands/MakeModelCommand.py @@ -27,11 +27,7 @@ def handle(self): model_directory = self.option("directory") - with open( - os.path.join( - pathlib.Path(__file__).parent.absolute(), "stubs/model.stub" - ) - ) as fp: + with open(os.path.join(pathlib.Path(__file__).parent.absolute(), "stubs/model.stub")) as fp: output = fp.read() output = output.replace("__CLASS__", camelize(name)) @@ -43,18 +39,12 @@ def handle(self): full_directory_path = os.path.join(os.getcwd(), model_directory) if os.path.exists(os.path.join(full_directory_path, file_name)): - self.line( - f'Model "{name}" Already Exists ({full_directory_path}/{file_name})' - ) + self.line(f'Model "{name}" Already Exists ({full_directory_path}/{file_name})') return - os.makedirs( - os.path.dirname(os.path.join(full_directory_path)), exist_ok=True - ) + os.makedirs(os.path.dirname(os.path.join(full_directory_path)), exist_ok=True) - with open( - os.path.join(os.getcwd(), model_directory, file_name), "w+" - ) as fp: + with open(os.path.join(os.getcwd(), model_directory, file_name), "w+") as fp: fp.write(output) self.info(f"Model created: {os.path.join(model_directory, file_name)}") @@ -73,6 +63,4 @@ def handle(self): if self.option("seeder"): directory = self.option("seeders-directory") - self.call( - "seed", f"{self.argument('name')} --directory {directory}" - ) + self.call("seed", f"{self.argument('name')} --directory {directory}") diff --git a/src/masoniteorm/commands/MakeObserverCommand.py b/src/masoniteorm/commands/MakeObserverCommand.py index 667b3cb6..e9a60a5e 100644 --- a/src/masoniteorm/commands/MakeObserverCommand.py +++ b/src/masoniteorm/commands/MakeObserverCommand.py @@ -24,11 +24,7 @@ def handle(self): observer_directory = self.option("directory") - with open( - os.path.join( - pathlib.Path(__file__).parent.absolute(), "stubs/observer.stub" - ) - ) as fp: + with open(os.path.join(pathlib.Path(__file__).parent.absolute(), "stubs/observer.stub")) as fp: output = fp.read() output = output.replace("__CLASS__", camelize(name)) output = output.replace("__MODEL_VARIABLE__", underscore(model)) @@ -39,16 +35,12 @@ def handle(self): full_directory_path = os.path.join(os.getcwd(), observer_directory) if os.path.exists(os.path.join(full_directory_path, file_name)): - self.line( - f'Observer "{name}" Already Exists ({full_directory_path}/{file_name})' - ) + self.line(f'Observer "{name}" Already Exists ({full_directory_path}/{file_name})') return os.makedirs(os.path.join(full_directory_path), exist_ok=True) - with open( - os.path.join(os.getcwd(), observer_directory, file_name), "w+" - ) as fp: + with open(os.path.join(os.getcwd(), observer_directory, file_name), "w+") as fp: fp.write(output) self.info(f"Observer created: {file_name}") diff --git a/src/masoniteorm/commands/MakeSeedCommand.py b/src/masoniteorm/commands/MakeSeedCommand.py index e75caa6a..9447e651 100644 --- a/src/masoniteorm/commands/MakeSeedCommand.py +++ b/src/masoniteorm/commands/MakeSeedCommand.py @@ -35,18 +35,12 @@ def handle(self): output = output.replace("__SEEDER_NAME__", camelize(name)) file_name = f"{underscore(name)}.py" - full_path = pathlib.Path( - os.path.join(os.getcwd(), seed_directory, file_name) - ) + full_path = pathlib.Path(os.path.join(os.getcwd(), seed_directory, file_name)) - path_normalized = pathlib.Path(seed_directory) / pathlib.Path( - file_name - ) + path_normalized = pathlib.Path(seed_directory) / pathlib.Path(file_name) if os.path.exists(full_path): - return self.line( - f"{path_normalized} already exists." - ) + return self.line(f"{path_normalized} already exists.") with open(full_path, "w") as fp: fp.write(output) diff --git a/src/masoniteorm/commands/MigrateRollbackCommand.py b/src/masoniteorm/commands/MigrateRollbackCommand.py index 2da30979..10cdd4a1 100644 --- a/src/masoniteorm/commands/MigrateRollbackCommand.py +++ b/src/masoniteorm/commands/MigrateRollbackCommand.py @@ -21,6 +21,4 @@ def handle(self): migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), - ).rollback( - migration=self.option("migration"), output=self.option("show") - ) + ).rollback(migration=self.option("migration"), output=self.option("show")) diff --git a/src/masoniteorm/commands/ShellCommand.py b/src/masoniteorm/commands/ShellCommand.py index 4762127d..a7c129e8 100644 --- a/src/masoniteorm/commands/ShellCommand.py +++ b/src/masoniteorm/commands/ShellCommand.py @@ -1,8 +1,8 @@ +from collections import OrderedDict import os import re import shlex import subprocess -from collections import OrderedDict from ..config import load_config from .Command import Command @@ -31,9 +31,7 @@ def handle(self): connection = resolver.get_connection_details()["default"] config = resolver.get_connection_information(connection) if not config.get("full_details"): - self.line( - f"Connection configuration for '{connection}' not found !" - ) + self.line(f"Connection configuration for '{connection}' not found !") exit(-1) command, env = self.get_command(config) @@ -67,9 +65,7 @@ def get_command(self, config): try: get_driver_args = getattr(self, f"get_{driver}_args") except AttributeError: - self.line( - f"Connecting with driver '{driver}' is not implemented !" - ) + self.line(f"Connecting with driver '{driver}' is not implemented !") exit(-1) args, options = get_driver_args(config) # process positional arguments @@ -77,8 +73,7 @@ def get_command(self, config): # process optional arguments options = self.remove_empty_options(options) options_string = " ".join( - f"{option} {value}" if value else option - for option, value in options.items() + f"{option} {value}" if value else option for option, value in options.items() ) # finally build command string command = program @@ -106,9 +101,7 @@ def get_mysql_args(self, config): "--port": config.get("port"), "--user": config.get("user"), "--password": config.get("password"), - "--default-character-set": config.get("options", {}).get( - "charset" - ), + "--default-character-set": config.get("options", {}).get("charset"), } ) return args, options @@ -137,9 +130,7 @@ def get_mssql_args(self, config): if config.get("port"): server += f",{config.get('port')}" - trusted_connection = ( - config.get("options").get("trusted_connection") == "Yes" - ) + trusted_connection = config.get("options").get("trusted_connection") == "Yes" options = OrderedDict( { "-d": config.get("database"), @@ -175,9 +166,7 @@ def remove_empty_options(self, options): def get_sensitive_options(self, config): driver = config.get("full_details").get("driver") try: - sensitive_options = getattr( - self, f"get_{driver}_sensitive_options" - )() + sensitive_options = getattr(self, f"get_{driver}_sensitive_options")() except AttributeError: sensitive_options = [] return sensitive_options @@ -199,7 +188,5 @@ def hide_sensitive_options(self, config, command): if option in command: match = re.search(rf"{option} (\w+)", command) if match.groups(): - cleaned_command = cleaned_command.replace( - match.groups()[0], "***" - ) + cleaned_command = cleaned_command.replace(match.groups()[0], "***") return cleaned_command diff --git a/src/masoniteorm/config.py b/src/masoniteorm/config.py index 57cf0de2..ba0ec91b 100644 --- a/src/masoniteorm/config.py +++ b/src/masoniteorm/config.py @@ -11,21 +11,15 @@ def load_config(config_path=None): 1. try to load from DB_CONFIG_PATH environment variable 2. else try to load from default config_path: config/database """ - selected_config_path = ( - os.getenv("DB_CONFIG_PATH", config_path) or "config/database" - ) + selected_config_path = os.getenv("DB_CONFIG_PATH", config_path) or "config/database" os.environ["DB_CONFIG_PATH"] = selected_config_path # format path as python module if needed - selected_config_path = ( - selected_config_path.replace("/", ".").replace("\\", ".").rstrip(".py") - ) + selected_config_path = selected_config_path.replace("/", ".").replace("\\", ".").rstrip(".py") config_module = pydoc.locate(selected_config_path) if config_module is None: - raise ConfigurationNotFound( - f"ORM configuration file has not been found in {selected_config_path}." - ) + raise ConfigurationNotFound(f"ORM configuration file has not been found in {selected_config_path}.") return config_module @@ -96,11 +90,7 @@ def db_url(database_url=None, prefix="", options={}, log_queries=False): # lookup specified driver driver = DRIVERS_MAP[url.scheme] - port = ( - str(url.port) - if url.port and driver in [DRIVERS_MAP["mssql"]] - else url.port - ) + port = str(url.port) if url.port and driver in [DRIVERS_MAP["mssql"]] else url.port # build final configuration config = { diff --git a/src/masoniteorm/connections/BaseConnection.py b/src/masoniteorm/connections/BaseConnection.py index 035c70cc..20318fc8 100644 --- a/src/masoniteorm/connections/BaseConnection.py +++ b/src/masoniteorm/connections/BaseConnection.py @@ -50,7 +50,7 @@ def statement(self, query, bindings=()): ) self._cursor.execute(query, bindings) - end = "{:.2f}".format(timer() - start) + end = f"{timer() - start:.2f}" if self.full_details and self.full_details.get("log_queries", False): self.log(query, bindings, query_time=end) @@ -96,6 +96,4 @@ def enable_disable_foreign_keys(self): if foreign_keys: self._connection.execute(platform.enable_foreign_key_constraints()) elif foreign_keys is not None: - self._connection.execute( - platform.disable_foreign_key_constraints() - ) + self._connection.execute(platform.disable_foreign_key_constraints()) diff --git a/src/masoniteorm/connections/ConnectionFactory.py b/src/masoniteorm/connections/ConnectionFactory.py index 4da627ad..1662c142 100644 --- a/src/masoniteorm/connections/ConnectionFactory.py +++ b/src/masoniteorm/connections/ConnectionFactory.py @@ -43,17 +43,11 @@ def make(self, key): connections = self._resolver.get_connection_details() if key == "default": connection_details = connections.get(connections.get("default")) - connection = self._connections.get( - connection_details.get("driver") - ) + connection = self._connections.get(connection_details.get("driver")) else: connection = self._connections.get(key) if connection: return connection - raise Exception( - "The '{connection}' connection does not exist".format( - connection=key - ) - ) + raise Exception(f"The '{key}' connection does not exist") diff --git a/src/masoniteorm/connections/ConnectionResolver.py b/src/masoniteorm/connections/ConnectionResolver.py index 7839409f..bfe7735d 100644 --- a/src/masoniteorm/connections/ConnectionResolver.py +++ b/src/masoniteorm/connections/ConnectionResolver.py @@ -17,9 +17,7 @@ def __init__(self, config_path=None, connection_details=None): self.config_path = config_path self._connection_details = connection_details or {} - self.connection_factory = ConnectionFactory( - config_path=config_path, resolver=self - ) + self.connection_factory = ConnectionFactory(config_path=config_path, resolver=self) self.register(SQLiteConnection) self.register(PostgresConnection) self.register(MySQLConnection) @@ -56,9 +54,7 @@ def begin_transaction(self, name=None): driver = self.get_connection_details()[name].get("driver") connection = ( - self.connection_factory.make(driver)( - **self.get_connection_information(name) - ) + self.connection_factory.make(driver)(**self.get_connection_information(name)) .make_connection() .begin() ) @@ -127,6 +123,4 @@ def get_query_builder(self, connection="default"): ) def statement(self, query, bindings=(), connection="default"): - return ( - self.get_query_builder().on(connection).statement(query, bindings) - ) + return self.get_query_builder().on(connection).statement(query, bindings) diff --git a/src/masoniteorm/connections/MSSQLConnection.py b/src/masoniteorm/connections/MSSQLConnection.py index 53b87f95..45ea5bfd 100644 --- a/src/masoniteorm/connections/MSSQLConnection.py +++ b/src/masoniteorm/connections/MSSQLConnection.py @@ -150,11 +150,7 @@ def query(self, query, bindings=(), results="*"): return {} columnNames = [column[0] for column in cursor.description] result = cursor.fetchone() - return ( - dict(zip(columnNames, result)) - if result is not None - else {} - ) + return dict(zip(columnNames, result)) if result is not None else {} else: if not cursor.description: return {} diff --git a/src/masoniteorm/connections/MySQLConnection.py b/src/masoniteorm/connections/MySQLConnection.py index 01820840..5698bda2 100644 --- a/src/masoniteorm/connections/MySQLConnection.py +++ b/src/masoniteorm/connections/MySQLConnection.py @@ -35,9 +35,7 @@ def __init__( self.password = password self.prefix = prefix self.full_details = full_details or {} - self.connection_pool_size = full_details.get( - "connection_pooling_max_size", 100 - ) + self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.open = 0 @@ -82,9 +80,7 @@ def create_connection(self, autocommit=True): import pendulum import pymysql.converters - pymysql.converters.conversions[pendulum.DateTime] = ( - pymysql.converters.escape_datetime - ) + pymysql.converters.conversions[pendulum.DateTime] = pymysql.converters.escape_datetime # Initialize the connection pool if the option is set initialize_size = self.full_details.get("connection_pooling_min_size") diff --git a/src/masoniteorm/connections/PostgresConnection.py b/src/masoniteorm/connections/PostgresConnection.py index 374db744..f031f4aa 100644 --- a/src/masoniteorm/connections/PostgresConnection.py +++ b/src/masoniteorm/connections/PostgresConnection.py @@ -35,9 +35,7 @@ def __init__( self.prefix = prefix self.full_details = full_details or {} - self.connection_pool_size = full_details.get( - "connection_pooling_max_size", 100 - ) + self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.transaction_level = 0 diff --git a/src/masoniteorm/expressions/expressions.py b/src/masoniteorm/expressions/expressions.py index 12b1f47c..48dac0f8 100644 --- a/src/masoniteorm/expressions/expressions.py +++ b/src/masoniteorm/expressions/expressions.py @@ -108,7 +108,7 @@ def __init__(self, column, direction="ASC", raw=False, bindings=()): self.direction = direction self.bindings = bindings - if raw is False: + if not raw: if self.column.endswith(" desc"): self.column = self.column.split(" desc")[0].strip() self.direction = "DESC" @@ -168,9 +168,7 @@ def on_value(self, column, *args): def or_on_value(self, column, *args): equality, value = self._extract_operator_value(*args) - self.on_clauses += ( - (OnValueClause(column, equality, value, "value", operator="or")), - ) + self.on_clauses += ((OnValueClause(column, equality, value, "value", operator="or")),) return self def on_null(self, column): @@ -206,9 +204,7 @@ def or_on_null(self, column): Returns: self """ - self.on_clauses += ( - (OnValueClause(column, "=", None, "NULL", operator="or")), - ) + self.on_clauses += ((OnValueClause(column, "=", None, "NULL", operator="or")),) return self def or_on_not_null(self, column: str): @@ -220,14 +216,10 @@ def or_on_not_null(self, column: str): Returns: self """ - self.on_clauses += ( - (OnValueClause(column, "=", True, "NOT NULL", operator="or")), - ) + self.on_clauses += ((OnValueClause(column, "=", True, "NOT NULL", operator="or")),) return self - @deprecated( - "Using where() in a Join clause has been superceded by on_value()" - ) + @deprecated("Using where() in a Join clause has been superseded by on_value()") def where(self, column, *args): return self.on_value(column, *args) @@ -246,8 +238,7 @@ def _extract_operator_value(self, *args): if operator not in operators: raise ValueError( - "Invalid comparison operator. The operator can be %s" - % ", ".join(operators) + "Invalid comparison operator. The operator can be {}".format(", ".join(operators)) ) return operator, value diff --git a/src/masoniteorm/helpers/misc.py b/src/masoniteorm/helpers/misc.py index 2b26830e..fd0c156b 100644 --- a/src/masoniteorm/helpers/misc.py +++ b/src/masoniteorm/helpers/misc.py @@ -9,9 +9,7 @@ def deprecated(message): def deprecated_decorator(func): def deprecated_func(*args, **kwargs): warnings.warn( - "{} is a deprecated function. {}".format( - func.__name__, message - ), + f"{func.__name__} is a deprecated function. {message}", category=DeprecationWarning, stacklevel=2, ) diff --git a/src/masoniteorm/migrations/Migration.py b/src/masoniteorm/migrations/Migration.py index 51291e5d..b981ae10 100644 --- a/src/masoniteorm/migrations/Migration.py +++ b/src/masoniteorm/migrations/Migration.py @@ -59,9 +59,7 @@ def get_unran_migrations(self): all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) - if isfile(join(directory_path, f)) - and f != "__init__.py" - and not f.startswith(".") + if isfile(join(directory_path, f)) and f != "__init__.py" and not f.startswith(".") ] all_migrations.sort() unran_migrations = [] @@ -73,9 +71,7 @@ def get_unran_migrations(self): def get_rollback_migrations(self): return ( - self.migration_model.where( - "batch", self.migration_model.all().max("batch") - ) + self.migration_model.where("batch", self.migration_model.all().max("batch")) .order_by("migration_id", "desc") .get() .pluck("migration") @@ -83,11 +79,7 @@ def get_rollback_migrations(self): def get_all_migrations(self, reverse=False): if reverse: - return ( - self.migration_model.order_by("migration_id", "desc") - .get() - .pluck("migration") - ) + return self.migration_model.order_by("migration_id", "desc").get().pluck("migration") return self.migration_model.all().pluck("migration") @@ -98,13 +90,9 @@ def delete_migration(self, file_path): return self.migration_model.where("migration", file_path).delete() def locate(self, file_name): - migration_name = camelize( - "_".join(file_name.split("_")[4:]).replace(".py", "") - ) + migration_name = camelize("_".join(file_name.split("_")[4:]).replace(".py", "")) file_name = file_name.replace(".py", "") - migration_directory = self.migration_directory.replace( - "/", "." - ).replace("\\", ".") + migration_directory = self.migration_directory.replace("/", ".").replace("\\", ".") return locate(f"{migration_directory}.{file_name}.{migration_name}") def get_ran_migrations(self): @@ -112,18 +100,14 @@ def get_ran_migrations(self): all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) - if isfile(join(directory_path, f)) - and f != "__init__.py" - and not f.startswith(".") + if isfile(join(directory_path, f)) and f != "__init__.py" and not f.startswith(".") ] all_migrations.sort() ran = [] database_migrations = self.migration_model.all() for migration in all_migrations: - matched_migration = database_migrations.where( - "migration", migration - ).first() + matched_migration = database_migrations.where("migration", migration).first() if matched_migration: ran.append( { @@ -144,26 +128,20 @@ def migrate(self, migration="all", output=False): migration_class = self.locate(migration) except TypeError: - self.command_class.line( - f"Not Found: {migration}" - ) + self.command_class.line(f"Not Found: {migration}") continue self.last_migrations_ran.append(migration) if self.command_class: - self.command_class.line( - f"Migrating: {migration}" - ) + self.command_class.line(f"Migrating: {migration}") - migration_class = migration_class( - connection=self.connection, schema=self.schema_name - ) + migration_class = migration_class(connection=self.connection, schema=self.schema_name) if output: migration_class.schema.dry() start = timer() migration_class.up() - duration = "{:.2f}".format(timer() - start) + duration = f"{timer() - start:.2f}" if output: if self.command_class: @@ -183,9 +161,7 @@ def migrate(self, migration="all", output=False): f"Migrated: {migration} ({duration}s)" ) - self.migration_model.create( - {"batch": batch, "migration": migration.replace(".py", "")} - ) + self.migration_model.create({"batch": batch, "migration": migration.replace(".py", "")}) def rollback(self, migration="all", output=False): default_migrations = self.get_rollback_migrations() @@ -196,37 +172,28 @@ def rollback(self, migration="all", output=False): migration = migration.replace(".py", "") if self.command_class: - self.command_class.line( - f"Rolling back: {migration}" - ) + self.command_class.line(f"Rolling back: {migration}") try: migration_class = self.locate(migration) except TypeError: - self.command_class.line( - f"Not Found: {migration}" - ) + self.command_class.line(f"Not Found: {migration}") continue - migration_class = migration_class( - connection=self.connection, schema=self.schema_name - ) + migration_class = migration_class(connection=self.connection, schema=self.schema_name) if output: migration_class.schema.dry() start = timer() migration_class.down() - duration = "{:.2f}".format(timer() - start) + duration = f"{timer() - start:.2f}" if output: if self.command_class: table = self.command_class.table() table.set_header_row(["SQL"]) - if ( - hasattr(migration_class.schema, "_blueprint") - and migration_class.schema._blueprint - ): + if hasattr(migration_class.schema, "_blueprint") and migration_class.schema._blueprint: sql = migration_class.schema._blueprint.to_sql() if isinstance(sql, list): sql = ",".join(sql) @@ -247,14 +214,10 @@ def rollback(self, migration="all", output=False): ) def delete_migrations(self, migrations=None): - return self.migration_model.where_in( - "migration", migrations or [] - ).delete() + return self.migration_model.where_in("migration", migrations or []).delete() def delete_last_batch(self): - return self.migration_model.where( - "batch", self.get_last_batch_number() - ).delete() + return self.migration_model.where("batch", self.get_last_batch_number()).delete() def reset(self, migration="all"): default_migrations = self.get_all_migrations(reverse=True) @@ -268,18 +231,12 @@ def reset(self, migration="all"): for migration in migrations: if self.command_class: - self.command_class.line( - f"Rolling back: {migration}" - ) + self.command_class.line(f"Rolling back: {migration}") try: - self.locate(migration)( - connection=self.connection, schema=self.schema_name - ).down() + self.locate(migration)(connection=self.connection, schema=self.schema_name).down() except TypeError: - self.command_class.line( - f"Not Found: {migration}" - ) + self.command_class.line(f"Not Found: {migration}") continue # raise MigrationNotFound(f"Could not find {migration}") @@ -287,9 +244,7 @@ def reset(self, migration="all"): self.delete_migration(migration) if self.command_class: - self.command_class.line( - f"Rolled back: {migration}" - ) + self.command_class.line(f"Rolled back: {migration}") self.delete_migrations([migration]) @@ -322,9 +277,7 @@ def fresh(self, ignore_fk=False, migration="all"): if not self.get_unran_migrations(): if self.command_class: - self.command_class.line( - "Nothing to migrate" - ) + self.command_class.line("Nothing to migrate") return self.migrate(migration) diff --git a/src/masoniteorm/models/Model.py b/src/masoniteorm/models/Model.py index a7ee0a1f..43d1d237 100644 --- a/src/masoniteorm/models/Model.py +++ b/src/masoniteorm/models/Model.py @@ -1,14 +1,13 @@ +# ruff: noqa: E501 +from datetime import date as datetimedate, datetime, time as datetimetime +from decimal import Decimal import inspect import json import logging -from datetime import date as datetimedate -from datetime import datetime -from datetime import time as datetimetime -from decimal import Decimal from typing import Any, Dict -import pendulum from inflection import tableize, underscore +import pendulum from ..collection import Collection from ..config import load_config @@ -362,9 +361,7 @@ def get_foreign_key(self): Returns: str """ - return underscore( - self.__class__.__name__ + "_" + self.get_primary_key() - ) + return underscore(self.__class__.__name__ + "_" + self.get_primary_key()) def query(self): return self.get_builder() @@ -586,9 +583,7 @@ def create(cls, dictionary=None, query=False, cast=True, **kwargs): self: A hydrated version of a model """ if query: - return cls.builder.create( - dictionary, query=True, cast=cast, **kwargs - ) + return cls.builder.create(dictionary, query=True, cast=cast, **kwargs) return cls.builder.create(dictionary, cast=cast, **kwargs) @@ -620,9 +615,7 @@ def cast_values(self, attributes: Dict[str, Any]) -> Dict[str, Any]: updated_attribs = {} for key, value in attributes.items(): if key in self.get_dates(): - updated_attribs.update( - {key: self.get_new_datetime_string(value)} - ) + updated_attribs.update({key: self.get_new_datetime_string(value)}) elif key in self.__casts__: updated_attribs.update({key: self.cast_value(key, value)}) else: @@ -631,11 +624,7 @@ def cast_values(self, attributes: Dict[str, Any]) -> Dict[str, Any]: return updated_attribs def fresh(self): - return ( - self.get_builder() - .where(self.get_primary_key(), self.get_primary_key_value()) - .first() - ) + return self.get_builder().where(self.get_primary_key(), self.get_primary_key_value()).first() def serialize(self, exclude=None, include=None): """Takes the data as a model and converts it into a dictionary. @@ -647,9 +636,7 @@ def serialize(self, exclude=None, include=None): # prevent using both exclude and include at the same time if exclude is not None and include is not None: - raise AttributeError( - "Can not define both includes and exclude values." - ) + raise AttributeError("Can not define both includes and exclude values.") if exclude is not None: self.__hidden__ = exclude @@ -665,9 +652,7 @@ def serialize(self, exclude=None, include=None): if self.__visible__: new_serialized_dictionary = { - k: serialized_dictionary[k] - for k in self.__visible__ - if k in serialized_dictionary + k: serialized_dictionary[k] for k in self.__visible__ if k in serialized_dictionary } serialized_dictionary = new_serialized_dictionary else: @@ -676,14 +661,9 @@ def serialize(self, exclude=None, include=None): serialized_dictionary.pop(key) for date_column in self.get_dates(): - if ( - date_column in serialized_dictionary - and serialized_dictionary[date_column] - ): - serialized_dictionary[date_column] = ( - self.get_new_serialized_date( - serialized_dictionary[date_column] - ) + if date_column in serialized_dictionary and serialized_dictionary[date_column]: + serialized_dictionary[date_column] = self.get_new_serialized_date( + serialized_dictionary[date_column] ) serialized_dictionary.update(self.__dirty_attributes__) @@ -703,9 +683,7 @@ def serialize(self, exclude=None, include=None): if key in self.__hidden__: remove_keys.append(key) if hasattr(value, "serialize"): - value = value.serialize( - self.__relationship_hidden__.get(key, []) - ) + value = value.serialize(self.__relationship_hidden__.get(key, [])) if isinstance(value, datetime): value = self.get_new_serialized_date(value) if key in self.__casts__: @@ -817,22 +795,12 @@ def __getattr__(self, attribute): if (new_name_accessor) in self.__class__.__dict__: return self.__class__.__dict__.get(new_name_accessor)(self) - if ( - "__dirty_attributes__" in self.__dict__ - and attribute in self.__dict__["__dirty_attributes__"] - ): + if "__dirty_attributes__" in self.__dict__ and attribute in self.__dict__["__dirty_attributes__"]: return self.get_dirty_value(attribute) - if ( - "__attributes__" in self.__dict__ - and attribute in self.__dict__["__attributes__"] - ): + if "__attributes__" in self.__dict__ and attribute in self.__dict__["__attributes__"]: if attribute in self.get_dates(): - return ( - self.get_new_date(self.get_value(attribute)) - if self.get_value(attribute) - else None - ) + return self.get_new_date(self.get_value(attribute)) if self.get_value(attribute) else None return self.get_value(attribute) if attribute in self.__passthrough__: @@ -848,9 +816,7 @@ def method(*args, **kwargs): if attribute not in self.__dict__: name = self.__class__.__name__ - raise AttributeError( - f"class model '{name}' has no attribute {attribute}" - ) + raise AttributeError(f"class model '{name}' has no attribute {attribute}") return None @@ -884,9 +850,7 @@ def __setattr__(self, attribute, value): try: if not attribute.startswith("_"): - self.__dict__["__dirty_attributes__"].update( - {attribute: value} - ) + self.__dict__["__dirty_attributes__"].update({attribute: value}) else: self.__dict__[attribute] = value except KeyError: @@ -925,9 +889,7 @@ def save(self, query=False): if not query: if self.is_loaded(): - result = builder.update( - self.__dirty_attributes__, ignore_mass_assignment=True - ) + result = builder.update(self.__dirty_attributes__, ignore_mass_assignment=True) else: result = self.create( self.__dirty_attributes__, @@ -1121,9 +1083,7 @@ def detach(self, relation, related_record): related = getattr(self.__class__, relation) if not related_record.is_created(): - related_record = related_record.create( - related_record.all_attributes() - ) + related_record = related_record.create(related_record.all_attributes()) else: related_record.save() @@ -1159,11 +1119,7 @@ def delete_quietly(self): Returns: self """ - delete = ( - self.without_events() - .where(self.get_primary_key(), self.get_primary_key_value()) - .delete() - ) + delete = self.without_events().where(self.get_primary_key(), self.get_primary_key_value()).delete() self.with_events() return delete @@ -1178,15 +1134,11 @@ def filter_fillable(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: Passed dictionary is not mutated. """ if cls.__fillable__ != ["*"]: - dictionary = { - x: dictionary[x] for x in cls.__fillable__ if x in dictionary - } + dictionary = {x: dictionary[x] for x in cls.__fillable__ if x in dictionary} return dictionary @classmethod - def filter_mass_assignment( - cls, dictionary: Dict[str, Any] - ) -> Dict[str, Any]: + def filter_mass_assignment(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters the provided dictionary in preparation for a mass-assignment operation @@ -1204,6 +1156,4 @@ def filter_guarded(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: if cls.__guarded__ == ["*"]: # If all fields are guarded, all data should be filtered return {} - return { - f: dictionary[f] for f in dictionary if f not in cls.__guarded__ - } + return {f: dictionary[f] for f in dictionary if f not in cls.__guarded__} diff --git a/src/masoniteorm/models/Model.pyi b/src/masoniteorm/models/Model.pyi index 9a346215..a7f000e3 100644 --- a/src/masoniteorm/models/Model.pyi +++ b/src/masoniteorm/models/Model.pyi @@ -1,9 +1,8 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable from ..query.QueryBuilder import QueryBuilder class Model: - # ============================== # Model Methods # ============================== @@ -33,7 +32,7 @@ class Model: """ pass - def cast_values(self, attributes: Dict[str, Any]) -> Dict[str, Any]: + def cast_values(self, attributes: dict[str, Any]) -> dict[str, Any]: """ Runs provided dictionary through all model casters and returns the result. @@ -60,6 +59,14 @@ class Model: Returns: self: A hydrated version of a model + :param dictionary: + :type dictionary: + :param query: + :type query: + :param cast: + :type cast: + :param ignore_mass_assignment: + :type ignore_mass_assignment: """ pass @@ -94,7 +101,7 @@ class Model: pass @classmethod - def filter_fillable(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + def filter_fillable(cls, dictionary: dict[str, Any]) -> dict[str, Any]: """ Filters provided dictionary to only include fields specified in the model's __fillable__ property @@ -103,9 +110,7 @@ class Model: pass @classmethod - def filter_mass_assignment( - cls, dictionary: Dict[str, Any] - ) -> Dict[str, Any]: + def filter_mass_assignment(cls, dictionary: dict[str, Any]) -> dict[str, Any]: """ Filters the provided dictionary in preparation for a mass-assignment operation @@ -114,7 +119,7 @@ class Model: pass @classmethod - def filter_guarded(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: + def filter_guarded(cls, dictionary: dict[str, Any]) -> dict[str, Any]: """ Filters provided dictionary to exclude fields specified in the model's __guarded__ property @@ -191,9 +196,6 @@ class Model: def get_foreign_key(self): """Gets the foreign key based on this model name. - Args: - relationship (str): The relationship name. - Returns: str """ @@ -404,7 +406,6 @@ class Model: # all marked as @classmethod for IDE # autocomplete to work correctly # ============================== - @classmethod def add_select(cls, alias: str, callable: Any) -> QueryBuilder: """Specifies columns that should be selected @@ -415,9 +416,7 @@ class Model: pass @classmethod - def aggregate( - cls, aggregate: str, column: str, alias: str - ) -> QueryBuilder: + def aggregate(cls, aggregate: str, column: str, alias: str) -> QueryBuilder: """Helper function to aggregate. Arguments: @@ -427,7 +426,7 @@ class Model: pass @classmethod - def all(cls, selects: list = [], query: bool = False): + def all(cls, selects=None, query: bool = False): """Returns all records from the table. Returns: @@ -448,9 +447,7 @@ class Model: pass @classmethod - def between( - cls, column: str, low: str | int, high: str | int - ) -> QueryBuilder: + def between(cls, column: str, low: str | int, high: str | int) -> QueryBuilder: """Specifies a where between expression. Arguments: @@ -499,9 +496,7 @@ class Model: pass @classmethod - def delete( - cls, column: str = None, value: str = None, query: bool = False - ): + def delete(cls, column: str = None, value: str = None, query: bool = False): """Specify the column and value to delete or deletes everything based on a previously used where expression. @@ -546,9 +541,7 @@ class Model: """ pass - def find_or( - self, record_id: int, callback: Callable, args=None, column=None - ): + def find_or(self, record_id: int, callback: Callable, args=None, column=None): """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. Arguments: @@ -624,7 +617,7 @@ class Model: pass @classmethod - def get(cls, selects: list = []): + def get(cls, selects: list = None): """Runs the select query built from the query builder. Returns: @@ -743,9 +736,7 @@ class Model: pass @classmethod - def joins( - cls, *relationships: list[str], clause: str = "inner" - ) -> QueryBuilder: + def joins(cls, *relationships: list[str], clause: str = "inner") -> QueryBuilder: pass @classmethod @@ -850,9 +841,7 @@ class Model: pass @classmethod - def new_from_builder( - cls, from_builder: QueryBuilder = None - ) -> QueryBuilder: + def new_from_builder(cls, from_builder: QueryBuilder = None) -> QueryBuilder: """Creates a new QueryBuilder class. Returns: @@ -861,9 +850,7 @@ class Model: pass @classmethod - def not_between( - cls, column: str, low: str | int, high: str | int - ) -> QueryBuilder: + def not_between(cls, column: str, low: str | int, high: str | int) -> QueryBuilder: """Specifies a where not between expression. Arguments: @@ -931,7 +918,7 @@ class Model: pass @classmethod - def or_where_exists(cls, value: "str|int|QueryBuilder") -> QueryBuilder: + def or_where_exists(cls, value: str | int | QueryBuilder) -> QueryBuilder: """Specifies a where exists expression. Arguments: @@ -943,9 +930,7 @@ class Model: pass @classmethod - def or_where_not_exists( - cls, value: "str|int|QueryBuilder" - ) -> QueryBuilder: + def or_where_not_exists(cls, value: str | int | QueryBuilder) -> QueryBuilder: """Specifies a where exists expression. Arguments: @@ -978,9 +963,7 @@ class Model: pass @classmethod - def order_by( - cls, column: str, direction: str = "ASC|DESC" - ) -> QueryBuilder: + def order_by(cls, column: str, direction: str = "ASC|DESC") -> QueryBuilder: """Specifies a column to order by. Arguments: @@ -1356,9 +1339,7 @@ class Model: pass @classmethod - def with_count( - cls, relationship: str, callback: Any = None - ) -> QueryBuilder: + def with_count(cls, relationship: str, callback: Any = None) -> QueryBuilder: pass @classmethod diff --git a/src/masoniteorm/observers/ObservesEvents.py b/src/masoniteorm/observers/ObservesEvents.py index 3e2cb040..88fac0ed 100644 --- a/src/masoniteorm/observers/ObservesEvents.py +++ b/src/masoniteorm/observers/ObservesEvents.py @@ -1,6 +1,6 @@ class ObservesEvents: def observe_events(self, model, event): - if model.__has_events__ == True: + if model.__has_events__: for observer in model.__observers__.get(model.__class__, []): try: getattr(observer, event)(model) diff --git a/src/masoniteorm/pagination/BasePaginator.py b/src/masoniteorm/pagination/BasePaginator.py index 5b4932de..a10e4a98 100644 --- a/src/masoniteorm/pagination/BasePaginator.py +++ b/src/masoniteorm/pagination/BasePaginator.py @@ -3,8 +3,7 @@ class BasePaginator: def __iter__(self): - for result in self.result: - yield result + yield from self.result def to_json(self): return json.dumps(self.serialize()) diff --git a/src/masoniteorm/pagination/LengthAwarePaginator.py b/src/masoniteorm/pagination/LengthAwarePaginator.py index 34bb6e20..ea769c40 100644 --- a/src/masoniteorm/pagination/LengthAwarePaginator.py +++ b/src/masoniteorm/pagination/LengthAwarePaginator.py @@ -10,9 +10,7 @@ def __init__(self, result, per_page, current_page, total, url=None): self.per_page = per_page self.count = len(self.result) self.last_page = int(math.ceil(total / per_page)) - self.next_page = ( - (int(self.current_page) + 1) if self.has_more_pages() else None - ) + self.next_page = (int(self.current_page) + 1) if self.has_more_pages() else None self.previous_page = (int(self.current_page) - 1) or None self.total = total self.url = url diff --git a/src/masoniteorm/pagination/SimplePaginator.py b/src/masoniteorm/pagination/SimplePaginator.py index e5542e66..cf1b9a83 100644 --- a/src/masoniteorm/pagination/SimplePaginator.py +++ b/src/masoniteorm/pagination/SimplePaginator.py @@ -7,9 +7,7 @@ def __init__(self, result, per_page, current_page, url=None): self.current_page = current_page self.per_page = per_page self.count = len(self.result) - self.next_page = ( - (int(self.current_page) + 1) if self.has_more_pages() else None - ) + self.next_page = (int(self.current_page) + 1) if self.has_more_pages() else None self.previous_page = (int(self.current_page) - 1) or None self.url = url diff --git a/src/masoniteorm/providers/ORMProvider.py b/src/masoniteorm/providers/ORMProvider.py index 69c7e35c..a4cdfb9c 100644 --- a/src/masoniteorm/providers/ORMProvider.py +++ b/src/masoniteorm/providers/ORMProvider.py @@ -22,18 +22,20 @@ def __init__(self, application): self.application = application def register(self): - self.application.make("commands").add( - MakeMigrationCommand(), - MakeSeedCommand(), - MakeObserverCommand(), - MigrateCommand(), - MigrateResetCommand(), - MakeModelCommand(), - MigrateStatusCommand(), - MigrateRefreshCommand(), - MigrateRollbackCommand(), - SeedRunCommand(), - ), + ( + self.application.make("commands").add( + MakeMigrationCommand(), + MakeSeedCommand(), + MakeObserverCommand(), + MigrateCommand(), + MigrateResetCommand(), + MakeModelCommand(), + MigrateStatusCommand(), + MigrateRefreshCommand(), + MigrateRollbackCommand(), + SeedRunCommand(), + ), + ) def boot(self): pass diff --git a/src/masoniteorm/query/EagerRelation.py b/src/masoniteorm/query/EagerRelation.py index 3597e0fc..5675da1c 100644 --- a/src/masoniteorm/query/EagerRelation.py +++ b/src/masoniteorm/query/EagerRelation.py @@ -14,9 +14,7 @@ def register(self, *relations, callback=None): self.is_nested = True relation_key = relation.split(".")[0] if relation_key not in self.nested_eagers: - self.nested_eagers = { - relation_key: relation.split(".")[1:] - } + self.nested_eagers = {relation_key: relation.split(".")[1:]} else: self.nested_eagers[relation_key] += relation.split(".")[1:] elif isinstance(relation, (tuple, list)): diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index fb1a2adc..f056bb51 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -1,6 +1,7 @@ -import inspect +# ruff: noqa: E501 from copy import deepcopy from datetime import datetime +import inspect from typing import Any, Callable, Dict, List, Optional from ..collection.Collection import Collection @@ -154,27 +155,13 @@ def reset(self): def get_connection_information(self): return { - "host": self._connection_details.get(self.connection, {}).get( - "host" - ), - "database": self._connection_details.get(self.connection, {}).get( - "database" - ), - "user": self._connection_details.get(self.connection, {}).get( - "user" - ), - "port": self._connection_details.get(self.connection, {}).get( - "port" - ), - "password": self._connection_details.get(self.connection, {}).get( - "password" - ), - "prefix": self._connection_details.get(self.connection, {}).get( - "prefix" - ), - "options": self._connection_details.get(self.connection, {}).get( - "options", {} - ), + "host": self._connection_details.get(self.connection, {}).get("host"), + "database": self._connection_details.get(self.connection, {}).get("database"), + "user": self._connection_details.get(self.connection, {}).get("user"), + "port": self._connection_details.get(self.connection, {}).get("port"), + "password": self._connection_details.get(self.connection, {}).get("password"), + "prefix": self._connection_details.get(self.connection, {}).get("prefix"), + "options": self._connection_details.get(self.connection, {}).get("options", {}), "full_details": self._connection_details.get(self.connection, {}), } @@ -237,15 +224,8 @@ def get_table_name(self): """ return self._table.name - def get_connection(self): - """Sets a table on the query builder - - Arguments: - table {string} -- The name of the table - - Returns: - self - """ + def get_connection_class(self): + """Gets the connection class""" return self.connection_class def begin(self): @@ -288,16 +268,16 @@ def rollback(self): self._connection.rollback() return self - def get_relation(self, key): - """Sets a table on the query builder - - Arguments: - table {string} -- The name of the table - - Returns: - self - """ - return getattr(self.owner, key) + # def get_relation(self, key): + # """Sets a table on the query builder + # + # Arguments: + # table {string} -- The name of the table + # + # Returns: + # self + # """ + # return getattr(self.owner, key) def set_scope(self, name, callable): """Sets a scope based on a class and maps it to a name. @@ -372,31 +352,23 @@ def __getattr__(self, attribute): self """ if attribute == "__setstate__": - raise AttributeError( - "'QueryBuilder' object has no attribute '{}'".format(attribute) - ) + raise AttributeError(f"'QueryBuilder' object has no attribute '{attribute}'") if attribute in self._scopes: def method(*args, **kwargs): - return self._scopes[attribute]( - self._model, self, *args, **kwargs - ) + return self._scopes[attribute](self._model, self, *args, **kwargs) return method if attribute in self._macros: def method(*args, **kwargs): - return self._macros[attribute]( - self._model, self, *args, **kwargs - ) + return self._macros[attribute](self._model, self, *args, **kwargs) return method - raise AttributeError( - "'QueryBuilder' object has no attribute '{}'".format(attribute) - ) + raise AttributeError(f"'QueryBuilder' object has no attribute '{attribute}'") def on(self, connection): if connection == "default": @@ -405,19 +377,11 @@ def on(self, connection): self.connection = connection if self.connection not in self._connection_details: - raise ConnectionNotRegistered( - f"Could not find the '{self.connection}' connection details" - ) + raise ConnectionNotRegistered(f"Could not find the '{self.connection}' connection details") - self._connection_driver = self._connection_details.get( - self.connection - ).get("driver") - resolver = ConnectionResolver( - connection_details=self._connection_details - ) - self.connection_class = resolver.connection_factory.make( - self._connection_driver - ) + self._connection_driver = self._connection_details.get(self.connection).get("driver") + resolver = ConnectionResolver(connection_details=self._connection_details) + self.connection_class = resolver.connection_factory.make(self._connection_driver) self.grammar = self.connection_class.get_default_query_grammar() @@ -505,9 +469,7 @@ def bulk_create( model = model.hydrate(self._creates) if not self.dry: connection = self.new_connection() - query_result = connection.query( - self.to_qmark(), self._bindings, results=1 - ) + query_result = connection.query(self.to_qmark(), self._bindings, results=1) processed_results = query_result or self._creates else: @@ -563,9 +525,7 @@ def create( if not self.dry: connection = self.new_connection() - query_result = connection.query( - self.to_qmark(), self._bindings, results=1 - ) + query_result = connection.query(self.to_qmark(), self._bindings, results=1) if model: id_key = model.get_primary_key() @@ -641,24 +601,14 @@ def where(self, column, *args): if inspect.isfunction(column): builder = column(self.new()) - self._wheres += ( - (QueryExpression(None, operator, SubGroupExpression(builder))), - ) + self._wheres += ((QueryExpression(None, operator, SubGroupExpression(builder))),) elif isinstance(column, dict): for key, value in column.items(): self._wheres += ((QueryExpression(key, "=", value, "value")),) elif isinstance(value, QueryBuilder): - self._wheres += ( - ( - QueryExpression( - column, operator, SubSelectExpression(value) - ) - ), - ) + self._wheres += ((QueryExpression(column, operator, SubSelectExpression(value))),) else: - self._wheres += ( - (QueryExpression(column, operator, value, "value")), - ) + self._wheres += ((QueryExpression(column, operator, value, "value")),) return self def where_from_builder(self, builder): @@ -674,9 +624,7 @@ def where_from_builder(self, builder): self """ - self._wheres += ( - (QueryExpression(None, "=", SubGroupExpression(builder))), - ) + self._wheres += ((QueryExpression(None, "=", SubGroupExpression(builder))),) return self @@ -720,13 +668,7 @@ def where_raw(self, query: str, bindings=()): Returns: self """ - self._wheres += ( - ( - QueryExpression( - query, "=", None, "value", raw=True, bindings=bindings - ) - ), - ) + self._wheres += ((QueryExpression(query, "=", None, "value", raw=True, bindings=bindings)),) return self def or_where(self, column, *args): @@ -753,21 +695,9 @@ def or_where(self, column, *args): ), ) elif isinstance(value, QueryBuilder): - self._wheres += ( - ( - QueryExpression( - column, operator, SubSelectExpression(value) - ) - ), - ) + self._wheres += ((QueryExpression(column, operator, SubSelectExpression(value))),) else: - self._wheres += ( - ( - QueryExpression( - column, operator, value, "value", keyword="or" - ) - ), - ) + self._wheres += ((QueryExpression(column, operator, value, "value", keyword="or")),) return self def where_exists(self, value: "str|int|QueryBuilder"): @@ -780,21 +710,11 @@ def where_exists(self, value: "str|int|QueryBuilder"): self """ if inspect.isfunction(value): - self._wheres += ( - ( - QueryExpression( - None, "EXISTS", SubSelectExpression(value(self.new())) - ) - ), - ) + self._wheres += ((QueryExpression(None, "EXISTS", SubSelectExpression(value(self.new())))),) elif isinstance(value, QueryBuilder): - self._wheres += ( - (QueryExpression(None, "EXISTS", SubSelectExpression(value))), - ) + self._wheres += ((QueryExpression(None, "EXISTS", SubSelectExpression(value))),) else: - self._wheres += ( - (QueryExpression(None, "EXISTS", value, "value")), - ) + self._wheres += ((QueryExpression(None, "EXISTS", value, "value")),) return self @@ -830,13 +750,7 @@ def or_where_exists(self, value: "str|int|QueryBuilder"): ), ) else: - self._wheres += ( - ( - QueryExpression( - None, "EXISTS", value, "value", keyword="or" - ) - ), - ) + self._wheres += ((QueryExpression(None, "EXISTS", value, "value", keyword="or")),) return self @@ -861,17 +775,9 @@ def where_not_exists(self, value: "str|int|QueryBuilder"): ), ) elif isinstance(value, QueryBuilder): - self._wheres += ( - ( - QueryExpression( - None, "NOT EXISTS", SubSelectExpression(value) - ) - ), - ) + self._wheres += ((QueryExpression(None, "NOT EXISTS", SubSelectExpression(value))),) else: - self._wheres += ( - (QueryExpression(None, "NOT EXISTS", value, "value")), - ) + self._wheres += ((QueryExpression(None, "NOT EXISTS", value, "value")),) return self @@ -908,13 +814,7 @@ def or_where_not_exists(self, value: "str|int|QueryBuilder"): ), ) else: - self._wheres += ( - ( - QueryExpression( - None, "NOT EXISTS", value, "value", keyword="or" - ) - ), - ) + self._wheres += ((QueryExpression(None, "NOT EXISTS", value, "value", keyword="or")),) return self @@ -967,16 +867,12 @@ def or_where_null(self, column): Returns: self """ - self._wheres += ( - (QueryExpression(column, "=", None, "NULL", keyword="or")), - ) + self._wheres += ((QueryExpression(column, "=", None, "NULL", keyword="or")),) return self def chunk(self, chunk_amount): chunk_connection = self.new_connection() - for result in chunk_connection.select_many( - self.to_sql(), (), chunk_amount - ): + for result in chunk_connection.select_many(self.to_sql(), (), chunk_amount): yield self.prepare_result(result) def where_not_null(self, column: str): @@ -1008,13 +904,7 @@ def where_date(self, column: str, date: "str|datetime"): Returns: self """ - self._wheres += ( - ( - QueryExpression( - column, "=", self._get_date_string(date), "DATE" - ) - ), - ) + self._wheres += ((QueryExpression(column, "=", self._get_date_string(date), "DATE")),) return self def or_where_date(self, column: str, date: "str|datetime"): @@ -1071,9 +961,7 @@ def not_between(self, column: str, low: str, high: str): Returns: self """ - self._wheres += ( - BetweenExpression(column, low, high, equality="NOT BETWEEN"), - ) + self._wheres += (BetweenExpression(column, low, high, equality="NOT BETWEEN"),) return self def where_in(self, column, wheres=None): @@ -1095,17 +983,9 @@ def where_in(self, column, wheres=None): self._wheres += ((QueryExpression(0, "=", 1, "value_equals")),) elif isinstance(wheres, QueryBuilder): - self._wheres += ( - (QueryExpression(column, "IN", SubSelectExpression(wheres))), - ) + self._wheres += ((QueryExpression(column, "IN", SubSelectExpression(wheres))),) elif callable(wheres): - self._wheres += ( - ( - QueryExpression( - column, "IN", SubSelectExpression(wheres(self.new())) - ) - ), - ) + self._wheres += ((QueryExpression(column, "IN", SubSelectExpression(wheres(self.new())))),) else: self._wheres += ((QueryExpression(column, "IN", list(wheres))),) return self @@ -1115,17 +995,13 @@ def get_relation(self, relationship, builder=None): builder = self if not builder._model: - raise AttributeError( - "You must specify a model in order to use relationship methods" - ) + raise AttributeError("You must specify a model in order to use relationship methods") return getattr(builder._model, relationship) def has(self, *relationships): if not self._model: - raise AttributeError( - "You must specify a model in order to use 'has' relationship methods" - ) + raise AttributeError("You must specify a model in order to use 'has' relationship methods") for relationship in relationships: if "." in relationship: @@ -1140,28 +1016,20 @@ def has(self, *relationships): def or_has(self, *relationships): if not self._model: - raise AttributeError( - "You must specify a model in order to use 'has' relationship methods" - ) + raise AttributeError("You must specify a model in order to use 'has' relationship methods") for relationship in relationships: if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) - for index, split_relationship in enumerate( - relationship.split(".") - ): + for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = related.query_has( - last_builder, method="where_exists" - ) + last_builder = related.query_has(last_builder, method="where_exists") continue - last_builder = related.query_has( - last_builder, method="or_where_exists" - ) + last_builder = related.query_has(last_builder, method="or_where_exists") else: related = getattr(self._model, relationship) related.query_has(self, method="or_where_exists") @@ -1177,19 +1045,13 @@ def doesnt_have(self, *relationships): if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) - for index, split_relationship in enumerate( - relationship.split(".") - ): + for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = related.query_has( - last_builder, method="where_exists" - ) + last_builder = related.query_has(last_builder, method="where_exists") continue - last_builder = related.query_has( - last_builder, method="where_not_exists" - ) + last_builder = related.query_has(last_builder, method="where_not_exists") else: related = getattr(self._model, relationship) related.query_has(self, method="where_not_exists") @@ -1205,19 +1067,13 @@ def or_doesnt_have(self, *relationships): if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) - for index, split_relationship in enumerate( - relationship.split(".") - ): + for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = related.query_has( - last_builder, method="where_exists" - ) + last_builder = related.query_has(last_builder, method="where_exists") continue - last_builder = related.query_has( - last_builder, method="or_where_not_exists" - ) + last_builder = related.query_has(last_builder, method="or_where_not_exists") else: related = getattr(self._model, relationship) related.query_has(self, method="or_where_not_exists") @@ -1225,9 +1081,7 @@ def or_doesnt_have(self, *relationships): def where_has(self, relationship, callback): if not self._model: - raise AttributeError( - "You must specify a model in order to use 'has' relationship methods" - ) + raise AttributeError("You must specify a model in order to use 'has' relationship methods") if "." in relationship: last_builder = self._model.builder @@ -1237,13 +1091,9 @@ def where_has(self, relationship, callback): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = related.query_where_exists( - last_builder, callback, method="where_exists" - ) + last_builder = related.query_where_exists(last_builder, callback, method="where_exists") continue - last_builder = related.query_has( - last_builder, method="where_exists" - ) + last_builder = related.query_has(last_builder, method="where_exists") else: related = getattr(self._model, relationship) related.query_where_exists(self, callback, method="where_exists") @@ -1251,9 +1101,7 @@ def where_has(self, relationship, callback): def or_where_has(self, relationship, callback): if not self._model: - raise AttributeError( - "You must specify a model in order to use 'has' relationship methods" - ) + raise AttributeError("You must specify a model in order to use 'has' relationship methods") if "." in relationship: last_builder = self._model.builder @@ -1262,19 +1110,13 @@ def or_where_has(self, relationship, callback): for index, split_relationship in enumerate(splits): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = related.query_where_exists( - last_builder, callback, method="where_exists" - ) + last_builder = related.query_where_exists(last_builder, callback, method="where_exists") continue - last_builder = related.query_has( - last_builder, method="or_where_exists" - ) + last_builder = related.query_has(last_builder, method="or_where_exists") else: related = getattr(self._model, relationship) - related.query_where_exists( - self, callback, method="or_where_exists" - ) + related.query_where_exists(self, callback, method="or_where_exists") return self def where_doesnt_have(self, relationship, callback): @@ -1286,24 +1128,18 @@ def where_doesnt_have(self, relationship, callback): if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) - for index, split_relationship in enumerate( - relationship.split(".") - ): + for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = getattr( - last_builder._model, split_relationship - ).query_where_exists(self, callback, method="where_exists") + last_builder = getattr(last_builder._model, split_relationship).query_where_exists( + self, callback, method="where_exists" + ) continue - last_builder = related.query_has( - last_builder, method="where_not_exists" - ) + last_builder = related.query_has(last_builder, method="where_not_exists") else: related = getattr(self._model, relationship) - related.query_where_exists( - self, callback, method="where_not_exists" - ) + related.query_where_exists(self, callback, method="where_not_exists") return self def or_where_doesnt_have(self, relationship, callback): @@ -1315,31 +1151,23 @@ def or_where_doesnt_have(self, relationship, callback): if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) - for index, split_relationship in enumerate( - relationship.split(".") - ): + for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: - last_builder = getattr( - last_builder._model, split_relationship - ).query_where_exists(self, callback, method="where_exists") + last_builder = getattr(last_builder._model, split_relationship).query_where_exists( + self, callback, method="where_exists" + ) continue - last_builder = related.query_has( - last_builder, method="or_where_not_exists" - ) + last_builder = related.query_has(last_builder, method="or_where_not_exists") else: related = getattr(self._model, relationship) - related.query_where_exists( - self, callback, method="or_where_not_exists" - ) + related.query_where_exists(self, callback, method="or_where_not_exists") return self def with_count(self, relationship, callback=None): self.select(*self._model.get_selects()) - return getattr(self._model, relationship).get_with_count_query( - self, callback=callback - ) + return getattr(self._model, relationship).get_with_count_query(self, callback=callback) def where_not_in(self, column, wheres=None): """Specifies where a column does not contain a list of a values. @@ -1357,17 +1185,9 @@ def where_not_in(self, column, wheres=None): wheres = wheres or [] if isinstance(wheres, QueryBuilder): - self._wheres += ( - ( - QueryExpression( - column, "NOT IN", SubSelectExpression(wheres) - ) - ), - ) + self._wheres += ((QueryExpression(column, "NOT IN", SubSelectExpression(wheres))),) else: - self._wheres += ( - (QueryExpression(column, "NOT IN", list(wheres))), - ) + self._wheres += ((QueryExpression(column, "NOT IN", list(wheres))),) return self def join( @@ -1395,11 +1215,7 @@ def join( if inspect.isfunction(column1): self._joins += (column1(JoinClause(table, clause=clause)),) elif isinstance(table, str): - self._joins += ( - JoinClause(table, clause=clause).on( - column1, equality, column2 - ), - ) + self._joins += (JoinClause(table, clause=clause).on(column1, equality, column2),) else: self._joins += (table,) return self @@ -1538,9 +1354,7 @@ def update( if model and model.is_loaded(): self.where(model.get_primary_key(), model.get_primary_key_value()) - additional.update( - {model.get_primary_key(): model.get_primary_key_value()} - ) + additional.update({model.get_primary_key(): model.get_primary_key_value()}) self.observe_events(model, "updating") @@ -1550,11 +1364,7 @@ def update( updates = { attr: value for attr, value in updates.items() - if ( - value is None - or model.__original_attributes__.get(attr, None) - != value - ) + if (value is None or model.__original_attributes__.get(attr, None) != value) } # Do not execute query if no changes @@ -1626,24 +1436,18 @@ def increment(self, column, value=1, dry=False): if model and model.is_loaded(): self.where(model.get_primary_key(), model.get_primary_key_value()) - additional.update( - {model.get_primary_key(): model.get_primary_key_value()} - ) + additional.update({model.get_primary_key(): model.get_primary_key_value()}) self.observe_events(model, "updating") - self._updates += ( - UpdateQueryExpression(column, value, update_type="increment"), - ) + self._updates += (UpdateQueryExpression(column, value, update_type="increment"),) self.set_action("update") if dry or self.dry: return self results = self.new_connection().query(self.to_qmark(), self._bindings) - processed_results = self.get_processor().get_column_value( - self, column, results, id_key, id_value - ) + processed_results = self.get_processor().get_column_value(self, column, results, id_key, id_value) return processed_results def decrement(self, column, value=1, dry=False): @@ -1670,24 +1474,18 @@ def decrement(self, column, value=1, dry=False): if model and model.is_loaded(): self.where(model.get_primary_key(), model.get_primary_key_value()) - additional.update( - {model.get_primary_key(): model.get_primary_key_value()} - ) + additional.update({model.get_primary_key(): model.get_primary_key_value()}) self.observe_events(model, "updating") - self._updates += ( - UpdateQueryExpression(column, value, update_type="decrement"), - ) + self._updates += (UpdateQueryExpression(column, value, update_type="decrement"),) self.set_action("update") if dry or self.dry: return self result = self.new_connection().query(self.to_qmark(), self._bindings) - processed_results = self.get_processor().get_column_value( - self, column, result, id_key, id_value - ) + processed_results = self.get_processor().get_column_value(self, column, result, id_key, id_value) return processed_results def sum(self, column): @@ -1699,7 +1497,7 @@ def sum(self, column): Returns: self """ - self.aggregate("SUM", "{column}".format(column=column)) + self.aggregate("SUM", f"{column}") return self def count(self, column=None, dry=False): @@ -1711,9 +1509,7 @@ def count(self, column=None, dry=False): Returns: self """ - alias = ( - "m_count_reserved" if (column == "*" or column is None) else column - ) + alias = "m_count_reserved" if (column == "*" or column is None) else column if column == "*": self.aggregate("COUNT", f"{column} as {alias}") elif column is None: @@ -1725,9 +1521,7 @@ def count(self, column=None, dry=False): return self if not column: - result = self.new_connection().query( - self.to_qmark(), self._bindings, results=1 - ) + result = self.new_connection().query(self.to_qmark(), self._bindings, results=1) if isinstance(result, dict): return result.get(alias, 0) @@ -1748,7 +1542,7 @@ def max(self, column): Returns: self """ - self.aggregate("MAX", "{column}".format(column=column)) + self.aggregate("MAX", f"{column}") return self def order_by(self, column, direction="ASC"): @@ -1781,9 +1575,7 @@ def order_by_raw(self, query, bindings=None): """ if bindings is None: bindings = [] - self._order_by += ( - OrderByExpression(query, raw=True, bindings=bindings), - ) + self._order_by += (OrderByExpression(query, raw=True, bindings=bindings),) return self def group_by(self, column): @@ -1811,9 +1603,7 @@ def group_by_raw(self, query, bindings=None): """ if bindings is None: bindings = [] - self._group_by += ( - GroupByExpression(column=query, raw=True, bindings=bindings), - ) + self._group_by += (GroupByExpression(column=query, raw=True, bindings=bindings),) return self @@ -1824,11 +1614,7 @@ def aggregate(self, aggregate, column, alias=None): aggregate {string} -- The name of the aggregation. column {string} -- The name of the column to aggregate. """ - self._aggregates += ( - AggregateExpression( - aggregate=aggregate, column=column, alias=alias - ), - ) + self._aggregates += (AggregateExpression(aggregate=aggregate, column=column, alias=alias),) def first(self, fields=None, query=False): """Gets the first record. @@ -1845,9 +1631,7 @@ def first(self, fields=None, query=False): if query: return self - result = self.new_connection().query( - self.to_qmark(), self._bindings, results=1 - ) + result = self.new_connection().query(self.to_qmark(), self._bindings, results=1) return self.prepare_result(result) @@ -1948,9 +1732,7 @@ def find(self, record_id, column=None, query=False): return self.first() - def find_or( - self, record_id: int, callback: Callable, args=None, column=None - ): + def find_or(self, record_id: int, callback: Callable, args=None, column=None): """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. Arguments: @@ -2069,9 +1851,7 @@ def prepare_result(self, result, collection=False): else: related = self._model.get_related(eager) - result_set = related.get_related( - self, hydrated_model - ) + result_set = related.get_related(self, hydrated_model) self._register_relationships_to_model( related, @@ -2090,9 +1870,7 @@ def prepare_result(self, result, collection=False): else: return result or None - def _register_relationships_to_model( - self, related, related_result, hydrated_model, relation_key - ): + def _register_relationships_to_model(self, related, related_result, hydrated_model, relation_key): """Takes a related result and a hydrated model and registers them to eachother using the relation key. Args: @@ -2119,21 +1897,21 @@ def _register_relationships_to_model( def _map_related(self, related_result, related): return related.map_related(related_result) - def all(self, selects=[], query=False): + def all(self, selects=None, query=False): """Returns all records from the table. Returns: dictionary -- Returns a dictionary of results. """ + if selects is None: + selects = [] self.select(*selects) if query: return self - result = ( - self.new_connection().query(self.to_qmark(), self._bindings) or [] - ) + result = self.new_connection().query(self.to_qmark(), self._bindings) or [] return self.prepare_result(result, collection=True) @@ -2153,9 +1931,7 @@ def new_connection(self): return self._connection self._connection = ( - self.connection_class( - **self.get_connection_information(), name=self.connection - ) + self.connection_class(**self.get_connection_information(), name=self.connection) .set_schema(self._schema) .make_connection() ) @@ -2313,7 +2089,7 @@ def avg(self, column): Returns: self """ - self.aggregate("AVG", "{column}".format(column=column)) + self.aggregate("AVG", f"{column}") return self def min(self, column): @@ -2325,7 +2101,7 @@ def min(self, column): Returns: self """ - self.aggregate("MIN", "{column}".format(column=column)) + self.aggregate("MIN", f"{column}") return self def _extract_operator_value(self, *args): @@ -2355,8 +2131,7 @@ def _extract_operator_value(self, *args): if operator not in operators: raise ValueError( - "Invalid comparison operator. The operator can be %s" - % ", ".join(operators) + "Invalid comparison operator. The operator can be {}".format(", ".join(operators)) ) return operator, value @@ -2379,9 +2154,7 @@ def when(self, conditional, callback): return self def truncate(self, foreign_keys=False, dry=False): - sql = self.get_grammar().truncate_table( - self.get_table_name(), foreign_keys - ) + sql = self.get_grammar().truncate_table(self.get_table_name(), foreign_keys) if dry or self.dry: return sql diff --git a/src/masoniteorm/query/grammars/BaseGrammar.py b/src/masoniteorm/query/grammars/BaseGrammar.py index 2ab5297f..f4afaef3 100644 --- a/src/masoniteorm/query/grammars/BaseGrammar.py +++ b/src/masoniteorm/query/grammars/BaseGrammar.py @@ -142,9 +142,7 @@ def _compile_insert(self, qmark=False): self._sql = self.insert_format().format( key_equals=self._compile_key_value_equals(qmark=qmark), table=self.process_table(self.table), - columns=self.process_columns( - separator=", ", action="insert", qmark=qmark - ), + columns=self.process_columns(separator=", ", action="insert", qmark=qmark), values=self.process_values(separator=", ", qmark=qmark), ) @@ -167,10 +165,7 @@ def _compile_bulk_create(self, qmark=False): return self def columnize_bulk_columns(self, columns=[]): - return ", ".join( - self.column_string().format(column=x, separator="") - for x in columns - ).rstrip(",") + return ", ".join(self.column_string().format(column=x, separator="") for x in columns).rstrip(",") def columnize_bulk_values(self, columns=[], qmark=False): sql = "" @@ -180,27 +175,17 @@ def columnize_bulk_values(self, columns=[], qmark=False): for y in x: if qmark: self.add_binding(y) - inner += ( - "?, " - if qmark - else self.value_string().format( - value=y, separator=", " - ) - ) + inner += "?, " if qmark else self.value_string().format(value=y, separator=", ") inner = inner.rstrip(", ") - sql += self.process_value_string().format( - value=inner, separator=", " - ) + sql += self.process_value_string().format(value=inner, separator=", ") else: if qmark: self.add_binding(x) sql += ( "?, " if qmark - else self.process_value_string().format( - value="?" if qmark else x, separator=", " - ) + else self.process_value_string().format(value="?" if qmark else x, separator=", ") ) return sql.rstrip(", ") @@ -278,11 +263,7 @@ def process_joins(self, qmark=False): sql += self.join_string().format( foreign_table=self.process_table(join.table), - alias=( - f" AS {self.process_table(join.alias)}" - if join.alias - else "" - ), + alias=(f" AS {self.process_table(join.alias)}" if join.alias else ""), on=on_string, keyword=self.join_keywords[join.clause], ) @@ -323,11 +304,7 @@ def _compile_key_value_equals(self, qmark=False): sql += sql_string.format( column=self._table_column_string(key), value=( - self.value_string().format( - value=value, separator="" - ) - if not qmark - else "?" + self.value_string().format(value=value, separator="") if not qmark else "?" ), separator=", ", ) @@ -337,11 +314,7 @@ def _compile_key_value_equals(self, qmark=False): else: sql += sql_string.format( column=self._table_column_string(column), - value=( - self.value_string().format(value=value, separator=", ") - if not qmark - else "?" - ), + value=(self.value_string().format(value=value, separator=", ") if not qmark else "?"), separator=", ", ) if qmark: @@ -369,11 +342,7 @@ def process_aggregates(self): sql += ( aggregate_string.format( aggregate_function=aggregate_function, - column=( - "*" - if column == "*" - else self._table_column_string(column) - ), + column=("*" if column == "*" else self._table_column_string(column)), alias=self.process_alias(aggregates.alias or column), ) + ", " @@ -410,12 +379,8 @@ def process_order_by(self): if "." in column: column_string = self._table_column_string(column) else: - column_string = self.column_string().format( - column=column, separator="" - ) - order_crit += self.order_by_format().format( - column=column_string, direction=direction.upper() - ) + column_string = self.column_string().format(column=column, separator="") + order_crit += self.order_by_format().format(column=column_string, direction=direction.upper()) sql += self.order_by_string().format(order_columns=order_crit) return sql @@ -509,9 +474,7 @@ def process_offset(self): if not self._offset: return "" - return self.offset_string().format( - offset=self._offset, limit=self._limit or 1 - ) + return self.offset_string().format(offset=self._offset, limit=self._limit or 1) def process_locks(self): return self.locks.get(self.lock, "") @@ -538,11 +501,7 @@ def process_having(self, qmark=False): sql_string = self.having_equality_string() sql += sql_string.format( - column=( - self._table_column_string(column) - if raw is False - else column - ), + column=(self._table_column_string(column) if raw is False else column), equality=equality, value=self._compile_value(value), ) @@ -587,14 +546,10 @@ def process_wheres(self, qmark=False, strip_first_where=False): """If we have a raw query we just want to use the query supplied and don't need to compile anything. """ - sql += self.raw_query_string().format( - keyword=keyword, query=where.column - ) + sql += self.raw_query_string().format(keyword=keyword, query=where.column) if not isinstance(where.bindings, (list, tuple)): - raise ValueError( - f"Bindings must be tuple or list. Received {type(where.bindings)}" - ) + raise ValueError(f"Bindings must be tuple or list. Received {type(where.bindings)}") if where.bindings: self.add_binding(*where.bindings) @@ -669,11 +624,7 @@ def process_wheres(self, qmark=False, strip_first_where=False): grammar = value.builder.get_grammar() query_value = ( self.subquery_string() - .format( - query=grammar.process_wheres( - qmark=qmark, strip_first_where=True - ) - ) + .format(query=grammar.process_wheres(qmark=qmark, strip_first_where=True)) .replace("( ", "(") ) if grammar._bindings: @@ -686,9 +637,7 @@ def process_wheres(self, qmark=False, strip_first_where=False): self.add_binding(*value.builder._bindings) else: query_from_builder = value.builder.to_sql() - query_value = self.subquery_string().format( - query=query_from_builder - ) + query_value = self.subquery_string().format(query=query_from_builder) elif isinstance(value, list): query_value = "(" for val in value: @@ -696,9 +645,7 @@ def process_wheres(self, qmark=False, strip_first_where=False): query_value += "?, " self.add_binding(val) else: - query_value += self.value_string().format( - value=val, separator="," - ) + query_value += self.value_string().format(value=val, separator=",") query_value = query_value.rstrip(",").rstrip(", ") + ")" elif value is True and value_type != "NOT NULL": sql_string = self.get_true_column_string() @@ -719,23 +666,15 @@ def process_wheres(self, qmark=False, strip_first_where=False): if qmark: query_value = "?" else: - query_value = self.value_string().format( - value=value, separator="" - ) + query_value = self.value_string().format(value=value, separator="") self.add_binding(value) elif value_type == "column": - query_value = self._table_column_string( - column=value, separator="" - ) + query_value = self._table_column_string(column=value, separator="") elif value_type == "DATE": - query_value = self.value_string().format( - value=value, separator="" - ) + query_value = self.value_string().format(value=value, separator="") elif value_type == "having": - query_value = self._table_column_string( - column=value, separator="" - ) + query_value = self._table_column_string(column=value, separator="") else: query_value = "" @@ -852,9 +791,7 @@ def process_columns(self, separator="", action="select", qmark=False): sql += f"({builder_sql}) AS {column.alias}, " continue - sql += self._table_column_string( - column, alias=alias, separator=separator - ) + sql += self._table_column_string(column, alias=alias, separator=separator) if self._aggregates: sql += self.process_aggregates() @@ -913,9 +850,7 @@ def process_column(self, column, separator=""): table = None if column and "." in column: table, column = column.split(".") - return self.column_string().format( - column=column, separator=separator, table=table or self.table - ) + return self.column_string().format(column=column, separator=separator, table=table or self.table) def _table_column_string(self, column, alias=None, separator=""): """Compiles a column into the column syntax. @@ -972,9 +907,7 @@ def drop_table(self, table): Returns: self """ - self._sql = self.drop_table_string().format( - table=self.process_column(table) - ) + self._sql = self.drop_table_string().format(table=self.process_column(table)) return self def drop_table_if_exists(self, table): @@ -986,9 +919,7 @@ def drop_table_if_exists(self, table): Returns: self """ - self._sql = self.drop_table_if_exists_string().format( - table=self.process_column(table) - ) + self._sql = self.drop_table_if_exists_string().format(table=self.process_column(table)) return self def rename_table(self, current_table_name, new_table_name): @@ -1016,9 +947,7 @@ def truncate_table(self, table, foreign_keys=False): Returns: self """ - raise NotImplementedError( - f"'{self.__class__.__name__}' does not support truncating" - ) + raise NotImplementedError(f"'{self.__class__.__name__}' does not support truncating") def where_regexp_string(self): return "{keyword} {column} REGEXP {value}" diff --git a/src/masoniteorm/query/grammars/SQLiteGrammar.py b/src/masoniteorm/query/grammars/SQLiteGrammar.py index 191c1591..6ee4604b 100644 --- a/src/masoniteorm/query/grammars/SQLiteGrammar.py +++ b/src/masoniteorm/query/grammars/SQLiteGrammar.py @@ -223,6 +223,6 @@ def process_offset(self): self """ if not self._limit: - self._limit = int(-1) + self._limit = -1 return super().process_offset() diff --git a/src/masoniteorm/query/processors/MSSQLPostProcessor.py b/src/masoniteorm/query/processors/MSSQLPostProcessor.py index ecc7847d..4d464061 100644 --- a/src/masoniteorm/query/processors/MSSQLPostProcessor.py +++ b/src/masoniteorm/query/processors/MSSQLPostProcessor.py @@ -24,9 +24,7 @@ def process_insert_get_id(self, builder, results, id_key): dictionary: Should return the modified dictionary. """ - last_id = builder.new_connection().query( - "SELECT @@Identity as [id]", results=1 - ) + last_id = builder.new_connection().query("SELECT @@Identity as [id]", results=1) id = last_id["id"] diff --git a/src/masoniteorm/query/processors/MySQLPostProcessor.py b/src/masoniteorm/query/processors/MySQLPostProcessor.py index 6de4e289..08d7e2aa 100644 --- a/src/masoniteorm/query/processors/MySQLPostProcessor.py +++ b/src/masoniteorm/query/processors/MySQLPostProcessor.py @@ -23,9 +23,7 @@ def process_insert_get_id(self, builder, results, id_key): """ if id_key not in results: - results.update( - {id_key: builder._connection.get_cursor().lastrowid} - ) + results.update({id_key: builder._connection.get_cursor().lastrowid}) return results def get_column_value(self, builder, column, results, id_key, id_value): diff --git a/src/masoniteorm/query/processors/SQLitePostProcessor.py b/src/masoniteorm/query/processors/SQLitePostProcessor.py index ab1372a2..5a96b019 100644 --- a/src/masoniteorm/query/processors/SQLitePostProcessor.py +++ b/src/masoniteorm/query/processors/SQLitePostProcessor.py @@ -23,9 +23,7 @@ def process_insert_get_id(self, builder, results, id_key="id"): """ if id_key not in results: - results.update( - {id_key: builder.get_connection().get_cursor().lastrowid} - ) + results.update({id_key: builder.get_connection().get_cursor().lastrowid}) return results diff --git a/src/masoniteorm/relationships/BelongsTo.py b/src/masoniteorm/relationships/BelongsTo.py index 1a5063a7..a0caa442 100644 --- a/src/masoniteorm/relationships/BelongsTo.py +++ b/src/masoniteorm/relationships/BelongsTo.py @@ -30,9 +30,7 @@ def apply_query(self, foreign, owner): Returns: dict -- A dictionary of data which will be hydrated. """ - return foreign.where( - self.foreign_key, owner.__attributes__[self.local_key] - ).first() + return foreign.where(self.foreign_key, owner.__attributes__[self.local_key]).first() def query_has(self, current_query_builder, method="where_exists"): related_builder = self.get_builder() @@ -98,9 +96,7 @@ def attach(self, current_model, related_record): foreign_key_value = getattr(related_record, self.foreign_key) if not current_model.is_created(): current_model.fill({self.local_key: foreign_key_value}) - return current_model.create( - current_model.all_attributes(), cast=True - ) + return current_model.create(current_model.all_attributes(), cast=True) current_model.update({self.local_key: foreign_key_value}) return current_model @@ -111,14 +107,6 @@ def detach(self, current_model, related_record): def relate(self, related_record): return ( self.get_builder() - .where( - self.foreign_key, related_record.__attributes__[self.local_key] - ) - ._set_creates_related( - { - self.foreign_key: related_record.__attributes__[ - self.local_key - ] - } - ) + .where(self.foreign_key, related_record.__attributes__[self.local_key]) + ._set_creates_related({self.foreign_key: related_record.__attributes__[self.local_key]}) ) diff --git a/src/masoniteorm/relationships/BelongsToMany.py b/src/masoniteorm/relationships/BelongsToMany.py index d12e8ee3..5e566682 100644 --- a/src/masoniteorm/relationships/BelongsToMany.py +++ b/src/masoniteorm/relationships/BelongsToMany.py @@ -1,5 +1,5 @@ -import pendulum from inflection import singularize +import pendulum from ..collection import Collection from ..models.Pivot import Pivot @@ -133,9 +133,7 @@ def apply_query(self, query, owner): model.delete_attribute("m_reserved2") if self.pivot_id: - pivot_data.update( - {self.pivot_id: getattr(model, "m_reserved3")} - ) + pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")}) model.delete_attribute("m_reserved3") if self.with_fields: @@ -244,14 +242,10 @@ def make_query(self, query, relation, eagers=None, callback=None): Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: - return result.where( - self.local_owner_key, getattr(relation, self.local_owner_key) - ).get() + return result.where(self.local_owner_key, getattr(relation, self.local_owner_key)).get() def get_related(self, query, relation, eagers=None, callback=None): - final_result = self.make_query( - query, relation, eagers=eagers, callback=callback - ) + final_result = self.make_query(query, relation, eagers=eagers, callback=callback) builder = self.make_builder(eagers) for model in final_result: @@ -271,9 +265,7 @@ def get_related(self, query, relation, eagers=None, callback=None): ) if self.pivot_id: - pivot_data.update( - {self.pivot_id: getattr(model, "m_reserved3")} - ) + pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")}) model.delete_attribute("m_reserved3") if self.with_fields: @@ -356,13 +348,7 @@ def relate(self, related_record): return result def register_related(self, key, model, collection): - model.add_relation( - { - key: collection.where( - f"{self._table}_id", getattr(model, self.local_owner_key) - ) - } - ) + model.add_relation({key: collection.where(f"{self._table}_id", getattr(model, self.local_owner_key))}) def joins(self, builder, clause=None): if not self._table: @@ -487,22 +473,18 @@ def get_with_count_query(self, builder, callback): return_query = builder.add_select( f"{query.get_table_name()}_count", lambda q: ( - ( - q.count("*") - .where_column( - f"{builder.get_table_name()}.{self.local_owner_key}", - f"{self._table}.{self.local_key}", - ) - .table(self._table) - .when( - callback, - lambda q: ( - q.where_in( - self.foreign_key, - callback(query.select(self.other_owner_key)), - ) - ), - ) + q.count("*") + .where_column( + f"{builder.get_table_name()}.{self.local_owner_key}", + f"{self._table}.{self.local_key}", + ) + .table(self._table) + .when( + callback, + lambda q: q.where_in( + self.foreign_key, + callback(query.select(self.other_owner_key)), + ), ) ), ) @@ -515,9 +497,7 @@ def attach(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name( - current_model, related_record - ) + self._table = self._table or self.get_pivot_table_name(current_model, related_record) if self.with_timestamps: data.update( @@ -540,9 +520,7 @@ def detach(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name( - current_model, related_record - ) + self._table = self._table or self.get_pivot_table_name(current_model, related_record) return ( Pivot.on(current_model.get_builder().connection) @@ -558,9 +536,7 @@ def attach_related(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name( - current_model, related_record - ) + self._table = self._table or self.get_pivot_table_name(current_model, related_record) if self.with_timestamps: data.update( @@ -583,9 +559,7 @@ def detach_related(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name( - current_model, related_record - ) + self._table = self._table or self.get_pivot_table_name(current_model, related_record) if self.with_timestamps: data.update( diff --git a/src/masoniteorm/relationships/HasMany.py b/src/masoniteorm/relationships/HasMany.py index d4e9bbc9..4004ac44 100644 --- a/src/masoniteorm/relationships/HasMany.py +++ b/src/masoniteorm/relationships/HasMany.py @@ -15,9 +15,7 @@ def apply_query(self, foreign, owner): Returns: dict -- A dictionary of data which will be hydrated. """ - result = foreign.where( - self.foreign_key, owner.__attributes__[self.local_key] - ).get() + result = foreign.where(self.foreign_key, owner.__attributes__[self.local_key]).get() return result @@ -27,12 +25,7 @@ def set_keys(self, owner, attribute): return self def register_related(self, key, model, collection): - model.add_relation( - { - key: collection.get(getattr(model, self.local_key)) - or Collection() - } - ) + model.add_relation({key: collection.get(getattr(model, self.local_key)) or Collection()}) def map_related(self, related_result): return related_result.group_by(self.foreign_key) @@ -41,9 +34,7 @@ def attach(self, current_model, related_record): local_key_value = getattr(current_model, self.local_key) if not related_record.is_created(): related_record.fill({self.foreign_key: local_key_value}) - return related_record.create( - related_record.all_attributes(), cast=True - ) + return related_record.create(related_record.all_attributes(), cast=True) related_record.update({self.foreign_key: local_key_value}) return related_record diff --git a/src/masoniteorm/relationships/HasManyThrough.py b/src/masoniteorm/relationships/HasManyThrough.py index 50973c5e..d08c0996 100644 --- a/src/masoniteorm/relationships/HasManyThrough.py +++ b/src/masoniteorm/relationships/HasManyThrough.py @@ -51,9 +51,7 @@ def __get__(self, instance, owner): relationship2 = self.fn(self)[1]() self.distant_builder = relationship1.builder self.intermediary_builder = relationship2.builder - self.set_keys( - self.distant_builder, self.intermediary_builder, attribute - ) + self.set_keys(self.distant_builder, self.intermediary_builder, attribute) if not instance.is_loaded(): return self @@ -61,13 +59,9 @@ def __get__(self, instance, owner): if attribute in instance._relationships: return instance._relationships[attribute] - return self.apply_related_query( - self.distant_builder, self.intermediary_builder, instance - ) + return self.apply_related_query(self.distant_builder, self.intermediary_builder, instance) - def apply_related_query( - self, distant_builder, intermediary_builder, owner - ): + def apply_related_query(self, distant_builder, intermediary_builder, owner): """ Apply the query to return a Collection of data for the distant models to be hydrated with. @@ -87,9 +81,7 @@ def apply_related_query( intermediate_table = intermediary_builder.get_table_name() return ( - self.distant_builder.select( - f"{distant_table}.*, {intermediate_table}.{self.local_key}" - ) + self.distant_builder.select(f"{distant_table}.*, {intermediate_table}.{self.local_key}") .join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", @@ -140,9 +132,7 @@ def register_related(self, key, model, collection): model.add_relation({key: related if related else None}) - def get_related( - self, current_builder, relation, eagers=None, callback=None - ): + def get_related(self, current_builder, relation, eagers=None, callback=None): """ Get a Collection to hydrate the models for the distant table with Used when eager loading the model attribute @@ -164,9 +154,7 @@ def get_related( callback(current_builder) ( - self.distant_builder.select( - f"{distant_table}.*, {intermediate_table}.{self.local_key}" - ).join( + self.distant_builder.select(f"{distant_table}.*, {intermediate_table}.{self.local_key}").join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", @@ -203,9 +191,7 @@ def query_has(self, current_builder, method="where_exists"): return self.distant_builder - def query_where_exists( - self, current_builder, callback, method="where_exists" - ): + def query_where_exists(self, current_builder, callback, method="where_exists"): distant_table = self.distant_builder.get_table_name() intermediate_table = self.intermediary_builder.get_table_name() @@ -220,7 +206,7 @@ def query_where_exists( f"{intermediate_table}.{self.local_key}", f"{current_builder.get_table_name()}.{self.local_owner_key}", ) - .when(callback, lambda q: (callback(q))) + .when(callback, lambda q: callback(q)) ) def get_with_count_query(self, current_builder, callback): @@ -233,32 +219,24 @@ def get_with_count_query(self, current_builder, callback): return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( - ( - q.count("*") - .join( - f"{intermediate_table}", - f"{intermediate_table}.{self.foreign_key}", - "=", - f"{distant_table}.{self.other_owner_key}", - ) - .where_column( - f"{intermediate_table}.{self.local_key}", - f"{current_builder.get_table_name()}.{self.local_owner_key}", - ) - .table(distant_table) - .when( - callback, - lambda q: ( - q.where_in( - self.foreign_key, - callback( - self.distant_builder.select( - self.other_owner_key - ) - ), - ) - ), - ) + q.count("*") + .join( + f"{intermediate_table}", + f"{intermediate_table}.{self.foreign_key}", + "=", + f"{distant_table}.{self.other_owner_key}", + ) + .where_column( + f"{intermediate_table}.{self.local_key}", + f"{current_builder.get_table_name()}.{self.local_owner_key}", + ) + .table(distant_table) + .when( + callback, + lambda q: q.where_in( + self.foreign_key, + callback(self.distant_builder.select(self.other_owner_key)), + ), ) ), ) diff --git a/src/masoniteorm/relationships/HasOne.py b/src/masoniteorm/relationships/HasOne.py index b6805e1b..ce46be4e 100644 --- a/src/masoniteorm/relationships/HasOne.py +++ b/src/masoniteorm/relationships/HasOne.py @@ -30,9 +30,7 @@ def apply_query(self, foreign, owner): dict -- A dictionary of data which will be hydrated. """ - return foreign.where( - self.foreign_key, owner.__attributes__[self.local_key] - ).first() + return foreign.where(self.foreign_key, owner.__attributes__[self.local_key]).first() def get_related(self, query, relation, eagers=(), callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection @@ -87,9 +85,7 @@ def query_where_exists(self, builder, callback, method="where_exists"): return query def register_related(self, key, model, collection): - related = collection.where( - self.foreign_key, getattr(model, self.local_key) - ).first() + related = collection.where(self.foreign_key, getattr(model, self.local_key)).first() model.add_relation({key: related or None}) @@ -100,9 +96,7 @@ def attach(self, current_model, related_record): local_key_value = getattr(current_model, self.local_key) if not related_record.is_created(): related_record.fill({self.foreign_key: local_key_value}) - return related_record.create( - related_record.all_attributes(), cast=True - ) + return related_record.create(related_record.all_attributes(), cast=True) related_record.update({self.foreign_key: local_key_value}) return related_record diff --git a/src/masoniteorm/relationships/HasOneThrough.py b/src/masoniteorm/relationships/HasOneThrough.py index 1b4e7937..c3e39f2d 100644 --- a/src/masoniteorm/relationships/HasOneThrough.py +++ b/src/masoniteorm/relationships/HasOneThrough.py @@ -56,23 +56,17 @@ def __get__(self, instance, owner): relationship2 = self.fn(self)[1]() self.distant_builder = relationship1.builder self.intermediary_builder = relationship2.builder - self.set_keys( - self.distant_builder, self.intermediary_builder, attribute - ) + self.set_keys(self.distant_builder, self.intermediary_builder, attribute) if instance.is_loaded(): if attribute in instance._relationships: return instance._relationships[attribute] - return self.apply_relation_query( - self.distant_builder, self.intermediary_builder, instance - ) + return self.apply_relation_query(self.distant_builder, self.intermediary_builder, instance) else: return self - def apply_relation_query( - self, distant_builder, intermediary_builder, owner - ): + def apply_relation_query(self, distant_builder, intermediary_builder, owner): """ Apply the query and return a dict of data for the distant model to be hydrated with. @@ -92,9 +86,7 @@ def apply_relation_query( int_table = intermediary_builder.get_table_name() return ( - distant_builder.select( - f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}" - ) + distant_builder.select(f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}") .join( f"{int_table}", f"{int_table}.{self.foreign_key}", @@ -146,9 +138,7 @@ def register_related(self, key, model, collection): related = collection.get(getattr(model, self.local_key), None) model.add_relation({key: related[0] if related else None}) - def get_related( - self, current_builder, relation, eagers=None, callback=None - ): + def get_related(self, current_builder, relation, eagers=None, callback=None): """ Get the data to hydrate the model for the distant table with Used when eager loading the model attribute @@ -209,9 +199,7 @@ def query_has(self, current_builder, method="where_exists"): return self.distant_builder - def query_where_exists( - self, current_builder, callback, method="where_exists" - ): + def query_where_exists(self, current_builder, callback, method="where_exists"): dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() @@ -226,7 +214,7 @@ def query_where_exists( f"{int_table}.{self.local_owner_key}", f"{current_builder.get_table_name()}.{self.local_key}", ) - .when(callback, lambda q: (callback(q))) + .when(callback, lambda q: callback(q)) ) def get_with_count_query(self, current_builder, callback): @@ -239,32 +227,24 @@ def get_with_count_query(self, current_builder, callback): return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( - ( - q.count("*") - .join( - f"{int_table}", - f"{int_table}.{self.foreign_key}", - "=", - f"{dist_table}.{self.other_owner_key}", - ) - .where_column( - f"{int_table}.{self.local_owner_key}", - f"{current_builder.get_table_name()}.{self.local_key}", - ) - .table(dist_table) - .when( - callback, - lambda q: ( - q.where_in( - self.foreign_key, - callback( - self.distant_builder.select( - self.other_owner_key - ) - ), - ) - ), - ) + q.count("*") + .join( + f"{int_table}", + f"{int_table}.{self.foreign_key}", + "=", + f"{dist_table}.{self.other_owner_key}", + ) + .where_column( + f"{int_table}.{self.local_owner_key}", + f"{current_builder.get_table_name()}.{self.local_key}", + ) + .table(dist_table) + .when( + callback, + lambda q: q.where_in( + self.foreign_key, + callback(self.distant_builder.select(self.other_owner_key)), + ), ) ), ) diff --git a/src/masoniteorm/relationships/MorphMany.py b/src/masoniteorm/relationships/MorphMany.py index 272d3510..f0dd6735 100644 --- a/src/masoniteorm/relationships/MorphMany.py +++ b/src/masoniteorm/relationships/MorphMany.py @@ -104,9 +104,7 @@ def get_related(self, query, relation, eagers=None, callback=None): ) .where_in( self.morph_id, - relation.pluck( - relation.first().get_primary_key(), keep_nulls=False - ).unique(), + relation.pluck(relation.first().get_primary_key(), keep_nulls=False).unique(), ) .get() ) @@ -116,9 +114,9 @@ def get_related(self, query, relation, eagers=None, callback=None): if callback: return callback( - self.polymorphic_builder.where( - self.morph_key, record_type - ).where(self.morph_id, relation.get_primary_key_value()) + self.polymorphic_builder.where(self.morph_key, record_type).where( + self.morph_id, relation.get_primary_key_value() + ) ).get() return ( self.polymorphic_builder.where(self.morph_key, record_type) @@ -145,8 +143,6 @@ def get_record_key_lookup(self, relation): break if not record_type: - raise ValueError( - f"Could not find the record type key for the {relation} class" - ) + raise ValueError(f"Could not find the record type key for the {relation} class") return record_type diff --git a/src/masoniteorm/relationships/MorphOne.py b/src/masoniteorm/relationships/MorphOne.py index 4e34ee07..9750be5a 100644 --- a/src/masoniteorm/relationships/MorphOne.py +++ b/src/masoniteorm/relationships/MorphOne.py @@ -106,9 +106,7 @@ def get_related(self, query, relation, eagers=None, callback=None): ) .where_in( self.morph_id, - relation.pluck( - relation.first().get_primary_key(), keep_nulls=False - ).unique(), + relation.pluck(relation.first().get_primary_key(), keep_nulls=False).unique(), ) .get() ) @@ -117,9 +115,9 @@ def get_related(self, query, relation, eagers=None, callback=None): record_type = self.get_record_key_lookup(relation) if callback: return callback( - self.polymorphic_builder.where( - self.morph_key, record_type - ).where(self.morph_id, relation.get_primary_key_value()) + self.polymorphic_builder.where(self.morph_key, record_type).where( + self.morph_id, relation.get_primary_key_value() + ) ).first() return ( @@ -149,8 +147,6 @@ def get_record_key_lookup(self, relation): break if not record_type: - raise ValueError( - f"Could not find the record type key for the {relation} class" - ) + raise ValueError(f"Could not find the record type key for the {relation} class") return record_type diff --git a/src/masoniteorm/schema/Blueprint.py b/src/masoniteorm/schema/Blueprint.py index 05089501..17cf9ef8 100644 --- a/src/masoniteorm/schema/Blueprint.py +++ b/src/masoniteorm/schema/Blueprint.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 class Blueprint: """Used for building schemas for creating, modifying or altering schema.""" @@ -37,9 +38,7 @@ def string(self, column, length=255, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "string", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "string", length=length, nullable=nullable) return self @@ -56,9 +55,7 @@ def tiny_integer(self, column, length=1, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "tiny_integer", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "tiny_integer", length=length, nullable=nullable) return self def small_integer(self, column, length=5, nullable=False): @@ -74,9 +71,7 @@ def small_integer(self, column, length=5, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "small_integer", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "small_integer", length=length, nullable=nullable) return self def medium_integer(self, column, length=7, nullable=False): @@ -92,9 +87,7 @@ def medium_integer(self, column, length=7, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "medium_integer", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "medium_integer", length=length, nullable=nullable) return self def integer(self, column, length=11, nullable=False): @@ -110,9 +103,7 @@ def integer(self, column, length=11, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "integer", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "integer", length=length, nullable=nullable) return self def big_integer(self, column, length=32, nullable=False): @@ -128,9 +119,7 @@ def big_integer(self, column, length=32, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "big_integer", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "big_integer", length=length, nullable=nullable) return self def unsigned_big_integer(self, column, length=32, nullable=False): @@ -146,19 +135,13 @@ def unsigned_big_integer(self, column, length=32, nullable=False): Returns: self """ - return self.big_integer( - column, length=length, nullable=nullable - ).unsigned() + return self.big_integer(column, length=length, nullable=nullable).unsigned() def _compile_create(self): - return self.grammar( - creates=self._columns, table=self.table - )._compile_create() + return self.grammar(creates=self._columns, table=self.table)._compile_create() def _compile_alter(self): - return self.grammar( - creates=self._columns, table=self.table - )._compile_create() + return self.grammar(creates=self._columns, table=self.table)._compile_create() def increments(self, column, nullable=False): """Sets a column to be the auto incrementing primary key representation for the table. @@ -172,9 +155,7 @@ def increments(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "increments", nullable=nullable - ) + self._last_column = self.table.add_column(column, "increments", nullable=nullable) self.primary(column) return self @@ -191,9 +172,7 @@ def tiny_increments(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "tiny_increments", nullable=nullable - ) + self._last_column = self.table.add_column(column, "tiny_increments", nullable=nullable) self.primary(column) return self @@ -221,9 +200,7 @@ def uuid(self, column, nullable=False, length=36): Returns: self """ - self._last_column = self.table.add_column( - column, "uuid", nullable=nullable, length=length - ) + self._last_column = self.table.add_column(column, "uuid", nullable=nullable, length=length) return self def big_increments(self, column, nullable=False): @@ -238,9 +215,7 @@ def big_increments(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "big_increments", nullable=nullable - ) + self._last_column = self.table.add_column(column, "big_increments", nullable=nullable) self.primary(column) return self @@ -257,9 +232,7 @@ def binary(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "binary", nullable=nullable - ) + self._last_column = self.table.add_column(column, "binary", nullable=nullable) return self def boolean(self, column, nullable=False): @@ -274,9 +247,7 @@ def boolean(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "boolean", nullable=nullable - ) + self._last_column = self.table.add_column(column, "boolean", nullable=nullable) return self def default(self, value, raw=False): @@ -301,9 +272,7 @@ def char(self, column, length=1, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "char", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "char", length=length, nullable=nullable) return self def date(self, column, nullable=False): @@ -318,9 +287,7 @@ def date(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "date", nullable=nullable - ) + self._last_column = self.table.add_column(column, "date", nullable=nullable) return self def time(self, column, nullable=False): @@ -335,9 +302,7 @@ def time(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "time", nullable=nullable - ) + self._last_column = self.table.add_column(column, "time", nullable=nullable) return self def datetime(self, column, nullable=False, now=False): @@ -354,9 +319,7 @@ def datetime(self, column, nullable=False, now=False): Returns: self """ - self._last_column = self.table.add_column( - column, "datetime", nullable=nullable - ) + self._last_column = self.table.add_column(column, "datetime", nullable=nullable) if now: self._last_column.use_current() @@ -377,9 +340,7 @@ def timestamp(self, column, nullable=False, now=False): self """ - self._last_column = self.table.add_column( - column, "timestamp", nullable=nullable - ) + self._last_column = self.table.add_column(column, "timestamp", nullable=nullable) if now: self._last_column.use_current() @@ -416,9 +377,7 @@ def decimal(self, column, length=17, precision=6, nullable=False): self._last_column = self.table.add_column( column, "decimal", - length="{length}, {precision}".format( - length=length, precision=precision - ), + length=f"{length}, {precision}", nullable=nullable, ) return self @@ -440,9 +399,7 @@ def float(self, column, length=19, precision=4, nullable=False): self._last_column = self.table.add_column( column, "float", - length="{length}, {precision}".format( - length=length, precision=precision - ), + length=f"{length}, {precision}", nullable=nullable, ) return self @@ -459,9 +416,7 @@ def double(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "double", nullable=nullable - ) + self._last_column = self.table.add_column(column, "double", nullable=nullable) return self def enum(self, column, options=None, nullable=False): @@ -481,7 +436,7 @@ def enum(self, column, options=None, nullable=False): options = options or [] new_options = "" for option in options: - new_options += "'{}',".format(option) + new_options += f"'{option}'," new_options = new_options.rstrip(",") self._last_column = self.table.add_column( column, "enum", length="255", values=options, nullable=nullable @@ -501,9 +456,7 @@ def text(self, column, length=None, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "text", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "text", length=length, nullable=nullable) return self def tiny_text(self, column, length=None, nullable=False): @@ -519,9 +472,7 @@ def tiny_text(self, column, length=None, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "tiny_text", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "tiny_text", length=length, nullable=nullable) return self def unsigned_decimal(self, column, length=17, precision=6, nullable=False): @@ -540,9 +491,7 @@ def unsigned_decimal(self, column, length=17, precision=6, nullable=False): self._last_column = self.table.add_column( column, "decimal", - length="{length}, {precision}".format( - length=length, precision=precision - ), + length=f"{length}, {precision}", nullable=nullable, ).unsigned() return self @@ -561,9 +510,7 @@ def long_text(self, column, length=None, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "long_text", length=length, nullable=nullable - ) + self._last_column = self.table.add_column(column, "long_text", length=length, nullable=nullable) return self def json(self, column, nullable=False): @@ -578,9 +525,7 @@ def json(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "json", nullable=nullable - ) + self._last_column = self.table.add_column(column, "json", nullable=nullable) return self def jsonb(self, column, nullable=False): @@ -595,9 +540,7 @@ def jsonb(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "jsonb", nullable=nullable - ) + self._last_column = self.table.add_column(column, "jsonb", nullable=nullable) return self def inet(self, column, length=255, nullable=False): @@ -612,9 +555,7 @@ def inet(self, column, length=255, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "inet", length=255, nullable=nullable - ) + self._last_column = self.table.add_column(column, "inet", length=255, nullable=nullable) return self def cidr(self, column, length=255, nullable=False): @@ -629,9 +570,7 @@ def cidr(self, column, length=255, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "cidr", length=255, nullable=nullable - ) + self._last_column = self.table.add_column(column, "cidr", length=255, nullable=nullable) return self def macaddr(self, column, length=255, nullable=False): @@ -646,9 +585,7 @@ def macaddr(self, column, length=255, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "macaddr", length=255, nullable=nullable - ) + self._last_column = self.table.add_column(column, "macaddr", length=255, nullable=nullable) return self def point(self, column, nullable=False): @@ -663,9 +600,7 @@ def point(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "point", nullable=nullable - ) + self._last_column = self.table.add_column(column, "point", nullable=nullable) return self def geometry(self, column, nullable=False): @@ -680,9 +615,7 @@ def geometry(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "geometry", nullable=nullable - ) + self._last_column = self.table.add_column(column, "geometry", nullable=nullable) return self def year(self, column, length=4, default=None, nullable=False): @@ -736,9 +669,7 @@ def unsigned_integer(self, column, nullable=False): Returns: self """ - self._last_column = self.table.add_column( - column, "integer", nullable=nullable - ).unsigned() + self._last_column = self.table.add_column(column, "integer", nullable=nullable).unsigned() return self def morphs(self, column, nullable=False, indexes=True): @@ -754,14 +685,10 @@ def morphs(self, column, nullable=False, indexes=True): self """ _columns = [] + _columns.append(self.table.add_column(f"{column}_id", "integer", nullable=nullable).unsigned()) _columns.append( self.table.add_column( - "{}_id".format(column), "integer", nullable=nullable - ).unsigned() - ) - _columns.append( - self.table.add_column( - "{}_type".format(column), + f"{column}_type", "string", nullable=nullable, length=self._default_string_length, @@ -784,9 +711,7 @@ def to_sql(self): if self._action == "create": return self.platform().compile_create_sql(self.table) elif self._action == "create_table_if_not_exists": - return self.platform().compile_create_sql( - self.table, if_not_exists=True - ) + return self.platform().compile_create_sql(self.table, if_not_exists=True) else: if not self._dry: # get current table schema @@ -885,9 +810,7 @@ def fulltext(self, column=None, name=None): if not isinstance(column, list): column = [column] - self.table.add_constraint( - name or f"{'_'.join(column)}_fulltext", "fulltext", column - ) + self.table.add_constraint(name or f"{'_'.join(column)}_fulltext", "fulltext", column) return self @@ -923,15 +846,9 @@ def add_foreign(self, columns, name=None): columns {string} -- The name of the from_column . to_column . table """ if len(columns.split(".")) != 3: - raise Exception( - "Wrong add_foreign argument, the struncture is from_column.to_column.table" - ) + raise Exception("Wrong add_foreign argument, the struncture is from_column.to_column.table") from_column, to_column, table = columns.split(".") - return ( - self.foreign(from_column, name=name) - .references(to_column) - .on(table) - ) + return self.foreign(from_column, name=name).references(to_column).on(table) def foreign(self, column, name=None): """Starts the creation of a foreign key constraint @@ -980,11 +897,7 @@ def foreign_id_for(self, model, column=None): """ clm = column if column else model.get_foreign_key() - return ( - self.foreign_id(clm) - if model.get_primary_key_type() == "int" - else self.foreign_uuid(column) - ) + return self.foreign_id(clm) if model.get_primary_key_type() == "int" else self.foreign_uuid(column) def references(self, column): """Sets the other column on the foreign table that the local column will use to reference. @@ -1052,9 +965,7 @@ def rename(self, old_column, new_column, data_type, length=None): Returns: self """ - self.table.rename_column( - old_column, new_column, data_type, length=length - ) + self.table.rename_column(old_column, new_column, data_type, length=length) return self def after(self, old_column): @@ -1116,9 +1027,7 @@ def drop_unique(self, index): """ if isinstance(index, list): for column in index: - self.table.remove_unique_index( - f"{self.table.name}_{column}_unique" - ) + self.table.remove_unique_index(f"{self.table.name}_{column}_unique") return self diff --git a/src/masoniteorm/schema/Schema.py b/src/masoniteorm/schema/Schema.py index eaedabf8..a16098cc 100644 --- a/src/masoniteorm/schema/Schema.py +++ b/src/masoniteorm/schema/Schema.py @@ -98,13 +98,9 @@ def on(self, connection_key): if connection_detail: self._connection_driver = connection_detail.get("driver") else: - raise ConnectionNotRegistered( - f"Could not find the '{connection_key}' connection details" - ) + raise ConnectionNotRegistered(f"Could not find the '{connection_key}' connection details") - self.connection_class = resolver.connection_factory.make( - self._connection_driver - ) + self.connection_class = resolver.connection_factory.make(self._connection_driver) return self @@ -187,27 +183,13 @@ def table(self, table): def get_connection_information(self): return { - "host": self.connection_details.get(self.connection, {}).get( - "host" - ), - "database": self.connection_details.get(self.connection, {}).get( - "database" - ), - "user": self.connection_details.get(self.connection, {}).get( - "user" - ), - "port": self.connection_details.get(self.connection, {}).get( - "port" - ), - "password": self.connection_details.get(self.connection, {}).get( - "password" - ), - "prefix": self.connection_details.get(self.connection, {}).get( - "prefix" - ), - "options": self.connection_details.get(self.connection, {}).get( - "options", {} - ), + "host": self.connection_details.get(self.connection, {}).get("host"), + "database": self.connection_details.get(self.connection, {}).get("database"), + "user": self.connection_details.get(self.connection, {}).get("user"), + "port": self.connection_details.get(self.connection, {}).get("port"), + "password": self.connection_details.get(self.connection, {}).get("password"), + "prefix": self.connection_details.get(self.connection, {}).get("prefix"), + "options": self.connection_details.get(self.connection, {}).get("options", {}), "full_details": self.connection_details.get(self.connection), } @@ -241,9 +223,7 @@ def has_column(self, table, column, query_only=False): return bool(self.new_connection().query(sql, ())) def get_columns(self, table, dict=True): - table = self.platform().get_current_schema( - self.new_connection(), table, schema=self.get_schema() - ) + table = self.platform().get_current_schema(self.new_connection(), table, schema=self.get_schema()) result = {} if dict: for column in table.get_added_columns().items(): @@ -288,9 +268,7 @@ def rename(self, table, new_name): return bool(self.new_connection().query(sql, ())) def truncate(self, table, foreign_keys=False): - sql = self.platform().compile_truncate( - table, foreign_keys=foreign_keys - ) + sql = self.platform().compile_truncate(table, foreign_keys=foreign_keys) if self._dry: self._sql = sql @@ -300,9 +278,7 @@ def truncate(self, table, foreign_keys=False): def get_schema(self): """Gets the schema set on the migration class""" - return self.schema or self.get_connection_information().get( - "full_details" - ).get("schema") + return self.schema or self.get_connection_information().get("full_details").get("schema") def get_all_tables(self): """Gets all tables in the database""" @@ -317,9 +293,7 @@ def get_all_tables(self): result = self.new_connection().query(sql, ()) - return ( - list(map(lambda t: list(t.values())[0], result)) if result else [] - ) + return list(map(lambda t: list(t.values())[0], result)) if result else [] def has_table(self, table, query_only=False): """Checks if the a database has a specific table diff --git a/src/masoniteorm/schema/Table.py b/src/masoniteorm/schema/Table.py index 5b3b50dd..419eeb62 100644 --- a/src/masoniteorm/schema/Table.py +++ b/src/masoniteorm/schema/Table.py @@ -47,13 +47,9 @@ def add_column( return column def add_constraint(self, name, constraint_type, columns=None): - self.added_constraints.update( - {name: Constraint(name, constraint_type, columns=columns or [])} - ) + self.added_constraints.update({name: Constraint(name, constraint_type, columns=columns or [])}) - def add_foreign_key( - self, column, table=None, foreign_column=None, name=None - ): + def add_foreign_key(self, column, table=None, foreign_column=None, name=None): foreign_key = ForeignKeyConstraint( column, table, diff --git a/src/masoniteorm/schema/TableDiff.py b/src/masoniteorm/schema/TableDiff.py index 2901793b..205da558 100644 --- a/src/masoniteorm/schema/TableDiff.py +++ b/src/masoniteorm/schema/TableDiff.py @@ -22,9 +22,7 @@ def __init__(self, name): self.comment = None def remove_constraint(self, name): - self.removed_constraints.update( - {name: self.from_table.get_constraint(name)} - ) + self.removed_constraints.update({name: self.from_table.get_constraint(name)}) def get_removed_constraints(self): return self.removed_constraints diff --git a/src/masoniteorm/schema/platforms/MSSQLPlatform.py b/src/masoniteorm/schema/platforms/MSSQLPlatform.py index 0ee57b2a..02993bbe 100644 --- a/src/masoniteorm/schema/platforms/MSSQLPlatform.py +++ b/src/masoniteorm/schema/platforms/MSSQLPlatform.py @@ -63,34 +63,18 @@ class MSSQLPlatform(Platform): def compile_create_sql(self, table, if_not_exists=False): sql = [] - table_create_format = ( - self.create_if_not_exists_format() - if if_not_exists - else self.create_format() - ) + table_create_format = self.create_if_not_exists_format() if if_not_exists else self.create_format() sql.append( table_create_format.format( table=self.wrap_table(table.name), - columns=", ".join( - self.columnize(table.get_added_columns()) - ).strip(), + columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( - ", " - + ", ".join( - self.constraintize( - table.get_added_constraints(), table - ) - ) + ", " + ", ".join(self.constraintize(table.get_added_constraints(), table)) if table.get_added_constraints() else "" ), foreign_keys=( - ", " - + ", ".join( - self.foreign_key_constraintize( - table.name, table.added_foreign_keys - ) - ) + ", " + ", ".join(self.foreign_key_constraintize(table.name, table.added_foreign_keys)) if table.added_foreign_keys else "" ), @@ -116,8 +100,7 @@ def compile_alter_sql(self, table): sql.append( self.alter_format().format( table=self.wrap_table(table.name), - columns="ADD " - + ", ".join(self.columnize(table.added_columns)).strip(), + columns="ADD " + ", ".join(self.columnize(table.added_columns)).strip(), ) ) @@ -125,26 +108,19 @@ def compile_alter_sql(self, table): sql.append( self.alter_format().format( table=self.wrap_table(table.name), - columns="ALTER COLUMN " - + ", ".join(self.columnize(table.changed_columns)).strip(), + columns="ALTER COLUMN " + ", ".join(self.columnize(table.changed_columns)).strip(), ) ) if table.renamed_columns: for name, column in table.get_renamed_columns().items(): - sql.append( - self.rename_column_string( - table.name, name, column.name - ).strip() - ) + sql.append(self.rename_column_string(table.name, name, column.name).strip()) if table.dropped_columns: dropped_sql = [] for name in table.get_dropped_columns(): - dropped_sql.append( - self.drop_column_string().format(name=name).strip() - ) + dropped_sql.append(self.drop_column_string().format(name=name).strip()) sql.append( self.alter_format().format( @@ -169,12 +145,8 @@ def compile_alter_sql(self, table): constraint_name=foreign_key_constraint.constraint_name, column=self.wrap_column(column), table=self.wrap_table(table.name), - foreign_table=self.wrap_table( - foreign_key_constraint.foreign_table - ), - foreign_column=self.wrap_column( - foreign_key_constraint.foreign_column - ), + foreign_table=self.wrap_table(foreign_key_constraint.foreign_table), + foreign_column=self.wrap_column(foreign_key_constraint.foreign_column), cascade=cascade, ) ) @@ -182,9 +154,7 @@ def compile_alter_sql(self, table): if table.dropped_foreign_keys: constraints = table.dropped_foreign_keys for constraint in constraints: - sql.append( - f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}" - ) + sql.append(f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}") if table.added_indexes: for name, index in table.added_indexes.items(): @@ -196,18 +166,12 @@ def compile_alter_sql(self, table): ) ) - if ( - table.removed_indexes - or table.removed_unique_indexes - or table.dropped_primary_keys - ): + if table.removed_indexes or table.removed_unique_indexes or table.dropped_primary_keys: constraints = table.removed_indexes constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: - sql.append( - f"DROP INDEX {self.wrap_table(table.name)}.{self.wrap_table(constraint)}" - ) + sql.append(f"DROP INDEX {self.wrap_table(table.name)}.{self.wrap_table(constraint)}") if table.added_constraints: for name, constraint in table.added_constraints.items(): @@ -236,9 +200,7 @@ def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: - length = self.create_column_length(column.column_type).format( - length=column.length - ) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -249,10 +211,7 @@ def columnize(self, columns): elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: - if ( - isinstance(column.default, (str,)) - and not column.default_is_raw - ): + if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" @@ -291,9 +250,7 @@ def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( - getattr( - self, f"get_{constraint.constraint_type}_constraint_string" - )().format( + getattr(self, f"get_{constraint.constraint_type}_constraint_string")().format( columns=", ".join(constraint.columns), name_columns="_".join(constraint.columns), constraint_name=constraint.name, diff --git a/src/masoniteorm/schema/platforms/MySQLPlatform.py b/src/masoniteorm/schema/platforms/MySQLPlatform.py index 23206274..3a89a98e 100644 --- a/src/masoniteorm/schema/platforms/MySQLPlatform.py +++ b/src/masoniteorm/schema/platforms/MySQLPlatform.py @@ -63,9 +63,7 @@ def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: - length = self.create_column_length(column.column_type).format( - length=column.length - ) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -76,10 +74,7 @@ def columnize(self, columns): elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: - if ( - isinstance(column.default, (str,)) - and not column.default_is_raw - ): + if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" @@ -104,16 +99,8 @@ def columnize(self, columns): constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, - signed=( - " " + self.signed.get(column._signed) - if column._signed - else "" - ), - comment=( - "COMMENT '" + column.comment + "'" - if column.comment - else "" - ), + signed=(" " + self.signed.get(column._signed) if column._signed else ""), + comment=("COMMENT '" + column.comment + "'" if column.comment else ""), ) .strip() ) @@ -122,34 +109,18 @@ def columnize(self, columns): def compile_create_sql(self, table, if_not_exists=False): sql = [] - table_create_format = ( - self.create_if_not_exists_format() - if if_not_exists - else self.create_format() - ) + table_create_format = self.create_if_not_exists_format() if if_not_exists else self.create_format() sql.append( table_create_format.format( table=self.get_table_string().format(table=table.name), - columns=", ".join( - self.columnize(table.get_added_columns()) - ).strip(), + columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( - ", " - + ", ".join( - self.constraintize( - table.get_added_constraints(), table - ) - ) + ", " + ", ".join(self.constraintize(table.get_added_constraints(), table)) if table.get_added_constraints() else "" ), foreign_keys=( - ", " - + ", ".join( - self.foreign_key_constraintize( - table.name, table.added_foreign_keys - ) - ) + ", " + ", ".join(self.foreign_key_constraintize(table.name, table.added_foreign_keys)) if table.added_foreign_keys else "" ), @@ -177,9 +148,7 @@ def compile_alter_sql(self, table): for name, column in table.get_added_columns().items(): if column.length: - length = self.create_column_length( - column.column_type - ).format(length=column.length) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -203,30 +172,16 @@ def compile_alter_sql(self, table): add_columns.append( self.add_column_string() .format( - name=self.get_column_string().format( - column=column.name - ), + name=self.get_column_string().format(column=column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, constraint="PRIMARY KEY" if column.primary else "", nullable="NULL" if column.is_null else "NOT NULL", default=default, - signed=( - " " + self.signed.get(column._signed) - if column._signed - else "" - ), - after=( - (" AFTER " + self.wrap_column(column._after)) - if column._after - else "" - ), - comment=( - " COMMENT '" + column.comment + "'" - if column.comment - else "" - ), + signed=(" " + self.signed.get(column._signed) if column._signed else ""), + after=((" AFTER " + self.wrap_column(column._after)) if column._after else ""), + comment=(" COMMENT '" + column.comment + "'" if column.comment else ""), ) .strip() ) @@ -235,9 +190,7 @@ def compile_alter_sql(self, table): self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(add_columns).strip(), - comment=( - f" COMMENT '{table.comment}'" if table.comment else "" - ), + comment=(f" COMMENT '{table.comment}'" if table.comment else ""), ) ) @@ -246,9 +199,7 @@ def compile_alter_sql(self, table): for name, column in table.get_renamed_columns().items(): if column.length: - length = self.create_column_length( - column.column_type - ).format(length=column.length) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -272,10 +223,7 @@ def compile_alter_sql(self, table): sql.append( self.alter_format().format( table=self.wrap_table(table.name), - columns=", ".join( - f"MODIFY {x}" - for x in self.columnize(table.changed_columns) - ), + columns=", ".join(f"MODIFY {x}" for x in self.columnize(table.changed_columns)), ) ) @@ -321,9 +269,7 @@ def compile_alter_sql(self, table): if table.dropped_foreign_keys: constraints = table.dropped_foreign_keys for constraint in constraints: - sql.append( - f"ALTER TABLE {self.wrap_table(table.name)} DROP FOREIGN KEY {constraint}" - ) + sql.append(f"ALTER TABLE {self.wrap_table(table.name)} DROP FOREIGN KEY {constraint}") if table.added_indexes: for name, index in table.added_indexes.items(): @@ -350,22 +296,14 @@ def compile_alter_sql(self, table): f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})" ) - if ( - table.removed_indexes - or table.removed_unique_indexes - or table.dropped_primary_keys - ): + if table.removed_indexes or table.removed_unique_indexes or table.dropped_primary_keys: constraints = table.removed_indexes constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: - sql.append( - f"ALTER TABLE {self.wrap_table(table.name)} DROP INDEX {constraint}" - ) + sql.append(f"ALTER TABLE {self.wrap_table(table.name)} DROP INDEX {constraint}") if table.comment: - sql.append( - f"ALTER TABLE {self.wrap_table(table.name)} COMMENT '{table.comment}'" - ) + sql.append(f"ALTER TABLE {self.wrap_table(table.name)} COMMENT '{table.comment}'") return sql def add_column_string(self): @@ -381,15 +319,15 @@ def rename_column_string(self): return "CHANGE {old} {to}" def columnize_string(self): - return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}{comment}" + return ( + "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}{comment}" + ) def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( - getattr( - self, f"get_{constraint.constraint_type}_constraint_string" - )().format( + getattr(self, f"get_{constraint.constraint_type}_constraint_string")().format( columns=", ".join(constraint.columns), name_columns="_".join(constraint.columns), table=table.name, @@ -458,18 +396,14 @@ def get_current_schema(self, connection, table_name, schema=None): reversed_type_map = {v: k for k, v in self.type_map.items()} for column in result: - column_type = self.get_column_type( - reversed_type_map, column["Type"].upper() - ) + column_type = self.get_column_type(reversed_type_map, column["Type"].upper()) length = self.get_column_length(column["Type"]) default = column.get("Default") table.add_column( column["Field"], column_type, - column_python_type=Schema._type_hints_map.get( - column_type, str - ), + column_python_type=Schema._type_hints_map.get(column_type, str), default=default, length=length, ) diff --git a/src/masoniteorm/schema/platforms/Platform.py b/src/masoniteorm/schema/platforms/Platform.py index c39d65dd..c2d1b583 100644 --- a/src/masoniteorm/schema/platforms/Platform.py +++ b/src/masoniteorm/schema/platforms/Platform.py @@ -2,7 +2,6 @@ class Platform: foreign_key_actions = { "cascade": "CASCADE", "set null": "SET NULL", - "cascade": "CASCADE", "restrict": "RESTRICT", "no action": "NO ACTION", "default": "SET DEFAULT", @@ -14,9 +13,7 @@ def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: - length = self.create_column_length(column.column_type).format( - length=column.length - ) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -25,10 +22,7 @@ def columnize(self, columns): elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: - if ( - isinstance(column.default, (str,)) - and not column.default_is_raw - ): + if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" @@ -72,9 +66,7 @@ def foreign_key_constraintize(self, table, foreign_keys): constraint_name=foreign_key.constraint_name, table=self.wrap_table(table), foreign_table=self.wrap_table(foreign_key.foreign_table), - foreign_column=self.wrap_column( - foreign_key.foreign_column - ), + foreign_column=self.wrap_column(foreign_key.foreign_column), cascade=cascade, ) ) @@ -84,9 +76,9 @@ def constraintize(self, constraints): sql = [] for name, constraint in constraints.items(): sql.append( - getattr( - self, f"get_{constraint.constraint_type}_constraint_string" - )().format(columns=", ".join(constraint.columns)) + getattr(self, f"get_{constraint.constraint_type}_constraint_string")().format( + columns=", ".join(constraint.columns) + ) ) return sql diff --git a/src/masoniteorm/schema/platforms/PostgresPlatform.py b/src/masoniteorm/schema/platforms/PostgresPlatform.py index 3b0184de..769c921c 100644 --- a/src/masoniteorm/schema/platforms/PostgresPlatform.py +++ b/src/masoniteorm/schema/platforms/PostgresPlatform.py @@ -75,34 +75,18 @@ class PostgresPlatform(Platform): def compile_create_sql(self, table, if_not_exists=False): sql = [] - table_create_format = ( - self.create_if_not_exists_format() - if if_not_exists - else self.create_format() - ) + table_create_format = self.create_if_not_exists_format() if if_not_exists else self.create_format() sql.append( table_create_format.format( table=self.wrap_table(table.name), - columns=", ".join( - self.columnize(table.get_added_columns()) - ).strip(), + columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( - ", " - + ", ".join( - self.constraintize( - table.get_added_constraints(), table - ) - ) + ", " + ", ".join(self.constraintize(table.get_added_constraints(), table)) if table.get_added_constraints() else "" ), foreign_keys=( - ", " - + ", ".join( - self.foreign_key_constraintize( - table.name, table.added_foreign_keys - ) - ) + ", " + ", ".join(self.foreign_key_constraintize(table.name, table.added_foreign_keys)) if table.added_foreign_keys else "" ), @@ -121,14 +105,10 @@ def compile_create_sql(self, table, if_not_exists=False): for name, column in table.get_added_columns().items(): if column.comment: - sql.append( - f"""COMMENT ON COLUMN "{table.name}"."{name}" is '{column.comment}'""" - ) + sql.append(f"""COMMENT ON COLUMN "{table.name}"."{name}" is '{column.comment}'""") if table.comment: - sql.append( - f"""COMMENT ON TABLE "{table.name}" is '{table.comment}'""" - ) + sql.append(f"""COMMENT ON TABLE "{table.name}" is '{table.comment}'""") return sql @@ -136,9 +116,7 @@ def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: - length = self.create_column_length(column.column_type).format( - length=column.length - ) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -149,10 +127,7 @@ def columnize(self, columns): elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: - if ( - isinstance(column.default, (str,)) - and not column.default_is_raw - ): + if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" @@ -192,9 +167,7 @@ def compile_alter_sql(self, table): for name, column in table.get_added_columns().items(): if column.length: - length = self.create_column_length( - column.column_type - ).format(length=column.length) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -226,11 +199,7 @@ def compile_alter_sql(self, table): column_constraint=column_constraint, nullable="NULL" if column.is_null else "NOT NULL", default=default, - after=( - (" AFTER " + self.wrap_column(column._after)) - if column._after - else "" - ), + after=((" AFTER " + self.wrap_column(column._after)) if column._after else ""), ) .strip() ) @@ -247,9 +216,7 @@ def compile_alter_sql(self, table): for name, column in table.get_renamed_columns().items(): if column.length: - length = self.create_column_length( - column.column_type - ).format(length=column.length) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -273,11 +240,7 @@ def compile_alter_sql(self, table): dropped_sql = [] for name in table.get_dropped_columns(): - dropped_sql.append( - self.drop_column_string() - .format(name=self.wrap_column(name)) - .strip() - ) + dropped_sql.append(self.drop_column_string().format(name=self.wrap_column(name)).strip()) sql.append( self.alter_format().format( @@ -290,7 +253,6 @@ def compile_alter_sql(self, table): changed_sql = [] for name, column in table.changed_columns.items(): - column_constraint = "" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) @@ -303,8 +265,7 @@ def compile_alter_sql(self, table): nullable="NULL" if column.is_null else "NOT NULL", length=( "(" + str(column.length) + ")" - if column.column_type - not in self.types_without_lengths + if column.column_type not in self.types_without_lengths else "" ), column_constraint=column_constraint, @@ -314,18 +275,12 @@ def compile_alter_sql(self, table): ) if column.is_null: - changed_sql.append( - f"ALTER COLUMN {self.wrap_column(name)} DROP NOT NULL" - ) + changed_sql.append(f"ALTER COLUMN {self.wrap_column(name)} DROP NOT NULL") else: - changed_sql.append( - f"ALTER COLUMN {self.wrap_column(name)} SET NOT NULL" - ) + changed_sql.append(f"ALTER COLUMN {self.wrap_column(name)} SET NOT NULL") if column.default is not None: - changed_sql.append( - f"ALTER COLUMN {self.wrap_column(name)} SET DEFAULT {column.default}" - ) + changed_sql.append(f"ALTER COLUMN {self.wrap_column(name)} SET DEFAULT {column.default}") sql.append( self.alter_format().format( @@ -349,12 +304,8 @@ def compile_alter_sql(self, table): column=self.wrap_column(column), constraint_name=foreign_key_constraint.constraint_name, table=self.wrap_table(table.name), - foreign_table=self.wrap_table( - foreign_key_constraint.foreign_table - ), - foreign_column=self.wrap_column( - foreign_key_constraint.foreign_column - ), + foreign_table=self.wrap_table(foreign_key_constraint.foreign_table), + foreign_column=self.wrap_column(foreign_key_constraint.foreign_column), cascade=cascade, ) ) @@ -364,18 +315,12 @@ def compile_alter_sql(self, table): for constraint in constraints: sql.append(f"DROP INDEX {constraint}") - if ( - table.dropped_foreign_keys - or table.removed_unique_indexes - or table.dropped_primary_keys - ): + if table.dropped_foreign_keys or table.removed_unique_indexes or table.dropped_primary_keys: constraints = table.dropped_foreign_keys constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: - sql.append( - f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}" - ) + sql.append(f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}") if table.added_indexes: for name, index in table.added_indexes.items(): @@ -405,9 +350,7 @@ def compile_alter_sql(self, table): ) if table.comment: - sql.append( - f"""COMMENT ON TABLE {self.wrap_table(table.name)} is '{table.comment}'""" - ) + sql.append(f"""COMMENT ON TABLE {self.wrap_table(table.name)} is '{table.comment}'""") return sql @@ -436,9 +379,7 @@ def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( - getattr( - self, f"get_{constraint.constraint_type}_constraint_string" - )().format( + getattr(self, f"get_{constraint.constraint_type}_constraint_string")().format( columns=", ".join(constraint.columns), name_columns="_".join(constraint.columns), constraint_name=constraint.name, @@ -500,9 +441,7 @@ def compile_get_all_tables(self, database=None, schema=None): return f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{database}'" def get_current_schema(self, connection, table_name, schema=None): - sql = self.table_information_string().format( - table=table_name, schema=schema or "public" - ) + sql = self.table_information_string().format(table=table_name, schema=schema or "public") reversed_type_map = {v: k for k, v in self.type_map.items()} reversed_type_map.update(self.table_info_map) @@ -523,9 +462,7 @@ def get_current_schema(self, connection, table_name, schema=None): length = None # find default - default = column.get("dflt_value", "") or column.get( - "column_default", "" - ) + default = column.get("dflt_value", "") or column.get("column_default", "") if default and default.startswith("nextval"): table.set_primary_key(column["column_name"]) default = None @@ -534,9 +471,7 @@ def get_current_schema(self, connection, table_name, schema=None): column["column_name"], column_type, default=default, - column_python_type=Schema._type_hints_map.get( - column_type, str - ), + column_python_type=Schema._type_hints_map.get(column_type, str), length=length, ) diff --git a/src/masoniteorm/schema/platforms/SQLitePlatform.py b/src/masoniteorm/schema/platforms/SQLitePlatform.py index f52ad14b..f4cc8404 100644 --- a/src/masoniteorm/schema/platforms/SQLitePlatform.py +++ b/src/masoniteorm/schema/platforms/SQLitePlatform.py @@ -65,32 +65,18 @@ class SQLitePlatform(Platform): def compile_create_sql(self, table, if_not_exists=False): sql = [] - table_create_format = ( - self.create_if_not_exists_format() - if if_not_exists - else self.create_format() - ) + table_create_format = self.create_if_not_exists_format() if if_not_exists else self.create_format() sql.append( table_create_format.format( table=self.get_table_string().format(table=table.name).strip(), - columns=", ".join( - self.columnize(table.get_added_columns()) - ).strip(), + columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( - ", " - + ", ".join( - self.constraintize(table.get_added_constraints()) - ) + ", " + ", ".join(self.constraintize(table.get_added_constraints())) if table.get_added_constraints() else "" ), foreign_keys=( - ", " - + ", ".join( - self.foreign_key_constraintize( - table.name, table.added_foreign_keys - ) - ) + ", " + ", ".join(self.foreign_key_constraintize(table.name, table.added_foreign_keys)) if table.added_foreign_keys else "" ), @@ -109,9 +95,7 @@ def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: - length = self.create_column_length(column.column_type).format( - length=column.length - ) + length = self.create_column_length(column.column_type).format(length=column.length) else: length = "" @@ -122,10 +106,7 @@ def columnize(self, columns): elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: - if ( - isinstance(column.default, (str,)) - and not column.default_is_raw - ): + if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" @@ -150,8 +131,7 @@ def columnize(self, columns): length=length, signed=( " " + self.signed.get(column._signed) - if column.column_type not in self.types_without_signs - and column._signed + if column.column_type not in self.types_without_signs and column._signed else "" ), constraint=constraint, @@ -169,7 +149,7 @@ def compile_alter_sql(self, diff): indexes = diff.removed_indexes indexes += diff.removed_unique_indexes for name in indexes: - sql.append("DROP INDEX {name}".format(name=name)) + sql.append(f"DROP INDEX {name}") if diff.added_columns: for name, column in diff.added_columns.items(): @@ -205,21 +185,14 @@ def compile_alter_sql(self, diff): default=default, signed=( " " + self.signed.get(column._signed) - if column.column_type - not in self.types_without_signs - and column._signed + if column.column_type not in self.types_without_signs and column._signed else "" ), constraint=constraint, ) .strip() ) - if ( - diff.renamed_columns - or diff.dropped_columns - or diff.changed_columns - or diff.added_foreign_keys - ): + if diff.renamed_columns or diff.dropped_columns or diff.changed_columns or diff.added_foreign_keys: original_columns = diff.from_table.added_columns # pop off the dropped columns. No need for them here for column in diff.dropped_columns: @@ -228,15 +201,11 @@ def compile_alter_sql(self, diff): sql.append( "CREATE TEMPORARY TABLE __temp__{table} AS SELECT {original_column_names} FROM {table}".format( table=diff.name, - original_column_names=", ".join( - diff.from_table.added_columns.keys() - ), + original_column_names=", ".join(diff.from_table.added_columns.keys()), ) ) - sql.append( - "DROP TABLE {table}".format(table=self.wrap_table(diff.name)) - ) + sql.append(f"DROP TABLE {self.wrap_table(diff.name)}") columns = diff.from_table.added_columns @@ -246,25 +215,15 @@ def compile_alter_sql(self, diff): sql.append( self.create_format().format( - table=self.get_table_string() - .format(table=diff.name) - .strip(), + table=self.get_table_string().format(table=diff.name).strip(), columns=", ".join(self.columnize(columns)).strip(), constraints=( - ", " - + ", ".join( - self.constraintize(diff.get_added_constraints()) - ) + ", " + ", ".join(self.constraintize(diff.get_added_constraints())) if diff.get_added_constraints() else "" ), foreign_keys=( - ", " - + ", ".join( - self.foreign_key_constraintize( - diff.name, diff.added_foreign_keys - ) - ) + ", " + ", ".join(self.foreign_key_constraintize(diff.name, diff.added_foreign_keys)) if diff.added_foreign_keys else "" ), @@ -279,20 +238,13 @@ def compile_alter_sql(self, diff): quoted_table=self.wrap_table(diff.name), table=diff.name, new_columns=", ".join(self.columnize_names(columns)), - original_column_names=", ".join( - diff.from_table.added_columns.keys() - ), + original_column_names=", ".join(diff.from_table.added_columns.keys()), ) ) - sql.append("DROP TABLE __temp__{table}".format(table=diff.name)) + sql.append(f"DROP TABLE __temp__{diff.name}") if diff.new_name: - sql.append( - "ALTER TABLE {old_name} RENAME TO {new_name}".format( - old_name=self.wrap_table(diff.name), - new_name=self.wrap_table(diff.new_name), - ) - ) + sql.append(f"ALTER TABLE {self.wrap_table(diff.name)} RENAME TO {self.wrap_table(diff.new_name)}") if diff.added_indexes: for name, index in diff.added_indexes.items(): @@ -348,9 +300,7 @@ def constraintize(self, constraints): sql = [] for name, constraint in constraints.items(): sql.append( - getattr( - self, f"get_{constraint.constraint_type}_constraint_string" - )().format( + getattr(self, f"get_{constraint.constraint_type}_constraint_string")().format( columns=", ".join(constraint.columns), constraint_name=constraint.name, ) @@ -371,9 +321,7 @@ def foreign_key_constraintize(self, table, foreign_keys): constraint_name=foreign_key.constraint_name, table=self.wrap_table(table), foreign_table=self.wrap_table(foreign_key.foreign_table), - foreign_column=self.wrap_column( - foreign_key.foreign_column - ), + foreign_column=self.wrap_column(foreign_key.foreign_column), cascade=cascade, ) ) @@ -394,9 +342,7 @@ def get_current_schema(self, connection, table_name, schema=None): result = connection.query(sql, ()) for column in result: - column_type = self.get_column_type( - reversed_type_map, column["type"].upper() - ) + column_type = self.get_column_type(reversed_type_map, column["type"].upper()) length = self.get_column_length(column["type"]) # find default @@ -407,9 +353,7 @@ def get_current_schema(self, connection, table_name, schema=None): table.add_column( column["name"], column_type, - column_python_type=Schema._type_hints_map.get( - column_type, str - ), + column_python_type=Schema._type_hints_map.get(column_type, str), default=default, length=length, nullable=int(column.get("notnull")) == 0, diff --git a/src/masoniteorm/scopes/SoftDeleteScope.py b/src/masoniteorm/scopes/SoftDeleteScope.py index 2f749d29..2b6ea510 100644 --- a/src/masoniteorm/scopes/SoftDeleteScope.py +++ b/src/masoniteorm/scopes/SoftDeleteScope.py @@ -8,9 +8,7 @@ def __init__(self, deleted_at_column="deleted_at"): self.deleted_at_column = deleted_at_column def on_boot(self, builder): - builder.set_global_scope( - "_where_null", self._where_null, action="select" - ) + builder.set_global_scope("_where_null", self._where_null, action="select") builder.set_global_scope( "_query_set_null_on_delete", self._query_set_null_on_delete, @@ -23,14 +21,10 @@ def on_boot(self, builder): def on_remove(self, builder): builder.remove_global_scope("_where_null", action="select") - builder.remove_global_scope( - "_query_set_null_on_delete", action="delete" - ) + builder.remove_global_scope("_query_set_null_on_delete", action="delete") def _where_null(self, builder): - return builder.where_null( - f"{builder.get_table_name()}.{self.deleted_at_column}" - ) + return builder.where_null(f"{builder.get_table_name()}.{self.deleted_at_column}") def _with_trashed(self, model, builder): builder.remove_global_scope("_where_null", action="select") @@ -46,9 +40,7 @@ def _force_delete(self, model, builder, query=False): return builder.remove_global_scope(self).delete() def _restore(self, model, builder): - return builder.remove_global_scope(self).update( - {self.deleted_at_column: None} - ) + return builder.remove_global_scope(self).update({self.deleted_at_column: None}) def _query_set_null_on_delete(self, builder): return builder.set_action("update").set_updates( diff --git a/src/masoniteorm/seeds/Seeder.py b/src/masoniteorm/seeds/Seeder.py index 41868a11..108545a5 100644 --- a/src/masoniteorm/seeds/Seeder.py +++ b/src/masoniteorm/seeds/Seeder.py @@ -2,9 +2,7 @@ class Seeder: - def __init__( - self, dry=False, seed_path="databases/seeds", connection=None - ): + def __init__(self, dry=False, seed_path="databases/seeds", connection=None): self.ran_seeds = [] self.dry = dry self.seed_path = seed_path @@ -18,9 +16,7 @@ def call(self, *seeder_classes): seeder_class(connection=self.connection).run() def run_database_seed(self): - database_seeder = pydoc.locate( - f"{self.seed_module}.database_seeder.DatabaseSeeder" - ) + database_seeder = pydoc.locate(f"{self.seed_module}.database_seeder.DatabaseSeeder") self.ran_seeds.append(database_seeder) diff --git a/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py b/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py index 0e296598..c5051c6c 100644 --- a/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py +++ b/src/masoniteorm/testing/BaseTestCaseSelectGrammar.py @@ -25,61 +25,38 @@ def setUp(self): def test_can_compile_select(self): to_sql = self.builder.to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_order_by_and_first(self): to_sql = self.builder.order_by("id", "asc").first(query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_columns(self): to_sql = self.builder.select("username", "password").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_where(self): - to_sql = ( - self.builder.select("username", "password").where("id", 1).to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username", "password").where("id", 1).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_or_where(self): to_sql = self.builder.where("name", 2).or_where("name", 3).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_grouped_where(self): - to_sql = self.builder.where( - lambda query: query.where("age", 2).where("name", "Joe") - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where(lambda query: query.where("age", 2).where("name", "Joe")).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_several_where(self): - to_sql = ( - self.builder.select("username", "password") - .where("id", 1) - .where("username", "joe") - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username", "password").where("id", 1).where("username", "joe").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_several_where_and_limit(self): @@ -90,163 +67,103 @@ def test_can_compile_with_several_where_and_limit(self): .limit(10) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_sum(self): to_sql = self.builder.sum("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_max(self): to_sql = self.builder.max("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_max_and_columns(self): to_sql = self.builder.select("username").max("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_max_and_columns_different_order(self): to_sql = self.builder.max("age").select("username").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_order_by(self): - to_sql = ( - self.builder.select("username").order_by("age", "desc").to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username").order_by("age", "desc").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_multiple_order_by(self): - to_sql = ( - self.builder.select("username") - .order_by("age", "desc") - .order_by("name") - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username").order_by("age", "desc").order_by("name").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_with_group_by(self): to_sql = self.builder.select("username").group_by("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_in(self): - to_sql = ( - self.builder.select("username").where_in("age", [1, 2, 3]).to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username").where_in("age", [1, 2, 3]).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_in_empty(self): to_sql = self.builder.where_in("age", []).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_not_in(self): - to_sql = ( - self.builder.select("username") - .where_not_in("age", [1, 2, 3]) - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.select("username").where_not_in("age", [1, 2, 3]).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_null(self): to_sql = self.builder.select("username").where_null("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_not_null(self): to_sql = self.builder.select("username").where_not_null("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_count(self): to_sql = self.builder.count("*").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_count_column(self): to_sql = self.builder.count("money").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_column(self): to_sql = self.builder.where_column("name", "email").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_sub_select(self): - to_sql = self.builder.where_in( - "name", self.builder.new().select("age") - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where_in("name", self.builder.new().select("age")).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_complex_sub_select(self): to_sql = self.builder.where_in( "name", - ( - self.builder.new() - .select("age") - .where_in("email", self.builder.new().select("email")) - ), + (self.builder.new().select("age").where_in("email", self.builder.new().select("email"))), ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_sub_select_where(self): to_sql = self.builder.where_in( "age", - self.builder.new() - .select("age") - .where("age", 2) - .where("name", "Joe"), + self.builder.new().select("age").where("age", 2).where("name", "Joe"), ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_sub_select_from_lambda(self): @@ -254,9 +171,7 @@ def test_can_compile_sub_select_from_lambda(self): self.builder.new() .where_in( "age", - lambda q: ( - q.select("age").where("age", 2).where("name", "Joe") - ), + lambda q: q.select("age").where("age", 2).where("name", "Joe"), ) .to_sql() ) @@ -264,75 +179,46 @@ def test_can_compile_sub_select_from_lambda(self): self.assertEqual(to_sql, sql) def test_can_compile_sub_select_value(self): - to_sql = self.builder.where( - "name", self.builder.new().sum("age") - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where("name", self.builder.new().sum("age")).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_exists(self): to_sql = ( self.builder.select("age") - .where_exists( - self.builder.new().select("username").where("age", 12) - ) + .where_exists(self.builder.new().select("username").where("age", 12)) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_not_exists(self): to_sql = ( self.builder.select("age") - .where_not_exists( - self.builder.new().select("username").where("age", 12) - ) + .where_not_exists(self.builder.new().select("username").where("age", 12)) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_having(self): to_sql = self.builder.sum("age").group_by("age").having("age").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_having_with_expression(self): - to_sql = ( - self.builder.sum("age").group_by("age").having("age", 10).to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.sum("age").group_by("age").having("age", 10).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_having_with_greater_than_expression(self): - to_sql = ( - self.builder.sum("age") - .group_by("age") - .having("age", ">", 10) - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.sum("age").group_by("age").having("age", ">", 10).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join(self): - to_sql = self.builder.join( - "contacts", "users.id", "=", "contacts.user_id" - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.join("contacts", "users.id", "=", "contacts.user_id").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join_clause(self): @@ -345,9 +231,7 @@ def test_can_compile_join_clause(self): ) to_sql = self.builder.join(clause).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_value(self): @@ -358,9 +242,7 @@ def test_can_compile_join_clause_with_value(self): ) to_sql = self.builder.join(clause).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_null(self): @@ -372,9 +254,7 @@ def test_can_compile_join_clause_with_null(self): ) to_sql = self.builder.join(clause).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_not_null(self): @@ -386,57 +266,39 @@ def test_can_compile_join_clause_with_not_null(self): ) to_sql = self.builder.join(clause).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_lambda(self): to_sql = self.builder.join( "report_groups as rg", - lambda clause: ( - clause.on("bgt.fund", "=", "rg.fund").on_null("bgt") - ), + lambda clause: clause.on("bgt.fund", "=", "rg.fund").on_null("bgt"), ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_left_join_clause_with_lambda(self): to_sql = self.builder.left_join( "report_groups as rg", - lambda clause: ( - clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt") - ), + lambda clause: clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt"), ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_right_join_clause_with_lambda(self): to_sql = self.builder.right_join( "report_groups as rg", - lambda clause: ( - clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt") - ), + lambda clause: clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt"), ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_left_join(self): - to_sql = self.builder.left_join( - "contacts", "users.id", "=", "contacts.user_id" - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.left_join("contacts", "users.id", "=", "contacts.user_id").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_multiple_join(self): @@ -445,149 +307,95 @@ def test_can_compile_multiple_join(self): .join("posts", "comments.post_id", "=", "posts.id") .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_limit_and_offset(self): to_sql = self.builder.limit(10).offset(10).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_between(self): to_sql = self.builder.between("age", 18, 21).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_not_between(self): to_sql = self.builder.not_between("age", 18, 21).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_user_where_raw_and_where(self): - to_sql = ( - self.builder.where_raw("age = '18'") - .where("name", "=", "James") - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where_raw("age = '18'").where("name", "=", "James").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_where_raw_and_where_with_multiple_bindings(self): - query = self.builder.where_raw( - "`age` = ? AND `is_admin` = ?", [18, True] - ).where("email", "test@example.com") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + query = self.builder.where_raw("`age` = ? AND `is_admin` = ?", [18, True]).where( + "email", "test@example.com" + ) + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(query.to_qmark(), sql) self.assertEqual(query._bindings, [18, True, "test@example.com"]) def test_can_compile_first_or_fail(self): - to_sql = ( - self.builder.where("is_admin", "=", True) - .first_or_fail(query=True) - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where("is_admin", "=", True).first_or_fail(query=True).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_where_like(self): to_sql = self.builder.where("age", "like", "%name%").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_where_regexp(self): to_sql = self.builder.where("age", "regexp", "Joe").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_where_exists_with_lambda(self): - to_sql = self.builder.where_exists( - lambda q: q.where("age", 1) - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where_exists(lambda q: q.where("age", 1)).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() print(to_sql) self.assertEqual(to_sql, sql) def test_where_not_exists_with_lambda(self): - to_sql = self.builder.where_not_exists( - lambda q: q.where("age", 1) - ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where_not_exists(lambda q: q.where("age", 1)).to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() print(to_sql) self.assertEqual(to_sql, sql) def test_where_not_regexp(self): to_sql = self.builder.where("age", "not regexp", "Joe").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_where_not_like(self): to_sql = self.builder.where("age", "not like", "%name%").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_shared_lock(self): to_sql = self.builder.where("votes", ">=", 100).shared_lock().to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_update_lock(self): - to_sql = ( - self.builder.where("votes", ">=", 100).lock_for_update().to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where("votes", ">=", 100).lock_for_update().to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_where_date(self): to_sql = self.builder.where_date("created_at", "2022-06-01").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_or_where_null(self): - to_sql = ( - self.builder.where_null("column1") - .or_where_null("column2") - .to_sql() - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where_null("column1").or_where_null("column2").to_sql() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_select_distinct(self): to_sql = self.builder.select("group").distinct().to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) From d1739449b2c79e2b623d3414cda40c559c2eb6b8 Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:04:58 +0800 Subject: [PATCH 06/10] format tests --- .../commands/MakeModelDocstringCommand.py | 4 +- src/masoniteorm/commands/MigrateCommand.py | 4 +- src/masoniteorm/factories/Factory.py | 4 +- .../query/grammars/MySQLGrammar.py | 8 +- .../relationships/BaseRelationship.py | 40 +-- src/masoniteorm/relationships/MorphTo.py | 8 +- src/masoniteorm/relationships/MorphToMany.py | 8 +- src/masoniteorm/scopes/TimeStampsScope.py | 12 +- src/masoniteorm/scopes/UUIDPrimaryKeyScope.py | 10 +- tests/collection/test_collection.py | 16 +- tests/commands/test_shell.py | 9 +- tests/eagers/test_eager.py | 20 +- tests/models/test_models.py | 58 +---- .../mssql/builder/test_mssql_query_builder.py | 40 +-- .../grammar/test_mssql_delete_grammar.py | 10 +- .../grammar/test_mssql_select_grammar.py | 40 ++- .../grammar/test_mssql_update_grammar.py | 15 +- .../mssql/schema/test_mssql_schema_builder.py | 8 +- .../schema/test_mssql_schema_builder_alter.py | 12 +- .../builder/test_mysql_builder_transaction.py | 4 +- tests/mysql/builder/test_query_builder.py | 220 ++++------------ .../builder/test_query_builder_scopes.py | 8 +- .../grammar/test_mysql_delete_grammar.py | 22 +- .../grammar/test_mysql_insert_grammar.py | 20 +- tests/mysql/grammar/test_mysql_qmark.py | 50 +--- .../grammar/test_mysql_select_grammar.py | 26 +- .../grammar/test_mysql_update_grammar.py | 35 +-- .../model/test_accessors_and_mutators.py | 4 +- tests/mysql/model/test_model.py | 52 +--- .../relationships/test_belongs_to_many.py | 16 +- .../relationships/test_has_many_through.py | 8 +- .../relationships/test_has_one_through.py | 8 +- .../mysql/relationships/test_relationships.py | 20 +- .../mysql/schema/test_mysql_schema_builder.py | 34 +-- .../schema/test_mysql_schema_builder_alter.py | 20 +- .../scopes/test_can_use_global_scopes.py | 4 +- tests/mysql/scopes/test_soft_delete.py | 4 +- .../builder/test_postgres_query_builder.py | 192 ++++---------- tests/postgres/grammar/test_delete_grammar.py | 18 +- tests/postgres/grammar/test_insert_grammar.py | 16 +- tests/postgres/grammar/test_select_grammar.py | 32 +-- tests/postgres/grammar/test_update_grammar.py | 37 +-- .../schema/test_postgres_schema_builder.py | 24 +- .../test_postgres_schema_builder_alter.py | 12 +- .../builder/test_sqlite_query_builder.py | 238 +++++------------- ...test_sqlite_query_builder_relationships.py | 4 +- .../grammar/test_sqlite_delete_grammar.py | 19 +- .../grammar/test_sqlite_insert_grammar.py | 20 +- .../grammar/test_sqlite_select_grammar.py | 26 +- .../grammar/test_sqlite_update_grammar.py | 33 +-- tests/sqlite/models/test_sqlite_model.py | 4 +- ...st_sqlite_has_many_through_relationship.py | 4 +- ...est_sqlite_has_one_through_relationship.py | 14 +- .../test_sqlite_relationships.py | 8 +- .../schema/test_sqlite_schema_builder.py | 20 +- .../test_sqlite_schema_builder_alter.py | 16 +- tests/sqlite/schema/test_table.py | 4 +- 57 files changed, 422 insertions(+), 1200 deletions(-) diff --git a/src/masoniteorm/commands/MakeModelDocstringCommand.py b/src/masoniteorm/commands/MakeModelDocstringCommand.py index 9f4c187a..84d612d6 100644 --- a/src/masoniteorm/commands/MakeModelDocstringCommand.py +++ b/src/masoniteorm/commands/MakeModelDocstringCommand.py @@ -19,9 +19,7 @@ def handle(self): schema = DB.get_schema_builder(self.option("connection")) if not schema.has_table(table): - return self.line_error( - f"There is no such table {table} for this connection." - ) + return self.line_error(f"There is no such table {table} for this connection.") self.info(f"Model Docstring for table: {table}") print('"""') diff --git a/src/masoniteorm/commands/MigrateCommand.py b/src/masoniteorm/commands/MigrateCommand.py index d86dea59..f9e352d4 100644 --- a/src/masoniteorm/commands/MigrateCommand.py +++ b/src/masoniteorm/commands/MigrateCommand.py @@ -22,9 +22,7 @@ def handle(self): if os.getenv("APP_ENV") == "production" and not self.option("force"): answer = "" while answer not in ["y", "n"]: - answer = input( - "Do you want to run migrations in PRODUCTION ? (y/n)\n" - ).lower() + answer = input("Do you want to run migrations in PRODUCTION ? (y/n)\n").lower() if answer != "y": self.info("Migrations cancelled") exit(0) diff --git a/src/masoniteorm/factories/Factory.py b/src/masoniteorm/factories/Factory.py index 83f24115..aed8fcdb 100644 --- a/src/masoniteorm/factories/Factory.py +++ b/src/masoniteorm/factories/Factory.py @@ -11,9 +11,7 @@ def faker(self): try: from faker import Faker except ImportError: - raise ImportError( - "Could not find the 'faker' library. Run 'pip install faker' to fix this." - ) + raise ImportError("Could not find the 'faker' library. Run 'pip install faker' to fix this.") if not Factory._faker: Factory._faker = Faker() diff --git a/src/masoniteorm/query/grammars/MySQLGrammar.py b/src/masoniteorm/query/grammars/MySQLGrammar.py index e8eb6a53..6eba6a47 100644 --- a/src/masoniteorm/query/grammars/MySQLGrammar.py +++ b/src/masoniteorm/query/grammars/MySQLGrammar.py @@ -120,14 +120,10 @@ def process_table(self, table): if not table: return "" if isinstance(table, str): - return ".".join( - self.table_string().format(table=t) for t in table.split(".") - ) + return ".".join(self.table_string().format(table=t) for t in table.split(".")) if table.raw: return table.name - return ".".join( - self.table_string().format(table=t) for t in table.name.split(".") - ) + return ".".join(self.table_string().format(table=t) for t in table.name.split(".")) def subquery_alias_string(self): return "AS {alias}" diff --git a/src/masoniteorm/relationships/BaseRelationship.py b/src/masoniteorm/relationships/BaseRelationship.py index cbf5bb24..f6ecd284 100644 --- a/src/masoniteorm/relationships/BaseRelationship.py +++ b/src/masoniteorm/relationships/BaseRelationship.py @@ -77,16 +77,12 @@ def apply_query(self, foreign, owner): dict -- A dictionary of data which will be hydrated. """ klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'apply_query' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'apply_query' method") def query_where_exists(self, builder, callback, method="where_exists"): """Adds a criteria clause to the query filter for existing related records""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'query_where_exists' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'query_where_exists' method") def joins(self, builder, clause=None): """Helper method for adding join clauses to a relationship""" @@ -110,52 +106,36 @@ def get_with_count_query(self, builder, callback): def attach(self, current_model, related_record): """Link a related model to the current model""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'attach' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'attach' method") def get_related(self, query, relation, eagers=None, callback=None): klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'get_related' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'get_related' method") def relate(self, related_record): klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'relate' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'relate' method") def detach(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'detach' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'detach' method") def attach_related(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'attach_related' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'attach_related' method") def detach_related(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'detach_related' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'detach_related' method") def query_has(self, current_query_builder, method="where_exists"): """Adds a clause to the query to chek if a rwlarion exists""" klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'query_has' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'query_has' method") def map_related(self, related_result): klass = self.__class__.__name__ - raise NotImplementedError( - f"{klass} relationship does not implement the 'related_result' method" - ) + raise NotImplementedError(f"{klass} relationship does not implement the 'related_result' method") diff --git a/src/masoniteorm/relationships/MorphTo.py b/src/masoniteorm/relationships/MorphTo.py index 638c55cb..39864b45 100644 --- a/src/masoniteorm/relationships/MorphTo.py +++ b/src/masoniteorm/relationships/MorphTo.py @@ -84,9 +84,7 @@ def get_related(self, query, relation, eagers=None, callback=None): relations.merge( morphed_model.where_in( f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}", - Collection(items) - .pluck(self.morph_id, keep_nulls=False) - .unique(), + Collection(items).pluck(self.morph_id, keep_nulls=False).unique(), ).get() ) return relations @@ -98,9 +96,7 @@ def get_related(self, query, relation, eagers=None, callback=None): def register_related(self, key, model, collection): morphed_model = self.morph_map().get(getattr(model, self.morph_key)) - related = collection.where( - morphed_model.get_primary_key(), getattr(model, self.morph_id) - ).first() + related = collection.where(morphed_model.get_primary_key(), getattr(model, self.morph_id)).first() model.add_relation({key: related}) diff --git a/src/masoniteorm/relationships/MorphToMany.py b/src/masoniteorm/relationships/MorphToMany.py index a5c46a61..53f90f11 100644 --- a/src/masoniteorm/relationships/MorphToMany.py +++ b/src/masoniteorm/relationships/MorphToMany.py @@ -84,9 +84,7 @@ def get_related(self, query, relation, eagers=None, callback=None): relations.merge( morphed_model.where_in( f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}", - Collection(items) - .pluck(self.morph_id, keep_nulls=False) - .unique(), + Collection(items).pluck(self.morph_id, keep_nulls=False).unique(), ).get() ) return relations @@ -98,9 +96,7 @@ def get_related(self, query, relation, eagers=None, callback=None): def register_related(self, key, model, collection): morphed_model = self.morph_map().get(getattr(model, self.morph_key)) - related = collection.where( - morphed_model.get_primary_key(), getattr(model, self.morph_id) - ) + related = collection.where(morphed_model.get_primary_key(), getattr(model, self.morph_id)) model.add_relation({key: related}) diff --git a/src/masoniteorm/scopes/TimeStampsScope.py b/src/masoniteorm/scopes/TimeStampsScope.py index e9da387e..07d2b35d 100644 --- a/src/masoniteorm/scopes/TimeStampsScope.py +++ b/src/masoniteorm/scopes/TimeStampsScope.py @@ -6,13 +6,9 @@ class TimeStampsScope(BaseScope): """Global scope class to add soft deleting to models.""" def on_boot(self, builder): - builder.set_global_scope( - "_timestamps", self.set_timestamp_create, action="insert" - ) + builder.set_global_scope("_timestamps", self.set_timestamp_create, action="insert") - builder.set_global_scope( - "_timestamp_update", self.set_timestamp_update, action="update" - ) + builder.set_global_scope("_timestamp_update", self.set_timestamp_update, action="update") def on_remove(self, builder): pass @@ -40,8 +36,6 @@ def set_timestamp_update(self, builder): return builder._updates += ( UpdateQueryExpression( - { - builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string() - } + {builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string()} ), ) diff --git a/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py b/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py index 96e26b7c..30b4bee4 100644 --- a/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py +++ b/src/masoniteorm/scopes/UUIDPrimaryKeyScope.py @@ -7,9 +7,7 @@ class UUIDPrimaryKeyScope(BaseScope): """Global scope class to use UUID4 as primary key.""" def on_boot(self, builder): - builder.set_global_scope( - "_UUID_primary_key", self.set_uuid_create, action="insert" - ) + builder.set_global_scope("_UUID_primary_key", self.set_uuid_create, action="insert") builder.set_global_scope( "_UUID_primary_key", self.set_bulk_uuid_create, @@ -34,11 +32,7 @@ def generate_uuid(self, builder, uuid_version, bytes=False): def build_uuid_pk(self, builder): uuid_version = getattr(builder._model, "__uuid_version__", 4) uuid_bytes = getattr(builder._model, "__uuid_bytes__", False) - return { - builder._model.__primary_key__: self.generate_uuid( - builder, uuid_version, uuid_bytes - ) - } + return {builder._model.__primary_key__: self.generate_uuid(builder, uuid_version, uuid_bytes)} def set_uuid_create(self, builder): # if there is already a primary key, no need to set a new one diff --git a/tests/collection/test_collection.py b/tests/collection/test_collection.py index 1687b37c..3722190c 100644 --- a/tests/collection/test_collection.py +++ b/tests/collection/test_collection.py @@ -228,9 +228,7 @@ def test_count(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.count(), 4) - collection = Collection( - [{"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}] - ) + collection = Collection([{"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}]) self.assertEqual(collection.count(), 2) def test_chunk(self): @@ -350,9 +348,7 @@ def test_reject(self): collection.reject(lambda x: x if x["age"] > 2 else None) self.assertEqual( - Collection( - [{"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}] - ), + Collection([{"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}]), collection.all(), ) @@ -493,9 +489,7 @@ def test_implode(self): result = collection.implode("-") self.assertEqual(result, "1-2-3-4") - collection = Collection( - [{"name": "Corentin"}, {"name": "Joe"}, {"name": "Marlysson"}] - ) + collection = Collection([{"name": "Corentin"}, {"name": "Joe"}, {"name": "Marlysson"}]) result = collection.implode(key="name") self.assertEqual(result, "Corentin,Joe,Marlysson") @@ -510,9 +504,7 @@ def __eq__(self, other): return self.code == other.code currencies = collection.map_into(Currency) - self.assertEqual( - currencies.all(), [Currency("USD"), Currency("EUR"), Currency("GBP")] - ) + self.assertEqual(currencies.all(), [Currency("USD"), Currency("EUR"), Currency("GBP")]) def test_map(self): collection = Collection([1, 2, 3, 4]) diff --git a/tests/commands/test_shell.py b/tests/commands/test_shell.py index 7de7b08f..fb1e9c95 100644 --- a/tests/commands/test_shell.py +++ b/tests/commands/test_shell.py @@ -63,10 +63,7 @@ def test_for_mssql(self): "full_details": {"driver": "mssql"}, } command, _ = self.command.get_command(config) - assert ( - command - == "sqlcmd -d orm -U root -P secretpostgres -S tcp:db.masonite.com,1234" - ) + assert command == "sqlcmd -d orm -U root -P secretpostgres -S tcp:db.masonite.com,1234" def test_running_command_with_sqlite(self): self.command_tester.execute("-c dev") @@ -85,6 +82,4 @@ def test_hiding_sensitive_options(self): } command, _ = self.command.get_command(config) cleaned_command = self.command.hide_sensitive_options(config, command) - assert ( - cleaned_command == "mysql orm --host localhost --user root --password ***" - ) + assert cleaned_command == "mysql orm --host localhost --user root --password ***" diff --git a/tests/eagers/test_eager.py b/tests/eagers/test_eager.py index b162c180..759f64bb 100644 --- a/tests/eagers/test_eager.py +++ b/tests/eagers/test_eager.py @@ -5,9 +5,7 @@ class TestEagerRelation(unittest.TestCase): def test_can_register_string_eager_load(self): - self.assertEqual( - EagerRelations().register("profile").get_eagers(), [["profile"]] - ) + self.assertEqual(EagerRelations().register("profile").get_eagers(), [["profile"]]) self.assertEqual(EagerRelations().register("profile").is_nested, False) self.assertEqual( EagerRelations().register("profile.user").get_eagers(), @@ -18,9 +16,7 @@ def test_can_register_string_eager_load(self): [{"profile": ["user", "logo"]}], ) self.assertEqual( - EagerRelations() - .register("profile.user", "profile.logo", "profile.bio") - .get_eagers(), + EagerRelations().register("profile.user", "profile.logo", "profile.bio").get_eagers(), [{"profile": ["user", "logo", "bio"]}], ) self.assertEqual( @@ -29,9 +25,7 @@ def test_can_register_string_eager_load(self): ) def test_can_register_tuple_eager_load(self): - self.assertEqual( - EagerRelations().register(("profile",)).get_eagers(), [["profile"]] - ) + self.assertEqual(EagerRelations().register(("profile",)).get_eagers(), [["profile"]]) self.assertEqual( EagerRelations().register(("profile", "user")).get_eagers(), [["profile", "user"]], @@ -42,9 +36,7 @@ def test_can_register_tuple_eager_load(self): ) def test_can_register_list_eager_load(self): - self.assertEqual( - EagerRelations().register(["profile"]).get_eagers(), [["profile"]] - ) + self.assertEqual(EagerRelations().register(["profile"]).get_eagers(), [["profile"]]) self.assertEqual( EagerRelations().register(["profile", "user"]).get_eagers(), [["profile", "user"]], @@ -62,8 +54,6 @@ def test_can_register_list_eager_load(self): [["logo"], {"profile": ["name"]}], ) self.assertEqual( - EagerRelations() - .register(["profile.name", "logo", "profile.user"]) - .get_eagers(), + EagerRelations().register(["profile.name", "logo", "profile.user"]).get_eagers(), [["logo"], {"profile": ["name", "user"]}], ) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 9263d01f..4036e456 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -61,22 +61,14 @@ def test_model_can_access_str_dates_as_pendulum_from_correct_datetimes( ): model = ModelTest() - self.assertEqual( - model.get_new_date(datetime.datetime(2021, 1, 1, 7, 10)).hour, 7 - ) + self.assertEqual(model.get_new_date(datetime.datetime(2021, 1, 1, 7, 10)).hour, 7) self.assertEqual(model.get_new_date(datetime.date(2021, 1, 1)).hour, 0) self.assertEqual(model.get_new_date(datetime.time(1, 1, 1)).hour, 1) self.assertEqual(model.get_new_date("2020-11-28 11:42:07").hour, 11) def test_model_can_access_str_dates_on_relationships(self): model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"}) - model.add_relation( - { - "profile": ModelTest.hydrate( - {"name": "bob", "due_date": "2020-11-28 11:42:07"} - ) - } - ) + model.add_relation({"profile": ModelTest.hydrate({"name": "bob", "due_date": "2020-11-28 11:42:07"})}) self.assertEqual(model.profile.name, "bob") self.assertTrue(model.profile.due_date.is_past()) @@ -85,9 +77,7 @@ def test_model_original_and_dirty_attributes(self): model = ModelTest.hydrate({"username": "joe", "admin": True}) self.assertEqual(model.username, "joe") - self.assertEqual( - model.__original_attributes__, {"username": "joe", "admin": True} - ) + self.assertEqual(model.__original_attributes__, {"username": "joe", "admin": True}) model.username = "bob" @@ -97,9 +87,7 @@ def test_model_original_and_dirty_attributes(self): self.assertEqual(model.__dirty_attributes__["username"], "bob") self.assertEqual(model.get_dirty_keys(), ["username"]) self.assertTrue(model.is_dirty() is True) - self.assertEqual( - model.__original_attributes__, {"username": "joe", "admin": True} - ) + self.assertEqual(model.__original_attributes__, {"username": "joe", "admin": True}) def test_model_creates_when_new(self): model = ModelTest.hydrate({"id": 1, "username": "joe", "admin": True}) @@ -173,21 +161,15 @@ def test_model_can_cast_dict_attributes(self): self.assertEqual(type(model.d), Decimal) def test_valid_json_cast(self): - model = ModelTest.hydrate( - {"payload": {"this": "dict", "is": "usable", "as": "json"}} - ) + model = ModelTest.hydrate({"payload": {"this": "dict", "is": "usable", "as": "json"}}) self.assertEqual(type(model.payload), dict) - model = ModelTest.hydrate( - {"payload": {"this": "dict", "is": "invalid", "as": "json"}} - ) + model = ModelTest.hydrate({"payload": {"this": "dict", "is": "invalid", "as": "json"}}) self.assertEqual(type(model.payload), dict) - model = ModelTest.hydrate( - {"payload": '{"this": "dict", "is": "usable", "as": "json"}'} - ) + model = ModelTest.hydrate({"payload": '{"this": "dict", "is": "usable", "as": "json"}'}) self.assertEqual(type(model.payload), dict) @@ -204,9 +186,7 @@ def test_valid_json_cast(self): model.save() def test_model_update_without_changes(self): - model = ModelTest.hydrate( - {"id": 1, "username": "joe", "name": "Joe", "admin": True} - ) + model = ModelTest.hydrate({"id": 1, "username": "joe", "name": "Joe", "admin": True}) model.username = "joe" model.name = "Bill" @@ -215,9 +195,7 @@ def test_model_update_without_changes(self): self.assertNotIn("username", sql) def test_force_update_on_model_class(self): - model = ModelTestForced.hydrate( - {"id": 1, "username": "joe", "name": "Joe", "admin": True} - ) + model = ModelTestForced.hydrate({"id": 1, "username": "joe", "name": "Joe", "admin": True}) model.username = "joe" model.name = "Bill" @@ -227,17 +205,13 @@ def test_force_update_on_model_class(self): self.assertIn("name", sql) def test_only_method(self): - model = ModelTestForced.hydrate( - {"id": 1, "username": "joe", "name": "Joe", "admin": True} - ) + model = ModelTestForced.hydrate({"id": 1, "username": "joe", "name": "Joe", "admin": True}) self.assertEqual({"username": "joe"}, model.only("username")) self.assertEqual({"username": "joe"}, model.only(["username"])) def test_model_update_without_changes_at_all(self): - model = ModelTest.hydrate( - {"id": 1, "username": "joe", "name": "Joe", "admin": True} - ) + model = ModelTest.hydrate({"id": 1, "username": "joe", "name": "Joe", "admin": True}) model.username = "joe" model.name = "Joe" @@ -258,11 +232,7 @@ def test_model_using_or_where_and_chaining_wheres(self): sql = ( model.where("name", "=", "joe") - .or_where( - lambda query: query.where("username", "Joseph").or_where( - "age", ">=", 18 - ) - ) + .or_where(lambda query: query.where("username", "Joseph").or_where("age", ">=", 18)) .to_sql() ) @@ -302,9 +272,7 @@ def test_model_can_provide_default_select(self): ) def test_model_can_override_to_default_select(self): - sql = ModelWithBaseModel.select( - ["products.name", "products.id", "store.name"] - ).to_sql() + sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql() self.assertEqual( sql, """SELECT `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""", diff --git a/tests/mssql/builder/test_mssql_query_builder.py b/tests/mssql/builder/test_mssql_query_builder.py index 16cf68c4..f86ab19a 100644 --- a/tests/mssql/builder/test_mssql_query_builder.py +++ b/tests/mssql/builder/test_mssql_query_builder.py @@ -39,9 +39,7 @@ def test_sum(self): builder = self.get_builder() builder.sum("age") - self.assertEqual( - builder.to_sql(), "SELECT SUM([users].[age]) AS age FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT SUM([users].[age]) AS age FROM [users]") def test_where_like(self): builder = self.get_builder() @@ -65,25 +63,19 @@ def test_max(self): builder = self.get_builder() builder.max("age") - self.assertEqual( - builder.to_sql(), "SELECT MAX([users].[age]) AS age FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT MAX([users].[age]) AS age FROM [users]") def test_min(self): builder = self.get_builder() builder.min("age") - self.assertEqual( - builder.to_sql(), "SELECT MIN([users].[age]) AS age FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT MIN([users].[age]) AS age FROM [users]") def test_avg(self): builder = self.get_builder() builder.avg("age") - self.assertEqual( - builder.to_sql(), "SELECT AVG([users].[age]) AS age FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT AVG([users].[age]) AS age FROM [users]") def test_all(self): builder = self.get_builder() @@ -131,9 +123,7 @@ def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") - self.assertEqual( - builder.to_sql(), "SELECT count(email) as email_count FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT count(email) as email_count FROM [users]") def test_create(self): builder = self.get_builder().without_global_scopes() @@ -210,9 +200,7 @@ def test_right_join(self): ) def test_update(self): - builder = self.get_builder().update( - {"name": "Joe", "email": "joe@yopmail.com"}, dry=True - ) + builder = self.get_builder().update({"name": "Joe", "email": "joe@yopmail.com"}, dry=True) self.assertEqual( builder.to_sql(), "UPDATE [users] SET [users].[name] = 'Joe', [users].[email] = 'joe@yopmail.com'", @@ -235,9 +223,7 @@ def test_update(self): def test_count(self): builder = self.get_builder() builder.count("id") - self.assertEqual( - builder.to_sql(), "SELECT COUNT([users].[id]) AS id FROM [users]" - ) + self.assertEqual(builder.to_sql(), "SELECT COUNT([users].[id]) AS id FROM [users]") def test_order_by_asc(self): builder = self.get_builder() @@ -247,9 +233,7 @@ def test_order_by_asc(self): def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") - self.assertEqual( - builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC" - ) + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC") def test_where_column(self): builder = self.get_builder() @@ -312,9 +296,7 @@ def test_where_not_null(self): def test_having(self): builder = self.get_builder(table="payments") - builder.select("user_id").avg("salary").group_by("user_id").having( - "salary", ">=", "1000" - ) + builder.select("user_id").avg("salary").group_by("user_id").having("salary", ">=", "1000") self.assertEqual( builder.to_sql(), @@ -425,9 +407,7 @@ def test_truncate_without_foreign_keys(self): def test_latest(self): builder = self.get_builder() builder.latest("email") - self.assertEqual( - builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC" - ) + self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC") def test_latest_multiple(self): builder = self.get_builder() diff --git a/tests/mssql/grammar/test_mssql_delete_grammar.py b/tests/mssql/grammar/test_mssql_delete_grammar.py index 1e5007a0..dbe78e38 100644 --- a/tests/mssql/grammar/test_mssql_delete_grammar.py +++ b/tests/mssql/grammar/test_mssql_delete_grammar.py @@ -16,14 +16,8 @@ def test_can_compile_delete(self): def test_can_compile_delete_with_where(self): to_sql = ( - self.builder.where("age", 20) - .where("profile", 1) - .set_action("delete") - .delete(query=True) - .to_sql() + self.builder.where("age", 20).where("profile", 1).set_action("delete").delete(query=True).to_sql() ) - sql = ( - "DELETE FROM [users] WHERE [users].[age] = '20' AND [users].[profile] = '1'" - ) + sql = "DELETE FROM [users] WHERE [users].[age] = '20' AND [users].[profile] = '1'" self.assertEqual(to_sql, sql) diff --git a/tests/mssql/grammar/test_mssql_select_grammar.py b/tests/mssql/grammar/test_mssql_select_grammar.py index a0595df8..200d451b 100644 --- a/tests/mssql/grammar/test_mssql_select_grammar.py +++ b/tests/mssql/grammar/test_mssql_select_grammar.py @@ -117,9 +117,9 @@ def can_compile_where_raw(self): return "SELECT * FROM [users] WHERE [users].[age] = '18'" def test_can_compile_where_raw_and_where_with_multiple_bindings(self): - query = self.builder.where_raw( - "[age] = ? AND [is_admin] = ?", [18, True] - ).where("email", "test@example.com") + query = self.builder.where_raw("[age] = ? AND [is_admin] = ?", [18, True]).where( + "email", "test@example.com" + ) self.assertEqual( query.to_qmark(), "SELECT * FROM [users] WHERE [age] = ? AND [is_admin] = ? AND [users].[email] = ?", @@ -175,9 +175,7 @@ def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ - return ( - "SELECT * FROM [users] WHERE [users].[name] = '2' OR [users].[name] = '3'" - ) + return "SELECT * FROM [users] WHERE [users].[name] = '2' OR [users].[name] = '3'" def can_grouped_where(self): """ @@ -273,13 +271,17 @@ def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() """ - return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] = '10'" + return ( + "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] = '10'" + ) def can_compile_having_with_greater_than_expression(self): """ builder.sum('age').group_by('age').having('age', '>', 10).to_sql() """ - return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] > '10'" + return ( + "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] > '10'" + ) def can_compile_join(self): """ @@ -304,14 +306,8 @@ def test_can_compile_where_raw(self): self.assertEqual(to_sql, "SELECT * FROM [users] WHERE [age] = '18'") def test_can_compile_having_raw(self): - to_sql = ( - self.builder.select_raw("COUNT(*) as counts") - .having_raw("counts > 10") - .to_sql() - ) - self.assertEqual( - to_sql, "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10" - ) + to_sql = self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 10").to_sql() + self.assertEqual(to_sql, "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10") def test_can_compile_having_raw_order(self): to_sql = ( @@ -327,16 +323,12 @@ def test_can_compile_having_raw_order(self): def test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def can_compile_first_or_fail(self): @@ -491,9 +483,7 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM [users] WHERE NOT EXISTS (SELECT * FROM [users] WHERE [users].[age] = '1')""" def where_date(self): - return ( - """SELECT * FROM [users] WHERE DATE([users].[created_at]) = '2022-06-01'""" - ) + return """SELECT * FROM [users] WHERE DATE([users].[created_at]) = '2022-06-01'""" def or_where_null(self): return """SELECT * FROM [users] WHERE [users].[column1] IS NULL OR [users].[column2] IS NULL""" diff --git a/tests/mssql/grammar/test_mssql_update_grammar.py b/tests/mssql/grammar/test_mssql_update_grammar.py index aa4c6f17..2fc0bb60 100644 --- a/tests/mssql/grammar/test_mssql_update_grammar.py +++ b/tests/mssql/grammar/test_mssql_update_grammar.py @@ -10,22 +10,17 @@ def setUp(self): self.builder = QueryBuilder(MSSQLGrammar, table="users") def test_can_compile_update(self): - to_sql = ( - self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - ) + to_sql = self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() sql = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob'" self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): - to_sql = ( - self.builder.where("name", "bob") - .where("age", 20) - .update({"name": "Joe"}, dry=True) - .to_sql() - ) + to_sql = self.builder.where("name", "bob").where("age", 20).update({"name": "Joe"}, dry=True).to_sql() - sql = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob' AND [users].[age] = '20'" + sql = ( + "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob' AND [users].[age] = '20'" + ) self.assertEqual(to_sql, sql) def test_raw_expression(self): diff --git a/tests/mssql/schema/test_mssql_schema_builder.py b/tests/mssql/schema/test_mssql_schema_builder.py index 71f07bb7..2c091bec 100644 --- a/tests/mssql/schema/test_mssql_schema_builder.py +++ b/tests/mssql/schema/test_mssql_schema_builder.py @@ -170,9 +170,7 @@ def test_can_advanced_table_creation2(self): blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.integer("author_id").unsigned().nullable() - blueprint.foreign("author_id").references("id").on("users").on_delete( - "CASCADE" - ) + blueprint.foreign("author_id").references("id").on("users").on_delete("CASCADE") blueprint.text("description") blueprint.timestamps() @@ -193,9 +191,7 @@ def test_can_advanced_table_creation2(self): def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") - blueprint.foreign("profile_id", name="profile_foreign").references("id").on( - "profiles" - ) + blueprint.foreign("profile_id", name="profile_foreign").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( diff --git a/tests/mssql/schema/test_mssql_schema_builder_alter.py b/tests/mssql/schema/test_mssql_schema_builder_alter.py index e23e6acf..7c0c76a3 100644 --- a/tests/mssql/schema/test_mssql_schema_builder_alter.py +++ b/tests/mssql/schema/test_mssql_schema_builder_alter.py @@ -27,9 +27,7 @@ def test_can_add_columns(self): self.assertEqual(len(blueprint.table.added_columns), 2) - sql = [ - "ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [age] INT NOT NULL" - ] + sql = ["ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [age] INT NOT NULL"] self.assertEqual(blueprint.to_sql(), sql) @@ -82,9 +80,7 @@ def test_alter_drop1(self): def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( - "cascade" - ) + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete("cascade") sql = [ "ALTER TABLE [users] ADD [playlist_id] INT NULL", @@ -133,9 +129,7 @@ def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") - sql = [ - "ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)" - ] + sql = ["ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)"] self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/mysql/builder/test_mysql_builder_transaction.py b/tests/mysql/builder/test_mysql_builder_transaction.py index 54cebed8..ebf2b38c 100644 --- a/tests/mysql/builder/test_mysql_builder_transaction.py +++ b/tests/mysql/builder/test_mysql_builder_transaction.py @@ -19,9 +19,7 @@ class BaseTestQueryRelationships(unittest.TestCase): def get_builder(self, table="users"): connection = ConnectionFactory().make("mysql") - return QueryBuilder( - grammar=MySQLGrammar, connection=connection, table=table - ).on("mysql") + return QueryBuilder(grammar=MySQLGrammar, connection=connection, table=table).on("mysql") def test_transaction(self): builder = self.get_builder() diff --git a/tests/mysql/builder/test_query_builder.py b/tests/mysql/builder/test_query_builder.py index 0a33d096..15b5a8e3 100644 --- a/tests/mysql/builder/test_query_builder.py +++ b/tests/mysql/builder/test_query_builder.py @@ -42,94 +42,72 @@ def test_sum(self): builder = self.get_builder() builder.sum("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_sum_chained(self): builder = self.get_builder() builder.sum("age").max("salary") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_with_(self): builder = self.get_builder() builder.with_("articles").sum("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_first(self): builder = self.get_builder().first(query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_find_with_model(self): @@ -173,41 +151,31 @@ def test_find_with_builder_without_column(self): def test_select(self): builder = self.get_builder() builder.select("name", "email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_with_table(self): builder = self.get_builder() builder.select("users.*") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_with_table_raw(self): builder = self.get_builder() builder.select("users.*").from_raw("orders, customers") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_with_alias(self): builder = self.get_builder() builder.select("users.username as name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_add_select(self): @@ -218,9 +186,7 @@ def test_add_select(self): .add_select("salary", lambda q: q.count("*").table("salary")) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_add_select_no_table(self): @@ -236,9 +202,7 @@ def test_add_select_no_table(self): ) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_create(self): @@ -247,82 +211,60 @@ def test_create(self): {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True, ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_update(self): - builder = self.get_builder().update( - {"name": "Joe", "email": "joe@yopmail.com"}, dry=True - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + builder = self.get_builder().update({"name": "Joe", "email": "joe@yopmail.com"}, dry=True) + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) # def test_increment(self): @@ -344,104 +286,78 @@ def test_update(self): def test_count(self): builder = self.get_builder() builder.count("id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = self.get_builder() builder.where_not_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") - builder.select("user_id").avg("salary").group_by("user_id").having( - "salary", ">=", "1000" - ) + builder.select("user_id").avg("salary").group_by("user_id").having("salary", ">=", "1000") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_builder_alone(self): @@ -468,49 +384,37 @@ def test_builder_alone(self): def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder = self.get_builder() builder.where("age", "!=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_like_as_operator(self): @@ -541,33 +445,25 @@ def test_can_call_with_multi_tables(self): def test_truncate(self): builder = self.get_builder(dry=True) sql = builder.truncate() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = self.get_builder(dry=True) sql = builder.truncate(foreign_keys=True) - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_shared_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).shared_lock().to_sql() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_update_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).lock_for_update().to_sql() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) @@ -880,9 +776,7 @@ def or_where(self): builder = self.get_builder() builder.where('age', '20').or_where('age','<', 20) """ - return ( - "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'" - ) + return "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'" def where_like(self): """ @@ -933,17 +827,13 @@ def update_lock(self): def test_latest(self): builder = self.get_builder() builder.latest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def latest(self): diff --git a/tests/mysql/builder/test_query_builder_scopes.py b/tests/mysql/builder/test_query_builder_scopes.py index 4686cf22..bc23010c 100644 --- a/tests/mysql/builder/test_query_builder_scopes.py +++ b/tests/mysql/builder/test_query_builder_scopes.py @@ -22,9 +22,7 @@ def get_builder(self, table="users"): ) def test_scopes(self): - builder = self.get_builder().set_scope( - "gender", lambda model, q: q.where("gender", "w") - ) + builder = self.get_builder().set_scope("gender", lambda model, q: q.where("gender", "w")) self.assertEqual( builder.gender().where("id", 1).to_sql(), @@ -53,9 +51,7 @@ def test_global_scope_from_class(self): def test_global_scope_remove_from_class(self): builder = ( - self.get_builder() - .set_global_scope(SoftDeleteScope()) - .remove_global_scope(SoftDeleteScope()) + self.get_builder().set_global_scope(SoftDeleteScope()).remove_global_scope(SoftDeleteScope()) ) self.assertEqual( diff --git a/tests/mysql/grammar/test_mysql_delete_grammar.py b/tests/mysql/grammar/test_mysql_delete_grammar.py index c17acee7..98a39e2f 100644 --- a/tests/mysql/grammar/test_mysql_delete_grammar.py +++ b/tests/mysql/grammar/test_mysql_delete_grammar.py @@ -12,31 +12,21 @@ def setUp(self): def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): to_sql = ( - self.builder.where("age", 20) - .where("profile", 1) - .set_action("delete") - .delete(query=True) - .to_sql() + self.builder.where("age", 20).where("profile", 1).set_action("delete").delete(query=True).to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) @@ -74,6 +64,4 @@ def can_compile_delete_with_where(self): .to_sql() ) """ - return ( - "DELETE FROM `users` WHERE `users`.`age` = '20' AND `users`.`profile` = '1'" - ) + return "DELETE FROM `users` WHERE `users`.`age` = '20' AND `users`.`profile` = '1'" diff --git a/tests/mysql/grammar/test_mysql_insert_grammar.py b/tests/mysql/grammar/test_mysql_insert_grammar.py index 0089ba2c..e815de49 100644 --- a/tests/mysql/grammar/test_mysql_insert_grammar.py +++ b/tests/mysql/grammar/test_mysql_insert_grammar.py @@ -12,17 +12,13 @@ def setUp(self): def test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): @@ -36,9 +32,7 @@ def test_can_compile_bulk_create(self): query=True, ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): @@ -46,9 +40,7 @@ def test_can_compile_bulk_create_qmark(self): [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_multiple(self): @@ -61,9 +53,7 @@ def test_can_compile_bulk_create_multiple(self): query=True, ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/mysql/grammar/test_mysql_qmark.py b/tests/mysql/grammar/test_mysql_qmark.py index f276101d..532a22cf 100644 --- a/tests/mysql/grammar/test_mysql_qmark.py +++ b/tests/mysql/grammar/test_mysql_qmark.py @@ -12,104 +12,78 @@ def setUp(self): def test_can_compile_select(self): mark = self.builder.select("username").where("name", "Joe") - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_delete(self): mark = self.builder.where("name", "Joe").delete(query=True) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_update(self): mark = self.builder.update({"name": "Bob"}, dry=True).where("name", "Joe") - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_in(self): mark = self.builder.where_in("id", [1, 2, 3]) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_not_null(self): mark = self.builder.where_not_null("id") - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, []) def test_can_compile_where_with_falsy_values(self): mark = self.builder.where("name", 0) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_with_true_value(self): mark = self.builder.where("is_admin", True) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_with_false_value(self): mark = self.builder.where("is_admin", False) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_sub_group_bindings(self): mark = self.builder.where( - lambda query: ( - query.where("challenger", 1) - .or_where("proposer", 1) - .or_where("referee", 1) - ) + lambda query: query.where("challenger", 1).or_where("proposer", 1).or_where("referee", 1) ) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_increment(self): builder = self.builder.increment("age", dry=True) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_qmark(), sql) self.assertEqual(builder._bindings, bindings) def test_can_decrement(self): builder = self.builder.decrement("age", dry=True) - sql, bindings = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql, bindings = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_qmark(), sql) self.assertEqual(builder._bindings, bindings) diff --git a/tests/mysql/grammar/test_mysql_select_grammar.py b/tests/mysql/grammar/test_mysql_select_grammar.py index a87dd772..e2cc9e64 100644 --- a/tests/mysql/grammar/test_mysql_select_grammar.py +++ b/tests/mysql/grammar/test_mysql_select_grammar.py @@ -176,9 +176,7 @@ def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ - return ( - "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'" - ) + return "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'" def can_grouped_where(self): """ @@ -256,13 +254,17 @@ def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() """ - return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` = '10'" + return ( + "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` = '10'" + ) def can_compile_having_with_greater_than_expression(self): """ builder.sum('age').group_by('age').having('age', '>', 10).to_sql() """ - return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` > '10'" + return ( + "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` > '10'" + ) def can_compile_join(self): """ @@ -299,14 +301,8 @@ def test_can_compile_where_raw(self): self.assertEqual(to_sql, "SELECT * FROM `users` WHERE `age` = '18'") def test_can_compile_having_raw(self): - to_sql = ( - self.builder.select_raw("COUNT(*) as counts") - .having_raw("counts > 10") - .to_sql() - ) - self.assertEqual( - to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10" - ) + to_sql = self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 10").to_sql() + self.assertEqual(to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10") def test_can_compile_having_raw_order(self): to_sql = ( @@ -480,9 +476,7 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')""" def where_date(self): - return ( - """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'""" - ) + return """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'""" def or_where_null(self): return """SELECT * FROM `users` WHERE `users`.`column1` IS NULL OR `users`.`column2` IS NULL""" diff --git a/tests/mysql/grammar/test_mysql_update_grammar.py b/tests/mysql/grammar/test_mysql_update_grammar.py index dde37bb3..56412dca 100644 --- a/tests/mysql/grammar/test_mysql_update_grammar.py +++ b/tests/mysql/grammar/test_mysql_update_grammar.py @@ -11,36 +11,21 @@ def setUp(self): self.builder = QueryBuilder(self.grammar, table="users") def test_can_compile_update(self): - to_sql = ( - self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - ) + to_sql = self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): - to_sql = self.builder.update( - {"name": "Joe", "email": "user@email.com"}, dry=True - ).to_sql() + to_sql = self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): - to_sql = ( - self.builder.where("name", "bob") - .where("age", 20) - .update({"name": "Joe"}, dry=True) - .to_sql() - ) + to_sql = self.builder.where("name", "bob").where("age", 20).update({"name": "Joe"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): @@ -62,9 +47,7 @@ def test_can_compile_update_with_multiple_where(self): def test_raw_expression(self): to_sql = self.builder.update({"name": Raw("`username`")}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) @@ -100,7 +83,9 @@ def can_compile_update_with_multiple_where(self): 'name': 'Joe' }).to_sql() """ - return "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob' AND `users`.`age` = '20'" + return ( + "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob' AND `users`.`age` = '20'" + ) def can_compile_increment(self): """ diff --git a/tests/mysql/model/test_accessors_and_mutators.py b/tests/mysql/model/test_accessors_and_mutators.py index 609785b2..c18f25a7 100644 --- a/tests/mysql/model/test_accessors_and_mutators.py +++ b/tests/mysql/model/test_accessors_and_mutators.py @@ -22,9 +22,7 @@ def set_name_attribute(self, attribute): class TestAccessor(unittest.TestCase): def test_can_get_accessor(self): - user = User.hydrate( - {"name": "joe", "email": "joe@masoniteproject.com", "is_admin": 1} - ) + user = User.hydrate({"name": "joe", "email": "joe@masoniteproject.com", "is_admin": 1}) self.assertEqual(user.email, "joe@masoniteproject.com") self.assertEqual(user.name, "Hello, joe") self.assertTrue(user.is_admin is True, f"{user.is_admin} is not True") diff --git a/tests/mysql/model/test_model.py b/tests/mysql/model/test_model.py index 0961c4d1..1a917033 100644 --- a/tests/mysql/model/test_model.py +++ b/tests/mysql/model/test_model.py @@ -78,18 +78,12 @@ class ProductNames(Model): class TestModel(unittest.TestCase): def test_create_can_use_fillable(self): - sql = ProfileFillable.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ).to_sql() + sql = ProfileFillable.create({"name": "Joe", "email": "user@example.com"}, query=True).to_sql() - self.assertEqual( - sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" - ) + self.assertEqual(sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')") def test_create_can_use_fillable_asterisk(self): - sql = ProfileFillAsterisk.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ).to_sql() + sql = ProfileFillAsterisk.create({"name": "Joe", "email": "user@example.com"}, query=True).to_sql() self.assertEqual( sql, @@ -97,18 +91,12 @@ def test_create_can_use_fillable_asterisk(self): ) def test_create_can_use_guarded(self): - sql = ProfileGuarded.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ).to_sql() + sql = ProfileGuarded.create({"name": "Joe", "email": "user@example.com"}, query=True).to_sql() - self.assertEqual( - sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" - ) + self.assertEqual(sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')") def test_create_can_use_guarded_asterisk(self): - sql = ProfileGuardedAsterisk.create( - {"name": "Joe", "email": "user@example.com"}, query=True - ).to_sql() + sql = ProfileGuardedAsterisk.create({"name": "Joe", "email": "user@example.com"}, query=True).to_sql() # An asterisk guarded attribute excludes all fields from mass-assignment. # This would raise a DB error if there are any required fields. @@ -168,14 +156,10 @@ def test_bulk_create_can_use_guarded_asterisk(self): # An asterisk guarded attribute excludes all fields from mass-assignment. # This would obviously raise an invalid SQL syntax error. # TODO: Raise a clearer error? - self.assertEqual( - query_builder.to_sql(), "INSERT INTO `profiles` () VALUES (), ()" - ) + self.assertEqual(query_builder.to_sql(), "INSERT INTO `profiles` () VALUES (), ()") def test_update_can_use_fillable(self): - query_builder = ProfileFillable().update( - {"name": "Joe", "email": "user@example.com"}, dry=True - ) + query_builder = ProfileFillable().update({"name": "Joe", "email": "user@example.com"}, dry=True) self.assertEqual( query_builder.to_sql(), @@ -183,9 +167,7 @@ def test_update_can_use_fillable(self): ) def test_update_can_use_fillable_asterisk(self): - query_builder = ProfileFillAsterisk().update( - {"name": "Joe", "email": "user@example.com"}, dry=True - ) + query_builder = ProfileFillAsterisk().update({"name": "Joe", "email": "user@example.com"}, dry=True) self.assertEqual( query_builder.to_sql(), @@ -193,9 +175,7 @@ def test_update_can_use_fillable_asterisk(self): ) def test_update_can_use_guarded(self): - query_builder = ProfileGuarded().update( - {"name": "Joe", "email": "user@example.com"}, dry=True - ) + query_builder = ProfileGuarded().update({"name": "Joe", "email": "user@example.com"}, dry=True) self.assertEqual( query_builder.to_sql(), @@ -205,9 +185,7 @@ def test_update_can_use_guarded(self): def test_update_can_use_guarded_asterisk(self): profile = ProfileGuardedAsterisk() initial_sql = profile.get_builder().to_sql() - query_builder = profile.update( - {"name": "Joe", "email": "user@example.com"}, dry=True - ) + query_builder = profile.update({"name": "Joe", "email": "user@example.com"}, dry=True) # An asterisk guarded attribute excludes all fields from mass-assignment. # The query builder's sql should not have been altered in any way. @@ -234,9 +212,7 @@ def test_json(self): self.assertEqual(profile.to_json(), '{"name": "Joe", "id": 1}') def test_serialize_with_hidden(self): - profile = ProfileSerialize.hydrate( - {"name": "Joe", "id": 1, "password": "secret"} - ) + profile = ProfileSerialize.hydrate({"name": "Joe", "id": 1, "password": "secret"}) self.assertTrue(profile.serialize().get("name")) self.assertTrue(profile.serialize().get("id")) @@ -251,9 +227,7 @@ def test_serialize_with_visible(self): "email": "joe@masonite.com", } ) - self.assertTrue( - {"name": "Joe", "email": "joe@masonite.com"}, profile.serialize() - ) + self.assertTrue({"name": "Joe", "email": "joe@masonite.com"}, profile.serialize()) def test_serialize_with_visible_and_hidden_raise_error(self): profile = ProfileSerializeWithVisibleAndHidden.hydrate( diff --git a/tests/mysql/relationships/test_belongs_to_many.py b/tests/mysql/relationships/test_belongs_to_many.py index 69494f9b..92b4f6c1 100644 --- a/tests/mysql/relationships/test_belongs_to_many.py +++ b/tests/mysql/relationships/test_belongs_to_many.py @@ -47,9 +47,7 @@ class MySQLRelationships(unittest.TestCase): maxDiff = None def test_belongs_to_many(self): - sql = Permission.where_has( - "role", lambda query: query.where("slug", "users") - ).to_sql() + sql = Permission.where_has("role", lambda query: query.where("slug", "users")).to_sql() self.assertEqual( sql, @@ -95,9 +93,7 @@ def test_belongs_to_many_or_doesnt_have(self): def test_where_doesnt_have(self): sql = ( Role.where("name", "role_name") - .where_doesnt_have( - "permissions", lambda q: q.where("name", "Creates Users") - ) + .where_doesnt_have("permissions", lambda q: q.where("name", "Creates Users")) .to_sql() ) @@ -109,9 +105,7 @@ def test_where_doesnt_have(self): def test_or_where_doesnt_have(self): sql = ( Role.where("name", "role_name") - .or_where_doesnt_have( - "permissions", lambda q: q.where("name", "Creates Users") - ) + .or_where_doesnt_have("permissions", lambda q: q.where("name", "Creates Users")) .to_sql() ) @@ -121,9 +115,7 @@ def test_or_where_doesnt_have(self): ) def test_belongs_to_many_where_has(self): - sql = Role.where_has( - "permissions", lambda q: q.where("name", "Creates Users") - ).to_sql() + sql = Role.where_has("permissions", lambda q: q.where("name", "Creates Users")).to_sql() self.assertEqual( sql, diff --git a/tests/mysql/relationships/test_has_many_through.py b/tests/mysql/relationships/test_has_many_through.py index f341a52d..129b5799 100644 --- a/tests/mysql/relationships/test_has_many_through.py +++ b/tests/mysql/relationships/test_has_many_through.py @@ -44,9 +44,7 @@ def test_or_has(self): ) def test_where_has_query(self): - sql = InboundShipment.where_has( - "from_country", lambda query: query.where("name", "USA") - ).to_sql() + sql = InboundShipment.where_has("from_country", lambda query: query.where("name", "USA")).to_sql() self.assertEqual( sql, @@ -76,9 +74,7 @@ def test_doesnt_have(self): def test_or_where_doesnt_have(self): sql = ( InboundShipment.where("name", "Joe") - .or_where_doesnt_have( - "from_country", lambda query: query.where("name", "USA") - ) + .or_where_doesnt_have("from_country", lambda query: query.where("name", "USA")) .to_sql() ) diff --git a/tests/mysql/relationships/test_has_one_through.py b/tests/mysql/relationships/test_has_one_through.py index 0078b1dc..9bcb6d84 100644 --- a/tests/mysql/relationships/test_has_one_through.py +++ b/tests/mysql/relationships/test_has_one_through.py @@ -44,9 +44,7 @@ def test_or_has(self): ) def test_where_has_query(self): - sql = InboundShipment.where_has( - "from_country", lambda query: query.where("name", "USA") - ).to_sql() + sql = InboundShipment.where_has("from_country", lambda query: query.where("name", "USA")).to_sql() self.assertEqual( sql, @@ -76,9 +74,7 @@ def test_doesnt_have(self): def test_or_where_doesnt_have(self): sql = ( InboundShipment.where("name", "Joe") - .or_where_doesnt_have( - "from_country", lambda query: query.where("name", "USA") - ) + .or_where_doesnt_have("from_country", lambda query: query.where("name", "USA")) .to_sql() ) diff --git a/tests/mysql/relationships/test_relationships.py b/tests/mysql/relationships/test_relationships.py index a1eddcf0..f9d97ba8 100644 --- a/tests/mysql/relationships/test_relationships.py +++ b/tests/mysql/relationships/test_relationships.py @@ -60,11 +60,7 @@ def test_or_has_nested(self): ) def test_relationship_where_has(self): - sql = ( - User.where("name", "Joe") - .where_has("profile", lambda q: q.where("profile_id", 1)) - .to_sql() - ) + sql = User.where("name", "Joe").where_has("profile", lambda q: q.where("profile_id", 1)).to_sql() self.assertEqual( sql, @@ -87,11 +83,7 @@ def test_relationship_where_has_nested(self): ) def test_relationship_or_where_has(self): - sql = ( - User.where("name", "Joe") - .or_where_has("profile", lambda q: q.where("profile_id", 1)) - .to_sql() - ) + sql = User.where("name", "Joe").or_where_has("profile", lambda q: q.where("profile_id", 1)).to_sql() self.assertEqual( sql, @@ -130,9 +122,7 @@ def test_relationship_doesnt_have_nested(self): ) def test_relationship_where_doesnt_have(self): - sql = User.where_doesnt_have( - "profile", lambda q: q.where("profile_id", 1) - ).to_sql() + sql = User.where_doesnt_have("profile", lambda q: q.where("profile_id", 1)).to_sql() self.assertEqual( sql, @@ -150,9 +140,7 @@ def test_relationship_where_doesnt_have_nested(self): ) def test_relationship_or_where_doesnt_have(self): - sql = User.or_where_doesnt_have( - "profile", lambda q: q.where("profile_id", 1) - ).to_sql() + sql = User.or_where_doesnt_have("profile", lambda q: q.where("profile_id", 1)).to_sql() self.assertEqual( sql, diff --git a/tests/mysql/schema/test_mysql_schema_builder.py b/tests/mysql/schema/test_mysql_schema_builder.py index c83a3c84..6878dcc8 100644 --- a/tests/mysql/schema/test_mysql_schema_builder.py +++ b/tests/mysql/schema/test_mysql_schema_builder.py @@ -33,9 +33,7 @@ def test_can_add_columns1(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)" - ], + ["CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)"], ) def test_can_add_tiny_text(self): @@ -66,9 +64,7 @@ def test_can_create_table_if_not_exists(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - "CREATE TABLE IF NOT EXISTS `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)" - ], + ["CREATE TABLE IF NOT EXISTS `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)"], ) def test_can_add_columns_with_constaint(self): @@ -93,9 +89,7 @@ def test_add_column_comment(self): self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), - [ - "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL COMMENT 'A users username')" - ], + ["CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL COMMENT 'A users username')"], ) def test_can_add_table_comment(self): @@ -106,9 +100,7 @@ def test_can_add_table_comment(self): self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), - [ - "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL) COMMENT 'A users table'" - ], + ["CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL) COMMENT 'A users table'"], ) def test_can_add_columns_with_foreign_key_constaint(self): @@ -168,11 +160,7 @@ def test_can_add_primary_constraint_without_column_name(self): self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual(len(blueprint.table.added_constraints), 1) - self.assertTrue( - blueprint.to_sql()[0].startswith( - "CREATE TABLE `users` (`user_id` INT(11) NOT NULL" - ) - ) + self.assertTrue(blueprint.to_sql()[0].startswith("CREATE TABLE `users` (`user_id` INT(11) NOT NULL")) def test_can_advanced_table_creation2(self): with self.schema.create("users") as blueprint: @@ -187,9 +175,7 @@ def test_can_advanced_table_creation2(self): blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.integer("author_id").unsigned().nullable() - blueprint.foreign("author_id").references("id").on("users").on_delete( - "CASCADE" - ) + blueprint.foreign("author_id").references("id").on("users").on_delete("CASCADE") blueprint.text("description") blueprint.timestamps() @@ -208,9 +194,7 @@ def test_can_advanced_table_creation2(self): def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") - blueprint.foreign("profile_id", name="profile_foreign").references("id").on( - "profiles" - ) + blueprint.foreign("profile_id", name="profile_foreign").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( @@ -367,7 +351,5 @@ def test_can_add_enum(self): self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), - [ - "CREATE TABLE `users` (`status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active')" - ], + ["CREATE TABLE `users` (`status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active')"], ) diff --git a/tests/mysql/schema/test_mysql_schema_builder_alter.py b/tests/mysql/schema/test_mysql_schema_builder_alter.py index bba8007e..eddf03f1 100644 --- a/tests/mysql/schema/test_mysql_schema_builder_alter.py +++ b/tests/mysql/schema/test_mysql_schema_builder_alter.py @@ -27,9 +27,7 @@ def test_can_add_columns(self): self.assertEqual(len(blueprint.table.added_columns), 2) - sql = [ - "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL, ADD `age` INT(11) NOT NULL" - ] + sql = ["ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL, ADD `age` INT(11) NOT NULL"] self.assertEqual(blueprint.to_sql(), sql) @@ -129,9 +127,7 @@ def test_alter_drop1(self): def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( - "cascade" - ) + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete("cascade") sql = [ "ALTER TABLE `users` ADD `playlist_id` INT UNSIGNED NULL", @@ -184,9 +180,7 @@ def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") - sql = [ - "ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)" - ] + sql = ["ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)"] self.assertEqual(blueprint.to_sql(), sql) @@ -301,9 +295,7 @@ def test_can_add_column_enum(self): self.assertEqual(len(blueprint.table.added_columns), 1) - sql = [ - "ALTER TABLE `users` ADD `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" - ] + sql = ["ALTER TABLE `users` ADD `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'"] self.assertEqual(blueprint.to_sql(), sql) @@ -313,8 +305,6 @@ def test_can_change_column_enum(self): self.assertEqual(len(blueprint.table.changed_columns), 1) - sql = [ - "ALTER TABLE `users` MODIFY `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" - ] + sql = ["ALTER TABLE `users` MODIFY `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'"] self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/mysql/scopes/test_can_use_global_scopes.py b/tests/mysql/scopes/test_can_use_global_scopes.py index 89b94d19..0d881582 100644 --- a/tests/mysql/scopes/test_can_use_global_scopes.py +++ b/tests/mysql/scopes/test_can_use_global_scopes.py @@ -31,9 +31,7 @@ def test_can_use_global_scopes_on_select(self): def test_can_use_global_scopes_on_time(self): sql = "INSERT INTO `users` (`users`.`name`, `users`.`updated_at`, `users`.`created_at`) VALUES ('Joe'" - self.assertTrue( - User.create({"name": "Joe"}, query=True).to_sql().startswith(sql) - ) + self.assertTrue(User.create({"name": "Joe"}, query=True).to_sql().startswith(sql)) # def test_can_use_global_scopes_on_inherit(self): # sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NULL" diff --git a/tests/mysql/scopes/test_soft_delete.py b/tests/mysql/scopes/test_soft_delete.py index a1e9318d..88755073 100644 --- a/tests/mysql/scopes/test_soft_delete.py +++ b/tests/mysql/scopes/test_soft_delete.py @@ -49,9 +49,7 @@ def test_restore(self): def test_force_delete_with_wheres(self): sql = "DELETE FROM `users` WHERE `users`.`active` = '1'" - self.assertEqual( - sql, UserSoft.where("active", 1).force_delete(query=True).to_sql() - ) + self.assertEqual(sql, UserSoft.where("active", 1).force_delete(query=True).to_sql()) def test_that_trashed_users_are_not_returned_by_default(self): sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL" diff --git a/tests/postgres/builder/test_postgres_query_builder.py b/tests/postgres/builder/test_postgres_query_builder.py index 4860fa89..28fe6bc0 100644 --- a/tests/postgres/builder/test_postgres_query_builder.py +++ b/tests/postgres/builder/test_postgres_query_builder.py @@ -39,84 +39,64 @@ def test_sum(self): builder = self.get_builder() builder.sum("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_first(self): builder = self.get_builder().first(query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select(self): builder = self.get_builder() builder.select("name", "email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_add_select_no_table(self): @@ -132,17 +112,13 @@ def test_add_select_no_table(self): ) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_create(self): @@ -151,82 +127,60 @@ def test_create(self): {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True, ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_update(self): - builder = self.get_builder().update( - {"name": "Joe", "email": "joe@yopmail.com"}, dry=True - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + builder = self.get_builder().update({"name": "Joe", "email": "joe@yopmail.com"}, dry=True) + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) # def test_increment(self): @@ -248,104 +202,78 @@ def test_update(self): def test_count(self): builder = self.get_builder() builder.count("id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = self.get_builder() builder.where_not_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") - builder.select("user_id").avg("salary").group_by("user_id").having( - "salary", ">=", "1000" - ) + builder.select("user_id").avg("salary").group_by("user_id").having("salary", ">=", "1000") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_builder_alone(self): @@ -372,49 +300,37 @@ def test_builder_alone(self): def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder = self.get_builder() builder.where("age", "!=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_can_call_with_schema(self): @@ -433,33 +349,25 @@ def test_can_call_with_schema(self): def test_truncate(self): builder = self.get_builder(dry=True) sql = builder.truncate() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = self.get_builder(dry=True) sql = builder.truncate(foreign_keys=True) - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_shared_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).shared_lock().to_sql() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_update_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).lock_for_update().to_sql() - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) @@ -777,17 +685,13 @@ def shared_lock(self): def test_latest(self): builder = self.get_builder() builder.latest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def oldest(self): diff --git a/tests/postgres/grammar/test_delete_grammar.py b/tests/postgres/grammar/test_delete_grammar.py index 690e7253..b3231b56 100644 --- a/tests/postgres/grammar/test_delete_grammar.py +++ b/tests/postgres/grammar/test_delete_grammar.py @@ -12,31 +12,21 @@ def setUp(self): def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): to_sql = ( - self.builder.where("age", 20) - .where("profile", 1) - .set_action("delete") - .delete(query=True) - .to_sql() + self.builder.where("age", 20).where("profile", 1).set_action("delete").delete(query=True).to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/postgres/grammar/test_insert_grammar.py b/tests/postgres/grammar/test_insert_grammar.py index 6404d2e3..6e4d1310 100644 --- a/tests/postgres/grammar/test_insert_grammar.py +++ b/tests/postgres/grammar/test_insert_grammar.py @@ -12,17 +12,13 @@ def setUp(self): def test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): @@ -36,9 +32,7 @@ def test_can_compile_bulk_create(self): query=True, ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): @@ -46,9 +40,7 @@ def test_can_compile_bulk_create_qmark(self): [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/postgres/grammar/test_select_grammar.py b/tests/postgres/grammar/test_select_grammar.py index 475d6eee..8d1255c8 100644 --- a/tests/postgres/grammar/test_select_grammar.py +++ b/tests/postgres/grammar/test_select_grammar.py @@ -71,9 +71,7 @@ def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ - return ( - """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" - ) + return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" def can_compile_with_group_by(self): """ @@ -109,9 +107,7 @@ def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ - return ( - """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" - ) + return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" def can_compile_where_raw(self): """ @@ -210,7 +206,9 @@ def can_compile_sub_select_value(self): ).to_sql() """ - return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" + return ( + """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" + ) def can_compile_complex_sub_select(self): """ @@ -243,7 +241,9 @@ def can_compile_having(self): """ builder.sum('age').group_by('age').having('age').to_sql() """ - return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + return ( + """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + ) def can_compile_having_order(self): """ @@ -304,11 +304,7 @@ def test_can_compile_where_raw(self): self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_having_raw(self): - to_sql = ( - self.builder.select_raw("COUNT(*) as counts") - .having_raw("counts > 10") - .to_sql() - ) + to_sql = self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 10").to_sql() self.assertEqual( to_sql, """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10""", @@ -327,9 +323,9 @@ def test_can_compile_having_raw_order(self): ) def test_can_compile_where_raw_and_where_with_multiple_bindings(self): - query = self.builder.where_raw( - """ "age" = ? AND "is_admin" = ?""", [18, True] - ).where("email", "test@example.com") + query = self.builder.where_raw(""" "age" = ? AND "is_admin" = ?""", [18, True]).where( + "email", "test@example.com" + ) self.assertEqual( query.to_qmark(), """SELECT * FROM "users" WHERE "age" = ? AND "is_admin" = ? AND "users"."email" = ?""", @@ -496,9 +492,7 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): - return ( - """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" - ) + return """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" def or_where_null(self): return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" diff --git a/tests/postgres/grammar/test_update_grammar.py b/tests/postgres/grammar/test_update_grammar.py index 88d458d3..fa8cb3c3 100644 --- a/tests/postgres/grammar/test_update_grammar.py +++ b/tests/postgres/grammar/test_update_grammar.py @@ -9,41 +9,24 @@ class BaseTestCaseUpdateGrammar: def setUp(self): - self.builder = QueryBuilder( - PostgresGrammar, connection_class=PostgresConnection, table="users" - ) + self.builder = QueryBuilder(PostgresGrammar, connection_class=PostgresConnection, table="users") def test_can_compile_update(self): - to_sql = ( - self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - ) + to_sql = self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): - to_sql = self.builder.update( - {"name": "Joe", "email": "user@email.com"}, dry=True - ).to_sql() + to_sql = self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): - to_sql = ( - self.builder.where("name", "bob") - .where("age", 20) - .update({"name": "Joe"}, dry=True) - .to_sql() - ) - - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where("name", "bob").where("age", 20).update({"name": "Joe"}, dry=True).to_sql() + + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): @@ -65,9 +48,7 @@ def test_can_compile_update_with_multiple_where(self): def test_raw_expression(self): to_sql = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/postgres/schema/test_postgres_schema_builder.py b/tests/postgres/schema/test_postgres_schema_builder.py index c2270b34..ebac316b 100644 --- a/tests/postgres/schema/test_postgres_schema_builder.py +++ b/tests/postgres/schema/test_postgres_schema_builder.py @@ -27,9 +27,7 @@ def test_can_add_columns(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' - ], + ['CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'], ) def test_can_add_tiny_text(self): @@ -37,9 +35,7 @@ def test_can_add_tiny_text(self): blueprint.tiny_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) - self.assertEqual( - blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] - ) + self.assertEqual(blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)']) def test_can_add_unsigned_decimal(self): with self.schema.create("users") as blueprint: @@ -59,9 +55,7 @@ def test_can_create_table_if_not_exists(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - 'CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' - ], + ['CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'], ) def test_can_add_column_comment(self): @@ -138,9 +132,7 @@ def test_can_add_columns_with_long_text(self): blueprint.long_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) - self.assertEqual( - blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] - ) + self.assertEqual(blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)']) def test_can_have_unsigned_columns(self): with self.schema.create("users") as blueprint: @@ -220,9 +212,7 @@ def test_can_advanced_table_creation2(self): blueprint.integer("premium") blueprint.double("amount").default(0.0) blueprint.integer("author_id").unsigned().nullable() - blueprint.foreign("author_id").references("id").on("authors").on_delete( - "CASCADE" - ) + blueprint.foreign("author_id").references("id").on("authors").on_delete("CASCADE") blueprint.text("description") blueprint.timestamps() @@ -262,9 +252,7 @@ def test_can_add_uuid_column(self): def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") - blueprint.foreign("profile_id", name="profile_foreign").references("id").on( - "profiles" - ) + blueprint.foreign("profile_id", name="profile_foreign").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( diff --git a/tests/postgres/schema/test_postgres_schema_builder_alter.py b/tests/postgres/schema/test_postgres_schema_builder_alter.py index 9d929dd2..e9510bc2 100644 --- a/tests/postgres/schema/test_postgres_schema_builder_alter.py +++ b/tests/postgres/schema/test_postgres_schema_builder_alter.py @@ -99,9 +99,7 @@ def test_alter_drop(self): def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( - "cascade" - ) + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete("cascade") sql = [ 'ALTER TABLE "users" ADD COLUMN "playlist_id" INTEGER NULL', @@ -165,9 +163,7 @@ def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") - sql = [ - 'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)' - ] + sql = ['ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)'] self.assertEqual(blueprint.to_sql(), sql) @@ -246,9 +242,7 @@ def test_change_string(self): blueprint.table.from_table = table - sql = [ - 'ALTER TABLE "users" ALTER COLUMN "name" TYPE VARCHAR(93), ALTER COLUMN "name" SET NOT NULL' - ] + sql = ['ALTER TABLE "users" ALTER COLUMN "name" TYPE VARCHAR(93), ALTER COLUMN "name" SET NOT NULL'] self.assertEqual(blueprint.to_sql(), sql) diff --git a/tests/sqlite/builder/test_sqlite_query_builder.py b/tests/sqlite/builder/test_sqlite_query_builder.py index 89234729..12236c2f 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder.py +++ b/tests/sqlite/builder/test_sqlite_query_builder.py @@ -72,9 +72,7 @@ def test_sum(self): builder = self.get_builder() builder.sum("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_sum_aggregate(self): @@ -88,9 +86,7 @@ def test_sum_aggregate_with_alias(self): builder = self.get_builder() builder.aggregate("SUM", "age", alias="number") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_sum_aggregate_with_alias_in_column_name(self): @@ -104,67 +100,51 @@ def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_first(self): builder = self.get_builder().first(query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_last(self): @@ -188,9 +168,7 @@ def test_find_or_404_exception(self): def test_select(self): builder = self.get_builder() builder.select("name", "email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_multiple(self): @@ -207,9 +185,7 @@ def test_add_select(self): .add_select("salary", lambda q: q.count("*").table("salary")) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_add_select_no_table(self): @@ -225,9 +201,7 @@ def test_add_select_no_table(self): ) .to_sql() ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_add_select_with_raw(self): @@ -237,24 +211,16 @@ def test_add_select_with_raw(self): .from_("some_table") .add_select( "other_test", - lambda query: ( - query.max("updated_at") - .from_("different_table") - .where("some_id", "=", "3") - ), + lambda query: query.max("updated_at").from_("different_table").where("some_id", "=", "3"), ) ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_create(self): @@ -263,25 +229,19 @@ def test_create(self): {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True, ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_dictionary(self): @@ -293,152 +253,112 @@ def test_where_dictionary(self): def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_offset_with_limit(self): builder = self.get_builder() builder.limit(2).offset(5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_update(self): - builder = self.get_builder().update( - {"name": "Joe", "email": "joe@yopmail.com"}, dry=True - ) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + builder = self.get_builder().update({"name": "Joe", "email": "joe@yopmail.com"}, dry=True) + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_increment(self): builder = self.get_builder().increment("age", 1, dry=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_decrement(self): builder = self.get_builder().decrement("age", 1, dry=True) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_count(self): builder = self.get_builder() builder.count("id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_multiple(self): builder = self.get_builder() builder.order_by("email, name, active") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_reference_direction(self): builder = self.get_builder() builder.order_by("email, name desc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_raw(self): builder = self.get_builder() builder.order_by_raw("col asc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_between_persisted(self): @@ -450,9 +370,7 @@ def test_between_persisted(self): def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_not_between_persisted(self): @@ -465,65 +383,49 @@ def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = self.get_builder() builder.where_not_null("name") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") - builder.select("user_id").avg("salary").group_by("user_id").having( - "salary", ">=", "1000" - ) + builder.select("user_id").avg("salary").group_by("user_id").having("salary", ">=", "1000") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by_raw(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by_raw("count(*)") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by_multiple(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id").group_by("salary") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_group_by_multiple_in_same_group_by(self): @@ -550,49 +452,37 @@ def test_builder_alone(self): def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder = self.get_builder() builder.where("age", "!=", "20") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_can_call_with_schema(self): @@ -616,17 +506,13 @@ def test_can_call_with_raw(self): def test_truncate(self): builder = self.get_builder() sql = builder.truncate(dry=True) - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = self.get_builder() sql = builder.truncate(foreign_keys=True) - sql_ref = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql_ref = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(sql, sql_ref) @@ -853,9 +739,7 @@ def order_by_multiple(self): """ builder.order_by('email', 'asc') """ - return ( - """SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, "active" ASC""" - ) + return """SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, "active" ASC""" def order_by_raw(self): """ @@ -1030,17 +914,13 @@ def truncate_without_foreign_keys(self): def test_latest(self): builder = self.get_builder() builder.latest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(builder.to_sql(), sql) def oldest(self): diff --git a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py index b364ecda..350acc56 100644 --- a/tests/sqlite/builder/test_sqlite_query_builder_relationships.py +++ b/tests/sqlite/builder/test_sqlite_query_builder_relationships.py @@ -76,9 +76,7 @@ def test_doesnt_have(self): def test_where_doesnt_have(self): builder = self.get_builder() - sql = builder.where_doesnt_have( - "articles", lambda q: q.where("title", "Eggs and Ham") - ).to_sql() + sql = builder.where_doesnt_have("articles", lambda q: q.where("title", "Eggs and Ham")).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE NOT EXISTS (""" diff --git a/tests/sqlite/grammar/test_sqlite_delete_grammar.py b/tests/sqlite/grammar/test_sqlite_delete_grammar.py index 3bd36a87..15fdd558 100644 --- a/tests/sqlite/grammar/test_sqlite_delete_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_delete_grammar.py @@ -12,30 +12,19 @@ def setUp(self): def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): - to_sql = ( - self.builder.where("age", 20) - .where("profile", 1) - .delete(query=True) - .to_sql() - ) + to_sql = self.builder.where("age", 20).where("profile", 1).delete(query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/sqlite/grammar/test_sqlite_insert_grammar.py b/tests/sqlite/grammar/test_sqlite_insert_grammar.py index 35ee7eb9..72ccb7e4 100644 --- a/tests/sqlite/grammar/test_sqlite_insert_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_insert_grammar.py @@ -12,17 +12,13 @@ def setUp(self): def test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): @@ -36,9 +32,7 @@ def test_can_compile_bulk_create(self): query=True, ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): @@ -46,9 +40,7 @@ def test_can_compile_bulk_create_qmark(self): [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_multiple(self): @@ -61,9 +53,7 @@ def test_can_compile_bulk_create_multiple(self): query=True, ).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/sqlite/grammar/test_sqlite_select_grammar.py b/tests/sqlite/grammar/test_sqlite_select_grammar.py index 45575648..7b6d0d21 100644 --- a/tests/sqlite/grammar/test_sqlite_select_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_select_grammar.py @@ -78,9 +78,7 @@ def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ - return ( - """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" - ) + return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" def can_compile_with_group_by(self): """ @@ -116,9 +114,7 @@ def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ - return ( - """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" - ) + return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" def can_compile_where_raw(self): """ @@ -202,7 +198,9 @@ def can_compile_sub_select_value(self): ).to_sql() """ - return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" + return ( + """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" + ) def can_compile_complex_sub_select(self): """ @@ -235,7 +233,9 @@ def can_compile_having(self): """ builder.sum('age').group_by('age').having('age').to_sql() """ - return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + return ( + """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" + ) def can_compile_having_order(self): """ @@ -296,9 +296,9 @@ def test_can_compile_where_raw(self): self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_where_raw_and_where_with_multiple_bindings(self): - query = self.builder.where_raw( - """ "age" = ? AND "is_admin" = ? """, [18, True] - ).where("email", "test@example.com") + query = self.builder.where_raw(""" "age" = ? AND "is_admin" = ? """, [18, True]).where( + "email", "test@example.com" + ) self.assertEqual( query.to_qmark(), """SELECT * FROM "users" WHERE "age" = ? AND "is_admin" = ? AND "users"."email" = ?""", @@ -465,9 +465,7 @@ def where_not_exists_with_lambda(self): return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): - return ( - """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" - ) + return """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" def or_where_null(self): return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" diff --git a/tests/sqlite/grammar/test_sqlite_update_grammar.py b/tests/sqlite/grammar/test_sqlite_update_grammar.py index 0c8e06e0..df8a289e 100644 --- a/tests/sqlite/grammar/test_sqlite_update_grammar.py +++ b/tests/sqlite/grammar/test_sqlite_update_grammar.py @@ -11,36 +11,21 @@ def setUp(self): self.builder = QueryBuilder(SQLiteGrammar, table="users") def test_can_compile_update(self): - to_sql = ( - self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - ) + to_sql = self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): - to_sql = self.builder.update( - {"name": "Joe", "email": "user@email.com"}, dry=True - ).to_sql() + to_sql = self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): - to_sql = ( - self.builder.where("name", "bob") - .where("age", 20) - .update({"name": "Joe"}, dry=True) - .to_sql() - ) - - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + to_sql = self.builder.where("name", "bob").where("age", 20).update({"name": "Joe"}, dry=True).to_sql() + + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): @@ -60,9 +45,7 @@ def test_can_compile_update_with_multiple_where(self): def test_raw_expression(self): to_sql = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql() - sql = getattr( - self, inspect.currentframe().f_code.co_name.replace("test_", "") - )() + sql = getattr(self, inspect.currentframe().f_code.co_name.replace("test_", ""))() self.assertEqual(to_sql, sql) diff --git a/tests/sqlite/models/test_sqlite_model.py b/tests/sqlite/models/test_sqlite_model.py index 2f0ef9b8..3f118cdd 100644 --- a/tests/sqlite/models/test_sqlite_model.py +++ b/tests/sqlite/models/test_sqlite_model.py @@ -225,9 +225,7 @@ def test_should_return_relation_applying_hidden_attributes(self): blueprint.foreign("user_id").references("id").on("users_hidden") blueprint.timestamps() - UserHydrateHidden.create( - name="Name", password="pass_value", token="token_value" - ) + UserHydrateHidden.create(name="Name", password="pass_value", token="token_value") Group.create(name="Group") diff --git a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py index 4b10780e..618da600 100644 --- a/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py +++ b/tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py @@ -26,9 +26,7 @@ class Course(Model): __connection__ = "dev" __fillable__ = ["course_id", "name"] - @has_many_through( - None, "in_course_id", "active_student_id", "course_id", "student_id" - ) + @has_many_through(None, "in_course_id", "active_student_id", "course_id", "student_id") def students(self): return [Student, Enrolment] diff --git a/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py index 73988429..153fe9d2 100644 --- a/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py +++ b/tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py @@ -126,14 +126,8 @@ def test_has_one_through_can_eager_load(self): self.assertEqual(shipment2.from_country.country_id, 40) # check .first() and .get() produce the same result - single = ( - IncomingShipment.where("name", "Tractor Parts") - .with_("from_country") - .first() - ) - single_get = ( - IncomingShipment.where("name", "Tractor Parts").with_("from_country").get() - ) + single = IncomingShipment.where("name", "Tractor Parts").with_("from_country").first() + single_get = IncomingShipment.where("name", "Tractor Parts").with_("from_country").get() self.assertEqual(single.from_country.country_id, 10) self.assertEqual(single_get.count(), 1) self.assertEqual( @@ -158,7 +152,5 @@ def test_has_one_through_can_get_related(self): self.assertEqual(shipment.from_country.country_id, 10) def test_has_one_through_has_query(self): - shipments = IncomingShipment.where_has( - "from_country", lambda query: query.where("name", "USA") - ) + shipments = IncomingShipment.where_has("from_country", lambda query: query.where("name", "USA")) self.assertEqual(shipments.count(), 2) diff --git a/tests/sqlite/relationships/test_sqlite_relationships.py b/tests/sqlite/relationships/test_sqlite_relationships.py index 866a2a16..9259193c 100644 --- a/tests/sqlite/relationships/test_sqlite_relationships.py +++ b/tests/sqlite/relationships/test_sqlite_relationships.py @@ -159,12 +159,8 @@ def test_belongs_to_many(self): self.assertEqual(store.products.count(), 3) self.assertEqual(store.products.serialize()[0]["id"], 4) self.assertEqual(store.products.serialize()[0]["name"], "Handgun") - self.assertEqual( - store.products.serialize()[0]["updated_at"], "2020-01-01T00:00:00+00:00" - ) - self.assertEqual( - store.products.serialize()[0]["created_at"], "2020-01-01T00:00:00+00:00" - ) + self.assertEqual(store.products.serialize()[0]["updated_at"], "2020-01-01T00:00:00+00:00") + self.assertEqual(store.products.serialize()[0]["created_at"], "2020-01-01T00:00:00+00:00") def test_belongs_to_eager_many(self): store = Store.hydrate({"id": 2, "name": "Walmart"}) diff --git a/tests/sqlite/schema/test_sqlite_schema_builder.py b/tests/sqlite/schema/test_sqlite_schema_builder.py index b1e771c5..14965168 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder.py @@ -25,9 +25,7 @@ def test_can_add_columns(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' - ], + ['CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'], ) def test_can_add_tiny_text(self): @@ -58,9 +56,7 @@ def test_can_create_table_if_not_exists(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - 'CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' - ], + ['CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'], ) def test_can_add_columns_with_constraint(self): @@ -72,9 +68,7 @@ def test_can_add_columns_with_constraint(self): self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), - [ - 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, UNIQUE(name))' - ], + ['CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, UNIQUE(name))'], ) def test_can_have_float_type(self): @@ -114,9 +108,7 @@ def test_can_add_columns_with_foreign_key_constraint_name(self): blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") - blueprint.foreign("profile_id", name="profile_foreign").references("id").on( - "profiles" - ) + blueprint.foreign("profile_id", name="profile_foreign").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( @@ -260,9 +252,7 @@ def test_can_advanced_table_creation2(self): blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.integer("author_id").unsigned().nullable() - blueprint.foreign("author_id").references("id").on("users").on_delete( - "set null" - ) + blueprint.foreign("author_id").references("id").on("users").on_delete("set null") blueprint.text("description") blueprint.timestamps() diff --git a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py index 82580aa0..efa64253 100644 --- a/tests/sqlite/schema/test_sqlite_schema_builder_alter.py +++ b/tests/sqlite/schema/test_sqlite_schema_builder_alter.py @@ -156,18 +156,16 @@ def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") - sql = [ - 'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)' - ] + sql = ['ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() - blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( - "cascade" - ).on_update("SET NULL") + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete("cascade").on_update( + "SET NULL" + ) table = Table("users") table.add_column("age", "string") @@ -189,9 +187,9 @@ def test_alter_add_column_and_foreign_key(self): def test_alter_add_foreign_key_only(self): with self.schema.table("users") as blueprint: - blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( - "cascade" - ).on_update("set null") + blueprint.foreign("playlist_id").references("id").on("playlists").on_delete("cascade").on_update( + "set null" + ) table = Table("users") table.add_column("age", "string") diff --git a/tests/sqlite/schema/test_table.py b/tests/sqlite/schema/test_table.py index 464f56d0..6b21345f 100644 --- a/tests/sqlite/schema/test_table.py +++ b/tests/sqlite/schema/test_table.py @@ -101,9 +101,7 @@ def test_create_sql_with_foreign_key_constraint(self): def test_can_build_table_from_connection_call(self): sql_details = DATABASES["dev"] table = self.platform.get_current_schema( - SQLiteConnection( - database=sql_details["database"], name="dev" - ).make_connection(), + SQLiteConnection(database=sql_details["database"], name="dev").make_connection(), "table_schema", ) From 1e2d65545b2c2a232ae6d340540640c849a46efb Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:25:49 +0800 Subject: [PATCH 07/10] removed CI branch restrictions --- .github/workflows/pythonapp.yml | 55 ++++++++++++++++---------------- orm.sqlite3 | Bin 188416 -> 188416 bytes 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 411d37d1..d22f340f 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -2,7 +2,6 @@ name: CI on: push: - branches: [main, master] pull_request: concurrency: @@ -33,33 +32,33 @@ jobs: needs: lint services: - postgres: - image: postgres:16 - env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: postgres - ports: - - 5432/tcp - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - - mysql: - image: mysql:8.0 - env: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: orm - ports: - - 3306/tcp - options: >- - --health-cmd "mysqladmin ping" - --health-interval 10s - --health-timeout 5s - --health-retries 5 - +# postgres: +# image: postgres:16 +# env: +# POSTGRES_USER: postgres +# POSTGRES_PASSWORD: postgres +# POSTGRES_DB: postgres +# ports: +# - 5432/tcp +# options: >- +# --health-cmd pg_isready +# --health-interval 10s +# --health-timeout 5s +# --health-retries 5 +# +# mysql: +# image: mysql:8.0 +# env: +# MYSQL_ALLOW_EMPTY_PASSWORD: yes +# MYSQL_DATABASE: orm +# ports: +# - 3306/tcp +# options: >- +# --health-cmd "mysqladmin ping" +# --health-interval 10s +# --health-timeout 5s +# --health-retries 5 +# strategy: fail-fast: false matrix: diff --git a/orm.sqlite3 b/orm.sqlite3 index cdda25e6cbe72200a2528df38cc8037284a18d9b..6baaa386689671e2e43182a10670c80864f631b0 100644 GIT binary patch delta 401 zcmZoTz};|wdxDe@mlOj7gB}n=0ZYR~9U~?#sm6q@35-knnWMPkr?XFBT* Date: Tue, 17 Mar 2026 17:30:25 +0800 Subject: [PATCH 08/10] reverted workflow setup --- .github/workflows/pythonapp.yml | 54 ++++++++++++++++----------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index d22f340f..6270c4bb 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -32,33 +32,33 @@ jobs: needs: lint services: -# postgres: -# image: postgres:16 -# env: -# POSTGRES_USER: postgres -# POSTGRES_PASSWORD: postgres -# POSTGRES_DB: postgres -# ports: -# - 5432/tcp -# options: >- -# --health-cmd pg_isready -# --health-interval 10s -# --health-timeout 5s -# --health-retries 5 -# -# mysql: -# image: mysql:8.0 -# env: -# MYSQL_ALLOW_EMPTY_PASSWORD: yes -# MYSQL_DATABASE: orm -# ports: -# - 3306/tcp -# options: >- -# --health-cmd "mysqladmin ping" -# --health-interval 10s -# --health-timeout 5s -# --health-retries 5 -# + postgres: + image: postgres:16 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: + - 5432/tcp + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + mysql: + image: mysql:8.0 + env: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: orm + ports: + - 3306/tcp + options: >- + --health-cmd "mysqladmin ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: fail-fast: false matrix: From 75840f842c39024814ad2a00005a2e8e7015cb00 Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:36:43 +0800 Subject: [PATCH 09/10] added faker back ro requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index b608f926..79edac6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ ruff +faker pytest pytest-env coverage From 1440478cfddb8bc6970b5a8387844c6a2541907e Mon Sep 17 00:00:00 2001 From: Kieren Eaton <499977+circulon@users.noreply.github.com> Date: Wed, 18 Mar 2026 10:40:03 +0800 Subject: [PATCH 10/10] updated action versions --- .github/workflows/pythonapp.yml | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 6270c4bb..36dab03c 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -13,16 +13,13 @@ jobs: name: Lint (ruff) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: python-version: "3.12" cache: pip - - name: Install dev dependencies run: make init-ci - - name: Lint & format check run: make check @@ -30,7 +27,6 @@ jobs: name: Python ${{ matrix.python-version }} runs-on: ubuntu-latest needs: lint - services: postgres: image: postgres:16 @@ -45,7 +41,6 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 - mysql: image: mysql:8.0 env: @@ -58,23 +53,18 @@ jobs: --health-interval 10s --health-timeout 5s --health-retries 5 - strategy: fail-fast: false matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} cache: pip - - name: Install test dependencies run: make init-ci - - name: Run migrations env: POSTGRES_DATABASE_HOST: localhost @@ -90,7 +80,6 @@ jobs: run: | python orm migrate --connection postgres python orm migrate --connection mysql - - name: Run tests env: POSTGRES_DATABASE_HOST: localhost