diff --git a/.github/workflows/python-app-ci.yml b/.github/workflows/python-app-ci.yml index 3af05f30..1d7474be 100644 --- a/.github/workflows/python-app-ci.yml +++ b/.github/workflows/python-app-ci.yml @@ -23,12 +23,14 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "latest" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install flake8 pytest - if ${{ -f requirements.txt }}; then pip install -r requirements.txt; fi - python -m pip install -e . + uv sync --all-extras + uv pip install -e . #- name: Lint with flake8 # run: | # # stop the build if there are Python syntax errors or undefined names @@ -39,7 +41,7 @@ jobs: env: MONGODB_PASSWORD: ${{secrets.MONGODB_PASSWORD}} run: | - pytest + uv run pytest lint-and-format-backend: continue-on-error: true @@ -56,6 +58,10 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "latest" - name: Check Black formatting uses: reviewdog/action-black@v3 @@ -76,8 +82,7 @@ jobs: REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | cd monggregate - pip3 install ruff - ruff check . \ + uv run --with ruff ruff check . \ | reviewdog -name=Ruff -reporter=local -reporter=github-pr-review -efm="%f:%l:%c: %m" -filter-mode=added -fail-on-error=true - name: Check Mypy linting @@ -87,7 +92,7 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} reporter: github-pr-review filter_mode: added - setup_command: pip3 install --no-cache-dir --upgrade -r ../requirements/all.txt + setup_command: uv sync --all-extras setup_method: install level: warning workdir: mongreggate diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index e6a79064..7b851b61 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -23,17 +23,19 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "latest" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install . - pip install -r requirements/core.txt - pip install -r requirements/testing.txt + uv sync --all-extras + uv pip install -e . - name: Test with pytest env: MONGODB_PASSWORD: ${{secrets.MONGODB_PASSWORD}} run: | - pytest + uv run pytest lint-and-format-backend: continue-on-error: true @@ -50,6 +52,10 @@ jobs: uses: actions/setup-python@v3 with: python-version: "3.10" + - name: Install uv + uses: astral-sh/setup-uv@v3 + with: + version: "latest" - name: Check Black formatting uses: reviewdog/action-black@v3 @@ -68,8 +74,7 @@ jobs: env: REVIEWDOG_GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - pip3 install ruff - ruff check . \ + uv run --with ruff ruff check . \ | reviewdog -name=Ruff -reporter=local -reporter=github-pr-review -efm="%f:%l:%c: %m" -filter-mode=added -fail-on-error=false - name: Check Mypy linting @@ -79,7 +84,7 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} reporter: github-pr-review filter_mode: added - setup_command: pip3 install --no-cache-dir --upgrade -r requirements/all.txt + setup_command: uv sync --all-extras setup_method: install level: warning mypy_flags: --python-version 3.10 diff --git a/.gitignore b/.gitignore index 5a4d128b..8e0855f9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ dist test.py site/ .pypirc -clean_pycache.ps1 \ No newline at end of file +clean_pycache.ps1 +test_deployments \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3494bb61..b65e355f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ [project] name = "monggregate" -version = "0.21.0" +version = "0.22.0" description = "MongoDB aggregation pipelines made easy. Joins, grouping, counting and much more..." readme = "README.md" authors = [{ name = "Vianney Mixtur", email = "vianney.mixtur@outlook.fr" }] @@ -22,12 +22,18 @@ dependencies = [ "typing-extensions>=4.0", ] +[project.optional-dependencies] +mongodb = [ + "pymongo>=3.0.0", + "motor>=3.0.0", +] + [project.urls] Homepage = "https://github.com/VianneyMI/monggregate" documentation = "https://vianneymi.github.io/monggregate/" [tool.bumpver] -current_version = "0.21.0" +current_version = "0.22.0" version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]" commit_message = "bump version {old_version} -> {new_version}" commit = true @@ -41,11 +47,37 @@ push = false ] "src/monggregate/__init__.py" = ['__version__ = "{version}"'] +[tool.setuptools] +packages = ["monggregate", "tests"] + +[tool.setuptools.package-dir] +monggregate = "src/monggregate" +tests = "tests" + [dependency-groups] -dev = [ +test = [ "pytest>=8.3.5", + "pytest-html>=3.2.0", + "python-dotenv>=1.0.0", + "certifi", +] +lint = [ + "mypy>=1.6.1", + "ruff", + "black", ] doc = [ "mkdocs==1.5.2", "mkdocs-material==9.2.7", ] +dev = [ + "pytest>=8.3.5", + "pytest-html>=3.2.0", + "python-dotenv>=1.0.0", + "certifi", + "mypy>=1.6.1", + "ruff", + "black", + "mkdocs==1.5.2", + "mkdocs-material==9.2.7", +] diff --git a/pytest.ini b/pytest.ini index 359d7ff1..2ef11ffa 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] # Be careful when registering the mark the semi-colon (:) should be right after the mark without any space before it. +pythonpath = . tests markers = unit: unit tests functional: functional tests diff --git a/src/monggregate/__init__.py b/src/monggregate/__init__.py index a8fa57f9..fa94b1d9 100644 --- a/src/monggregate/__init__.py +++ b/src/monggregate/__init__.py @@ -5,7 +5,7 @@ __all__ = ["Pipeline", "S", "SS"] -__version__ = "0.21.0" +__version__ = "0.22.0" __author__ = "Vianney Mixtur" __contact__ = "prenom.nom@outlook.fr" __copyright__ = "Copyright © 2022-2024 Vianney Mixtur" diff --git a/src/monggregate/base.py b/src/monggregate/base.py index 4e4cfa6c..4cc41143 100644 --- a/src/monggregate/base.py +++ b/src/monggregate/base.py @@ -72,6 +72,21 @@ class Config(pyd.BaseConfig): alias_generator = camelize +class ExpressionWrapper(BaseModel): + """Wrapper for an expression. + + To be used for Stage, Operator or other MongoDB object that hasn't been interfaced yet in `monggregate`. + """ + + _expression: Expression + + @property + def expression(self) -> Expression: + """Expression property""" + + return self.expression + + def isbasemodel(instance: Any) -> TypeGuard[BaseModel]: """Returns true if instance is an instance of BaseModel""" @@ -82,24 +97,20 @@ def express(obj: Any) -> dict | list[dict]: """Resolves an expression encapsulated in an object from a class inheriting from BaseModel""" if isbasemodel(obj): + # If it's a BaseModel instance, get its expression output: dict | list = obj.expression - elif isinstance(obj, list) and any(map(isbasemodel, obj)): + elif isinstance(obj, list): + # Always process lists recursively - they might contain nested BaseModel instances output = [] for element in obj: - if isinstance(element, BaseModel): - output.append( - element.expression - ) # probably should call express(element) - else: - output.append(element) + output.append(express(element)) elif isinstance(obj, dict): + # Always process dictionaries recursively - they might contain nested BaseModel instances output = {} for key, value in obj.items(): - if isinstance(value, BaseModel): - output[key] = value.expression # probably should call express(value) - else: - output[key] = express(value) + output[key] = express(value) else: + # For primitive types (int, str, bool, None, etc.), return as-is output = obj return output diff --git a/src/monggregate/operators/operator.py b/src/monggregate/operators/operator.py index 96b1493d..c7e7fdff 100644 --- a/src/monggregate/operators/operator.py +++ b/src/monggregate/operators/operator.py @@ -1,7 +1,7 @@ """Operator Module""" # Standard Library imports -#---------------------------- +# ---------------------------- from abc import ABC # Package imports @@ -9,13 +9,14 @@ from monggregate.base import BaseModel from monggregate.utils import StrEnum + class Operator(BaseModel, ABC): """MongoDB operator abstract base class""" + class OperatorEnum(StrEnum): """Enumeration of available operators""" - ABS = "$abs" ACCUMULATOR = "$accumulator" ACOS = "$acos" @@ -34,7 +35,7 @@ class OperatorEnum(StrEnum): ATANH = "$atanh" AVG = "$avg" BINARY_SIZE = "$binarySize" - BSON_SIZE ="$bsonSize" + BSON_SIZE = "$bsonSize" CEIL = "$ceil" CMP = "$cmp" CONCAT = "$concat" @@ -47,7 +48,7 @@ class OperatorEnum(StrEnum): DATE_FROM_STRING = "$dateFromString" DATE_TO_PARTS = "$dateToParts" DATE_TO_STRING = "$dateToString" - DAY_OF_MONTH ="$dayOfMonth" + DAY_OF_MONTH = "$dayOfMonth" DAY_OF_WEEK = "$dayOfWeek" DAY_OF_YEAR = "$dayOfYear" DEGREES_TO_RADIANS = "$degreesToRadians" @@ -55,7 +56,7 @@ class OperatorEnum(StrEnum): EQ = "$eq" EXP = "$exp" FILTER = "$filter" - FIRST = "$first" # two operators one for array one for accumulator + FIRST = "$first" # two operators one for array one for accumulator FLOOR = "$floor" FUNCTION = "$function" GET_FIELD = "$getField" @@ -71,9 +72,9 @@ class OperatorEnum(StrEnum): IS_NUMBER = "$isNumber" ISO_DAY_OF_WEEK = "$isoDayOfWeek" ISO_WEEK = "$isoWeek" - ISO_WEEK_YEAR ="$isoWeekYear" + ISO_WEEK_YEAR = "$isoWeekYear" LAST = "$last" # two operators one for array one for accumulator - LET ="$let" + LET = "$let" LITERAL = "$literal" LN = "$ln" LOG = "$log" @@ -85,15 +86,15 @@ class OperatorEnum(StrEnum): MAX = "$max" MERGE_OBJECTS = "$mergeObjects" META = "$meta" - MILLI_SECOND = "$millisecond" + MILLISECOND = "$millisecond" MIN = "$min" - MINUTE ="$minute" - MOD ="$mod" + MINUTE = "$minute" + MOD = "$mod" MONTH = "$month" - MULTIPLY ="$multiply" - NE ="$ne" - NOT ="$not" - OBJECT_TO_ARRAY ="$objectToArray" + MULTIPLY = "$multiply" + NE = "$ne" + NOT = "$not" + OBJECT_TO_ARRAY = "$objectToArray" OR = "$or" POW = "$pow" PUSH = "$push" @@ -101,10 +102,10 @@ class OperatorEnum(StrEnum): RAND = "$rand" RANGE = "$range" REDUCE = "$reduce" - REGEX_FIND ="$regexFind" + REGEX_FIND = "$regexFind" REGEX_FIND_ALL = "$regexFindAll" REGEX_MATCH = "$regexMatch" - REPLACE_ONE ="$replaceOne" + REPLACE_ONE = "$replaceOne" REPLACE_ALL = "$replaceAll" REVERSE_ARRAY = "$reverseArray" ROUND = "$round" @@ -126,16 +127,16 @@ class OperatorEnum(StrEnum): STD_DEV_SAMP = "$stdDevSamp" STR_LEN_BYTES = "$strLenBytes" STR_LEN_CP = "$strLenCP" - STR_CASE_CMP = "$strcasecmp" + STRCASECMP = "$strcasecmp" SUBSTR = "$substr" SUBSTR_BYTES = "$substrBytes" SUBSTR_CP = "$substrCP" - SUBSTRACT = "$subtract" + SUBTRACT = "$subtract" SUM = "$sum" SWITCH = "$switch" TAN = "$tan" TANH = "$tanh" - TO_BOOL ="$toBool" + TO_BOOL = "$toBool" TO_DATE = "$toDate" TO_DECIMAL = "$toDecimal" TO_DOUBLE = "$toDouble" @@ -150,4 +151,4 @@ class OperatorEnum(StrEnum): TYPE = "$type" WEEK = "$week" YEAR = "$year" - ZIP = "$zip" \ No newline at end of file + ZIP = "$zip" diff --git a/src/monggregate/pipeline.py b/src/monggregate/pipeline.py index ea42a806..d7c874cf 100644 --- a/src/monggregate/pipeline.py +++ b/src/monggregate/pipeline.py @@ -6,7 +6,7 @@ from typing_extensions import Self -from monggregate.base import BaseModel, Expression +from monggregate.base import BaseModel, Expression, express from monggregate.stages import ( AnyStage, BucketAuto, @@ -21,7 +21,6 @@ Project, ReplaceRoot, Sample, - Stage, Search, SearchMeta, SearchStageMap, @@ -43,8 +42,7 @@ from monggregate.dollar import ROOT - -class Pipeline(BaseModel): # pylint: disable=too-many-public-methods +class Pipeline(BaseModel): # pylint: disable=too-many-public-methods """ MongoDB aggregation pipeline abstraction. @@ -79,34 +77,18 @@ class Pipeline(BaseModel): # pylint: disable=too-many-public-methods """ - - stages : list[AnyStage|Expression] = [] - - + stages: list[AnyStage | Expression] = [] @property - def expression(self)->list[Expression]: + def expression(self) -> list[Expression]: """Returns the pipeline statement""" - # TODO : Add test on this case - # https://github.com/VianneyMI/monggregate/issues/106 - stages_expressions = [] - - for stage in self.stages: - if isinstance(stage, Stage): - stages_expressions.append(stage.expression) - else: - stages_expressions.append(stage) - - return stages_expressions - - - + return express(self.stages) # ------------------------------------------------ # Pipeline Internal Methods - #------------------------------------------------- - def export(self)->list[dict]: + # ------------------------------------------------- + def export(self) -> list[dict]: """ Exports current pipeline to pymongo format. @@ -117,55 +99,53 @@ def export(self)->list[dict]: return self.expression - # -------------------------------------------------- # Pipeline List Methods - #--------------------------------------------------- - def __add__(self, other:Self)->Self: + # --------------------------------------------------- + def __add__(self, other: Self) -> Self: """Concatenates two pipelines together""" if not isinstance(other, Pipeline): - raise TypeError(f"unsupported operand type(s) for +: 'Pipeline' and '{type(other)}'") - - return Pipeline( - stages=self.stages + other.stages + raise TypeError( + f"unsupported operand type(s) for +: 'Pipeline' and '{type(other)}'" ) - - def __getitem__(self, index:int)->AnyStage: + + return Pipeline(stages=self.stages + other.stages) + + def __getitem__(self, index: int) -> AnyStage: """Returns a stage from the pipeline""" # https://realpython.com/inherit-python-list/ return self.stages[index] - def __setitem__(self, index:int, stage:AnyStage)->None: + def __setitem__(self, index: int, stage: AnyStage) -> None: """Sets a stage in the pipeline""" self.stages[index] = stage - def __delitem__(self, index:int)->None: + def __delitem__(self, index: int) -> None: """Deletes a stage from the pipeline""" del self.stages[index] - def __len__(self)->int: + def __len__(self) -> int: """Returns the length of the pipeline""" return len(self.stages) - - def append(self, stage:AnyStage)->None: + + def append(self, stage: AnyStage) -> None: """Appends a stage to the pipeline""" self.stages.append(stage) - def insert(self, index:int, stage:AnyStage)->None: + def insert(self, index: int, stage: AnyStage) -> None: """Inserts a stage in the pipeline""" self.stages.insert(index, stage) - def extend(self, stages:list[AnyStage])->None: + def extend(self, stages: list[AnyStage]) -> None: """Extends the pipeline with a list of stages""" self.stages.extend(stages) - - #--------------------------------------------------- + # --------------------------------------------------- # Stages - #--------------------------------------------------- + # --------------------------------------------------- # The below methods wrap the constructors of the classes of the same name - def add_fields(self, document:dict={}, **kwargs:Any)->Self: + def add_fields(self, document: dict = {}, **kwargs: Any) -> Self: """ Adds an add_fields stage to the current pipeline. @@ -177,17 +157,23 @@ def add_fields(self, document:dict={}, **kwargs:Any)->Self: Online MongoDB documentation: ----------------------------- Adds new fields to documents. set outputs documents that contain all existing fields from the inputs documents and newly added fields. Both stages are equivalent to a project stage that explicitly specifies all existing fields in the inputs documents and adds the new fields. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/set/#mongodb-pipeline-pipe.-set """ document = document | kwargs - self.stages.append( - Set(document=document) - ) + self.stages.append(Set(document=document)) return self - def bucket(self, *, boundaries:list, by:Any=None, group_by:Any=None, default:Any=None, output:dict|None=None)->Self: + def bucket( + self, + *, + boundaries: list, + by: Any = None, + group_by: Any = None, + default: Any = None, + output: dict | None = None, + ) -> Self: """ Adds a bucket stage to the current pipeline. This stage aggregates documents into buckets specified by the boundaries argument. @@ -230,22 +216,27 @@ def bucket(self, *, boundaries:list, by:Any=None, group_by:Any=None, default:Any Categorizes incoming documents into groups, called buckets, based on a specified expression and bucket boundaries and outputs a document per each bucket. Each output document contains an _id field whose value specifies the inclusive lower bound of the bucket. The output option specifies the fields included in each output document. - $bucket only produces output documents for buckets that contain at least one input document. - + $bucket only produces output documents for buckets that contain at least one input document. + Source : https://www.mongodb.com/docs/manual/meta/aggregation-quick-reference/ """ - + self.stages.append( Bucket( - by = by or group_by, - boundaries = boundaries, - default = default, - output = output + by=by or group_by, boundaries=boundaries, default=default, output=output ) ) return self - def bucket_auto(self, *, by:Any=None, group_by:Any=None, buckets:int, output:dict|None=None, granularity:GranularityEnum|None=None)->Self: + def bucket_auto( + self, + *, + by: Any = None, + group_by: Any = None, + buckets: int, + output: dict | None = None, + granularity: GranularityEnum | None = None, + ) -> Self: """ Adds a bucket_auto stage to the current pipeline. This stage aggregates documents into buckets automatically computed to statisfy the number of buckets desired @@ -287,22 +278,21 @@ def bucket_auto(self, *, by:Any=None, group_by:Any=None, buckets:int, output:dic * The _id.max field specifies the upper bound for the bucket. This bound is exclusive for all buckets except the final bucket in the series, where it is inclusive. * A count field that contains the number of documents in the bucket. The count field is included by default when the output document is not specified. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/bucketAuto/ """ self.stages.append( BucketAuto( - by = by or group_by, - buckets = buckets, - output = output, - granularity = granularity + by=by or group_by, + buckets=buckets, + output=output, + granularity=granularity, ) ) return self - - def count(self, name:str)->Self: + def count(self, name: str) -> Self: """ Adds a count stage to the current pipeline. Passes a document to the next stage that contains a count of the number of documents input to the stage. @@ -322,22 +312,22 @@ def count(self, name:str)->Self: is the name of the output field which has the count as its value. must be a non-empty string, must not start with $ and must not contain the . character. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/count/#mongodb-pipeline-pipe.-count """ - self.stages.append( - Count(name=name) - ) + self.stages.append(Count(name=name)) return self - def explode(self, \ - path_to_array:str|None=None, - path:str|None=None, - *, - include_array_index:str|None=None, - always:bool=False, - preserve_null_and_empty_arrays:bool=False)->Self: + def explode( + self, + path_to_array: str | None = None, + path: str | None = None, + *, + include_array_index: str | None = None, + always: bool = False, + preserve_null_and_empty_arrays: bool = False, + ) -> Self: """ Adds a unwind stage to the current pipeline. @@ -353,25 +343,27 @@ def explode(self, \ ----------------------------- Deconstructs an array field from the input documents to output a document for each element. Each output document is the input document with the value of the array field replaced by the element. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/unwind/#mongodb-pipeline-pipe.-unwind """ self.stages.append( - Unwind( - path_to_array=path_to_array or path, - include_array_index=include_array_index, - always=always or preserve_null_and_empty_arrays - ) + Unwind( + path_to_array=path_to_array or path, + include_array_index=include_array_index, + always=always or preserve_null_and_empty_arrays, ) + ) return self - def group(self, *, by:Any|None=None, _id:Any|None=None, query:dict={})->Self: + def group( + self, *, by: Any | None = None, _id: Any | None = None, query: dict = {} + ) -> Self: """ Adds a group stage to the current pipeline. The group stage separates documents into groups according to a "group key". The output is one document for each unique group key. The output documents can also contain additional fields that are set using accumulator expressions. - + Arguments: ------------------------ - by (_id), str | list[str] | set[str] | dict | None : field or group of fields to group by @@ -394,15 +386,10 @@ def group(self, *, by:Any|None=None, _id:Any|None=None, query:dict={})->Self: Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#mongodb-pipeline-pipe.-group """ - self.stages.append( - Group( - by=by or _id, - query=query - ) - ) + self.stages.append(Group(by=by or _id, query=query)) return self - def limit(self, value:int)->Self: + def limit(self, value: int) -> Self: """ Adds a limit stage to the current pipeline. Limits the number of documents passed to the next stage in the pipeline. @@ -416,7 +403,7 @@ def limit(self, value:int)->Self: Online MongoDB documentation: ----------------------------- Limits the number of documents passed to the next stage in the pipeline. - + $limit takes a positive integer that specifies the maximum number of documents to pass along. NOTE : Starting in MongoDB 5.0, the $limit pipeline aggregation has a 64-bit integer limit. Values @@ -425,25 +412,26 @@ def limit(self, value:int)->Self: Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/group/#mongodb-pipeline-pipe.-group """ - self.stages.append( - Limit(value=value) - ) + self.stages.append(Limit(value=value)) return self - def lookup(self, *, \ - name:str, - right:str|None=None, - on:str|None=None, - left_on:str|None=None, - local_field:str|None=None, - right_on:str|None=None, - foreign_field:str|None=None)->Self: + def lookup( + self, + *, + name: str, + right: str | None = None, + on: str | None = None, + left_on: str | None = None, + local_field: str | None = None, + right_on: str | None = None, + foreign_field: str | None = None, + ) -> Self: """ Adds a lookup stage to the current pipeline. Performs a left outer join to a collection in the same database to filter in documents from the "joined" collection for processing. The lookup stage adds a new array field to each input document. The new array field contains the matching documents from the "joined" collection. The lookup stage passes these reshaped documents to the next stage. - + Arguments: ---------------------------- - right / from (official MongoDB name), str : foreign collection @@ -487,24 +475,26 @@ def lookup(self, *, \ self.stages.append( Lookup( - right = right, - on = on, - left_on = left_on or local_field, - right_on = right_on or foreign_field, - name = name + right=right, + on=on, + left_on=left_on or local_field, + right_on=right_on or foreign_field, + name=name, ) ) return self def join( - self, - *, - other:str, - how:Literal["left", "right", "inner"]="left", # TODO : Implement outer and cross joins - on:str|None=None, - left_on:str|None=None, - right_on:str|None=None - )->Self: + self, + *, + other: str, + how: Literal[ + "left", "right", "inner" + ] = "left", # TODO : Implement outer and cross joins + on: str | None = None, + left_on: str | None = None, + right_on: str | None = None, + ) -> Self: """ Adds a combination of stages, that together reproduce SQL joins. This is a virtual and unofficial stage. It is not documented on MongoDB aggregation pipeline reference page. @@ -520,12 +510,12 @@ def join( the left collection that match documents from the right collection - - on, str|None=None: key to use to perform the join, + - on, str|None=None: key to use to perform the join, if the key name is the same in both collections - left_on, str|None=None: key to use on the left collection to perform the join. Must be use with right_on. - right_on, str|None=None: key to use on the right collection to perform the join - Must be use with left_on. + Must be use with left_on. """ # NOTE : Currently chose to implement a real SQL join, that is we chose to promote the matches in the local collection, the matches of the foreign collection @@ -544,57 +534,51 @@ def join( self.__inner_join(right=other, on=on, left_on=left_on, right_on=right_on) return self - - def __join_common(self, right:str, on:str|None, left_on:str|None, right_on:str|None)->str: - """Common parts between various join types""" + def __join_common( + self, right: str, on: str | None, left_on: str | None, right_on: str | None + ) -> str: + """Common parts between various join types""" _prefix = right.lower() join_field = "__" + _prefix + "__" self.stages.append( Lookup( - right = right, - on = on, - left_on = left_on, - right_on = right_on, - name = join_field + right=right, on=on, left_on=left_on, right_on=right_on, name=join_field ) ) - self.stages.append( - Unwind(path_to_array=join_field) - ) + self.stages.append(Unwind(path_to_array=join_field)) self.stages.append( ReplaceRoot( - document=MergeObjects( - operand=[ROOT, "$"+join_field] - ).expression + document=MergeObjects(operand=[ROOT, "$" + join_field]).expression ) ) - self.stages.append( - Project(exclude=join_field) - ) + self.stages.append(Project(exclude=join_field)) return join_field - def __left_join(self, right:str, on:str|None, left_on:str|None, right_on:str|None) -> None: + def __left_join( + self, right: str, on: str | None, left_on: str | None, right_on: str | None + ) -> None: """Implements SQL left join""" self.__join_common(right=right, on=on, left_on=left_on, right_on=right_on) - - - def __inner_join(self, right:str, on:str|None, left_on:str|None, right_on:str|None) -> None: + + def __inner_join( + self, right: str, on: str | None, left_on: str | None, right_on: str | None + ) -> None: """Implements SQL inner join""" - join_field = self.__join_common(right=right, on=on, left_on=left_on, right_on=right_on) - + join_field = self.__join_common( + right=right, on=on, left_on=left_on, right_on=right_on + ) + filter_no_match = Match( - query = { - join_field : [] - } - ) # used to filter out documents in the left collection, that has no match in the right collection + query={join_field: []} + ) # used to filter out documents in the left collection, that has no match in the right collection self.stages.insert(-3, filter_no_match) - def match(self, query:dict={}, expr:Expression=None, **kwargs:Any)->Self: + def match(self, query: dict = {}, expr: Expression = None, **kwargs: Any) -> Self: """ Adds a match stage to the current pipeline. Filters the documents to pass only the documents that match the specified condition(s) to the next pipeline stage. @@ -604,13 +588,13 @@ def match(self, query:dict={}, expr:Expression=None, **kwargs:Any)->Self: - query, dict : a simple MQL query use to filter the documents. - operand, Any:an aggregation expression used to filter the documents - + NOTE : Use query if you're using a MQL query and expression if you're using aggregation expressions. Online MongoDB documentation: ----------------------------- Filters the documents to pass only the documents that match the specified condition(s) to the next pipeline stage. - + $match takes a document that specifies the query conditions. The query syntax is identical to the read operation query syntax; i.e. $match does not accept raw aggregation expressions. Instead, use a $expr query expression to include aggregation expression in $match @@ -620,12 +604,16 @@ def match(self, query:dict={}, expr:Expression=None, **kwargs:Any)->Self: """ query = query | kwargs - self.stages.append( - Match(query=query, expr=expr) - ) + self.stages.append(Match(query=query, expr=expr)) return self - def out(self, collection:str|None=None, coll:str|None=None, *, db:str|None=None)->Self: + def out( + self, + collection: str | None = None, + coll: str | None = None, + *, + db: str | None = None, + ) -> Self: """ Adds an out stage to the current pipeline. Takes the documents returned by the aggregation pipeline and writes them to a specified collection. Starting in MongoDB 4.4, you can specify the output database. @@ -642,24 +630,22 @@ def out(self, collection:str|None=None, coll:str|None=None, *, db:str|None=None) WARNING : out replaces the specified collection if it exists. See [Replace Existing Collection](https://www.mongodb.com/docs/manual/reference/operator/aggregation/out/#std-label-replace-existing-collection) for details. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/out/#mongodb-pipeline-pipe.-out """ - self.stages.append( - Out( - collection=collection or coll, - db = db - ) - ) + self.stages.append(Out(collection=collection or coll, db=db)) return self - def project(self, *,\ - include : str|set[str]|list[str]|dict|bool|None = None, - exclude : str|set[str]|list[str]|dict|bool|None = None, - fields : str|set[str]|list[str]|None = None, - projection : dict = {}, - **kwargs:Any)->Self: + def project( + self, + *, + include: str | set[str] | list[str] | dict | bool | None = None, + exclude: str | set[str] | list[str] | dict | bool | None = None, + fields: str | set[str] | list[str] | None = None, + projection: dict = {}, + **kwargs: Any, + ) -> Self: """ Adds a project stage to the current pipeline. Passes along the documents with the requested fields to the next stage in the pipeline. The specified fields can be existing fields from the input documents or newly computed fields. @@ -679,28 +665,31 @@ def project(self, *,\ the suppression of the _id field, the addition of new fields, and the resetting of the values of existing fields. Alternatively, you may specify the exclusion of fields. The $project specifications have the following forms: - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/project/#mongodb-pipeline-pipe.-project """ projection = projection | kwargs self.stages.append( - Project( - include = include, - exclude = exclude, - fields = fields, - projection = projection - ) + Project( + include=include, exclude=exclude, fields=fields, projection=projection ) + ) return self - def replace_root(self, path:str|None=None, path_to_new_root:str|None=None, *,document:dict|None=None)->Self: + def replace_root( + self, + path: str | None = None, + path_to_new_root: str | None = None, + *, + document: dict | None = None, + ) -> Self: """ Adds a replace_root stage to the current pipeline. Replaces the input document with the specified document. The operation replaces all existing fields in the input document, including the _id field. You can promote an existing embedded document to the top level, or create a new document for promotion - + Arguments: ------------------------------------- @@ -717,19 +706,22 @@ def replace_root(self, path:str|None=None, path_to_new_root:str|None=None, *,doc The replacement document can be any valid expression that resolves to a document. The stage errors and fails if is not a document. For more information on expressions, see Expressions. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/replaceRoot/#mongodb-pipeline-pipe.-replaceRoot """ self.stages.append( - ReplaceRoot( - path=path or path_to_new_root, - document=document - ) - ) + ReplaceRoot(path=path or path_to_new_root, document=document) + ) return self - def replace_with(self, path:str|None=None, path_to_new_root:str|None=None, *,document:dict|None=None)->Self: + def replace_with( + self, + path: str | None = None, + path_to_new_root: str | None = None, + *, + document: dict | None = None, + ) -> Self: """ Adds a replace_with stage to the current pipeline. @@ -749,19 +741,16 @@ def replace_with(self, path:str|None=None, path_to_new_root:str|None=None, *,doc The replacement document can be any valid expression that resolves to a document. The stage errors and fails if is not a document. For more information on expressions, see Expressions. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/replaceRoot/#mongodb-pipeline-pipe.-replaceRoot """ self.stages.append( - ReplaceRoot( - path=path or path_to_new_root, - document=document - ) - ) + ReplaceRoot(path=path or path_to_new_root, document=document) + ) return self - def sample(self, value:int)->Self: + def sample(self, value: int) -> Self: """ Adds a sample stage to the current pipeline. Randomly selects the specified number of documents from the input documents. @@ -774,36 +763,34 @@ def sample(self, value:int)->Self: Online MongoDB documentation: ----------------------------- Randomly selects the specified number of documents from the input documents. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/sample/#mongodb-pipeline-pipe.-sample """ - self.stages.append( - Sample(value=value) - ) - + self.stages.append(Sample(value=value)) + return self - + # TODO : Check that clause_type and facet_type parameters don't break anything def search( - self, - path:str|list[str]|None=None, - query:str|list[str]|None=None, - *, - operator_name:OperatorLiteral|None=None, - collector_name:Literal["facet"]|None=None, - # Including the below parameters to give them visibility - #--------------------------------------------------- - clause_type:ClauseType|None=None, - facet_type:FacetType|None=None, - #--------------------------------------------------- - index:str="default", - count:CountOptions|None=None, - highlight:HighlightOptions|None=None, - return_stored_source:bool=False, - score_details:bool=False, - **kwargs:Any - )->Self: + self, + path: str | list[str] | None = None, + query: str | list[str] | None = None, + *, + operator_name: OperatorLiteral | None = None, + collector_name: Literal["facet"] | None = None, + # Including the below parameters to give them visibility + # --------------------------------------------------- + clause_type: ClauseType | None = None, + facet_type: FacetType | None = None, + # --------------------------------------------------- + index: str = "default", + count: CountOptions | None = None, + highlight: HighlightOptions | None = None, + return_stored_source: bool = False, + score_details: bool = False, + **kwargs: Any, + ) -> Self: """ Adds a search stage to the current pipeline. The search stage performs a full-text search on the specified field or fields which must be covered by an Atlas Search index. @@ -817,7 +804,7 @@ def search( - index, str : name of the index to use for the search. Defaults to defaut - count, CountOptions|None : document that specifies the count options for retrieving a count of the results - - highlight, HighlightOptions|None : document that specifies the highlight options for + - highlight, HighlightOptions|None : document that specifies the highlight options for displaying search terms in their original context - return_stored_source, bool : Indicates whether to use the copy of the documents in the Atlas Search index (with just a subset of the fields) @@ -827,9 +814,9 @@ def search( False => Do a lookup and return the original documents. - score_details, bool : Indicates whether to retrieve the detailed breakdown of the score for the documents in the results. Defaults to False. - To view the details, you must use the $meta expression in the + To view the details, you must use the $meta expression in the $project stage. - - operator_name, str : Name of the operator to search with. Use the compound operator to run a + - operator_name, str : Name of the operator to search with. Use the compound operator to run a compound (i.e query with multiple operators). - kwargs, Any : Operators specific options. Includes (non-exhaustive): @@ -839,10 +826,10 @@ def search( - allow_analyzed_field, bool (controls index scanning) - synonyms - like, dict|list[dict] (allow looking for similar documents) - + Online MongoDB documentation: ----------------------------- - The search stage performs a full-text search on the specified field or fields + The search stage performs a full-text search on the specified field or fields which must be covered by an Atlas Search index. Source : https://www.mongodb.com/docs/atlas/atlas-search/query-syntax/#mongodb-pipeline-pipe.-search @@ -851,7 +838,6 @@ def search( if not collector_name and not operator_name: operator_name = "text" - # If pipeline is empty, adds a search stage if len(self) == 0: # if facet_type is not None: @@ -870,18 +856,20 @@ def search( highlight=highlight, return_stored_source=return_stored_source, score_details=score_details, - **kwargs + **kwargs, ) - + # If pipeline is not empty then the first stage must be Search stage. # If so, adds the operator to the existing stage using Compound. elif len(self) >= 1 and isinstance(self.stages[0], Search): - kwargs.update({ - # "collector_name":collector_name, - "operator_name":operator_name, - "path":path, - "query":query, - }) + kwargs.update( + { + # "collector_name":collector_name, + "operator_name": operator_name, + "path": path, + "query": query, + } + ) has_facet_arg = self.__has_facet_arg(**kwargs) if has_facet_arg: self._append_facet(facet_type, **kwargs) @@ -890,33 +878,32 @@ def search( else: raise TypeError("search stage has to be the first stage of the pipeline") - + return self - def search_meta( - self, - path:str|list[str]|None=None, - query:str|list[str]|None=None, - *, - operator_name:OperatorLiteral|None=None, - collector_name:Literal["facet"]|None=None, - # Including the below parameters to give them visibility - #--------------------------------------------------- - clause_type:ClauseType|None=None, - facet_type:FacetType|None=None, - #--------------------------------------------------- - index:str="default", - count:CountOptions|None=None, - highlight:HighlightOptions|None=None, - return_stored_source:bool=False, - score_details:bool=False, - **kwargs:Any - )->Self: + self, + path: str | list[str] | None = None, + query: str | list[str] | None = None, + *, + operator_name: OperatorLiteral | None = None, + collector_name: Literal["facet"] | None = None, + # Including the below parameters to give them visibility + # --------------------------------------------------- + clause_type: ClauseType | None = None, + facet_type: FacetType | None = None, + # --------------------------------------------------- + index: str = "default", + count: CountOptions | None = None, + highlight: HighlightOptions | None = None, + return_stored_source: bool = False, + score_details: bool = False, + **kwargs: Any, + ) -> Self: """ Adds a searchMeta stage to the current pipeline. The searchMeta stage returns different types of metadata result documents. - + NOTE : if used, search has to be the first stage of the pipeline Arguments: @@ -926,7 +913,7 @@ def search_meta( - index, str : name of the index to use for the search. Defaults to defaut - count, dict|None : document that specifies the count options for retrieving a count of the results - - highlight, dict|None : document that specifies the highlight options for + - highlight, dict|None : document that specifies the highlight options for displaying search terms in their original context - return_stored_source, bool : Indicates whether to use the copy of the documents in the Atlas Search index (with just a subset of the fields) @@ -936,9 +923,9 @@ def search_meta( False => Do a lookup and return the original documents. - score_details, bool : Indicates whether to retrieve the detailed breakdown of the score for the documents in the results. Defaults to False. - To view the details, you must use the $meta expression in the + To view the details, you must use the $meta expression in the $project stage. - - operator_name, str : Name of the operator to search with. Use the compound operator to run a + - operator_name, str : Name of the operator to search with. Use the compound operator to run a compound (i.e query with multiple operators). - kwargs, Any : Operators specific options. Includes (non-exhaustive): @@ -949,11 +936,10 @@ def search_meta( - synonyms - like, dict|list[dict] (allow looking for similar documents) """ - + if not collector_name and not operator_name: operator_name = "text" - # If pipeline is empty, adds a search stage if len(self) == 0: self._init_search( @@ -967,18 +953,20 @@ def search_meta( highlight=highlight, return_stored_source=return_stored_source, score_details=score_details, - **kwargs + **kwargs, ) - + # If pipeline is not empty then the first stage must be Search stage. # If so, adds the operator to the existing stage using Compound. elif len(self) >= 1 and isinstance(self.stages[0], SearchMeta): - kwargs.update({ - # "collector_name":collector_name, - "operator_name":operator_name, - "path":path, - "query":query, - }) + kwargs.update( + { + # "collector_name":collector_name, + "operator_name": operator_name, + "path": path, + "query": query, + } + ) has_facet_arg = self.__has_facet_arg(facet_type=facet_type, **kwargs) if has_facet_arg: self._append_facet(facet_type, **kwargs) @@ -990,21 +978,21 @@ def search_meta( return self - def _init_search( - self, - search_class:Literal["search", "searchMeta"], - path:str|list[str]|None=None, - query:str|list[str]|None=None, - *, - operator_name:OperatorLiteral|None=None, - collector_name:Literal["facet"]|None=None, - index:str="default", - count:CountOptions|None=None, - highlight:HighlightOptions|None=None, - return_stored_source:bool=False, - score_details:bool=False, - **kwargs:Any)->None: + self, + search_class: Literal["search", "searchMeta"], + path: str | list[str] | None = None, + query: str | list[str] | None = None, + *, + operator_name: OperatorLiteral | None = None, + collector_name: Literal["facet"] | None = None, + index: str = "default", + count: CountOptions | None = None, + highlight: HighlightOptions | None = None, + return_stored_source: bool = False, + score_details: bool = False, + **kwargs: Any, + ) -> None: """Adds a search stage to the pipeline.""" if not collector_name and operator_name: @@ -1017,7 +1005,7 @@ def _init_search( highlight=highlight, return_stored_source=return_stored_source, score_details=score_details, - **kwargs + **kwargs, ) else: search_stage = SearchStageMap[search_class].init_facet( @@ -1030,24 +1018,22 @@ def _init_search( return_stored_source=return_stored_source, score_details=score_details, collector_name=collector_name, - **kwargs + **kwargs, ) - self.stages.append( - search_stage - ) + self.stages.append(search_stage) return None - def _append_clause( - self, - clause_type:ClauseType|None=None, - *, - operator_name:OperatorLiteral|None=None, - path:str|list[str]|None=None, - query:str|list[str]|None=None, - **kwargs:Any)->None: + self, + clause_type: ClauseType | None = None, + *, + operator_name: OperatorLiteral | None = None, + path: str | list[str] | None = None, + query: str | list[str] | None = None, + **kwargs: Any, + ) -> None: """Adds a clause to the search stage of the pipeline.""" first_stage = self.stages[0] @@ -1059,35 +1045,47 @@ def _append_clause( else: default_minimum_should_match = 0 - minimum_should_match = kwargs.pop("minimum_should_match", default_minimum_should_match) + minimum_should_match = kwargs.pop( + "minimum_should_match", default_minimum_should_match + ) - kwargs.update({ - "path":path, - "query":query - }) + kwargs.update({"path": path, "query": query}) if isinstance(first_stage.collector, Facet): if isinstance(first_stage.collector.operator, Compound): # Add clause to existing compound - first_stage.__get_operators_map__(operator_name=operator_name)(clause_type, **kwargs) + first_stage.__get_operators_map__(operator_name=operator_name)( + clause_type, **kwargs + ) elif first_stage.collector.operator is None: # Create a compound operator with the to-be operator as a clause new_operator = Compound(minimum_should_match=minimum_should_match) - new_operator.__get_operators_map__(operator_name=operator_name)(clause_type, **kwargs) - first_stage.operator = new_operator + new_operator.__get_operators_map__(operator_name=operator_name)( + clause_type, **kwargs + ) + first_stage.operator = new_operator else: # Retrieve current operator and create a compound operator # and add the current operator as a clause - new_operator = Compound(should=[first_stage.collector.operator], minimum_should_match=minimum_should_match) - new_operator.__get_operators_map__(operator_name=operator_name)(clause_type, **kwargs) + new_operator = Compound( + should=[first_stage.collector.operator], + minimum_should_match=minimum_should_match, + ) + new_operator.__get_operators_map__(operator_name=operator_name)( + clause_type, **kwargs + ) first_stage.operator = new_operator elif isinstance(first_stage.operator, Compound): # Add clause to existing compound - first_stage.__get_operators_map__(operator_name=operator_name)(clause_type, **kwargs) + first_stage.__get_operators_map__(operator_name=operator_name)( + clause_type, **kwargs + ) elif first_stage.operator is not None: # Create a compound operator with the to-be operator as a clause new_operator = Compound(minimum_should_match=minimum_should_match) - new_operator.__get_operators_map__(operator_name=operator_name)(clause_type, **kwargs) + new_operator.__get_operators_map__(operator_name=operator_name)( + clause_type, **kwargs + ) first_stage.operator = new_operator else: @@ -1096,8 +1094,7 @@ def _append_clause( return None - - def _append_facet(self, facet_type:FacetType|None=None, **kwargs:Any)->None: + def _append_facet(self, facet_type: FacetType | None = None, **kwargs: Any) -> None: """Adds a facet to the search stage of the pipeline.""" if not facet_type: @@ -1114,12 +1111,10 @@ def _append_facet(self, facet_type:FacetType|None=None, **kwargs:Any)->None: first_stage.collector.facet(type=facet_type, **kwargs) - return None - @classmethod - def __has_facet_arg(cls, **kwargs:Any)->bool: + def __has_facet_arg(cls, **kwargs: Any) -> bool: """Checks if the kwargs contains a facet argument""" facet_args = ["facet_type", "num_buckets", "boundaries", "default"] @@ -1131,9 +1126,8 @@ def __has_facet_arg(cls, **kwargs:Any)->bool: break return has_facet_arg - - - def set(self, document:dict={}, **kwargs:Any)->Self: + + def set(self, document: dict = {}, **kwargs: Any) -> Self: """ Adds a set stage to the current pipeline. Adds new fields to documents. $set outputs documents that conain all existing fields from the inputs documents @@ -1148,17 +1142,15 @@ def set(self, document:dict={}, **kwargs:Any)->Self: Online MongoDB documentation: ----------------------------- Adds new fields to documents. set outputs documents that contain all existing fields from the inputs documents and newly added fields. Both stages are equivalent to a project stage that explicitly specifies all existing fields in the inputs documents and adds the new fields. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/set/#mongodb-pipeline-pipe.-set """ document = document | kwargs - self.stages.append( - Set(document=document) - ) + self.stages.append(Set(document=document)) return self - def skip(self, value:int)->Self: + def skip(self, value: int) -> Self: """ Adds a skip stage to the current pipeline. Skips over the specified number of documents that pass into the stage and passes the remaining documents to the next stage in the pipeline. @@ -1171,30 +1163,31 @@ def skip(self, value:int)->Self: Online MongoDB documentation: ----------------------------- Skips over the specified number of documents that pass into the stage and passes the remaining documents to the next stage in the pipeline. - + $skip takes a positive integer that specifies the maximum number of documents to skip. NOTE : Starting in MongoDB 5.0, the $skip pipeline aggregation has a 64-bit integer limit. Values passed to the pipeline which exceed this limit will return an invalid argument error. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/skip/#mongodb-pipeline-pipe.-skip """ - self.stages.append( - Skip(value=value) - ) + self.stages.append(Skip(value=value)) return self - def sort(self, *,\ - descending : str|list[str]|dict|bool|None = None, - ascending : str|list[str]|dict|bool|None = None, - by : list[str]|None = None, - query : dict[str, Literal[1, -1]] = {}, - **kwargs:Any)->Self: + def sort( + self, + *, + descending: str | list[str] | dict | bool | None = None, + ascending: str | list[str] | dict | bool | None = None, + by: list[str] | None = None, + query: dict[str, Literal[1, -1]] = {}, + **kwargs: Any, + ) -> Self: """ Adds a sort stage to the current pipeline. Sorts all input documents and returns them to the pipeline in sorted order. - + Arguments: ----------------------- - statement, dict : the statement generated after instantiation @@ -1220,22 +1213,17 @@ def sort(self, *,\ Online MongoDB documentation: ----------------------------- Sorts all input documents and returns them to the pipeline in sorted order. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/sort/#mongodb-pipeline-pipe.-sort """ query = query | kwargs self.stages.append( - Sort( - descending = descending, - ascending = ascending, - by = by, - query = query - ) - ) + Sort(descending=descending, ascending=ascending, by=by, query=query) + ) return self - def sort_by_count(self, by:str)->Self: + def sort_by_count(self, by: str) -> Self: """ Adds a sort_by_count stage to the current pipeline. Groups incoming documents based on the value of a specified expression, then computes the count of documents in each distinct group. @@ -1262,12 +1250,12 @@ def sort_by_count(self, by:str)->Self: Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/sortByCount/#mongodb-pipeline-pipe.-sortByCount """ - self.stages.append( - SortByCount(by=by) - ) + self.stages.append(SortByCount(by=by)) return self - - def union_with(self, collection:str, coll:str, pipeline:list[dict]|None=None)->Self: + + def union_with( + self, collection: str, coll: str, pipeline: list[dict] | None = None + ) -> Self: """ Adds a union_with stage to the current pipeline. Performs a union of two collections. unionWith combines pipeline results from two collections into a single result set. The stage outputs the combined result set (including duplicates) to the next stage. @@ -1275,7 +1263,7 @@ def union_with(self, collection:str, coll:str, pipeline:list[dict]|None=None)->S The order in which the combined result set documents are output is unspecified. Arguments: --------------------------------- - + - collection / coll, str : The collection or view whose pipeline results you wish to include in the result set - pipeline, list[dict] | Pipeline | None : An aggregation pipeline to apply to the specified coll. @@ -1285,24 +1273,22 @@ def union_with(self, collection:str, coll:str, pipeline:list[dict]|None=None)->S unionWith combines pipeline results from two collections into a single result set. The stage outputs the combined result set (including duplicates) to the next stage. The order in which the combined result set documents are output is unspecified. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/unionWith/#mongodb-pipeline-pipe.-unionWith """ - self.stages.append( - UnionWith( - collection=collection or coll, - pipeline=pipeline) - ) + self.stages.append(UnionWith(collection=collection or coll, pipeline=pipeline)) return self - def unwind(self, \ - path:str|None=None, - path_to_array:str|None=None, - include_array_index:str|None=None, - always:bool=False, - preserve_null_and_empty_arrays:bool=False)->Self: + def unwind( + self, + path: str | None = None, + path_to_array: str | None = None, + include_array_index: str | None = None, + always: bool = False, + preserve_null_and_empty_arrays: bool = False, + ) -> Self: """ Adds a unwind stage to the current pipeline. @@ -1318,54 +1304,50 @@ def unwind(self, \ ----------------------------- Deconstructs an array field from the input documents to output a document for each element. Each output document is the input document with the value of the array field replaced by the element. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/unwind/#mongodb-pipeline-pipe.-unwind """ self.stages.append( - Unwind( - path = path or path_to_array, - include_array_index = include_array_index, - always = always or preserve_null_and_empty_arrays, - ) + Unwind( + path=path or path_to_array, + include_array_index=include_array_index, + always=always or preserve_null_and_empty_arrays, ) + ) return self - - def unset(self, field:str|None=None, fields:list[str]|None=None)->Self: + def unset(self, field: str | None = None, fields: list[str] | None = None) -> Self: """ Adds an unset stage to the current pipeline. Removes/excludes fields from documents. - + Arguments: ------------------------------- - field, str|None: field to be removed - fields, list[str]|None, list of fields to be removed - + Online MongoDB documentation: ----------------------------- Removes/excludes fields from documents. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/unset/#definition """ - self.stages.append( - Unset(field=field, fields=fields) - ) + self.stages.append(Unset(field=field, fields=fields)) return self - def vector_search( - self, - index:str, - path:str, - query_vector:list[float], - num_candidates:int, - limit:int, - filter:dict|None=None, - )->Self: + self, + index: str, + path: str, + query_vector: list[float], + num_candidates: int, + limit: int, + filter: dict | None = None, + ) -> Self: """ Adds a vector_search stage to the current pipeline. @@ -1378,7 +1360,7 @@ def vector_search( - num_candidates, int : number of nearest neighbors to use during the search - limit, int : number of documents to return in the results - filter, dict|None : any MQL match expression that compares an indexed field with a boolean, number (not decimals), or string to use as a prefilter - + """ self.stages.append( @@ -1388,8 +1370,7 @@ def vector_search( query_vector=query_vector, num_candidates=num_candidates, limit=limit, - filter=filter + filter=filter, ) - ) return self diff --git a/src/monggregate/search/operators/__init__.py b/src/monggregate/search/operators/__init__.py index 63750285..fe00bbd7 100644 --- a/src/monggregate/search/operators/__init__.py +++ b/src/monggregate/search/operators/__init__.py @@ -24,7 +24,6 @@ """ - from monggregate.search.operators.autocomplete import Autocomplete from monggregate.search.operators.compound import Compound from monggregate.search.operators.equals import Equals @@ -35,13 +34,23 @@ from monggregate.search.operators.text import Text from monggregate.search.operators.wildcard import Wildcard -AnyOperator = Autocomplete | Compound | Equals | Exists | MoreLikeThis | Range | Regex | Text | Wildcard +AnyOperator = ( + Autocomplete + | Compound + | Equals + | Exists + | MoreLikeThis + | Range + | Regex + | Text + | Wildcard +) OperatorMap = { "autocomplete": Autocomplete, "compound": Compound, "equals": Equals, "exists": Exists, - "moreLikeThis": MoreLikeThis, + "more_like_this": MoreLikeThis, "range": Range, "regex": Regex, "text": Text, diff --git a/src/monggregate/stages/lookup.py b/src/monggregate/stages/lookup.py index 712b281a..72121fdd 100644 --- a/src/monggregate/stages/lookup.py +++ b/src/monggregate/stages/lookup.py @@ -355,7 +355,7 @@ def set_type(cls, value: str, values: dict) -> str: if right and left_on and right_on and not (let or pipeline): type_ = "simple" - elif let and left_on and right_on and pipeline is not None: + elif let and pipeline is not None: type_ = "correlated" elif not let and pipeline is not None: @@ -376,36 +376,15 @@ def set_type(cls, value: str, values: dict) -> str: def expression(self) -> Expression: """Generates statement from attributes""" - # Generate statement: - # ----------------------------------------------- - if self.type_ == "simple": - statement = { - "$lookup": { - "from": self.right, - "localField": self.left_on, - "foreignField": self.right_on, - "as": self.name, - } - } - elif self.type_ == "uncorrelated": - statement = { - "$lookup": { - "from": self.right, - "let": self.let, - "pipeline": self.pipeline, - "as": self.name, - } - } - else: # should be correlated case - statement = { - "$lookup": { - "from": self.right, - "localField": self.right_on, - "foreignField": self.right_on, - "let": self.let, - "pipeline": self.pipeline, - "as": self.name, - } + statement = { + "$lookup": { + "from": self.right, + "localField": self.left_on, + "foreignField": self.right_on, + "let": self.let, + "pipeline": self.pipeline, + "as": self.name, } + } return self.express(statement) diff --git a/src/monggregate/stages/match.py b/src/monggregate/stages/match.py index 48f5bbe3..9e141c70 100644 --- a/src/monggregate/stages/match.py +++ b/src/monggregate/stages/match.py @@ -52,7 +52,8 @@ from monggregate.base import pyd, Expression from monggregate.stages.stage import Stage from monggregate.operators.operator import Operator -#from monggregate.expressions import Expression +# from monggregate.expressions import Expression + class Match(Stage): """ @@ -63,14 +64,14 @@ class Match(Stage): - query, dict : a simple MQL query use to filter the documents. - operand, Any:an aggregation expression used to filter the documents - + NOTE : Use query if you're using a MQL query and expression if you're using aggregation expressions. - - + + Online MongoDB documentation: ----------------------------- Filters the documents to pass only the documents that match the specified condition(s) to the next pipeline stage. - + $match takes a document that specifies the query conditions. The query syntax is identical to the read operation query syntax; i.e. $match does not accept raw aggregation expressions. Instead, use a $expr query expression to include aggregation expression in $match @@ -78,30 +79,30 @@ class Match(Stage): Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/match/#mongodb-pipeline-pipe.-match """ - query : dict = {} #| None - expr : Expression | None = None + query: dict = {} # | None + expr: Expression | None = None @pyd.validator("expr", pre=True, always=True) - def validate_operand(cls, expr:Any)-> Any: - - c1 = isinstance(expr, dict) # expression is "expressed/resolved" already - c2 = isinstance(expr, Operator) # expression is an operator object - - if expr and not (c1 or c2 ): - raise ValueError("The expression argument must be a valid expression, operator or a dict.") - + def validate_operand(cls, expr: Any) -> Any: + c1 = isinstance(expr, dict) # expression is "expressed/resolved" already + c2 = isinstance(expr, Operator) # expression is an operator object + + if expr and not (c1 or c2): + raise ValueError( + "The expression argument must be a valid expression, operator or a dict." + ) + if isinstance(expr, dict) and "$expr" not in expr: - expr = {"$expr":expr} - + expr = {"$expr": expr} + return expr @property def expression(self) -> Expression: - if self.expr: - _statement = self.express({"$match":{"$expr":self.expr}}) - + _statement = self.express({"$match": self.expr}) + else: - _statement = self.express({"$match":self.query}) + _statement = self.express({"$match": self.query}) return _statement diff --git a/src/monggregate/stages/project.py b/src/monggregate/stages/project.py index 6d15f8a1..9a516f5a 100644 --- a/src/monggregate/stages/project.py +++ b/src/monggregate/stages/project.py @@ -131,6 +131,7 @@ ProjectionArgs = str | list[str] | set[str] + class Project(Stage): """ Abstraction of MongoDB $project statement which Passes along the documents with the requested fields to the next stage in the pipeline. @@ -141,7 +142,7 @@ class Project(Stage): - fields, ProjectionArgs | None : fields to be kept or excluded (depending on include/exclude parameters when those are booleans) - include, ProjectionArgs| dict | bool | None : fields to be kept - exclude, ProjectionArgs | dict | bool | None : fields to be excluded - + Online MongoDB documentation: ----------------------------- Passes along the documents with the requested fields to the next stage in the pipeline. The specified fields can be existing fields from the input documents or newly computed fields. @@ -150,55 +151,87 @@ class Project(Stage): the suppression of the _id field, the addition of new fields, and the resetting of the values of existing fields. Alternatively, you may specify the exclusion of fields. The $project specifications have the following forms: - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/project/#mongodb-pipeline-pipe.-project """ - include : list[str] | dict | bool | None = None - exclude : list[str] | dict | bool | None = None - fields : list[str] | None = None - projection : dict = {} + include: list[str] | dict | bool | None = None + exclude: list[str] | dict | bool | None = None + fields: list[str] | None = None + projection: dict = {} @pyd.validator("include", "exclude", pre=True, always=True) @classmethod - def parse_include_exclude(cls, value:ProjectionArgs|dict|bool|None)->list[str]|dict|bool|None: + def parse_include_exclude( + cls, value: ProjectionArgs | dict | bool | None + ) -> list[str] | dict | bool | None: """Parses include and exclude arguments""" return to_unique_list(value) @pyd.validator("exclude") @classmethod - def validates_booleans(cls, exclude:ProjectionArgs|dict|bool|None, values:dict[str, ProjectionArgs|dict|bool|None]) -> list[str]|bool|None: + def validates_booleans( + cls, + exclude: ProjectionArgs | dict | bool | None, + values: dict[str, ProjectionArgs | dict | bool | None], + ) -> list[str] | bool | None: """Validates combination of include and exclude""" include = values.get("include") if isinstance(include, bool) and isinstance(exclude, bool): - raise ValueError("Cannot both include and exclude fields when using include and exlude as booleans") + raise ValueError( + "Cannot both include and exclude fields when using include and exlude as booleans" + ) return exclude # TODO : When using fields, consider include = True as default @pyd.validator("fields", pre=True) @classmethod - def validates_fields(cls, value:ProjectionArgs|None, values:dict[str, list[str]|dict|bool|None])-> list[str]|None: + def validates_fields( + cls, + value: ProjectionArgs | None, + values: dict[str, list[str] | dict | bool | None], + ) -> list[str] | None: """Validates fields""" include = values.get("include") exclude = values.get("exclude") - if value and not (isinstance(include, bool) or isinstance(exclude, bool)): - raise ValueError("Either include or exclude must be set and be a boolean when using fields") + if value: + # When fields is provided, include or exclude must be boolean + if not (isinstance(include, bool) or isinstance(exclude, bool)): + raise ValueError( + "Either include or exclude must be set and be a boolean when using fields" + ) + + # Forbid mixing boolean and list/dict approaches + if ( + isinstance(include, bool) + and (isinstance(exclude, list) or isinstance(exclude, dict)) + ) or ( + isinstance(exclude, bool) + and (isinstance(include, list) or isinstance(include, dict)) + ): + raise ValueError( + "Cannot mix boolean include/exclude with list/dict include/exclude. " + "Use include=['field1', 'field2'], exclude=['field3', 'field4'] instead." + ) fields = to_unique_list(value) return fields - @pyd.validator("projection", pre=True, always=True) @classmethod - def generates_projection(cls, projection:dict, values:dict[str, list[str] | dict | bool | None])->dict: + def generates_projection( + cls, projection: dict, values: dict[str, list[str] | dict | bool | None] + ) -> dict: """Validates and if necessary generates projection""" - def _to_projection(projection:dict, projection_args:list[str]|dict, include:bool)->None: + def _to_projection( + projection: dict, projection_args: list[str] | dict, include: bool + ) -> None: """ Inserts fields in include or exlude arguments inside a projection Ex: @@ -214,12 +247,18 @@ def _to_projection(projection:dict, projection_args:list[str]|dict, include:bool - include, bool : whether the fields are to be included or excluded """ - if isinstance(projection_args, list): for field in projection_args: projection[field] = include else: - projection.update(projection_args) + # For dict, we need to respect the include parameter + # If include=False (exclude case), set all fields to 0 + # If include=True (include case), use the dict values as-is + if include: + projection.update(projection_args) + else: + for field in projection_args: + projection[field] = 0 # Retrieving validated fields # ----------------------------- @@ -227,18 +266,18 @@ def _to_projection(projection:dict, projection_args:list[str]|dict, include:bool exclude = values.get("exclude") fields = values.get("fields") - # Initizaling projection if not provided # -------------------------------------- if not projection: - # Case #1 : fields is provided # ------------------------------ if fields: # validates_fields ensures that include and exclude are either None or booleans when fields is provided - # None or boolean_value = boolean_value # valdiates_booleans ensures that include or exclude are not both, booleans at the same time - _to_projection(projection, fields, include or exclude) + if isinstance(include, bool): + _to_projection(projection, fields, include) + elif isinstance(exclude, bool): + _to_projection(projection, fields, not exclude) # Case #2 : fields is not provided # ------------------------------- @@ -251,13 +290,14 @@ def _to_projection(projection:dict, projection_args:list[str]|dict, include:bool # TODO : Validate final projection if not projection: - raise ValueError(f"Invalid combination of arguments with include={include}, exclude={exclude} and fields={fields}.") + raise ValueError( + f"Invalid combination of arguments with include={include}, exclude={exclude} and fields={fields}." + ) return projection - @property - def expression(self)->Expression: + def expression(self) -> Expression: """Generates statement from other attributes""" - return self.express({"$project":self.projection}) + return self.express({"$project": self.projection}) diff --git a/src/monggregate/stages/search/base.py b/src/monggregate/stages/search/base.py index 3fcf218b..361e5cb0 100644 --- a/src/monggregate/stages/search/base.py +++ b/src/monggregate/stages/search/base.py @@ -118,13 +118,13 @@ def init(cls, values: dict) -> dict: if operator_name and not operator: operator = OperatorMap[operator_name](**values) - # values.pop("collector", None) - # values["operator"] = operator + values.pop("collector", None) + values["operator"] = operator if collector_name and not collector: collector = Facet(operator=operator, **values) values.pop("operator", None) - # values["collector"] = collector + values["collector"] = collector if not collector and not operator: values["operator"] = Compound() diff --git a/src/monggregate/stages/sort.py b/src/monggregate/stages/sort.py index e05d1d4b..30e798ee 100644 --- a/src/monggregate/stages/sort.py +++ b/src/monggregate/stages/sort.py @@ -101,6 +101,7 @@ SortArgs = str | list[str] | set[str] + class Sort(Stage): """ Abstration of MongoDB $sort statement that sorts all input documents and returns them to the pipeline in sorted order. @@ -130,88 +131,93 @@ class Sort(Stage): Online MongoDB documentation: ----------------------------- Sorts all input documents and returns them to the pipeline in sorted order. - + Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/sort/#mongodb-pipeline-pipe.-sort """ - descending : list[str] | dict | bool | None = None - ascending : list[str] | dict | bool | None = None - by : list[str] | None = None - query : dict[str, Literal[1, -1]] = {} + descending: list[str] | dict | bool | None = None + ascending: list[str] | dict | bool | None = None + by: list[str] | None = None + query: dict[str, Literal[1, -1]] = {} # NOTE : The below are pyd.validators are very close to what is used for project => CONSIDER factorizing @pyd.validator("ascending", "descending", pre=True, always=True) @classmethod - def parse_ascending_descending(cls, value:SortArgs|dict|bool|None)->list[str]|dict|bool|None: + def parse_ascending_descending( + cls, value: SortArgs | dict | bool | None + ) -> list[str] | dict | bool | None: """Parses ascending and descending""" return to_unique_list(value) - @pyd.validator("ascending") + @pyd.root_validator(pre=True) @classmethod - def validates_booleans(cls, ascending:list[str]|dict|bool|None, values:dict)->list[str]|bool|None: + def validates_booleans(cls, values: dict) -> dict: """Validates combination of ascending and descending""" + ascending = values.get("ascending") descending = values.get("descending") # Preventing to use both ascending and descending as booleans at the same time # to avoid conflicting behaviors if isinstance(descending, bool) and isinstance(ascending, bool): - raise ValueError("Cannot use both ascending and descending as booleans at the same time") + raise ValueError( + "Cannot use both ascending and descending as booleans at the same time" + ) # If neither is set, in case by is set, we set ascending as True by default # so that the documents will be sorted by the fields provided in by in ascending order elif ascending is None and descending is None: - ascending = True + values["ascending"] = True # If descending is provided as a bool, we symetrically compute ascending, so that we only need one of argument # in validates_by below elif ascending is None and isinstance(descending, bool): - ascending = not descending + values["ascending"] = not descending # and reciprocally, if ascending is provided as a bool, we symetrically compute descending. # (WARNING: removing this branch breaks the pyd.validator on a functional stand point) elif descending is None and isinstance(ascending, bool): - descending = not ascending + values["descending"] = not ascending - elif isinstance(ascending, list) or isinstance(descending, list): + elif isinstance(ascending, (list, set)) or isinstance(descending, (list, set)): pass - elif isinstance(ascending, dict) or isinstance(descending, dict): pass - # if we are in none of the cases above, we raise an error. Hopefully we don't have false positives ! else: raise TypeError( f"Wrong combination of arguments.\ Cannot have ascending with type {type(ascending)} and descending with type {type(descending)} at the same time" - ) - - return ascending + ) + return values @pyd.validator("by", pre=True) @classmethod - def validates_by(cls, value:SortArgs|None, values:dict)->list[str]|None: + def validates_by(cls, value: SortArgs | None, values: dict) -> list[str] | None: """Validates by""" ascending = values.get("ascending") descending = values.get("descending") if value and not (isinstance(ascending, bool) or isinstance(descending, bool)): - raise ValueError("Either ascending or descending must be set and be a boolean when using fields") + raise ValueError( + "Either ascending or descending must be set and be a boolean when using fields" + ) return to_unique_list(value) - @pyd.validator("query", pre=True, always=True) @classmethod - def generates_query(cls, query:dict, values:dict)->dict: + def generates_query(cls, query: dict, values: dict) -> dict: """Generates query if not provided""" - def _to_query(query:dict, sort_args:list[str]|dict, direction:bool)->None: + def _to_query( + query: dict, sort_args: list[str] | dict, direction: bool + ) -> None: """ Inserts fields in ascending and descending arguments inside a query Ex: @@ -224,8 +230,8 @@ def _to_query(query:dict, sort_args:list[str]|dict, direction:bool)->None: """ _sort_order_map = { - True:1, # ascending - False:-1 # descending + True: 1, # ascending + False: -1, # descending } if isinstance(sort_args, list): @@ -234,25 +240,27 @@ def _to_query(query:dict, sort_args:list[str]|dict, direction:bool)->None: else: query.update(sort_args) - # Retrieving validated fields # ----------------------------- - ascending:list[str]|dict|bool|None = values.get("ascending") - descending:list[str]|dict|bool|None = values.get("descending") - by:list[str]|None = values.get("by") - + ascending: list[str] | dict | bool | None = values.get("ascending") + descending: list[str] | dict | bool | None = values.get("descending") + by: list[str] | None = values.get("by") # Initizaling projection if not provided # -------------------------------------- if not query: - # Case #1 : By is provided # ------------------------------ if by: # validates_by ensures that ascending and descending are either None or booleans when by is provided # None or boolean_value = boolean_value # valdiates_booleans ensures that ascending or descending are not both, booleans at the same time - _to_query(query, by, ascending) + # If ascending is None due to validation failure, we should not continue + if ascending is None: + raise ValueError( + "Invalid configuration: ascending cannot be None when by is provided" + ) + _to_query(query, by, ascending) # Case #2 : by is not provided # ------------------------------- @@ -266,7 +274,7 @@ def _to_query(query:dict, sort_args:list[str]|dict, direction:bool)->None: return query @property - def expression(self)->Expression: + def expression(self) -> Expression: """Generates statement from other attributes""" - return self.express({"$sort":self.query}) + return self.express({"$sort": self.query}) diff --git a/src/monggregate/stages/sort_by_count.py b/src/monggregate/stages/sort_by_count.py index 40bd5973..577124f1 100644 --- a/src/monggregate/stages/sort_by_count.py +++ b/src/monggregate/stages/sort_by_count.py @@ -74,14 +74,32 @@ class SortByCount(Stage): Source : https://www.mongodb.com/docs/manual/reference/operator/aggregation/sortByCount/#mongodb-pipeline-pipe.-sortByCount """ - by: str # TODO : Allow more types - # Should be a field path or a valid expression + by: str | list[str] | set[str] + # Validators # ------------------------ _validates_path_to_array = pyd.validator( "by", allow_reuse=True, pre=True, always=True )(validate_field_path) + @pyd.validator("by", pre=True) + @classmethod + def validates_by( + cls, value: str | list[str] | set[str], values: dict + ) -> str | set[str]: + """Validates by""" + + output: str | set[str] + + if isinstance(value, list) or isinstance(value, set): + output = set([validate_field_path(path) for path in value]) + elif isinstance(value, str): + output = validate_field_path(value) + else: + raise ValueError("by must be a string, list of strings, or set of strings") + + return output + @property def expression(self) -> Expression: """Generates sort_by_count stage statement from SortByCount class keywords arguments""" diff --git a/tests/test_coverage.py b/tests/test_coverage.py index b9b4293a..021b0684 100644 --- a/tests/test_coverage.py +++ b/tests/test_coverage.py @@ -162,7 +162,6 @@ def generate_error_message( return error_msg -# @pytest.mark.xfail(reason="We first need to catch up the existing code base.") def test_all_modules_have_tests() -> None: """Test that every Python module in src/monggregate has a corresponding test file in tests/tests_monggregate with the appropriate naming convention. diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..6625a62e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,65 @@ +"""Tests for the `tests.utils` module.""" + +from tests.utils import generate_enum_member_name + + +class TestGenerateEnumMemberName: + """Tests for the `generate_enum_member_name` function.""" + + def test_with_real_operator_names(self) -> None: + """Test that the `generate_enum_member_name` function generates the correct enum member name.""" + + assert generate_enum_member_name("$addToSet") == "ADD_TO_SET" + assert generate_enum_member_name("bottomN") == "BOTTOM_N" + assert generate_enum_member_name("topN") == "TOP_N" + assert generate_enum_member_name("bottomN") == "BOTTOM_N" + + def test_with_stages_names(self) -> None: + """Test that the `generate_enum_member_name` function generates the correct enum member name.""" + + assert generate_enum_member_name("$lookup") == "LOOKUP" + assert generate_enum_member_name("$project") == "PROJECT" + assert generate_enum_member_name("$match") == "MATCH" + assert generate_enum_member_name("$unwind") == "UNWIND" + assert generate_enum_member_name("$group") == "GROUP" + assert generate_enum_member_name("$sort") == "SORT" + assert generate_enum_member_name("$limit") == "LIMIT" + assert generate_enum_member_name("$skip") == "SKIP" + assert generate_enum_member_name("$sample") == "SAMPLE" + + def test_on_random_names(self) -> None: + """Test that the `generate_enum_member_name` function generates the correct enum member name.""" + + assert generate_enum_member_name("random_name") == "RANDOM_NAME" + assert generate_enum_member_name("addToSet") == "ADD_TO_SET" + + def test_consecutive_uppercase_letters(self) -> None: + """Test that consecutive uppercase letters are kept together.""" + + # These are the problematic cases that currently fail + assert generate_enum_member_name("$indexOfCP") == "INDEX_OF_CP" + assert generate_enum_member_name("$indexCP") == "INDEX_CP" + assert generate_enum_member_name("$strLenCP") == "STR_LEN_CP" + assert generate_enum_member_name("$indexOfBytes") == "INDEX_OF_BYTES" + assert generate_enum_member_name("$substrCP") == "SUBSTR_CP" + assert generate_enum_member_name("$strLenBytes") == "STR_LEN_BYTES" + assert generate_enum_member_name("$substrBytes") == "SUBSTR_BYTES" + + def test_edge_cases(self) -> None: + """Test that the `generate_enum_member_name` function generates the correct enum member name on edge cases. + + - Empty string + - Single letter + - Single letter with $ prefix + - Single letter without $ prefix + - Hybrid case + """ + + assert generate_enum_member_name("") == "" + assert generate_enum_member_name("a") == "A" + assert generate_enum_member_name("$a") == "A" + assert generate_enum_member_name("aB") == "A_B" + assert generate_enum_member_name("aBc") == "A_BC" + assert generate_enum_member_name("XMLHttpRequest") == "XML_HTTP_REQUEST" + assert generate_enum_member_name("HTMLParser") == "HTML_PARSER" + assert generate_enum_member_name("JSONData") == "JSON_DATA" diff --git a/tests/tests_legacy/test_join_alias.py b/tests/tests_legacy/test_join_alias.py index a761dac8..8f881b12 100644 --- a/tests/tests_legacy/test_join_alias.py +++ b/tests/tests_legacy/test_join_alias.py @@ -7,36 +7,32 @@ from monggregate import Pipeline from monggregate.stages import Lookup, Match -def test_left_join()->None: + +def test_left_join() -> None: """Tests left join in pipeline class""" pipeline = Pipeline(collection="left") - pipeline.join( - other = "right", - how = "left", - on = "zipcode" - ) + pipeline.join(other="right", how="left", on="zipcode") assert len(pipeline.stages) == 4, pipeline assert isinstance(pipeline[0], Lookup) assert pipeline[0]() == { - "$lookup":{ - "from" : "right", # from references the right collection - "localField" : "zipcode", - "foreignField" : "zipcode", - "as" : "__right__" - }} - - -def test_inner_join()->None: + "$lookup": { + "from": "right", # from references the right collection + "localField": "zipcode", + "foreignField": "zipcode", + "as": "__right__", + "let": None, + "pipeline": None, + } + } + + +def test_inner_join() -> None: """Tests left join in pipeline class""" pipeline = Pipeline(collection="left") - pipeline.join( - other = "right", - how = "inner", - on = "zipcode" - ) + pipeline.join(other="right", how="inner", on="zipcode") assert len(pipeline.stages) == 5, pipeline assert isinstance(pipeline[0], Lookup) @@ -46,12 +42,16 @@ def test_inner_join()->None: assert isinstance(pipeline[1], Match) assert pipeline[0]() == { - "$lookup":{ - "from" : "right", # from references the right collection - "localField" : "zipcode", - "foreignField" : "zipcode", - "as" : "__right__" - }} + "$lookup": { + "from": "right", # from references the right collection + "localField": "zipcode", + "foreignField": "zipcode", + "as": "__right__", + "let": None, + "pipeline": None, + } + } + if __name__ == "__main__": test_left_join() diff --git a/tests/tests_legacy/test_stages.py b/tests/tests_legacy/test_stages.py index 7b04c879..97ab0dc6 100644 --- a/tests/tests_legacy/test_stages.py +++ b/tests/tests_legacy/test_stages.py @@ -9,7 +9,7 @@ from monggregate.base import pyd from monggregate import Pipeline -from monggregate.stages import( # pylint: disable=import-error +from monggregate.stages import ( # pylint: disable=import-error Stage, BucketAuto, Bucket, @@ -29,11 +29,12 @@ Sort, UnionWith, Unset, - Unwind + Unwind, ) State = dict[str, Stage] + @pytest.mark.stages @pytest.mark.unit @pytest.mark.functional @@ -41,7 +42,7 @@ class TestStages: """This class only aims at reusing the markers""" @pytest.fixture(autouse=True, scope="class") - def state(self)->State: + def state(self) -> State: """ Creates a memory for the tests. @@ -56,33 +57,28 @@ def state(self)->State: """ return {} - def test_stage(self)->None: + def test_stage(self) -> None: """Tests the stage parent class""" with pytest.raises(TypeError): # Checking that Stage cannot be instantiated stage = Stage(expression={}) # pylint: disable=abstract-class-instantiated - - def test_bucket_auto(self, state:State)->None: + def test_bucket_auto(self, state: State) -> None: """Tests $bucketAuto stage class and mirror function""" - bucket_auto = BucketAuto( - group_by="test", - buckets=10 - ) + bucket_auto = BucketAuto(group_by="test", buckets=10) assert bucket_auto state["bucket_auto"] = bucket_auto assert BucketAuto( group_by="test", - buckets = 4, - output = {"new_var":{"$sum":"my_expression"}}, - granularity="E12" + buckets=4, + output={"new_var": {"$sum": "my_expression"}}, + granularity="E12", ) - - def test_bucket(self, state:State)->None: + def test_bucket(self, state: State) -> None: """Tests the $bucket stage class and mirror function""" bucket = Bucket( @@ -93,171 +89,114 @@ def test_bucket(self, state:State)->None: state["bucket"] = bucket assert Bucket( - group_by="income", - boundaries=[25000, 40000, 60000, 100000], - default="other" + group_by="income", boundaries=[25000, 40000, 60000, 100000], default="other" ) - assert Bucket( group_by="income", boundaries=[25000, 40000, 60000, 100000], default="other", - output={"output":{"$sum":1}} + output={"output": {"$sum": 1}}, ) - - - def test_count(self, state:State)->None: + def test_count(self, state: State) -> None: """Tests the count stage""" count = Count(name="count") assert count state["count"] = count - - def test_group(self, state:State)->None: + def test_group(self, state: State) -> None: """Tests the group stage""" # Testing mandatory fields # ------------------------ - group = Group( - query = {"_id":"count"} - ) + group = Group(query={"_id": "count"}) assert group state["group"] = group # Tests aliases # ----------------------- - assert Group( - _id="count" - ) + assert Group(_id="count") # Test optional arguments # ------------------------ - assert Group( - by="count", - query = { - "output":{"$sum":"income"} - } - ) + assert Group(by="count", query={"output": {"$sum": "income"}}) # Test by as list # ------------------------ - assert Group( - by=["name", "age"], - query = { - "output":{"$sum":"income"} - } - ) + assert Group(by=["name", "age"], query={"output": {"$sum": "income"}}) # Test by as set # ------------------------ - assert Group( - by=set(["name", "age"]), - query = { - "output":{"$sum":"income"} - } - ) + assert Group(by=set(["name", "age"]), query={"output": {"$sum": "income"}}) # Test by as constant # ------------------------ - assert Group( - by=1, - query = { - "output":{"$sum":"income"} - } - ) + assert Group(by=1, query={"output": {"$sum": "income"}}) # Test by as dict # ------------------------ - assert Group( - by={"name":"$name"}, - query = { - "output":{"$sum":"income"} - } - ) + assert Group(by={"name": "$name"}, query={"output": {"$sum": "income"}}) # Test by as None # ------------------------ - assert Group( - by=None, - query = { - "output":{"$sum":"income"} - } - ) - - + assert Group(by=None, query={"output": {"$sum": "income"}}) - def test_limit(self, state:State)->None: + def test_limit(self, state: State) -> None: """Tests the limit stage""" limit = Limit(value=10) assert limit state["limit"] = limit - - def test_lookup(self, state:State)->None: + def test_lookup(self, state: State) -> None: """Tests the lookup stage""" # Testing mandatory attributes # ----------------------------- lookup = Lookup( - right = "other_collection", - left_on = "_id", - right_on = "foreign_key", - name = "matches" + right="other_collection", + left_on="_id", + right_on="foreign_key", + name="matches", ) assert lookup state["lookup"] = lookup - assert Lookup( - right = "restaurants", - left_on = "restaurant_name", - right_on = "name", - let = {"orders_drink":"$drink"}, - pipeline = [ - { - "$match" : { - "$expr" : { - "$in" : ["$$orders_drink", "$beverages"] - } - } - } - ], - name = "matches" + assert Lookup( + right="restaurants", + left_on="restaurant_name", + right_on="name", + let={"orders_drink": "$drink"}, + pipeline=[{"$match": {"$expr": {"$in": ["$$orders_drink", "$beverages"]}}}], + name="matches", ) - # Testing aliases # ----------------------------- params = { - "from" : "other_collection", - "local_field" : "_id", - "foreign_field" : "foreign_key", - "as" :"matches" + "from": "other_collection", + "local_field": "_id", + "foreign_field": "foreign_key", + "as": "matches", } simple = Lookup(**params) assert simple # Testing optional attributes # ----------------------------- - uncorrelated = Lookup( - right = "other_collection", - pipeline = [], - name = "new_fields" - ) + uncorrelated = Lookup(right="other_collection", pipeline=[], name="new_fields") assert uncorrelated - - def test_match(self, state:State)->None: + def test_match(self, state: State) -> None: """Tests the match stage""" - match = Match(query={"_id":"12345"}) + match = Match(query={"_id": "12345"}) assert match state["match"] = match - def test_out(self, state:State)->None: + def test_out(self, state: State) -> None: """Tests the out stage""" out = Out(coll="my_collection") @@ -272,13 +211,12 @@ def test_out(self, state:State)->None: # ---------------- assert Out(db="other_db", coll="new_collection") - - def test_project(self, state:State)->None: + def test_project(self, state: State) -> None: """Tests the project stage""" # Testing mandatory attributes # ----------------------------- - project = Project(projection={"_id":0}) #projection is not really mandatory + project = Project(projection={"_id": 0}) # projection is not really mandatory assert project state["project"] = project @@ -290,35 +228,22 @@ def test_project(self, state:State)->None: # Testing aliases # ----------------------------- - assert Project( - fields = "email", - include = True - ) - + assert Project(fields="email", include=True) # Testing optional attributes # ----------------------------- - # Include/Exclude as set - assert Project( - include=set(["name"]) - ) - - assert Project( - exclude=set(["password"]) - ) - # Include/Exclude as str or list of strings - assert Project( - include=set(["name"]) - ) + # Include/Exclude as set + assert Project(include=set(["name"])) - assert Project( - exclude=["name", "_id"] - ) + assert Project(exclude=set(["password"])) + # Include/Exclude as str or list of strings + assert Project(include=set(["name"])) - #TODO: test with string parameters and list + assert Project(exclude=["name", "_id"]) + # TODO: test with string parameters and list - def test_replace_root(self, state:State)->None: + def test_replace_root(self, state: State) -> None: """Tests the replace_root stage""" # Testing mandatory attributes @@ -336,12 +261,12 @@ def test_replace_root(self, state:State)->None: # ----------------------------- # N/A - def test_sample(self, state:State)->None: + def test_sample(self, state: State) -> None: """Tests the sample stage""" # Testing mandatory attributes # ----------------------------- - sample = Sample(value=3) # value is not really mandatory + sample = Sample(value=3) # value is not really mandatory assert sample state["sample"] = sample @@ -351,27 +276,19 @@ def test_sample(self, state:State)->None: # Testing optional attributes # ----------------------------- - sample = Sample() # sample has no mandatory field - # however it has a default + sample = Sample() # sample has no mandatory field + # however it has a default assert sample assert sample.value == 10 - - def test_search(self, state:State)->None: + def test_search(self, state: State) -> None: """Tests the search stage""" - search = Search( - operator={ - "query":"test", - "path":"description" - } - - ) + search = Search(operator={"query": "test", "path": "description"}) state["search"] = search assert search - search = Search.from_operator(operator_name="more_like_this", like={}) assert search @@ -379,16 +296,12 @@ def test_search(self, state:State)->None: # with pytest.raises(pyd.ValidationError): # Search() - - def test_set(self, state:State)->None: + def test_set(self, state: State) -> None: """Tests the set stage""" # Testing mandatory attributes # ----------------------------- - set = Set(document={ - "field1":"value1", - "fieldN":"valueN" - }) + set = Set(document={"field1": "value1", "fieldN": "valueN"}) assert set state["set"] = set @@ -400,8 +313,7 @@ def test_set(self, state:State)->None: # ----------------------------- # N/A - - def test_skip(self, state:State)->None: + def test_skip(self, state: State) -> None: """Tests the skip stage""" # Testing mandatory attributes @@ -418,15 +330,12 @@ def test_skip(self, state:State)->None: # ----------------------------- # N/A - - def test_sort_by_count(self, state:State)->None: + def test_sort_by_count(self, state: State) -> None: """Tests the sort_by_count stage""" # Testing mandatory attributes # ----------------------------- - sort_by_count = SortByCount( - by = "name" - ) + sort_by_count = SortByCount(by="name") assert sort_by_count state["sort_by_count"] = sort_by_count @@ -438,56 +347,46 @@ def test_sort_by_count(self, state:State)->None: # ----------------------------- # N/A - - def test_sort(self, state:State)->None: + def test_sort(self, state: State) -> None: """Tests the sort stage""" # Testing mandatory attributes # ----------------------------- with pytest.raises(pyd.ValidationError): - sort = Sort(query={"field1":1, "fieldN":0}) + sort = Sort(query={"field1": 1, "fieldN": 0}) with pytest.raises(pyd.ValidationError): sort = Sort(query={}) - sort = Sort(query={"field1":1, "fieldN":-1}) + sort = Sort(query={"field1": 1, "fieldN": -1}) assert sort state["sort"] = sort # Testing aliases # ----------------------------- - assert Sort( - by = "count" - ) + assert Sort(by="count") # Testing optional attributes # ----------------------------- assert Sort(ascending=set(["field1", "field2"])) assert Sort(descending=set(["field1"])) assert Sort( - ascending=set(["field1", "field2"]), - descending=set(["field3", "fieldN"]) - ) + ascending=set(["field1", "field2"]), descending=set(["field3", "fieldN"]) + ) # Further testing ascending and descending complex logic # ---------------------------- - assert Sort( - by = "year", - ascending = True - ) + assert Sort(by="year", ascending=True) - assert Sort( - ascending = {"year":1} - ) - - def test_union_with(self, state:State)->None: + assert Sort(ascending={"year": 1}) + + def test_union_with(self, state: State) -> None: """Tests the $unionWith stage""" # Testing mandatory attributes # ----------------------------- union_with = UnionWith( - collection="other_collection", - pipeline = Pipeline().match({"name":"test"}) + collection="other_collection", pipeline=Pipeline().match({"name": "test"}) ) assert union_with state["union_with"] = union_with @@ -500,8 +399,7 @@ def test_union_with(self, state:State)->None: # ----------------------------- # N/A - - def test_unset(self, state:State)->None: + def test_unset(self, state: State) -> None: """Tests the $unset stage""" # Testing mandatory attributes @@ -518,233 +416,219 @@ def test_unset(self, state:State)->None: # ----------------------------- # N/A - - def test_unwind(self, state:State)->None: + def test_unwind(self, state: State) -> None: """Tests the $unwind stage""" # Testing mandatory attributes # ----------------------------- unwind = Unwind( - path_to_array = "xyz", + path_to_array="xyz", ) assert unwind state["unwind"] = unwind # Testing aliases # ----------------------------- - assert Unwind( - path = "xyz" - ) + assert Unwind(path="xyz") # Testing optional attributes # ----------------------------- - assert Unwind( - path = "xyz", - include_array_index = "index", - always = True - ) + assert Unwind(path="xyz", include_array_index="index", always=True) @pytest.mark.latest class TestStagesFunctional(TestStages): """This class gather the functional tests of the stages""" - - def test_bucket_auto_statement(self, state:State)->None: + def test_bucket_auto_statement(self, state: State) -> None: """Tests the BucketAuto class statement and its mirror function""" - assert state["bucket_auto"].expression == Pipeline().bucket_auto( - by = "test", - buckets = 10 - )[-1].expression == { - "$bucketAuto" : { - "groupBy" : "$test", - "buckets" : 10, - "output" : None, - "granularity" : None + assert ( + state["bucket_auto"].expression + == Pipeline().bucket_auto(by="test", buckets=10)[-1].expression + == { + "$bucketAuto": { + "groupBy": "$test", + "buckets": 10, + "output": None, + "granularity": None, } } + ) - def test_bucket_statement(self, state:State)->None: + def test_bucket_statement(self, state: State) -> None: """ Tests the Bucket class statement and its mirror function """ bucket = state["bucket"] - assert bucket.expression == Pipeline().bucket( - by = "income", - boundaries = [25000, 40000, 60000, 100000])[-1].expression == { - "$bucket" : { - "groupBy" : "$income", - "boundaries" : [25000, 40000, 60000, 100000], - "default" : None, - "output" : None + assert ( + bucket.expression + == Pipeline() + .bucket(by="income", boundaries=[25000, 40000, 60000, 100000])[-1] + .expression + == { + "$bucket": { + "groupBy": "$income", + "boundaries": [25000, 40000, 60000, 100000], + "default": None, + "output": None, + } } - } + ) - def test_count_statement(self, state:State)->None: + def test_count_statement(self, state: State) -> None: """Tests the Count class statement and its mirror function""" count = state["count"] - assert count.expression == Pipeline().count(name="count")[0].expression == { - "$count" : "count" - } + assert ( + count.expression + == Pipeline().count(name="count")[0].expression + == {"$count": "count"} + ) - def test_group_statement(self, state:State)->None: + def test_group_statement(self, state: State) -> None: """Tests the Group class statement and its mirror function""" group = state["group"] - assert group.expression == Pipeline().group( - query = { - "_id" :"count" - } - )[0].expression == { - "$group" : { - "_id" :"count" - } - } + assert ( + group.expression + == Pipeline().group(query={"_id": "count"})[0].expression + == {"$group": {"_id": "count"}} + ) - def test_limit_statement(self, state:State)->None: + def test_limit_statement(self, state: State) -> None: """Tests the Limit class statement and its mirror function""" limit = state["limit"] - assert limit.expression == Pipeline().limit(10)[0].expression == { - "$limit" : 10 - } + assert limit.expression == Pipeline().limit(10)[0].expression == {"$limit": 10} - def test_lookup_statement(self, state:State)->None: + def test_lookup_statement(self, state: State) -> None: """Tests the Limit class statement and its mirror function""" lookup = state["lookup"] - assert lookup.expression == Pipeline().lookup( - right = "other_collection", - left_on = "_id", - right_on = "foreign_key", - name = "matches" - )[0].expression == { - "$lookup" :{ - "from" : "other_collection", - "localField" : "_id", - "foreignField" : "foreign_key", - "as" : "matches" + assert ( + lookup.expression + == Pipeline() + .lookup( + right="other_collection", + left_on="_id", + right_on="foreign_key", + name="matches", + )[0] + .expression + == { + "$lookup": { + "from": "other_collection", + "localField": "_id", + "foreignField": "foreign_key", + "as": "matches", + "let": None, + "pipeline": None, + } } - } + ) - def test_match_statement(self, state:State)->None: + def test_match_statement(self, state: State) -> None: """Tests the Match class and its mirror function""" match = state["match"] - assert match.expression == Pipeline().match( - query = { - "_id":"12345" - } - )[0].expression == { - "$match" : { - "_id" : "12345" - } - } + assert ( + match.expression + == Pipeline().match(query={"_id": "12345"})[0].expression + == {"$match": {"_id": "12345"}} + ) - def test_out_satement(self, state:State)->None: + def test_out_satement(self, state: State) -> None: """Tests the Out class and its mirror function""" out = state["out"] - assert out.expression == Pipeline().out("my_collection")[0].expression == { - "$out" : "my_collection" - } + assert ( + out.expression + == Pipeline().out("my_collection")[0].expression + == {"$out": "my_collection"} + ) - def test_project_statement(self, state:State)->None: + def test_project_statement(self, state: State) -> None: """Tests the Project class and its mirror function""" project = state["project"] - assert project.expression == Pipeline().project( - exclude = "_id" - )[0].expression == { - "$project" : { - "_id" : 0 - } - } + assert ( + project.expression + == Pipeline().project(exclude="_id")[0].expression + == {"$project": {"_id": 0}} + ) - def test_replace_root_statement(self, state:State)->None: + def test_replace_root_statement(self, state: State) -> None: """Tests the ReplaceRoot class and its mirror function""" replace_root = state["replace_root"] - assert replace_root.expression == Pipeline().replace_root( - "myarray.mydocument" - )[0].expression == { - "$replaceRoot" : { - "newRoot" : "$myarray.mydocument" - } - } + assert ( + replace_root.expression + == Pipeline().replace_root("myarray.mydocument")[0].expression + == {"$replaceRoot": {"newRoot": "$myarray.mydocument"}} + ) - def test_sample_statement(self, state:State) -> None: + def test_sample_statement(self, state: State) -> None: """Tests the Sample class and its mirror function""" sample = state["sample"] - assert sample.expression == Pipeline().sample(3)[0].expression == { - "$sample" : { - "size" : 3 - } - } + assert ( + sample.expression + == Pipeline().sample(3)[0].expression + == {"$sample": {"size": 3}} + ) - def test_set_statement(self, state:State) -> None: + def test_set_statement(self, state: State) -> None: """Tests the Set class and its mirror function""" set = state["set"] - assert set.expression == Pipeline().set( - { - "field1":"value1", - "fieldN":"valueN" - } - )[0].expression == { - "$set" : { - "field1":"value1", - "fieldN":"valueN" - } - } + assert ( + set.expression + == Pipeline().set({"field1": "value1", "fieldN": "valueN"})[0].expression + == {"$set": {"field1": "value1", "fieldN": "valueN"}} + ) - def test_skip_statement(self, state:State)->None: + def test_skip_statement(self, state: State) -> None: """Tests the Skip class and its mirror function""" skip = state["skip"] - assert skip.expression == Pipeline().skip(10)[0].expression == { - "$skip" : 10 - } - + assert skip.expression == Pipeline().skip(10)[0].expression == {"$skip": 10} - def test_sort_by_count_statement(self, state:State)->None: + def test_sort_by_count_statement(self, state: State) -> None: """Tests the SortByCount class and its mirror function""" sort_by_count = state["sort_by_count"] - assert sort_by_count.expression == Pipeline().sort_by_count( - by = "name" - )[0].expression == { - "$sortByCount" : "$name" - } + assert ( + sort_by_count.expression + == Pipeline().sort_by_count(by="name")[0].expression + == {"$sortByCount": "$name"} + ) - def test_sort_statement(self, state:State)->None: + def test_sort_statement(self, state: State) -> None: """Tests the Sort class and its mirror function""" sort = state["sort"] - assert sort.expression == Pipeline().sort(field1=1, fieldN=-1)[0].expression == { - "$sort" : { - "field1" : 1, - "fieldN" : -1 - } - } + assert ( + sort.expression + == Pipeline().sort(field1=1, fieldN=-1)[0].expression + == {"$sort": {"field1": 1, "fieldN": -1}} + ) - def test_unwind_statement(self, state:State)->None: + def test_unwind_statement(self, state: State) -> None: """Tests the Unwind class and its mirror function""" unwind = state["unwind"] - assert unwind.expression == Pipeline().unwind("xyz")[0].expression == { - "$unwind" : { - "path" : "$xyz" - } - } + assert ( + unwind.expression + == Pipeline().unwind("xyz")[0].expression + == {"$unwind": {"path": "$xyz"}} + ) + # ------------------------ # Debugging: -#------------------------- +# ------------------------- if __name__ == "__main__": # TestStages().test_stage() # TestStages().test_bucket_auto({}) diff --git a/tests/tests_monggregate/test_base.py b/tests/tests_monggregate/test_base.py index 11eff068..6db317ee 100644 --- a/tests/tests_monggregate/test_base.py +++ b/tests/tests_monggregate/test_base.py @@ -1,25 +1,24 @@ -import pytest from pydantic import BaseModel as PydanticBaseModel from monggregate.base import BaseModel, Expression, Singleton, express, isbasemodel # Create a simple subclass of BaseModel for testing -class TestModel(BaseModel): +class DummyModel(BaseModel): field1: str = "default value" field2: int = 0 @property def expression(self) -> Expression: - return {"$add": [self.field1, self.field2]} + return {"$dummy": [self.field1, self.field2]} # Create a simple subclass of PydanticBaseModel for testing -class TestPydanticModel(PydanticBaseModel): +class DummyPydanticModel(PydanticBaseModel): field1: str = "default value" field2: int = 0 -def test_singleton_instantiation(): +def test_singleton_instantiation() -> None: """Test that Singleton class can be instantiated correctly.""" # Create two instances of the Singleton class @@ -30,28 +29,28 @@ def test_singleton_instantiation(): assert instance1 is instance2 -def test_base_model_instantiation(): +def test_base_model_instantiation() -> None: """Test that BaseModel class can be instantiated correctly.""" # Instantiate the model - model = TestModel() + model = DummyModel() # Check the default values assert model.field1 == "default value" assert model.field2 == 0 # Test with custom values - custom_model = TestModel(field1="custom", field2=42) + custom_model = DummyModel(field1="custom", field2=42) assert custom_model.field1 == "custom" assert custom_model.field2 == 42 -def test_isbasemodel(): +def test_isbasemodel() -> None: """Test that isbasemodel function works correctly.""" # Instantiate the model - test_model = TestModel() - test_pydantic_model = TestPydanticModel() + test_model = DummyModel() + test_pydantic_model = DummyPydanticModel() # Check that the model is a BaseModel assert isbasemodel(test_model) @@ -66,69 +65,205 @@ def test_isbasemodel(): class TestExpress: """Test that express function works correctly.""" - def test_with_basemodel_instance(self): + def test_with_basemodel_instance(self) -> None: """Test that express function works correctly for BaseModel objects.""" # Instantiate the model - test_model = TestModel() + test_model = DummyModel() # Check that the expression is correct - assert express(test_model) == {"$add": ["default value", 0]} + assert express(test_model) == {"$dummy": ["default value", 0]} - def test_with_list_of_basemodel_instances(self): + def test_with_list_of_basemodel_instances(self) -> None: """Test that express function works correctly for a list of BaseModel objects.""" # Instantiate the model - test_model_1 = TestModel() - test_model_2 = TestModel() + test_model_1 = DummyModel() + test_model_2 = DummyModel() # Create a list of the models test_model_list = [test_model_1, test_model_2] # Check that the expression is correct assert express(test_model_list) == [ - {"$add": ["default value", 0]}, - {"$add": ["default value", 0]}, + {"$dummy": ["default value", 0]}, + {"$dummy": ["default value", 0]}, ] - def test_with_dict_of_basemodel_instances(self): + def test_with_dict_of_basemodel_instances(self) -> None: """Test that express function works correctly for a dictionary of BaseModel objects.""" # Instantiate the model - test_model = TestModel() + test_model = DummyModel() unresolved_expression = {"$add": [test_model, 0]} # Check that the expression is correct # fmt: off assert express(unresolved_expression) == { "$add": [ - {"$add": ["default value", 0]}, + {"$dummy": ["default value", 0]}, 0], } # fmt: on - @pytest.mark.xfail( - reason="This comes from an issue in the recursion of the express function." - ) - def test_with_nested_basemodel_instances(self): + def test_with_nested_basemodel_instances(self) -> None: """Test that express function works correctly for a nested BaseModel object.""" # Instantiate the model - test_model = TestModel() - unresolved_expression_layer_1 = {"$add": [test_model, 0]} - unresolved_expression_layer_2 = {"$add": [unresolved_expression_layer_1, 0]} + test_model = DummyModel() + unresolved_expression_nested_layer = {"$add": [test_model, 0]} + unresolved_expression_top_layer = { + "$add": [unresolved_expression_nested_layer, 1] + } # Check that the expression is correct # fmt: off - assert express(unresolved_expression_layer_2) == { + assert express(unresolved_expression_top_layer) == { "$add": [ { "$add": [ - {"$add": ["default value", 0]}, + {"$dummy": ["default value", 0]}, 0, ] }, - 0, + 1, ], - }, express(unresolved_expression_layer_1) + }, express(unresolved_expression_top_layer) # fmt: on + + def test_with_primitive_types(self) -> None: + """Test that express function works correctly with primitive types.""" + + # Test various primitive types + assert express(42) == 42 + assert express("hello") == "hello" + assert express(True) is True + assert express(False) is False + assert express(None) is None + assert express(3.14) == 3.14 + + def test_with_empty_containers(self) -> None: + """Test that express function works correctly with empty containers.""" + + # Test empty list and dict + assert express([]) == [] + assert express({}) == {} + + def test_with_mixed_list_types(self) -> None: + """Test that express function works correctly with lists containing mixed types.""" + + test_model = DummyModel() + mixed_list = [test_model, 42, "hello", None, {"key": "value"}] + + expected = [ + {"$dummy": ["default value", 0]}, + 42, + "hello", + None, + {"key": "value"}, + ] + + assert express(mixed_list) == expected + + def test_with_deeply_nested_structures(self) -> None: + """Test that express function works correctly with deeply nested structures.""" + + test_model = DummyModel() + deeply_nested = { + "level1": {"level2": {"level3": [test_model, {"level4": test_model}]}} + } + + expected = { + "level1": { + "level2": { + "level3": [ + {"$dummy": ["default value", 0]}, + {"level4": {"$dummy": ["default value", 0]}}, + ] + } + } + } + + assert express(deeply_nested) == expected + + def test_with_list_of_dicts_containing_basemodels(self) -> None: + """Test that express function works correctly with lists of dictionaries containing BaseModels.""" + + test_model1 = DummyModel(field1="first", field2=1) + test_model2 = DummyModel(field1="second", field2=2) + + list_of_dicts = [ + {"operation": "$add", "operands": [test_model1, 10]}, + {"operation": "$multiply", "operands": [test_model2, 5]}, + ] + + expected = [ + {"operation": "$add", "operands": [{"$dummy": ["first", 1]}, 10]}, + {"operation": "$multiply", "operands": [{"$dummy": ["second", 2]}, 5]}, + ] + + assert express(list_of_dicts) == expected + + def test_with_basemodel_containing_complex_expression(self) -> None: + """Test that express function works correctly when BaseModel expression is complex.""" + + class ComplexModel(BaseModel): + name: str = "complex" + + @property + def expression(self) -> Expression: + return { + "$complex": { + "name": self.name, + "nested": {"array": [1, 2, 3], "object": {"key": "value"}}, + } + } + + complex_model = ComplexModel() + nested_structure = { + "$pipeline": [{"$match": complex_model}, {"$group": {"_id": complex_model}}] + } + + expected = { + "$pipeline": [ + { + "$match": { + "$complex": { + "name": "complex", + "nested": {"array": [1, 2, 3], "object": {"key": "value"}}, + } + } + }, + { + "$group": { + "_id": { + "$complex": { + "name": "complex", + "nested": { + "array": [1, 2, 3], + "object": {"key": "value"}, + }, + } + } + } + }, + ] + } + + assert express(nested_structure) == expected + + def test_with_tuple_and_other_sequences(self) -> None: + """Test that express function works correctly with tuples and other sequence types.""" + + test_model = DummyModel() + + # Test with tuple containing BaseModel + test_tuple = (test_model, 42, "hello") + # Tuples should be treated as regular objects since they're immutable + # The function should return the tuple as-is since it's not a list or dict + assert express(test_tuple) == (test_model, 42, "hello") + + # Test with set containing BaseModel (though sets are not JSON serializable) + # Sets should also be returned as-is + test_set = {42, "hello"} # Can't put BaseModel in set due to unhashable type + assert express(test_set) == test_set diff --git a/tests/tests_monggregate/test_fields.py b/tests/tests_monggregate/test_fields.py index fae7d61d..609207f8 100644 --- a/tests/tests_monggregate/test_fields.py +++ b/tests/tests_monggregate/test_fields.py @@ -5,8 +5,9 @@ class TestFieldName: """Test the FieldName class.""" - def test_validate_valid_field_name(self): + def test_validate_valid_field_name(self) -> None: """Test that a valid field name passes validation.""" + # Setup field_name = "validField" @@ -16,8 +17,9 @@ def test_validate_valid_field_name(self): # Assert assert result == field_name - def test_validate_invalid_field_name_with_dollar(self): + def test_validate_invalid_field_name_with_dollar(self) -> None: """Test that a field name starting with $ fails validation.""" + # Setup field_name = "$invalidField" @@ -25,8 +27,9 @@ def test_validate_invalid_field_name_with_dollar(self): with pytest.raises(ValueError): FieldName.validate(field_name) - def test_validate_invalid_field_name_with_dot(self): + def test_validate_invalid_field_name_with_dot(self) -> None: """Test that a field name containing a dot fails validation.""" + # Setup field_name = "invalid.field" @@ -34,8 +37,9 @@ def test_validate_invalid_field_name_with_dot(self): with pytest.raises(ValueError): FieldName.validate(field_name) - def test_validate_edge_case_empty_string(self): + def test_validate_edge_case_empty_string(self) -> None: """Test that an empty string fails validation.""" + # Setup field_name = "" @@ -47,8 +51,9 @@ def test_validate_edge_case_empty_string(self): class TestFieldPath: """Test the FieldPath class.""" - def test_validate_valid_field_path(self): + def test_validate_valid_field_path(self) -> None: """Test that a valid field path passes validation.""" + # Setup field_path = "$validPath" @@ -58,46 +63,23 @@ def test_validate_valid_field_path(self): # Assert assert result == field_path - def test_validate_invalid_field_path_without_dollar(self): + def test_validate_invalid_field_path(self) -> None: """Test that a field path without $ fails validation.""" - # Setup - field_path = "invalidPath" - # Act & Assert - with pytest.raises(ValueError): - FieldPath.validate(field_path) - - @pytest.mark.xfail(reason="Should be fixed in the code.") - def test_validate_invalid_field_path_with_double_dollar(self): - """Test that a field path with $$ fails validation.""" # Setup - field_path = "$$invalidPath" + field_path = "invalidPath" # Act & Assert with pytest.raises(ValueError): FieldPath.validate(field_path) - pytest.mark.xfail( - reason="This passes but should fail. Need to be fixed in the code" - ) - - def test_validate_edge_case_single_dollar(self): - """Test that just a single $ passes validation.""" - # Setup - field_path = "$" - - # Act - result = FieldPath.validate(field_path) - - # Assert - assert result == field_path - class TestVariable: """Test the Variable class.""" - def test_validate_valid_variable(self): + def test_validate_valid_variable(self) -> None: """Test that a valid variable passes validation.""" + # Setup variable = "$$validVariable" @@ -107,8 +89,9 @@ def test_validate_valid_variable(self): # Assert assert result == variable - def test_validate_invalid_variable_without_dollars(self): + def test_validate_invalid_variable_without_dollars(self) -> None: """Test that a variable without $$ fails validation.""" + # Setup variable = "invalidVariable" @@ -116,8 +99,9 @@ def test_validate_invalid_variable_without_dollars(self): with pytest.raises(ValueError): Variable.validate(variable) - def test_validate_invalid_variable_with_single_dollar(self): + def test_validate_invalid_variable_with_single_dollar(self) -> None: """Test that a variable with single $ fails validation.""" + # Setup variable = "$invalidVariable" @@ -125,8 +109,9 @@ def test_validate_invalid_variable_with_single_dollar(self): with pytest.raises(ValueError): Variable.validate(variable) - def test_validate_edge_case_system_variable(self): + def test_validate_edge_case_system_variable(self) -> None: """Test that a system variable passes validation.""" + # Setup variable = "$$NOW" @@ -137,8 +122,9 @@ def test_validate_edge_case_system_variable(self): assert result == variable -def test_field_types_validation(): +def test_field_types_validation() -> None: """Test that field types validate input correctly.""" + # Valid field name (no $ at start, no dots) assert FieldName.validate("validField") == "validField" diff --git a/tests/tests_monggregate/test_pipeline.py b/tests/tests_monggregate/test_pipeline.py index 2ac72ec5..ea2e34cf 100644 --- a/tests/tests_monggregate/test_pipeline.py +++ b/tests/tests_monggregate/test_pipeline.py @@ -1,11 +1,33 @@ import pytest -from monggregate.pipeline import Pipeline, Match, Project +from monggregate.pipeline import Pipeline +from monggregate.stages import ( + AddFields, + Bucket, + BucketAuto, + Count, + Group, + Limit, + Lookup, + Match, + Out, + Project, + ReplaceRoot, + Sample, + Set, + Skip, + SortByCount, + Sort, + UnionWith, + Unset, + Unwind, + VectorSearch, +) class TestPipeline: """Test the Pipeline class.""" - def test_instantiation(self): + def test_instantiation(self) -> None: """Test that Pipeline class can be instantiated correctly.""" pipeline = Pipeline() @@ -18,7 +40,7 @@ def test_instantiation(self): # Check that the pipeline's expression property returns an empty list assert pipeline.expression == [] - def test___add__(self): + def test___add__(self) -> None: """Test the __add__ method of the Pipeline class.""" pipeline1 = Pipeline() @@ -42,7 +64,7 @@ def test___add__(self): "$project": {"name": 1, "age": 1} } - def test___add_order_should_matter(self): + def test___add_order_should_matter(self) -> None: """Test that the order of addition matters.""" pipeline1 = Pipeline() pipeline2 = Pipeline() @@ -55,7 +77,7 @@ def test___add_order_should_matter(self): assert combined_pipeline.export() != reversed_combined_pipeline.export() - def test___add__with_non_pipeline_object(self): + def test___add__with_non_pipeline_object(self) -> None: """Test the __add__ method of the Pipeline class with a non-Pipeline object.""" pipeline = Pipeline() pipeline.match(query={"name": "John"}) @@ -63,7 +85,7 @@ def test___add__with_non_pipeline_object(self): with pytest.raises(TypeError): pipeline + Project(fields=["name", "age"], include=True) - def test___getitem__(self): + def test___getitem__(self) -> None: """Test the __getitem__ method of the Pipeline class.""" pipeline = Pipeline() pipeline.match(query={"name": "John"}) @@ -71,7 +93,7 @@ def test___getitem__(self): assert pipeline[0].expression == {"$match": {"name": "John"}} - def test__getitem__index_out_of_range(self): + def test__getitem__index_out_of_range(self) -> None: """Test the __getitem__ method of the Pipeline class with edge cases.""" pipeline = Pipeline() @@ -82,7 +104,7 @@ def test__getitem__index_out_of_range(self): with pytest.raises(IndexError): pipeline[2] - def test__setitem__(self): + def test__setitem__(self) -> None: """Test the __setitem__ method of the Pipeline class.""" index = 0 @@ -95,7 +117,7 @@ def test__setitem__(self): assert isinstance(pipeline[index], Project) - def test__setitem__index_out_of_range(self): + def test__setitem__index_out_of_range(self) -> None: """Test the __setitem__ method of the Pipeline class with edge cases.""" index = 2 @@ -106,7 +128,7 @@ def test__setitem__index_out_of_range(self): with pytest.raises(IndexError): pipeline[index] = Project(fields=["name", "age"], include=True) - def test__delitem__(self): + def test__delitem__(self) -> None: """Test the __delitem__ method of the Pipeline class.""" pipeline = Pipeline() pipeline.unwind(path="name") @@ -115,7 +137,7 @@ def test__delitem__(self): del pipeline[0] assert pipeline.export() == [{"$match": {"name": "John"}}] - def test__len__(self): + def test__len__(self) -> None: """Test the __len__ method of the Pipeline class.""" pipeline = Pipeline() @@ -124,7 +146,7 @@ def test__len__(self): assert len(pipeline) == 2 - def test_append(self): + def test_append(self) -> None: """Test the append method of the Pipeline class.""" pipeline = Pipeline() @@ -136,7 +158,7 @@ def test_append(self): assert len(pipeline) == 3 assert isinstance(pipeline[2], Project) - def test_insert(self): + def test_insert(self) -> None: """Test the insert method of the Pipeline class.""" pipeline = Pipeline() @@ -147,7 +169,7 @@ def test_insert(self): assert len(pipeline) == 3 - def test_extend(self): + def test_extend(self) -> None: """Test the extend method of the Pipeline class.""" pipeline = Pipeline() @@ -160,6 +182,1087 @@ def test_extend(self): # Stages # --------------------------------------------------- - # Add tests for stages methods below - # - # ...... + @pytest.mark.xfail( + reason="""AddFields is implemented as a simple alias for Set stage, + which is correct, but it should be done in a different way here. + Indeed right now, the symbol is set to $set rather than $addFields. + + """ + ) + class TestAddFields: + """Test the `add_fields` method of the Pipeline class.""" + + def test_with_document(self) -> None: + """Test the `add_fields` method of the Pipeline class with a document.""" + + expected_first_stage = AddFields(document={"name": "John", "age": 30}) + + pipeline = Pipeline() + pipeline.add_fields(document={"name": "John", "age": 30}) + + assert pipeline[0] == expected_first_stage + assert pipeline.export() == [{"$addFields": {"name": "John", "age": 30}}] + + def test_with_kwargs(self) -> None: + """Test the `add_fields` method of the Pipeline class with kwargs.""" + + pipeline = Pipeline() + pipeline.add_fields(name="John", age=30) + + assert pipeline.export() == [{"$addFields": {"name": "John", "age": 30}}] + + class TestBucket: + """Test the `bucket` method of the Pipeline class.""" + + def test_with_required_params(self) -> None: + """Test the `bucket` method with required parameters.""" + + pipeline = Pipeline() + boundaries = [0, 10, 20, 50, 100] + pipeline.bucket(by="price", boundaries=boundaries) + + expected_stage = Bucket(by="price", boundaries=boundaries) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Bucket) + assert len(pipeline) == 1 + + def test_with_group_by_param(self) -> None: + """Test the `bucket` method with group_by parameter.""" + + pipeline = Pipeline() + boundaries = [0, 10, 20, 50, 100] + pipeline.bucket(group_by="price", boundaries=boundaries) + + expected_stage = Bucket(by="price", boundaries=boundaries) + assert pipeline[0] == expected_stage + + def test_with_default_param(self) -> None: + """Test the `bucket` method with default parameter.""" + + pipeline = Pipeline() + boundaries = [0, 10, 20, 50, 100] + default_value = "Other" + pipeline.bucket(by="price", boundaries=boundaries, default=default_value) + + expected_stage = Bucket( + by="price", boundaries=boundaries, default=default_value + ) + assert pipeline[0] == expected_stage + + def test_with_output_param(self) -> None: + """Test the `bucket` method with output parameter.""" + + pipeline = Pipeline() + boundaries = [0, 10, 20, 50, 100] + output = {"count": {"$sum": 1}, "total": {"$sum": "$price"}} + pipeline.bucket(by="price", boundaries=boundaries, output=output) + + expected_stage = Bucket(by="price", boundaries=boundaries, output=output) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that bucket method returns self for chaining.""" + + pipeline = Pipeline() + boundaries = [0, 10, 20, 50, 100] + result = pipeline.bucket(by="price", boundaries=boundaries) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestBucketAuto: + """Test the `bucket_auto` method of the Pipeline class.""" + + def test_with_required_params(self) -> None: + """Test the `bucket_auto` method with required parameters.""" + + pipeline = Pipeline() + pipeline.bucket_auto(by="price", buckets=4) + + expected_stage = BucketAuto(by="price", buckets=4) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], BucketAuto) + assert len(pipeline) == 1 + + def test_with_group_by_param(self) -> None: + """Test the `bucket_auto` method with group_by parameter.""" + + pipeline = Pipeline() + pipeline.bucket_auto(group_by="price", buckets=4) + + expected_stage = BucketAuto(by="price", buckets=4) + assert pipeline[0] == expected_stage + + def test_with_output_param(self) -> None: + """Test the `bucket_auto` method with output parameter.""" + + pipeline = Pipeline() + output = {"count": {"$sum": 1}, "total": {"$sum": "$price"}} + pipeline.bucket_auto(by="price", buckets=4, output=output) + + expected_stage = BucketAuto(by="price", buckets=4, output=output) + assert pipeline[0] == expected_stage + + def test_with_granularity_param(self) -> None: + """Test the `bucket_auto` method with granularity parameter.""" + + pipeline = Pipeline() + from monggregate.stages import GranularityEnum + + pipeline.bucket_auto(by="price", buckets=4, granularity=GranularityEnum.R5) + + expected_stage = BucketAuto( + by="price", buckets=4, granularity=GranularityEnum.R5 + ) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that bucket_auto method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.bucket_auto(by="price", buckets=4) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestCount: + """Test the `count` method of the Pipeline class.""" + + def test_with_name(self) -> None: + """Test the `count` method of the Pipeline class with a name.""" + + expected_first_stage = Count(name="a_field") + + pipeline = Pipeline() + pipeline.count(name="a_field") + + assert pipeline[0] == expected_first_stage + assert pipeline.export() == [{"$count": "a_field"}] + + class TestExplode: + """Test the `explode` method of the Pipeline class.""" + + def test_with_path_to_array(self) -> None: + """Test the `explode` method with path_to_array parameter.""" + + pipeline = Pipeline() + pipeline.explode(path_to_array="tags") + + expected_stage = Unwind(path_to_array="tags") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Unwind) + + def test_with_path(self) -> None: + """Test the `explode` method with path parameter.""" + + pipeline = Pipeline() + pipeline.explode(path="tags") + + expected_stage = Unwind(path_to_array="tags") + assert pipeline[0] == expected_stage + + def test_with_include_array_index(self) -> None: + """Test the `explode` method with include_array_index parameter.""" + + pipeline = Pipeline() + pipeline.explode(path="tags", include_array_index="tag_index") + + expected_stage = Unwind( + path_to_array="tags", include_array_index="tag_index" + ) + assert pipeline[0] == expected_stage + + def test_with_always_parameter(self) -> None: + """Test the `explode` method with always parameter.""" + + pipeline = Pipeline() + pipeline.explode(path="tags", always=True) + + expected_stage = Unwind(path_to_array="tags", always=True) + assert pipeline[0] == expected_stage + + def test_with_preserve_null_and_empty_arrays(self) -> None: + """Test the `explode` method with preserve_null_and_empty_arrays parameter.""" + + pipeline = Pipeline() + pipeline.explode(path="tags", preserve_null_and_empty_arrays=True) + + expected_stage = Unwind(path_to_array="tags", always=True) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that explode method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.explode(path="tags") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestGroup: + """Test the `group` method of the Pipeline class.""" + + def test_with_by_parameter(self) -> None: + """Test the `group` method with by parameter.""" + + pipeline = Pipeline() + pipeline.group(by="category", query={"count": {"$sum": 1}}) + + expected_stage = Group(by="category", query={"count": {"$sum": 1}}) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Group) + + def test_with_id_parameter(self) -> None: + """Test the `group` method with _id parameter.""" + + pipeline = Pipeline() + pipeline.group(_id="category", query={"count": {"$sum": 1}}) + + expected_stage = Group(by="category", query={"count": {"$sum": 1}}) + assert pipeline[0] == expected_stage + + def test_with_multiple_fields(self) -> None: + """Test the `group` method with multiple grouping fields.""" + + pipeline = Pipeline() + pipeline.group(by=["category", "status"], query={"count": {"$sum": 1}}) + + expected_stage = Group( + by=["category", "status"], query={"count": {"$sum": 1}} + ) + assert pipeline[0] == expected_stage + + def test_with_dict_grouping(self) -> None: + """Test the `group` method with dictionary grouping.""" + + pipeline = Pipeline() + group_expr = {"category": "$category", "year": {"$year": "$date"}} + pipeline.group(by=group_expr, query={"count": {"$sum": 1}}) + + expected_stage = Group(by=group_expr, query={"count": {"$sum": 1}}) + assert pipeline[0] == expected_stage + + def test_with_no_grouping(self) -> None: + """Test the `group` method with no grouping (aggregate all documents).""" + + pipeline = Pipeline() + pipeline.group(by=None, query={"total": {"$sum": "$amount"}}) + + expected_stage = Group(by=None, query={"total": {"$sum": "$amount"}}) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that group method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.group(by="category", query={"count": {"$sum": 1}}) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestLimit: + """Test the `limit` method of the Pipeline class.""" + + def test_with_value(self) -> None: + """Test the `limit` method with a value.""" + + pipeline = Pipeline() + pipeline.limit(value=10) + + expected_stage = Limit(value=10) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Limit) + assert pipeline.export() == [{"$limit": 10}] + + def test_with_large_value(self) -> None: + """Test the `limit` method with a large value.""" + + pipeline = Pipeline() + pipeline.limit(value=1000000) + + expected_stage = Limit(value=1000000) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that limit method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.limit(value=5) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestLookup: + """Test the `lookup` method of the Pipeline class.""" + + def test_with_basic_params(self) -> None: + """Test the `lookup` method with basic parameters.""" + + pipeline = Pipeline() + pipeline.lookup( + name="orders", right="orders", left_on="customer_id", right_on="_id" + ) + + expected_stage = Lookup( + name="orders", right="orders", left_on="customer_id", right_on="_id" + ) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Lookup) + + def test_with_on_parameter(self) -> None: + """Test the `lookup` method with on parameter for equal field names.""" + + pipeline = Pipeline() + pipeline.lookup(name="user_data", right="users", on="user_id") + + expected_stage = Lookup(name="user_data", right="users", on="user_id") + assert pipeline[0] == expected_stage + + def test_with_official_mongodb_names(self) -> None: + """Test the `lookup` method with official MongoDB parameter names.""" + + pipeline = Pipeline() + pipeline.lookup( + name="orders", + right="orders", + local_field="customer_id", + foreign_field="_id", + ) + + expected_stage = Lookup( + name="orders", right="orders", left_on="customer_id", right_on="_id" + ) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that lookup method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.lookup(name="orders", right="orders", on="customer_id") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestMatch: + """Test the `match` method of the Pipeline class.""" + + def test_with_query_dict(self) -> None: + """Test the `match` method with query dictionary.""" + + pipeline = Pipeline() + query = {"status": "active", "age": {"$gte": 18}} + pipeline.match(query=query) + + expected_stage = Match(query=query) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Match) + assert pipeline.export() == [{"$match": query}] + + def test_with_kwargs(self) -> None: + """Test the `match` method with keyword arguments.""" + + pipeline = Pipeline() + pipeline.match(status="active", age=25) + + expected_query = {"status": "active", "age": 25} + expected_stage = Match(query=expected_query) + assert pipeline[0] == expected_stage + + def test_with_query_and_kwargs_combined(self) -> None: + """Test the `match` method with both query dict and kwargs.""" + + pipeline = Pipeline() + pipeline.match(query={"status": "active"}, age=25, category="premium") + + expected_query = {"status": "active", "age": 25, "category": "premium"} + expected_stage = Match(query=expected_query) + assert pipeline[0] == expected_stage + + def test_with_empty_query(self) -> None: + """Test the `match` method with empty query.""" + + pipeline = Pipeline() + pipeline.match() + + expected_stage = Match(query={}) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that match method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.match(status="active") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestOut: + """Test the `out` method of the Pipeline class.""" + + def test_with_collection_name(self) -> None: + """Test the `out` method with collection name.""" + + pipeline = Pipeline() + pipeline.out(collection="result_collection") + + expected_stage = Out(collection="result_collection") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Out) + + def test_with_coll_parameter(self) -> None: + """Test the `out` method with coll parameter.""" + + pipeline = Pipeline() + pipeline.out(coll="result_collection") + + expected_stage = Out(collection="result_collection") + assert pipeline[0] == expected_stage + + def test_with_database_name(self) -> None: + """Test the `out` method with database name.""" + + pipeline = Pipeline() + pipeline.out(collection="result_collection", db="analytics_db") + + expected_stage = Out(collection="result_collection", db="analytics_db") + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that out method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.out(collection="results") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestProject: + """Test the `project` method of the Pipeline class.""" + + def test_with_include_fields(self) -> None: + """Test the `project` method with include fields.""" + + pipeline = Pipeline() + pipeline.project(fields=["name", "age"], include=True) + + expected_stage = Project(fields=["name", "age"], include=True) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Project) + + def test_with_exclude_fields(self) -> None: + """Test the `project` method with exclude fields.""" + + pipeline = Pipeline() + pipeline.project(fields=["password", "internal_id"], exclude=True) + + expected_stage = Project(fields=["password", "internal_id"], exclude=True) + assert pipeline[0] == expected_stage + + def test_with_projection_dict(self) -> None: + """Test the `project` method with projection dictionary.""" + + pipeline = Pipeline() + projection = { + "name": 1, + "age": 1, + "full_name": {"$concat": ["$first_name", " ", "$last_name"]}, + } + pipeline.project(projection=projection) + + expected_stage = Project(projection=projection) + assert pipeline[0] == expected_stage + + def test_with_kwargs(self) -> None: + """Test the `project` method with keyword arguments.""" + + pipeline = Pipeline() + pipeline.project(name=1, age=1, status=0) + + expected_projection = {"name": 1, "age": 1, "status": 0} + expected_stage = Project(projection=expected_projection) + assert pipeline[0] == expected_stage + + def test_with_include_parameter_as_dict(self) -> None: + """Test the `project` method with include as dictionary.""" + + pipeline = Pipeline() + include_dict = {"name": 1, "age": 1} + pipeline.project(include=include_dict) + + expected_stage = Project(include=include_dict) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that project method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.project(fields=["name", "age"], include=True) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestReplaceRoot: + """Test the `replace_root` method of the Pipeline class.""" + + def test_with_path(self) -> None: + """Test the `replace_root` method with path parameter.""" + + pipeline = Pipeline() + pipeline.replace_root(path="user_info") + + expected_stage = ReplaceRoot(path="user_info") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], ReplaceRoot) + + def test_with_path_to_new_root(self) -> None: + """Test the `replace_root` method with path_to_new_root parameter.""" + + pipeline = Pipeline() + pipeline.replace_root(path_to_new_root="user_info") + + expected_stage = ReplaceRoot(path="user_info") + assert pipeline[0] == expected_stage + + def test_with_document(self) -> None: + """Test the `replace_root` method with document parameter.""" + + pipeline = Pipeline() + doc = {"name": "$user.name", "email": "$user.email", "full_info": "$$ROOT"} + pipeline.replace_root(document=doc) + + expected_stage = ReplaceRoot(document=doc) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that replace_root method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.replace_root(path="user_info") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestReplaceWith: + """Test the `replace_with` method of the Pipeline class.""" + + def test_with_path(self) -> None: + """Test the `replace_with` method with path parameter.""" + + pipeline = Pipeline() + pipeline.replace_with(path="user_info") + + expected_stage = ReplaceRoot(path="user_info") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], ReplaceRoot) + + def test_with_path_to_new_root(self) -> None: + """Test the `replace_with` method with path_to_new_root parameter.""" + + pipeline = Pipeline() + pipeline.replace_with(path_to_new_root="user_info") + + expected_stage = ReplaceRoot(path="user_info") + assert pipeline[0] == expected_stage + + def test_with_document(self) -> None: + """Test the `replace_with` method with document parameter.""" + + pipeline = Pipeline() + doc = {"name": "$user.name", "email": "$user.email"} + pipeline.replace_with(document=doc) + + expected_stage = ReplaceRoot(document=doc) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that replace_with method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.replace_with(path="user_info") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestSample: + """Test the `sample` method of the Pipeline class.""" + + def test_with_value(self) -> None: + """Test the `sample` method with a value.""" + + pipeline = Pipeline() + pipeline.sample(value=5) + + expected_stage = Sample(value=5) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Sample) + + def test_with_large_value(self) -> None: + """Test the `sample` method with a large value.""" + + pipeline = Pipeline() + pipeline.sample(value=1000) + + expected_stage = Sample(value=1000) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that sample method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.sample(value=10) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestSet: + """Test the `set` method of the Pipeline class.""" + + def test_with_document(self) -> None: + """Test the `set` method with document parameter.""" + + pipeline = Pipeline() + document = { + "full_name": {"$concat": ["$first_name", " ", "$last_name"]}, + "is_adult": {"$gte": ["$age", 18]}, + } + pipeline.set(document=document) + + expected_stage = Set(document=document) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Set) + + def test_with_kwargs(self) -> None: + """Test the `set` method with keyword arguments.""" + + pipeline = Pipeline() + pipeline.set(status="active", last_updated="$NOW") + + expected_document = {"status": "active", "last_updated": "$NOW"} + expected_stage = Set(document=expected_document) + assert pipeline[0] == expected_stage + + def test_with_document_and_kwargs(self) -> None: + """Test the `set` method with both document and kwargs.""" + + pipeline = Pipeline() + document = {"computed_field": {"$add": ["$a", "$b"]}} + pipeline.set(document=document, status="active", flag=True) + + expected_document = { + "computed_field": {"$add": ["$a", "$b"]}, + "status": "active", + "flag": True, + } + expected_stage = Set(document=expected_document) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that set method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.set(status="active") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestSkip: + """Test the `skip` method of the Pipeline class.""" + + def test_with_value(self) -> None: + """Test the `skip` method with a value.""" + + pipeline = Pipeline() + pipeline.skip(value=10) + + expected_stage = Skip(value=10) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Skip) + assert pipeline.export() == [{"$skip": 10}] + + def test_with_large_value(self) -> None: + """Test the `skip` method with a large value.""" + + pipeline = Pipeline() + pipeline.skip(value=1000000) + + expected_stage = Skip(value=1000000) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that skip method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.skip(value=5) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestSort: + """Test the `sort` method of the Pipeline class.""" + + def test_with_query_dict(self) -> None: + """Test the `sort` method with query dictionary.""" + + pipeline = Pipeline() + query = {"name": 1, "age": -1} + pipeline.sort(query=query) + + expected_stage = Sort(query=query) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Sort) + + def test_with_ascending_fields(self) -> None: + """Test the `sort` method with ascending fields.""" + + pipeline = Pipeline() + pipeline.sort(ascending=["name", "email"]) + + expected_stage = Sort(ascending=["name", "email"]) + assert pipeline[0] == expected_stage + + def test_with_descending_fields(self) -> None: + """Test the `sort` method with descending fields.""" + + pipeline = Pipeline() + pipeline.sort(descending=["created_date", "score"]) + + expected_stage = Sort(descending=["created_date", "score"]) + assert pipeline[0] == expected_stage + + def test_with_by_parameter(self) -> None: + """Test the `sort` method with by parameter.""" + + pipeline = Pipeline() + pipeline.sort(by=["name", "age"]) + + expected_stage = Sort(by=["name", "age"]) + assert pipeline[0] == expected_stage + + def test_with_kwargs(self) -> None: + """Test the `sort` method with keyword arguments.""" + + pipeline = Pipeline() + pipeline.sort(name=1, age=-1, score=1) + + expected_query = {"name": 1, "age": -1, "score": 1} + expected_stage = Sort(query=expected_query) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that sort method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.sort(ascending=["name"]) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestSortByCount: + """Test the `sort_by_count` method of the Pipeline class.""" + + def test_with_field(self) -> None: + """Test the `sort_by_count` method with a field.""" + + pipeline = Pipeline() + pipeline.sort_by_count(by="category") + + expected_stage = SortByCount(by="category") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], SortByCount) + + def test_with_expression(self) -> None: + """Test the `sort_by_count` method with an expression.""" + + pipeline = Pipeline() + pipeline.sort_by_count(by="$status") + + expected_stage = SortByCount(by="$status") + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that sort_by_count method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.sort_by_count(by="category") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestUnionWith: + """Test the `union_with` method of the Pipeline class.""" + + def test_with_collection_name(self) -> None: + """Test the `union_with` method with collection name.""" + + pipeline = Pipeline() + pipeline.union_with(collection="other_collection", coll=None) + + expected_stage = UnionWith(collection="other_collection") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], UnionWith) + + def test_with_coll_parameter(self) -> None: + """Test the `union_with` method with coll parameter.""" + + pipeline = Pipeline() + pipeline.union_with(collection=None, coll="other_collection") + + expected_stage = UnionWith(collection="other_collection") + assert pipeline[0] == expected_stage + + def test_with_pipeline(self) -> None: + """Test the `union_with` method with pipeline parameter.""" + + pipeline = Pipeline() + union_pipeline = [ + {"$match": {"status": "active"}}, + {"$project": {"name": 1}}, + ] + pipeline.union_with( + collection="other_collection", coll=None, pipeline=union_pipeline + ) + + expected_stage = UnionWith( + collection="other_collection", pipeline=union_pipeline + ) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that union_with method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.union_with(collection="other_collection", coll=None) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestUnwind: + """Test the `unwind` method of the Pipeline class.""" + + def test_with_path(self) -> None: + """Test the `unwind` method with path parameter.""" + + pipeline = Pipeline() + pipeline.unwind(path="tags") + + expected_stage = Unwind(path="tags") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Unwind) + + def test_with_path_to_array(self) -> None: + """Test the `unwind` method with path_to_array parameter.""" + + pipeline = Pipeline() + pipeline.unwind(path_to_array="tags") + + expected_stage = Unwind(path="tags") + assert pipeline[0] == expected_stage + + def test_with_include_array_index(self) -> None: + """Test the `unwind` method with include_array_index parameter.""" + + pipeline = Pipeline() + pipeline.unwind(path="tags", include_array_index="tag_index") + + expected_stage = Unwind(path="tags", include_array_index="tag_index") + assert pipeline[0] == expected_stage + + def test_with_always_parameter(self) -> None: + """Test the `unwind` method with always parameter.""" + + pipeline = Pipeline() + pipeline.unwind(path="tags", always=True) + + expected_stage = Unwind(path="tags", always=True) + assert pipeline[0] == expected_stage + + def test_with_preserve_null_and_empty_arrays(self) -> None: + """Test the `unwind` method with preserve_null_and_empty_arrays parameter.""" + + pipeline = Pipeline() + pipeline.unwind(path="tags", preserve_null_and_empty_arrays=True) + + expected_stage = Unwind(path="tags", always=True) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that unwind method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.unwind(path="tags") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestUnset: + """Test the `unset` method of the Pipeline class.""" + + def test_with_single_field(self) -> None: + """Test the `unset` method with a single field.""" + + pipeline = Pipeline() + pipeline.unset(field="password") + + expected_stage = Unset(field="password") + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], Unset) + + def test_with_multiple_fields(self) -> None: + """Test the `unset` method with multiple fields.""" + + pipeline = Pipeline() + pipeline.unset(fields=["password", "internal_id", "temp_data"]) + + expected_stage = Unset(fields=["password", "internal_id", "temp_data"]) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that unset method returns self for chaining.""" + + pipeline = Pipeline() + result = pipeline.unset(field="password") + + assert result is pipeline + assert len(pipeline) == 1 + + class TestVectorSearch: + """Test the `vector_search` method of the Pipeline class.""" + + def test_with_required_params(self) -> None: + """Test the `vector_search` method with required parameters.""" + + pipeline = Pipeline() + query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] + pipeline.vector_search( + index="vector_index", + path="embedding", + query_vector=query_vector, + num_candidates=100, + limit=10, + ) + + expected_stage = VectorSearch( + index="vector_index", + path="embedding", + query_vector=query_vector, + num_candidates=100, + limit=10, + ) + assert pipeline[0] == expected_stage + assert isinstance(pipeline[0], VectorSearch) + + def test_with_filter(self) -> None: + """Test the `vector_search` method with filter parameter.""" + + pipeline = Pipeline() + query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] + filter_dict = {"status": "active", "category": "premium"} + pipeline.vector_search( + index="vector_index", + path="embedding", + query_vector=query_vector, + num_candidates=100, + limit=10, + filter=filter_dict, + ) + + expected_stage = VectorSearch( + index="vector_index", + path="embedding", + query_vector=query_vector, + num_candidates=100, + limit=10, + filter=filter_dict, + ) + assert pipeline[0] == expected_stage + + def test_chaining(self) -> None: + """Test that vector_search method returns self for chaining.""" + + pipeline = Pipeline() + query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] + result = pipeline.vector_search( + index="vector_index", + path="embedding", + query_vector=query_vector, + num_candidates=100, + limit=10, + ) + + assert result is pipeline + assert len(pipeline) == 1 + + class TestMethodChaining: + """Test method chaining across different stage methods.""" + + def test_multiple_stages_chaining(self) -> None: + """Test chaining multiple different stage methods.""" + + pipeline = Pipeline() + result = ( + pipeline.match(status="active") + .project(fields=["name", "age"], include=True) + .sort(ascending=["name"]) + .limit(value=10) + .skip(value=5) + ) + + assert result is pipeline + assert len(pipeline) == 5 + assert isinstance(pipeline[0], Match) + assert isinstance(pipeline[1], Project) + assert isinstance(pipeline[2], Sort) + assert isinstance(pipeline[3], Limit) + assert isinstance(pipeline[4], Skip) + + def test_complex_pipeline_chaining(self) -> None: + """Test a complex pipeline with various stages.""" + + pipeline = Pipeline() + result = ( + pipeline.match(category="electronics") + .lookup(name="user_info", right="users", on="user_id") + .unwind(path="user_info") + .group( + by="user_info.region", query={"total_sales": {"$sum": "$amount"}} + ) + .sort(descending=["total_sales"]) + .limit(value=5) + ) + + assert result is pipeline + assert len(pipeline) == 6 + + +def test_pipeline_with_stages_and_raw_expressions() -> None: + """Test that the Pipeline class can be instantiated with stages and raw expressions.""" + + pipeline = Pipeline() + pipeline.match(query={"name": "John"}) + + assert pipeline.export() == [{"$match": {"name": "John"}}] + + pipeline.append( + { + "$redact": { + "$cond": { + "if": {"$eq": ["$name", "John"]}, + "then": "$$DESCEND", + "else": "$$PRUNE", + } + } + }, + ) + + assert pipeline.export() == [ + {"$match": {"name": "John"}}, + { + "$redact": { + "$cond": { + "if": {"$eq": ["$name", "John"]}, + "then": "$$DESCEND", + "else": "$$PRUNE", + } + } + }, + ] diff --git a/tests/tests_monggregate/test_utils.py b/tests/tests_monggregate/test_utils.py index a4da7f85..902642a8 100644 --- a/tests/tests_monggregate/test_utils.py +++ b/tests/tests_monggregate/test_utils.py @@ -1,4 +1,5 @@ -import pytest +"""Tests for the `monggregate.utils` module.""" + from monggregate.utils import ( to_unique_list, validate_field_path, @@ -7,7 +8,7 @@ ) -def test_str_enum(): +def test_str_enum() -> None: """Test that StrEnum returns the correct value.""" class TestEnum(StrEnum): @@ -17,7 +18,7 @@ class TestEnum(StrEnum): assert str(TestEnum.VALUE) == "value" -def test_to_unique_list(): +def test_to_unique_list() -> None: """Test that to_unique_list converts inputs to a list of unique values.""" # Test with a string assert to_unique_list("field") == ["field"] @@ -36,7 +37,7 @@ def test_to_unique_list(): assert to_unique_list(non_convertible) is non_convertible -def test_validate_field_path(): +def test_validate_field_path() -> None: """Test that validate_field_path adds $ prefix to paths when needed.""" # Path without $ prefix assert validate_field_path("field") == "$field" @@ -48,7 +49,7 @@ def test_validate_field_path(): assert validate_field_path(None) is None -def test_validate_field_paths(): +def test_validate_field_paths() -> None: """Test that validate_field_paths converts inputs to a list of unique values.""" # Test with a list assert validate_field_paths(["field1", "field2", "field1"]) == [ diff --git a/tests/tests_monggregate/tests_operators/test_operator.py b/tests/tests_monggregate/tests_operators/test_operator.py index 67503efa..f23692aa 100644 --- a/tests/tests_monggregate/tests_operators/test_operator.py +++ b/tests/tests_monggregate/tests_operators/test_operator.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.operator import Operator, OperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestOperator: @@ -52,13 +18,6 @@ def test_is_abstract(self) -> None: class TestOperatorEnum: """Tests for the `OperatorEnum` class.""" - @pytest.mark.xfail( - reason="""Some operators are not following the naming convention. - Ex: INDEX_OF_CP - - Need to review the generate_enum_member_name function. - """ - ) def test_naming_convention(self) -> None: """Test that the naming convention is correct.""" mismatches = [] diff --git a/tests/tests_monggregate/tests_operators/tests_accumulators/test_accumulator.py b/tests/tests_monggregate/tests_operators/tests_accumulators/test_accumulator.py index 352f8e05..8705e93f 100644 --- a/tests/tests_monggregate/tests_operators/tests_accumulators/test_accumulator.py +++ b/tests/tests_monggregate/tests_operators/tests_accumulators/test_accumulator.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.accumulators.accumulator import Accumulator, AccumulatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestAccumulator: diff --git a/tests/tests_monggregate/tests_operators/tests_arithmetic/test_arithmetic.py b/tests/tests_monggregate/tests_operators/tests_arithmetic/test_arithmetic.py index 4d271873..2a3ef7a0 100644 --- a/tests/tests_monggregate/tests_operators/tests_arithmetic/test_arithmetic.py +++ b/tests/tests_monggregate/tests_operators/tests_arithmetic/test_arithmetic.py @@ -6,41 +6,7 @@ ArithmeticOperator, ArithmeticOperatorEnum, ) - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestArithmeticOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_array/test_array.py b/tests/tests_monggregate/tests_operators/tests_array/test_array.py index 27ac2270..a5c0ef65 100644 --- a/tests/tests_monggregate/tests_operators/tests_array/test_array.py +++ b/tests/tests_monggregate/tests_operators/tests_array/test_array.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.array.array import ArrayOperator, ArrayOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestArrayOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_boolean/test_boolean.py b/tests/tests_monggregate/tests_operators/tests_boolean/test_boolean.py index 8762bf6e..dba40449 100644 --- a/tests/tests_monggregate/tests_operators/tests_boolean/test_boolean.py +++ b/tests/tests_monggregate/tests_operators/tests_boolean/test_boolean.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.boolean.boolean import BooleanOperator, BooleanOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestBooleanOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_comparison/test_comparator.py b/tests/tests_monggregate/tests_operators/tests_comparison/test_comparator.py index dad0a5d8..41fb9461 100644 --- a/tests/tests_monggregate/tests_operators/tests_comparison/test_comparator.py +++ b/tests/tests_monggregate/tests_operators/tests_comparison/test_comparator.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.comparison.comparator import Comparator, ComparatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestComparator: diff --git a/tests/tests_monggregate/tests_operators/tests_conditional/test_conditional.py b/tests/tests_monggregate/tests_operators/tests_conditional/test_conditional.py index af24cb21..27c74b07 100644 --- a/tests/tests_monggregate/tests_operators/tests_conditional/test_conditional.py +++ b/tests/tests_monggregate/tests_operators/tests_conditional/test_conditional.py @@ -6,41 +6,7 @@ ConditionalOperator, ConditionalOperatorEnum, ) - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestConditionalOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_custom/test_custom.py b/tests/tests_monggregate/tests_operators/tests_custom/test_custom.py index f29c953a..2b95b4ff 100644 --- a/tests/tests_monggregate/tests_operators/tests_custom/test_custom.py +++ b/tests/tests_monggregate/tests_operators/tests_custom/test_custom.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.custom.custom import CustomOperator, CustomOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestCustomOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_data_size/test_data_size.py b/tests/tests_monggregate/tests_operators/tests_data_size/test_data_size.py index b5543299..16b0b7a8 100644 --- a/tests/tests_monggregate/tests_operators/tests_data_size/test_data_size.py +++ b/tests/tests_monggregate/tests_operators/tests_data_size/test_data_size.py @@ -6,41 +6,7 @@ DataSizeOperator, DataSizeOperatorEnum, ) - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestDataSizeOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_date/test_date.py b/tests/tests_monggregate/tests_operators/tests_date/test_date.py index 2e47c02e..834f58dc 100644 --- a/tests/tests_monggregate/tests_operators/tests_date/test_date.py +++ b/tests/tests_monggregate/tests_operators/tests_date/test_date.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.date.date import DateOperator, DateOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestDateOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_objects/test_object_.py b/tests/tests_monggregate/tests_operators/tests_objects/test_object_.py index 68194c01..ec4430eb 100644 --- a/tests/tests_monggregate/tests_operators/tests_objects/test_object_.py +++ b/tests/tests_monggregate/tests_operators/tests_objects/test_object_.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.objects.object_ import ObjectOperator, ObjectOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestObjectOperator: diff --git a/tests/tests_monggregate/tests_operators/tests_strings/test_string.py b/tests/tests_monggregate/tests_operators/tests_strings/test_string.py index f40f3811..0ee4805f 100644 --- a/tests/tests_monggregate/tests_operators/tests_strings/test_string.py +++ b/tests/tests_monggregate/tests_operators/tests_strings/test_string.py @@ -3,41 +3,7 @@ import pytest from monggregate.operators.strings.string import StringOperator, StringOperatorEnum - - -def generate_enum_member_name(value: str) -> str: - """Generate the expected enum member name from a value. - - Args: - value: The enum value (with or without $ prefix) - - Returns: - The expected member name in UPPER_SNAKE_CASE format - - Example: - >>> generate_enum_member_name("$addToSet") - 'ADD_TO_SET' - >>> generate_enum_member_name("bottomN") - 'BOTTOM_N' - """ - # Remove the $ prefix if present - value_name = value[1:] if value.startswith("$") else value - - # Split camelCase into words - words = [] - current_word = value_name[0] - - for char in value_name[1:]: - if char.isupper(): - words.append(current_word) - current_word = char - else: - current_word += char - - words.append(current_word) - - # Convert to UPPER_SNAKE_CASE - return "_".join(word.upper() for word in words) +from tests.utils import generate_enum_member_name class TestStringOperator: @@ -52,13 +18,6 @@ def test_is_abstract(self) -> None: class TestStringOperatorEnum: """Tests for the `StringOperatorEnum` class.""" - @pytest.mark.xfail( - reason="""Some operators are not following the naming convention. - Ex: CONCAT_WS - - Need to review the generate_enum_member_name function. - """ - ) def test_naming_convention(self) -> None: """Test that the naming convention is correct.""" mismatches = [] diff --git a/tests/tests_monggregate/tests_stages/test_lookup.py b/tests/tests_monggregate/tests_stages/test_lookup.py index cd04bb12..f1c933f1 100644 --- a/tests/tests_monggregate/tests_stages/test_lookup.py +++ b/tests/tests_monggregate/tests_stages/test_lookup.py @@ -33,13 +33,15 @@ def test_expression(self) -> None: "localField": "foreign_field_id", "foreignField": "_id", "as": "foreign_documents", + "let": None, + "pipeline": None, } } # fmt: on # NOTE: The bug is that left_on and right_on are required in the code while they should # be optional in that case. - @pytest.mark.xfail(reason="This should be valid. Bug in the code.") + def test_expression_with_correlated_subquery(self) -> None: """Test that the expression method returns the correct expression.""" @@ -59,6 +61,8 @@ def test_expression_with_correlated_subquery(self) -> None: assert lookup.expression == { "$lookup": { "from": "right_collection", + "localField": None, + "foreignField": None, "let": {"variable": "$local_variable"}, "pipeline": [{"$match": {"$expr": {"$gte": ["$$variable", "$foreign_field_quantity"]}}}], "as": "foreign_documents", @@ -66,9 +70,6 @@ def test_expression_with_correlated_subquery(self) -> None: } # fmt: on - # NOTE: The bug is that left_on and right_on are required in the code while they should - # be optional in that case. - @pytest.mark.xfail(reason="This should be valid. Bug in the code.") def test_expression_with_uncorrelated_subquery(self) -> None: """Test that the expression method returns the correct expression.""" @@ -85,6 +86,9 @@ def test_expression_with_uncorrelated_subquery(self) -> None: assert lookup.expression == { "$lookup": { "from": "holidays", + "let": None, + "localField": None, + "foreignField": None, "pipeline": [{"$match": {"year": 2018}}, {"$project": {"name": 1, "date": 1, "_id": 0}}], "as": "holidaysIn2018", } diff --git a/tests/tests_monggregate/tests_stages/test_match.py b/tests/tests_monggregate/tests_stages/test_match.py index 0bab8bc1..4b86629f 100644 --- a/tests/tests_monggregate/tests_stages/test_match.py +++ b/tests/tests_monggregate/tests_stages/test_match.py @@ -16,9 +16,6 @@ def test_expression(self) -> None: match = Match(query={"status": "active"}) assert match.expression == {"$match": {"status": "active"}} - @pytest.mark.xfail(reason="This should be valid. Bug in the code.") - # NOTE: The bug is that the $exp is inserted twice. - # once in validate_operand and once in the expression method. def test_expression_with_expr(self) -> None: """Test that the expression method returns the correct expression.""" diff --git a/tests/tests_monggregate/tests_stages/test_project.py b/tests/tests_monggregate/tests_stages/test_project.py index 4e86607f..78924f2a 100644 --- a/tests/tests_monggregate/tests_stages/test_project.py +++ b/tests/tests_monggregate/tests_stages/test_project.py @@ -49,10 +49,6 @@ def test_expression_with_exclude_as_dict(self) -> None: project = Project(exclude={"field1": 0, "field2": 0}) assert project.expression == {"$project": {"field1": 0, "field2": 0}} - @pytest.mark.xfail(reason="Bug in the code.") - # NOTE: The issue is that when using booleans, only include is used. - # We should find a mechanism so that include = !exclude and vice versa. - # Or review the logic of the code. def test_expression_with_exclude_as_bool(self) -> None: """Test that the expression method returns the correct expression with exclude.""" @@ -67,10 +63,6 @@ def test_expression_with_include_and_exclude_both_as_list_of_strings(self) -> No "$project": {"field1": 1, "field2": 1, "field3": 0, "field4": 0} } - @pytest.mark.xfail(reason="Bug in the code.") - # NOTE: The issue is that when using booleans, only include is used. - # We should find a mechanism so that include = !exclude and vice versa. - # Or review the logic of the code. def test_expression_with_include_and_exclude_both_as_dict(self) -> None: """Test that the expression method returns the correct expression with include and exclude.""" @@ -88,15 +80,13 @@ def test_expression_with_include_and_exclude_both_as_bool(self) -> None: with pytest.raises(ValueError): project = Project(include=True, exclude=True, fields=["field1", "field2"]) - @pytest.mark.xfail(reason="This fails but we might want to forbid this case.") - def test_expression_with_include_and_exclude_both_as_bool_and_list_of_strings( - self, - ) -> None: - """Test that the expression method returns the correct expression with include and exclude.""" + def test_mixed_boolean_and_list_parameters_raises_error(self) -> None: + """Test that mixing boolean and list parameters raises an appropriate error.""" - project = Project( - include=True, exclude=["field3", "field4"], fields=["field1", "field2"] - ) - assert project.expression == { - "$project": {"field1": 1, "field2": 1, "field3": 0, "field4": 0} - } + with pytest.raises( + ValueError, + match="Cannot mix boolean include/exclude with list/dict include/exclude", + ): + Project( + include=True, exclude=["field3", "field4"], fields=["field1", "field2"] + ) diff --git a/tests/tests_monggregate/tests_stages/test_sort.py b/tests/tests_monggregate/tests_stages/test_sort.py index 43751f48..6abac361 100644 --- a/tests/tests_monggregate/tests_stages/test_sort.py +++ b/tests/tests_monggregate/tests_stages/test_sort.py @@ -58,9 +58,6 @@ def test_expression_with_ascending_descending_as_list_of_strings(self) -> None: sort = Sort(ascending=["field1"], descending=["field2"]) assert sort.expression == {"$sort": {"field1": 1, "field2": -1}} - @pytest.mark.xfail( - reason="Should raise a ValueError/ValidationError but raises a KeyError. " - ) def test_expression_with_ascending_descending_as_bool(self) -> None: """Test that the expression method returns the correct expression with ascending and descending.""" diff --git a/tests/tests_monggregate/tests_stages/test_sort_by_count.py b/tests/tests_monggregate/tests_stages/test_sort_by_count.py index fae40fd7..c203b879 100644 --- a/tests/tests_monggregate/tests_stages/test_sort_by_count.py +++ b/tests/tests_monggregate/tests_stages/test_sort_by_count.py @@ -19,9 +19,8 @@ def test_expression(self) -> None: sort_by_count = SortByCount(by="field1") assert sort_by_count.expression == {"$sortByCount": "$field1"} - @pytest.mark.xfail(reason="Bug in the code.") - def test_exppression_with_other_types(self) -> None: + def test_expression_with_other_types(self) -> None: """Test that the expression method returns the correct expression with other types.""" sort_by_count = SortByCount(by=["field1", "field2"]) - assert sort_by_count.expression == {"$sortByCount": ["$field1", "$field2"]} + assert sort_by_count.expression == {"$sortByCount": {"$field1", "$field2"}} diff --git a/tests/tests_monggregate/tests_stages/tests_search/test_base.py b/tests/tests_monggregate/tests_stages/tests_search/test_base.py index 4e2d8790..65f3564e 100644 --- a/tests/tests_monggregate/tests_stages/tests_search/test_base.py +++ b/tests/tests_monggregate/tests_stages/tests_search/test_base.py @@ -52,6 +52,7 @@ class TestSearchBase: "value": "value", "gte": 1, "lte": 2, + "like": {"title": "test"}, } def test_instantiation(self) -> None: @@ -61,10 +62,6 @@ def test_instantiation(self) -> None: assert isinstance(base, SearchBase) assert isinstance(base, BaseModel) - @pytest.mark.xfail( - reason="""Broken because operator is not correctly set. - Uncomment lines in init to fix.""" - ) @pytest.mark.parametrize("operator_name", OperatorLiteral.__args__) def test_init_with_operator_name(self, operator_name: OperatorLiteral) -> None: """Tests the init method of the SearchBase class.""" @@ -72,10 +69,6 @@ def test_init_with_operator_name(self, operator_name: OperatorLiteral) -> None: search_base = SearchBase(operator_name=operator_name, **self.default_args) assert isinstance(search_base.operator, OperatorMap[operator_name]) - @pytest.mark.xfail( - reason="""Broken because collector is not correctly set. - Uncomment lines in init to fix.""" - ) def test_init_with_collector_name(self, collector_name: str = "facet") -> None: """Tests the init method of the SearchBase class.""" diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..8ce17b0e --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,94 @@ +"""Test utilities for monggregate tests.""" + + +def _remove_dollar_prefix(value: str) -> str: + """Remove the $ prefix from a value if present. + + Args: + value: The enum value that may have a $ prefix + + Returns: + The value without the $ prefix + """ + return value[1:] if value.startswith("$") else value + + +def _should_start_new_word_after_uppercase( + current_word: str, current_index: int, value_name: str +) -> bool: + """Determine if we should start a new word after encountering an uppercase letter. + + Args: + current_word: The current word being built + current_index: Current position in the string + value_name: The full string being processed + + Returns: + True if we should start a new word, False otherwise + """ + if not current_word: + return False + + # If current word ends with lowercase, start new word + if current_word[-1].islower(): + return True + + # Check if this uppercase letter starts a new word in an acronym sequence + # (like 'C' in 'CP' when followed by lowercase) + next_char_is_lowercase = ( + current_index + 1 < len(value_name) and value_name[current_index + 1].islower() + ) + current_word_ends_with_uppercase = current_word[-1].isupper() + + return next_char_is_lowercase and current_word_ends_with_uppercase + + +def _split_camel_case_to_words(value_name: str) -> list[str]: + """Split a camelCase string into individual words. + + Args: + value_name: The camelCase string to split + + Returns: + List of words extracted from the camelCase string + """ + words = [] + current_word = "" + + for i, char in enumerate(value_name): + if char.isupper() and _should_start_new_word_after_uppercase( + current_word, i, value_name + ): + words.append(current_word) + current_word = char + else: + current_word += char + + if current_word: + words.append(current_word) + + return words + + +def generate_enum_member_name(value: str) -> str: + """Generate the expected enum member name from a value. + + Args: + value: The enum value (with or without $ prefix) + + Returns: + The expected member name in UPPER_SNAKE_CASE format + + Example: + >>> generate_enum_member_name("$addToSet") + 'ADD_TO_SET' + >>> generate_enum_member_name("bottomN") + 'BOTTOM_N' + """ + cleaned_value = _remove_dollar_prefix(value) + + # Handle empty string case by returning empty result + words = _split_camel_case_to_words(cleaned_value) if cleaned_value else [] + + enum_member_name = "_".join(word.upper() for word in words) + return enum_member_name