diff --git a/README.md b/README.md index 3e1485b21f..243dc85147 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -[![Flake8](https://img.shields.io/badge/Flake8-passed-brightgreen)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) -[![Pytest](https://img.shields.io/badge/Pytest-passed-brightgreen)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) -[![Coverage](https://img.shields.io/badge/Coverage-99%25-brightgreen)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) -[![Mypy](https://img.shields.io/badge/Mypy-1191%20errors-red)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) +[![Flake8](https://img.shields.io/badge/Flake8-failed-red)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) +[![Pytest](https://img.shields.io/badge/Pytest-failed-red)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) +[![Coverage](https://img.shields.io/badge/Coverage-%25-red)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) +[![Mypy](https://img.shields.io/badge/Mypy-1194%20errors-red)](https://github.com/RuminantFarmSystems/MASM/actions/workflows/combined_format_lint_test_mypy.yml) # RuFaS: Ruminant Farm Systems diff --git a/RUFAS/biophysical/animal/reproduction/reproduction.py b/RUFAS/biophysical/animal/reproduction/reproduction.py index c7f3d1fa5a..898f7e1e8d 100644 --- a/RUFAS/biophysical/animal/reproduction/reproduction.py +++ b/RUFAS/biophysical/animal/reproduction/reproduction.py @@ -1313,74 +1313,93 @@ def _calculate_conception_rate_on_ai_day(self) -> None: self.conception_rate = max(0.0, self.conception_rate) - def execute_cow_ed_protocol( # noqa + def execute_cow_ed_protocol( self, reproduction_data_stream: ReproductionDataStream, simulation_day: int ) -> ReproductionDataStream: """Execute the estrus detection (ED) protocol for cows.""" + self._update_ed_days(reproduction_data_stream) + if 1 <= reproduction_data_stream.days_in_milk <= AnimalConfig.voluntary_waiting_period: + reproduction_data_stream = self._repeat_estrus_simulation_before_vwp( + reproduction_data_stream, simulation_day + ) + elif reproduction_data_stream.days_in_milk > AnimalConfig.voluntary_waiting_period: + reproduction_data_stream = self._handle_ed_after_vwp(reproduction_data_stream, simulation_day) + return reproduction_data_stream + + def _update_ed_days(self, reproduction_data_stream: ReproductionDataStream) -> None: + """Update ED days statistic based on pregnancy status.""" if not reproduction_data_stream.is_pregnant: self.reproduction_statistics.ED_days += 1 else: self.reproduction_statistics.ED_days = 0 - if 1 <= reproduction_data_stream.days_in_milk <= AnimalConfig.voluntary_waiting_period: - reproduction_data_stream = self._repeat_estrus_simulation_before_vwp( - reproduction_data_stream, simulation_day + + def _handle_ed_after_vwp( + self, reproduction_data_stream: ReproductionDataStream, simulation_day: int + ) -> ReproductionDataStream: + """Handle ED protocol logic after the voluntary waiting period.""" + if ( + self.repro_state_manager.is_in(ReproStateEnum.ENTER_HERD_FROM_INIT) + and reproduction_data_stream.days_born > self.estrus_day + ): + reproduction_data_stream = self._simulate_estrus( + reproduction_data_stream, + reproduction_data_stream.days_born, + simulation_day, + animal_constants.ESTRUS_DAY_SCHEDULED_NOTE, + AnimalConfig.average_estrus_cycle_cow, + AnimalConfig.std_estrus_cycle_cow, ) - elif reproduction_data_stream.days_in_milk > AnimalConfig.voluntary_waiting_period: - if ( - self.repro_state_manager.is_in(ReproStateEnum.ENTER_HERD_FROM_INIT) - and reproduction_data_stream.days_born > self.estrus_day - ): - reproduction_data_stream = self._simulate_estrus( - reproduction_data_stream, - reproduction_data_stream.days_born, - simulation_day, - animal_constants.ESTRUS_DAY_SCHEDULED_NOTE, - AnimalConfig.average_estrus_cycle_cow, - AnimalConfig.std_estrus_cycle_cow, - ) + if self.repro_state_manager.is_in_any({ReproStateEnum.FRESH, ReproStateEnum.ENTER_HERD_FROM_INIT}): + self.repro_state_manager.enter(ReproStateEnum.WAITING_FULL_ED_CYCLE) + reproduction_data_stream.events.add_event( + reproduction_data_stream.days_born, + simulation_day, + f"Current repro state(s): {self.repro_state_manager}", + ) - if self.repro_state_manager.is_in_any({ReproStateEnum.FRESH, ReproStateEnum.ENTER_HERD_FROM_INIT}): - self.repro_state_manager.enter(ReproStateEnum.WAITING_FULL_ED_CYCLE) + if reproduction_data_stream.days_born == self.estrus_day: + reproduction_data_stream = self._handle_estrus_day_states(reproduction_data_stream, simulation_day) + + return reproduction_data_stream + + def _handle_estrus_day_states( + self, reproduction_data_stream: ReproductionDataStream, simulation_day: int + ) -> ReproductionDataStream: + """Handle estrus detection state transitions on the scheduled estrus day.""" + if self.repro_state_manager.is_in(ReproStateEnum.WAITING_SHORT_ED_CYCLE): + self.repro_state_manager.exit(ReproStateEnum.WAITING_SHORT_ED_CYCLE) + reproduction_data_stream = self._handle_estrus_detection( + reproduction_data_stream, + simulation_day, + on_estrus_detected=self._setup_ai_day_after_estrus_detected, + on_estrus_not_detected=self._enter_ovsynch_repro_state, + ) + if self.repro_state_manager.is_in(ReproStateEnum.IN_OVSYNCH): reproduction_data_stream.events.add_event( reproduction_data_stream.days_born, simulation_day, f"Current repro state(s): {self.repro_state_manager}", ) - if reproduction_data_stream.days_born == self.estrus_day: - if self.repro_state_manager.is_in(ReproStateEnum.WAITING_SHORT_ED_CYCLE): - self.repro_state_manager.exit(ReproStateEnum.WAITING_SHORT_ED_CYCLE) - reproduction_data_stream = self._handle_estrus_detection( - reproduction_data_stream, - simulation_day, - on_estrus_detected=self._setup_ai_day_after_estrus_detected, - on_estrus_not_detected=self._enter_ovsynch_repro_state, - ) - if self.repro_state_manager.is_in(ReproStateEnum.IN_OVSYNCH): - reproduction_data_stream.events.add_event( - reproduction_data_stream.days_born, - simulation_day, - f"Current repro state(s): {self.repro_state_manager}", - ) + elif self.repro_state_manager.is_in(ReproStateEnum.WAITING_FULL_ED_CYCLE): + self.repro_state_manager.exit(ReproStateEnum.WAITING_FULL_ED_CYCLE) + reproduction_data_stream = self._handle_estrus_detection( + reproduction_data_stream, + simulation_day, + on_estrus_detected=self._setup_ai_day_after_estrus_detected, + on_estrus_not_detected=self._simulate_full_estrus_cycle, + ) - elif self.repro_state_manager.is_in(ReproStateEnum.WAITING_FULL_ED_CYCLE): - self.repro_state_manager.exit(ReproStateEnum.WAITING_FULL_ED_CYCLE) - reproduction_data_stream = self._handle_estrus_detection( - reproduction_data_stream, - simulation_day, - on_estrus_detected=self._setup_ai_day_after_estrus_detected, - on_estrus_not_detected=self._simulate_full_estrus_cycle, - ) + elif self.repro_state_manager.is_in(ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH): + self.repro_state_manager.exit(ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH) + reproduction_data_stream = self._handle_estrus_detection( + reproduction_data_stream, + simulation_day, + on_estrus_detected=self._setup_ai_day_after_estrus_detected, + on_estrus_not_detected=self._simulate_full_estrus_cycle_before_ovsynch, + ) - elif self.repro_state_manager.is_in(ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH): - self.repro_state_manager.exit(ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH) - reproduction_data_stream = self._handle_estrus_detection( - reproduction_data_stream, - simulation_day, - on_estrus_detected=self._setup_ai_day_after_estrus_detected, - on_estrus_not_detected=self._simulate_full_estrus_cycle_before_ovsynch, - ) return reproduction_data_stream def _enter_ovsynch_repro_state( diff --git a/RUFAS/data_validator.py b/RUFAS/data_validator.py index b1716fb772..333688d683 100644 --- a/RUFAS/data_validator.py +++ b/RUFAS/data_validator.py @@ -1888,7 +1888,15 @@ def _evaluate_expression( Parameters ---------- expression_block : dict[str, Any] - A dictionary containing the expression block to be evaluated. + A dictionary containing the expression block to be evaluated. Supports two mutually + exclusive sub-block forms plus an optional ``save_as`` key: + + Aggregation form: + ``{"aggregation": {"function": "...", "operands": [...], "mode": "..."}, "save_as": "..."}`` + + Array-of-dicts form: + ``{"for_each": {"in": "...", "field": "...", "compare_value": "..." | "compare_field": "...", + "operator": "...", "mode": "filter" | "enforce"}, "save_as": "..."}`` eager_termination : bool Whether to raise an error if the expression is not successfully evaluated. relationship : str @@ -1904,51 +1912,89 @@ def _evaluate_expression( ------ ValueError Raises the error when the expression block contains unknown operation or missing ordered variables. + """ + if "for_each" in expression_block: + result, evaluated = self._evaluate_iterate_array_of_dicts( + expression_block["for_each"], eager_termination, relationship + ) - Notes - ----- - Expression block: - >>> { - ... "operation": "sum | difference | average | product | no_op", # optional, defaults to "no_op" - ... "apply_to": "individual | group", # optional - ... "ordered_variables": ["alias_0", "alias_1"], - ... "save_as": "alias_2" # optional - ... } + elif "aggregation" in expression_block: + result, evaluated = self._evaluate_aggregation_block( + expression_block["aggregation"], eager_termination, relationship + ) + else: + raise ValueError(f"Cross-validation error: Unknown expression block: {expression_block}") + + if evaluated and "save_as" in expression_block: + save_as_alise_name: str = expression_block["save_as"] + self._save_to_alias_pool(alias_name=save_as_alise_name, value=result) + return result, evaluated + + def _evaluate_aggregation_block( + self, aggregation_block: dict[str, Any], eager_termination: bool, relationship: str + ) -> tuple[Any, bool]: """ - operation = expression_block.get("operation", "no_op") + Evaluates an aggregation block, resolving ordered variables from the alias pool and + applying the specified aggregation operation. + + Parameters + ---------- + aggregation_block : dict[str, Any] + A dictionary containing the aggregation block to be evaluated. Expected keys: + + - ``function``: aggregation function name (e.g. ``"sum"``, ``"no_op"``). Defaults to + ``"no_op"`` when absent. + - ``operands``: list of alias names to resolve and aggregate. + - ``mode``: ``"element_wise"`` or ``"aggregate"`` — required when any resolved variable + is a list or dict. Defaults to ``"aggregate"`` for scalar variables. + eager_termination : bool + Whether to raise an error if evaluation fails. + relationship : str + The relationship being evaluated, forwarded to alias-pool lookups. + + Returns + ------- + tuple[Any, bool] + The aggregated result and ``True`` on success, or ``(None, False)`` on error. + + Raises + ------ + ValueError + If ``eager_termination`` is ``True`` and the operation is unknown or + ``ordered_variables`` is empty or missing. + """ + operation = aggregation_block.get("function", "no_op") aggregator = AGGREGATION_FUNCTIONS.get(operation) if operation not in AGGREGATION_FUNCTIONS or aggregator is None: self._event_logs.append( { - "error": "Unknown Operation", - "message": f"Unknown operation {operation} in cross validation rule. Expected one of " + "error": "Unknown Aggregation Function", + "message": f"Unknown function '{operation}' in aggregation block. Expected one of " f"{list(AGGREGATION_FUNCTIONS.keys())}.", "info_map": { "class": self.__class__.__name__, - "function": self._evaluate_expression.__name__, + "function": self._evaluate_aggregation_block.__name__, }, } ) if eager_termination: - raise ValueError(f"Cross-validation error: Unknown operation in expression block: {operation}") + raise ValueError(f"Cross-validation error: Unknown function in aggregation block: {operation}") else: return None, False - if not (ordered_variable_alias := expression_block.get("ordered_variables", [])): + if not (ordered_variable_alias := aggregation_block.get("operands", [])): self._event_logs.append( { - "error": "Missing Ordered Variables", - "message": "Ordered variables list is empty or missing in cross validation rule.", + "error": "Missing Operands", + "message": "Operands list is empty or missing in aggregation block.", "info_map": { "class": self.__class__.__name__, - "function": self._evaluate_expression.__name__, + "function": self._evaluate_aggregation_block.__name__, }, } ) if eager_termination: - raise ValueError( - "Cross-validation error: " "Ordered variables list is empty or missing in cross validation rule." - ) + raise ValueError("Cross-validation error: Operands list is empty or missing in aggregation block.") else: return None, False ordered_values: list[Any] = [] @@ -1958,22 +2004,117 @@ def _evaluate_expression( if any(isinstance(value, (list, dict)) for value in ordered_values): if not self._validate_expression_block_with_complex_variable_values( - expression_block, ordered_values, eager_termination + aggregation_block, ordered_values, eager_termination ): return None, False ordered_values = ( ordered_values[0] if isinstance(ordered_values[0], list) else list(ordered_values[0].values()) ) - apply_to = expression_block.get("apply_to", "group") - result = ordered_values if apply_to == "individual" else [aggregator(ordered_values)] + mode = aggregation_block.get("mode", "aggregate") + result = ordered_values if mode == "element_wise" else [aggregator(ordered_values)] else: result = ordered_values if operation == "no_op" else [aggregator(ordered_values)] - - if "save_as" in expression_block: - save_as_alise_name: str = expression_block["save_as"] - self._save_to_alias_pool(alias_name=save_as_alise_name, value=result) return result, True + def _evaluate_iterate_array_of_dicts( + self, iter_block: dict[str, Any], eager_termination: bool, outer_relationship: str + ) -> tuple[Any, bool]: + """ + Evaluates a ``for_each`` block against a list of dicts in the alias pool. + + Parameters + ---------- + iter_block : dict[str, Any] + The block describing how to iterate the array. Expected keys: + + - ``in``: alias for the ``list[dict]`` value in the alias pool. + - ``field``: key within each dict entry to evaluate. + - ``compare_value``: alias for the scalar value to compare against each entry's + ``field``. Mutually exclusive with ``compare_field``. + - ``compare_field``: key within each dict entry whose value is used as the right-hand + comparison target. Mutually exclusive with ``compare_value``. + - ``operator``: one of the supported operator strings (e.g. ``"equal"``). + - ``mode``: ``"filter"`` to return the subset of entries that satisfy the condition; + ``"enforce"`` to return ``[True]`` when all entries satisfy, otherwise ``[False]``. + eager_termination : bool + Whether to raise on error. + outer_relationship : str + The operator of the enclosing condition clause, used for alias-pool error handling. + + Returns + ------- + tuple[Any, bool] + ``(result, True)`` on success or ``(None, False)`` on error. + """ + source_alias: str = iter_block.get("in", "") + field: str = iter_block.get("field", "") + compare_value_alias: str | None = iter_block.get("compare_value") + compare_field: str | None = iter_block.get("compare_field") + operator: str = iter_block.get("operator", "") + mode: str = iter_block.get("mode", "enforce") + + compare_fn = self.relation_mapping.get(operator) + if compare_fn is None: + self._event_logs.append( + { + "error": "Invalid operator in for_each block", + "message": f"Unknown operator '{operator}' in for_each block. " + f"Expected one of {list(self.relation_mapping.keys())}.", + "info_map": { + "class": self.__class__.__name__, + "function": self._evaluate_iterate_array_of_dicts.__name__, + }, + } + ) + if eager_termination: + raise ValueError(f"Cross-validation error: Unknown operator in for_each block: {operator}") + return None, False + + has_compare_value = compare_value_alias is not None + has_compare_field = compare_field is not None + if has_compare_value == has_compare_field: + self._event_logs.append( + { + "error": "Invalid for_each block configuration", + "message": "Exactly one of 'compare_value' or 'compare_field' must be specified in a for_each block.", + "info_map": { + "class": self.__class__.__name__, + "function": self._evaluate_iterate_array_of_dicts.__name__, + }, + } + ) + if eager_termination: + raise ValueError( + "Cross-validation error: Exactly one of 'compare_value' or 'compare_field' must be specified in a for_each block." + ) + return None, False + + array_of_dicts = self._get_alias_value(source_alias, eager_termination, outer_relationship) + if array_of_dicts is None: + return None, False + + if has_compare_value: + comparison_value = self._get_alias_value(compare_value_alias, eager_termination, outer_relationship) + if comparison_value is None: + return None, False + if not isinstance(comparison_value, list): + comparison_value = [comparison_value] + get_right = lambda _entry: comparison_value + else: + get_right = lambda entry: [entry.get(compare_field)] + + if mode == "filter": + result = [ + entry for entry in array_of_dicts if compare_fn([entry.get(field)], get_right(entry), eager_termination) + ] + print(result) + return result, True + else: + all_satisfy = all( + compare_fn([entry.get(field)], get_right(entry), eager_termination) for entry in array_of_dicts + ) + return [all_satisfy], True + def _validate_expression_block_with_complex_variable_values( self, expression_block: dict[str, Any], ordered_values: list[Any], eager_termination: bool ) -> bool: @@ -1999,8 +2140,8 @@ def _validate_expression_block_with_complex_variable_values( ------ ValueError -If multiple complex variables are selected for cross-validation in a single expression block. - -If the 'apply_to' key is missing in the expression block when a complex variable is selected. - -If the 'apply_to' value is not one of the expected options ('individual' or 'group'). + -If the 'mode' key is missing in the expression block when a complex variable is selected. + -If the 'mode' value is not one of the expected options ('element_wise' or 'aggregate'). Returns ------- @@ -2028,11 +2169,11 @@ def _validate_expression_block_with_complex_variable_values( else: return False - if "apply_to" not in expression_block: + if "mode" not in expression_block: self._event_logs.append( { - "error": "Missing `apply_to` key", - "message": "The 'apply_to' key is required in expression block " + "error": "Missing `mode` key", + "message": "The 'mode' key is required in aggregation block " "when a complex data structure is selected.", "info_map": { "class": self.__class__.__name__, @@ -2042,16 +2183,16 @@ def _validate_expression_block_with_complex_variable_values( ) if eager_termination: raise ValueError( - "Cross-validation error: Missing 'apply_to' key in expression block for " + "Cross-validation error: Missing 'mode' key in aggregation block for " "selected complex data structure." ) else: return False - if apply_to := expression_block["apply_to"] not in ["individual", "group"]: + if mode := expression_block["mode"] not in ["element_wise", "aggregate"]: self._event_logs.append( { - "error": "Unknown apply_to value", - "message": f"Unknown apply_to value {apply_to} in expression block.", + "error": "Unknown mode value", + "message": f"Unknown mode value '{mode}' in aggregation block.", "info_map": { "class": self.__class__.__name__, "function": self._validate_expression_block_with_complex_variable_values.__name__, @@ -2059,7 +2200,7 @@ def _validate_expression_block_with_complex_variable_values( } ) if eager_termination: - raise ValueError(f"Cross-validation error: Unknown apply_to value in expression block: {apply_to}") + raise ValueError(f"Cross-validation error: Unknown mode value in aggregation block: {mode}") else: return False return True @@ -2084,7 +2225,7 @@ def _evaluate_condition(self, condition_clause: dict[str, Any], eager_terminatio """ if not self._validate_condition_clause(condition_clause, eager_termination): return False - relationship = condition_clause.get("relationship", "") + relationship = condition_clause.get("operator", "") left_hand, left_evaluated = self._evaluate_expression( condition_clause["left_hand"], eager_termination, relationship ) @@ -2095,18 +2236,18 @@ def _evaluate_condition(self, condition_clause: dict[str, Any], eager_terminatio if not (left_evaluated and right_evaluated): return False - evaluation_function = self.relation_mapping[condition_clause["relationship"]] + evaluation_function = self.relation_mapping[condition_clause["operator"]] return evaluation_function(left_hand, right_hand, eager_termination) def _validate_condition_clause(self, condition_clause: dict[str, Any], eager_termination: bool) -> bool: """Validate the whole condition block.""" left_expression = condition_clause.get("left_hand", False) right_expression = condition_clause.get("right_hand", False) - relationship = condition_clause.get("relationship", False) + relationship = condition_clause.get("operator", False) fields = { "left hand": left_expression, "right hand": right_expression, - "relationship": relationship, + "operator": relationship, } valid = True if self._validate_relationship(relationship, eager_termination): diff --git a/RUFAS/input_manager.py b/RUFAS/input_manager.py index 361f5e0466..90938d0c14 100644 --- a/RUFAS/input_manager.py +++ b/RUFAS/input_manager.py @@ -184,9 +184,9 @@ def _cross_validate_data(self, cross_validation_file_paths: list[str] | None, ea cross_validation_rules = self._load_cross_validation(cross_validation_file_paths) if cross_validation_rules is not None and len(cross_validation_rules) > 0: for cross_validation_ruleset in cross_validation_rules: - cross_validation_blocks = cross_validation_ruleset.get("cross-validation", []) + cross_validation_blocks = cross_validation_ruleset.get("cross_validation", []) for block in cross_validation_blocks: - target_and_save_block = block.get("target_and_save", {}) + target_and_save_block = block.get("aliases", {}) target_and_save_result = self._extract_target_and_save_block( target_and_save_block, eager_termination ) diff --git a/changelog.md b/changelog.md index b1500f223d..10ab98d822 100644 --- a/changelog.md +++ b/changelog.md @@ -62,6 +62,7 @@ v1.0.0 - [2929](https://github.com/RuminantFarmSystems/RuFaS/pull/2929) - [minor change] [GraphGenerator] [NoInputChange] [NoOutputChange] Sanitizes non-numerical data sent to graph generator to allow graphing to occur despite. - [2925](https://github.com/RuminantFarmSystems/RuFaS/pull/2925) - [minor change] [NoInputChange] [NoOutputChange] Fix the `graph_and_report` option in report_generation.py. - [2907](https://github.com/RuminantFarmSystems/RuFaS/pull/2907) - [minor change] [NoInputChange] [OutputChange] Fix the FarmGrownFeed Emissions unit issue. The mirror issue of [Fix FarmGrownFeed Emissions Unit on test #2908](https://github.com/RuminantFarmSystems/MASM/pull/2908) to update `dev`. +- [2934](https://github.com/RuminantFarmSystems/RuFaS/pull/2934) - [minor change] [NoInputChange] [NoOutputChange] [Animal][Reproduction] Refactor `Reproduction.execute_cow_ed_protocol()`. ### v1.0.0 diff --git a/input/metadata/cross_validation/example_cross_validation.json b/input/metadata/cross_validation/example_cross_validation.json index c6fefe3516..ccaab38a65 100644 --- a/input/metadata/cross_validation/example_cross_validation.json +++ b/input/metadata/cross_validation/example_cross_validation.json @@ -1,8 +1,8 @@ { - "cross-validation": [ + "cross_validation": [ { "description": "Number of stalls in calf pen", - "target_and_save": { + "aliases": { "variables": { "number_of_calves": "animal.herd_information.calf_num", "pen_animal_type": "animal.pen_information.0.animal_combination", @@ -16,47 +16,55 @@ "apply_when": [ { "left_hand": { - "operation": "no_op", - "apply_to": "group", - "ordered_variables": [ - "pen_animal_type" - ] + "aggregation": { + "function": "no_op", + "mode": "aggregate", + "operands": [ + "pen_animal_type" + ] + } }, "right_hand": { - "operation": "no_op", - "apply_to": "individual", - "ordered_variables": [ - "calf_pen_type" - ] + "aggregation": { + "function": "no_op", + "mode": "element_wise", + "operands": [ + "calf_pen_type" + ] + } }, - "relationship": "equal" + "operator": "equal" } ], "rules": [ { "left_hand": { - "operation": "sum", - "apply_to": "group", - "ordered_variables": [ - "pen_stall_num" - ] + "aggregation": { + "function": "sum", + "mode": "aggregate", + "operands": [ + "pen_stall_num" + ] + } }, "right_hand": { - "operation": "division", - "apply_to": "individual", - "ordered_variables": [ - "number_of_calves", - "pen_stocking_density" - ], + "aggregation": { + "function": "division", + "mode": "element_wise", + "operands": [ + "number_of_calves", + "pen_stocking_density" + ] + }, "save_as": "min_num_stalls" }, - "relationship": "greater_or_equal_to" + "operator": "greater_or_equal_to" } ] }, { "description": "Number of stalls in growing pens", - "target_and_save": { + "aliases": { "variables": { "number_of_heifer1s": "animal.herd_information.heiferI_num", "number_of_heifer2s": "animal.herd_information.heiferII_num", @@ -71,66 +79,78 @@ "apply_when": [ { "left_hand": { - "operation": "no_op", - "apply_to": "group", - "ordered_variables": [ - "pen_animal_type" - ] + "aggregation": { + "function": "no_op", + "mode": "aggregate", + "operands": [ + "pen_animal_type" + ] + } }, "right_hand": { - "operation": "no_op", - "apply_to": "individual", - "ordered_variables": [ - "growing_pen_type" - ] + "aggregation": { + "function": "no_op", + "mode": "element_wise", + "operands": [ + "growing_pen_type" + ] + } }, - "relationship": "equal" + "operator": "equal" } ], "rules": [ { "left_hand": { - "operation": "sum", - "apply_to": "group", - "ordered_variables": [ - "pen_stall_num" - ] + "aggregation": { + "function": "sum", + "mode": "aggregate", + "operands": [ + "pen_stall_num" + ] + } }, "right_hand": { - "operation": "division", - "apply_to": "individual", - "ordered_variables": [ - "number_of_heifer1s", - "pen_stocking_density" - ], + "aggregation": { + "function": "division", + "mode": "element_wise", + "operands": [ + "number_of_heifer1s", + "pen_stocking_density" + ] + }, "save_as": "min_num_stalls" }, - "relationship": "greater_or_equal_to" + "operator": "greater_or_equal_to" }, { "left_hand": { - "operation": "sum", - "apply_to": "group", - "ordered_variables": [ - "pen_stall_num" - ] + "aggregation": { + "function": "sum", + "mode": "aggregate", + "operands": [ + "pen_stall_num" + ] + } }, "right_hand": { - "operation": "division", - "apply_to": "individual", - "ordered_variables": [ - "number_of_heifer2s", - "pen_stocking_density" - ], + "aggregation": { + "function": "division", + "mode": "element_wise", + "operands": [ + "number_of_heifer2s", + "pen_stocking_density" + ] + }, "save_as": "max_num_stalls" }, - "relationship": "greater_or_equal_to" + "operator": "greater_or_equal_to" } ] }, { "description": "Sum of three parity fractions must equal 1.0", - "target_and_save": { + "aliases": { "variables": { "p1": "animal.herd_information.parity_fractions.1", "p2": "animal.herd_information.parity_fractions.2", @@ -145,24 +165,98 @@ "rules": [ { "left_hand": { - "operation": "sum", - "ordered_variables": [ - "p1", - "p2", - "p3", - "p4", - "p5" - ], + "aggregation": { + "function": "sum", + "operands": [ + "p1", + "p2", + "p3", + "p4", + "p5" + ] + }, "save_as": "sum" }, "right_hand": { - "ordered_variables": [ - "summed_parity_fractions" - ] + "aggregation": { + "function": "no_op", + "operands": [ + "summed_parity_fractions" + ] + } }, - "relationship": "equal" + "operator": "equal" + } + ] + }, + { + "description": "Example: for_each usage", + "aliases": { + "variables": { + "crop_configurations": "crop_configurations.crop_configurations" + }, + "constants": { + "true_constant": true + } + }, + "rules": [ + { + "left_hand": { + "for_each": { + "mode": "enforce", + "in": "crop_configurations", + "field": "optimal_temperature", + "compare_field": "minimum_temperature", + "operator": "greater_or_equal_to" + } + }, + "right_hand": { + "aggregation": { + "function": "no_op", + "mode": "element_wise", + "operands": [ + "true_constant" + ] + } + }, + "operator": "equal" + } + ] + }, + { + "description": "LAC pens", + "aliases": { + "variables": { + "pen_information": "animal.pen_information" + }, + "constants": { + "LAC_COW": "LAC_COW", + "true_constant": true + } + }, + "rules": [ + { + "left_hand": { + "for_each": { + "mode": "filter", + "in": "pen_information", + "field": "animal_combination", + "compare_value": "LAC_COW", + "operator": "equal" + } + }, + "right_hand": { + "aggregation": { + "function": "no_op", + "mode": "element_wise", + "operands": [ + "true_constant" + ] + } + }, + "operator": "not_equal" } ] } ] -} \ No newline at end of file +} diff --git a/tests/test_biophysical/test_animal/test_reproduction/test_reproduction.py b/tests/test_biophysical/test_animal/test_reproduction/test_reproduction.py index 3beaa93694..6246f89ddf 100644 --- a/tests/test_biophysical/test_animal/test_reproduction/test_reproduction.py +++ b/tests/test_biophysical/test_animal/test_reproduction/test_reproduction.py @@ -1324,71 +1324,113 @@ def test_execute_heifer_ed_protocol( @pytest.mark.parametrize( - "days_in_milk, repro_state, days_born, estrus_day, " - "expected_simulate_estrus, expected_repro_state_entered, expected_handle_called," - "expected_repeat_estrus_simulation", + "days_in_milk, expected_call_repeat, expected_call_handle_ed", [ - (1, ReproStateEnum.ENTER_HERD_FROM_INIT, 350, 400, False, False, False, True), - (10, ReproStateEnum.FRESH, 350, 500, False, False, False, True), - (100, ReproStateEnum.ENTER_HERD_FROM_INIT, 450, 400, True, True, False, False), - (100, ReproStateEnum.ENTER_HERD_FROM_INIT, 350, 400, False, True, False, False), - (100, ReproStateEnum.FRESH, 450, 400, False, False, False, False), - (100, ReproStateEnum.ENTER_HERD_FROM_INIT, 400, 400, False, False, False, False), - (100, ReproStateEnum.FRESH, 400, 400, False, False, False, False), - (100, ReproStateEnum.WAITING_SHORT_ED_CYCLE, 400, 400, False, False, True, False), - (100, ReproStateEnum.WAITING_FULL_ED_CYCLE, 400, 400, False, False, True, False), - (100, ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH, 400, 400, False, False, True, False), + (0, False, False), + (1, True, False), + (AnimalConfig.voluntary_waiting_period, True, False), + (AnimalConfig.voluntary_waiting_period + 1, False, True), ], ) def test_execute_cow_ed_protocol( days_in_milk: int, + expected_call_repeat: bool, + expected_call_handle_ed: bool, + mocker: MockerFixture, +) -> None: + reproduction = Reproduction() + simulation_day = 100 + mock_outputs = mock_reproduction_data_stream(animal_type=AnimalType.LAC_COW, days_in_milk=days_in_milk) + + mock_update_ed_days = mocker.patch.object(reproduction, "_update_ed_days") + mock_repeat = mocker.patch.object(reproduction, "_repeat_estrus_simulation_before_vwp", return_value=mock_outputs) + mock_handle_ed = mocker.patch.object(reproduction, "_handle_ed_after_vwp", return_value=mock_outputs) + + result = reproduction.execute_cow_ed_protocol(mock_outputs, simulation_day) + + mock_update_ed_days.assert_called_once_with(mock_outputs) + + if expected_call_repeat: + mock_repeat.assert_called_once_with(mock_outputs, simulation_day) + else: + mock_repeat.assert_not_called() + + if expected_call_handle_ed: + mock_handle_ed.assert_called_once_with(mock_outputs, simulation_day) + else: + mock_handle_ed.assert_not_called() + + assert result == mock_outputs + + +@pytest.mark.parametrize( + "is_pregnant, initial_ed_days, expected_ed_days", + [ + (False, 3, 4), + (True, 5, 0), + ], +) +def test_update_ed_days( + is_pregnant: bool, + initial_ed_days: int, + expected_ed_days: int, +) -> None: + reproduction = Reproduction() + reproduction.reproduction_statistics.ED_days = initial_ed_days + data_stream = mock_reproduction_data_stream( + animal_type=AnimalType.LAC_COW, + days_in_pregnancy=100 if is_pregnant else 0, + ) + + reproduction._update_ed_days(data_stream) + + assert reproduction.reproduction_statistics.ED_days == expected_ed_days + + +@pytest.mark.parametrize( + "repro_state, days_born, estrus_day, expected_simulate_estrus, expected_enter_waiting_full," + "expected_handle_estrus_day", + [ + (ReproStateEnum.ENTER_HERD_FROM_INIT, 450, 400, True, True, False), + (ReproStateEnum.ENTER_HERD_FROM_INIT, 350, 400, False, True, False), + (ReproStateEnum.ENTER_HERD_FROM_INIT, 400, 400, False, True, True), + (ReproStateEnum.FRESH, 450, 400, False, True, False), + (ReproStateEnum.FRESH, 400, 400, False, True, True), + (ReproStateEnum.WAITING_SHORT_ED_CYCLE, 400, 400, False, False, True), + (ReproStateEnum.WAITING_FULL_ED_CYCLE, 400, 400, False, False, True), + (ReproStateEnum.WAITING_FULL_ED_CYCLE, 450, 400, False, False, False), + ], +) +def test_handle_ed_after_vwp( repro_state: ReproStateEnum, days_born: int, estrus_day: int, expected_simulate_estrus: bool, - expected_repro_state_entered: bool, - expected_handle_called: bool, - expected_repeat_estrus_simulation: bool, + expected_enter_waiting_full: bool, + expected_handle_estrus_day: bool, mocker: MockerFixture, ) -> None: reproduction = Reproduction() reproduction.estrus_day = estrus_day reproduction.repro_state_manager = ReproStateManager() reproduction.repro_state_manager.enter(repro_state) - if repro_state == ReproStateEnum.WAITING_SHORT_ED_CYCLE: - reproduction.repro_state_manager.enter(ReproStateEnum.IN_OVSYNCH, keep_existing=True) - mock_enter_repro_state = mocker.patch.object(reproduction.repro_state_manager, "enter") - mock_exit_repro_state = mocker.patch.object(reproduction.repro_state_manager, "exit") - - mock_time = MagicMock(spec=RufasTime) - mock_time.simulation_day = 100 - + mock_enter = mocker.patch.object(reproduction.repro_state_manager, "enter") mock_events = MagicMock(spec=AnimalEvents) - mock_outputs = mock_reproduction_data_stream( - animal_type=AnimalType.LAC_COW, days_in_milk=days_in_milk, days_born=days_born, events=mock_events + animal_type=AnimalType.LAC_COW, days_born=days_born, events=mock_events ) + simulation_day = 100 - mock_repeat_estrus_simulation = mocker.patch.object( - reproduction, "_repeat_estrus_simulation_before_vwp", return_value=mock_outputs - ) mock_simulate_estrus = mocker.patch.object(reproduction, "_simulate_estrus", return_value=mock_outputs) - mock_handle_estrus_detection = mocker.patch.object( - reproduction, "_handle_estrus_detection", return_value=mock_outputs - ) - - result = reproduction.execute_cow_ed_protocol(mock_outputs, mock_time.simulation_day) + mock_handle_estrus_day = mocker.patch.object(reproduction, "_handle_estrus_day_states", return_value=mock_outputs) - if expected_repeat_estrus_simulation: - mock_repeat_estrus_simulation.assert_called_once_with(mock_outputs, mock_time.simulation_day) - else: - mock_repeat_estrus_simulation.assert_not_called() + result = reproduction._handle_ed_after_vwp(mock_outputs, simulation_day) if expected_simulate_estrus: mock_simulate_estrus.assert_called_once_with( mock_outputs, days_born, - mock_time.simulation_day, + simulation_day, animal_constants.ESTRUS_DAY_SCHEDULED_NOTE, AnimalConfig.average_estrus_cycle_cow, AnimalConfig.std_estrus_cycle_cow, @@ -1396,63 +1438,87 @@ def test_execute_cow_ed_protocol( else: mock_simulate_estrus.assert_not_called() - if expected_repro_state_entered: - mock_enter_repro_state.assert_called_once_with(ReproStateEnum.WAITING_FULL_ED_CYCLE) - - if expected_handle_called: - mock_handle_estrus_detection.assert_called_once() + if expected_enter_waiting_full: + mock_enter.assert_called_once_with(ReproStateEnum.WAITING_FULL_ED_CYCLE) + mock_events.add_event.assert_called_once() else: - mock_handle_estrus_detection.assert_not_called() + mock_enter.assert_not_called() - if ( - repro_state == ReproStateEnum.WAITING_SHORT_ED_CYCLE - or repro_state == ReproStateEnum.WAITING_FULL_ED_CYCLE - or repro_state == ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH - ): - mock_exit_repro_state.assert_called_once() + if expected_handle_estrus_day: + mock_handle_estrus_day.assert_called_once_with(mock_outputs, simulation_day) + else: + mock_handle_estrus_day.assert_not_called() assert result == mock_outputs -def test_execute_cow_ed_protocol_resets_ed_days_when_pregnant(mocker: MockerFixture) -> None: - """If the cow is pregnant, ED_days should be reset to 0.""" +@pytest.mark.parametrize( + "repro_state, expected_state_exited, expected_not_detected_method", + [ + (ReproStateEnum.WAITING_SHORT_ED_CYCLE, ReproStateEnum.WAITING_SHORT_ED_CYCLE, "_enter_ovsynch_repro_state"), + (ReproStateEnum.WAITING_FULL_ED_CYCLE, ReproStateEnum.WAITING_FULL_ED_CYCLE, "_simulate_full_estrus_cycle"), + ( + ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH, + ReproStateEnum.WAITING_FULL_ED_CYCLE_BEFORE_OVSYNCH, + "_simulate_full_estrus_cycle_before_ovsynch", + ), + ], +) +def test_handle_estrus_day_states( + repro_state: ReproStateEnum, + expected_state_exited: ReproStateEnum, + expected_not_detected_method: str, + mocker: MockerFixture, +) -> None: reproduction = Reproduction() - reproduction.reproduction_statistics.ED_days = 5 - mock_events = MagicMock(spec=AnimalEvents) - data_stream = mock_reproduction_data_stream( - animal_type=AnimalType.LAC_COW, - days_in_milk=0, - events=mock_events, - ) - data_stream.days_in_pregnancy = 100 - + reproduction.repro_state_manager = ReproStateManager() + reproduction.repro_state_manager.enter(repro_state) + mock_exit = mocker.patch.object(reproduction.repro_state_manager, "exit") + mock_outputs = mock_reproduction_data_stream(animal_type=AnimalType.LAC_COW) simulation_day = 100 - mock_repeat_estrus = mocker.patch.object( - reproduction, - "_repeat_estrus_simulation_before_vwp", - return_value=data_stream, - ) - mock_simulate_estrus = mocker.patch.object( - reproduction, - "_simulate_estrus", - return_value=data_stream, - ) mock_handle_estrus_detection = mocker.patch.object( - reproduction, - "_handle_estrus_detection", - return_value=data_stream, + reproduction, "_handle_estrus_detection", return_value=mock_outputs ) - result = reproduction.execute_cow_ed_protocol(data_stream, simulation_day) + result = reproduction._handle_estrus_day_states(mock_outputs, simulation_day) - assert reproduction.reproduction_statistics.ED_days == 0 + mock_exit.assert_called_once_with(expected_state_exited) + mock_handle_estrus_detection.assert_called_once_with( + mock_outputs, + simulation_day, + on_estrus_detected=reproduction._setup_ai_day_after_estrus_detected, + on_estrus_not_detected=getattr(reproduction, expected_not_detected_method), + ) + assert result == mock_outputs - mock_repeat_estrus.assert_not_called() - mock_simulate_estrus.assert_not_called() - mock_handle_estrus_detection.assert_not_called() - assert result is data_stream +def test_handle_estrus_day_states_no_matching_state() -> None: + reproduction = Reproduction() + reproduction.repro_state_manager = ReproStateManager() + reproduction.repro_state_manager.enter(ReproStateEnum.IN_OVSYNCH) + mock_outputs = mock_reproduction_data_stream(animal_type=AnimalType.LAC_COW) + + result = reproduction._handle_estrus_day_states(mock_outputs, simulation_day=100) + + assert result is mock_outputs + + +def test_handle_estrus_day_states_logs_event_when_in_ovsynch_after_short_ed_cycle(mocker: MockerFixture) -> None: + reproduction = Reproduction() + reproduction.repro_state_manager = ReproStateManager() + reproduction.repro_state_manager.enter(ReproStateEnum.WAITING_SHORT_ED_CYCLE) + reproduction.repro_state_manager.enter(ReproStateEnum.IN_OVSYNCH, keep_existing=True) + mock_events = MagicMock(spec=AnimalEvents) + mock_outputs = mock_reproduction_data_stream(animal_type=AnimalType.LAC_COW, events=mock_events) + simulation_day = 100 + + mocker.patch.object(reproduction.repro_state_manager, "exit") + mocker.patch.object(reproduction, "_handle_estrus_detection", return_value=mock_outputs) + + reproduction._handle_estrus_day_states(mock_outputs, simulation_day) + + mock_events.add_event.assert_called_once() @pytest.mark.parametrize( diff --git a/tests/test_data_validator.py b/tests/test_data_validator.py index 8b8b7a2ebc..7a9f18c0f9 100644 --- a/tests/test_data_validator.py +++ b/tests/test_data_validator.py @@ -2709,12 +2709,42 @@ def test_check_target_and_save_block_message_contains_all_invalid_keys_eager_ter @pytest.mark.parametrize( "expression_block, eager_termination", [ - ({"operation": "add", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, True), - ({"operation": "subtract", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, True), - ({"operation": "multiply", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, True), - ({"operation": "add", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, False), - ({"operation": "subtract", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, False), - ({"operation": "multiply", "ordered_variables": ["alias_0", "alias_1"], "save_as": "alias_2"}, False), + ( + {"aggregation": {"function": "add", "operands": ["alias_0", "alias_1"]}, "save_as": "alias_2"}, + True, + ), + ( + { + "aggregation": {"function": "subtract", "operands": ["alias_0", "alias_1"]}, + "save_as": "alias_2", + }, + True, + ), + ( + { + "aggregation": {"function": "multiply", "operands": ["alias_0", "alias_1"]}, + "save_as": "alias_2", + }, + True, + ), + ( + {"aggregation": {"function": "add", "operands": ["alias_0", "alias_1"]}, "save_as": "alias_2"}, + False, + ), + ( + { + "aggregation": {"function": "subtract", "operands": ["alias_0", "alias_1"]}, + "save_as": "alias_2", + }, + False, + ), + ( + { + "aggregation": {"function": "multiply", "operands": ["alias_0", "alias_1"]}, + "save_as": "alias_2", + }, + False, + ), ], ) def test_evaluate_expression_unknown_operation( @@ -2739,18 +2769,18 @@ def test_evaluate_expression_unknown_operation( @pytest.mark.parametrize( "expression_block, eager_termination", [ - ({"operation": "add", "save_as": "alias_2"}, True), - ({"operation": "subtract", "save_as": "alias_2"}, True), - ({"operation": "multiply", "ordered_variables": [], "save_as": "alias_2"}, True), - ({"operation": "add", "save_as": "alias_2"}, False), - ({"operation": "subtract", "save_as": "alias_2"}, False), - ({"operation": "multiply", "save_as": "alias_2"}, False), + ({"aggregation": {"function": "add"}, "save_as": "alias_2"}, True), + ({"aggregation": {"function": "subtract"}, "save_as": "alias_2"}, True), + ({"aggregation": {"function": "multiply", "operands": []}, "save_as": "alias_2"}, True), + ({"aggregation": {"function": "add"}, "save_as": "alias_2"}, False), + ({"aggregation": {"function": "subtract"}, "save_as": "alias_2"}, False), + ({"aggregation": {"function": "multiply"}, "save_as": "alias_2"}, False), ], ) def test_evaluate_expression_no_ordered_variables( expression_block: dict[str, Any], eager_termination: bool, mocker: MockerFixture ) -> None: - """Test the behavior of _evaluate_expression when ordered_variables is missing or empty.""" + """Test the behavior of _evaluate_expression when operands is missing or empty.""" cross_validator = CrossValidator() mock_get_alias_value = mocker.patch.object(cross_validator, "_get_alias_value") mock_save_to_alias_pool = mocker.patch.object(cross_validator, "_save_to_alias_pool") @@ -2769,24 +2799,24 @@ def test_evaluate_expression_no_ordered_variables( @pytest.mark.parametrize( "expression_block, selected_variables, eager_termination", [ - ({"operation": "sum", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, [[], []], True), - ({"operation": "difference", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, [{}, {}], True), + ({"function": "sum", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [[], []], True), + ({"function": "difference", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [{}, {}], True), ( - {"operation": "average", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, + {"function": "average", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [[1, 2, 3], {"a": 1, "b": 2}], True, ), ( - {"operation": "product", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, + {"function": "product", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [[4, 5, 6], ["a", "b", "c"]], False, ), ( - {"operation": "division", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, + {"function": "division", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [{"a": [], "b": []}, [{}, {}]], False, ), - ({"operation": "no_op", "ordered_variables": ["alias_0", "alias_1"], "apply_to": "group"}, [[{}], []], False), + ({"function": "no_op", "operands": ["alias_0", "alias_1"], "mode": "aggregate"}, [[{}], []], False), ], ) def test_validate_expression_block_with_complex_variable_values_multiple_complex_variable( @@ -2813,12 +2843,12 @@ def test_validate_expression_block_with_complex_variable_values_multiple_complex @pytest.mark.parametrize( "expression_block, selected_variables, eager_termination", [ - ({"operation": "sum", "ordered_variables": ["alias_0"]}, [[]], True), - ({"operation": "difference", "ordered_variables": ["alias_0"]}, [{}], True), - ({"operation": "average", "ordered_variables": ["alias_0"]}, [[1, 2, 3]], True), - ({"operation": "product", "ordered_variables": ["alias_0"]}, [{"a": 1, "b": 2, "c": 3}], False), - ({"operation": "division", "ordered_variables": ["alias_0"]}, [{"a": [], "b": []}], False), - ({"operation": "no_op", "ordered_variables": ["alias_0"]}, [[{}]], False), + ({"function": "sum", "operands": ["alias_0"]}, [[]], True), + ({"function": "difference", "operands": ["alias_0"]}, [{}], True), + ({"function": "average", "operands": ["alias_0"]}, [[1, 2, 3]], True), + ({"function": "product", "operands": ["alias_0"]}, [{"a": 1, "b": 2, "c": 3}], False), + ({"function": "division", "operands": ["alias_0"]}, [{"a": [], "b": []}], False), + ({"function": "no_op", "operands": ["alias_0"]}, [[{}]], False), ], ) def test_validate_expression_block_with_complex_variable_values_no_apply_to( @@ -2826,7 +2856,7 @@ def test_validate_expression_block_with_complex_variable_values_no_apply_to( ) -> None: """ Unit tests for _validate_expression_block_with_complex_variable_values() in CrossValidator - when a complex variable is selected and the `apply_to` key is missing. + when a complex variable is selected and the `mode` key is missing. """ cross_validator = CrossValidator() @@ -2845,8 +2875,8 @@ def test_validate_expression_block_with_complex_variable_values_no_apply_to( @pytest.mark.parametrize( "expression_block, selected_variables, eager_termination", [ - ({"operation": "sum", "ordered_variables": ["alias_0"], "apply_to": "unknown"}, [[]], True), - ({"operation": "sum", "ordered_variables": ["alias_0"], "apply_to": "unknown"}, [[]], False), + ({"function": "sum", "operands": ["alias_0"], "mode": "unknown"}, [[]], True), + ({"function": "sum", "operands": ["alias_0"], "mode": "unknown"}, [[]], False), ], ) def test_validate_expression_block_with_complex_variable_values_unknown_apply_to_value( @@ -2854,7 +2884,7 @@ def test_validate_expression_block_with_complex_variable_values_unknown_apply_to ) -> None: """ Unit tests for _validate_expression_block_with_complex_variable_values() in CrossValidator - when a complex variable is selected and the `apply_to` key is set to an unknown value. + when a complex variable is selected and the `mode` key is set to an unknown value. """ cross_validator = CrossValidator() @@ -2873,21 +2903,31 @@ def test_validate_expression_block_with_complex_variable_values_unknown_apply_to @pytest.mark.parametrize( "expression_block, selected_variables, expected_result", [ - ({"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual"}, [[1, 2, 3]], [1, 2, 3]), - ({"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual"}, [[]], []), ( - {"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual", "save_as": "abc"}, + {"aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}}, + [[1, 2, 3]], + [1, 2, 3], + ), + ({"aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}}, [[]], []), + ( + { + "aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}, + "save_as": "abc", + }, [{"a": 1, "b": 2, "c": 3}], [1, 2, 3], ), - ({"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual"}, [{}], []), + ({"aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}}, [{}], []), ( - {"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual", "save_as": "def"}, + { + "aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}, + "save_as": "def", + }, [{"a": [], "b": []}], [[], []], ), ( - {"operation": "no_op", "ordered_variables": ["alias_0"], "apply_to": "individual"}, + {"aggregation": {"function": "no_op", "operands": ["alias_0"], "mode": "element_wise"}}, [[{}, {}, {}]], [{}, {}, {}], ), @@ -2898,7 +2938,7 @@ def test_evaluate_expression_apply_to_individual( ) -> None: """ Unit tests for _evaluate_expression() in CrossValidator when a complex variable is selected - and `apply_to` is set to `individual` + and `mode` is set to `element_wise` """ cross_validator = CrossValidator() mock_get_alias_value = mocker.patch.object(cross_validator, "_get_alias_value", side_effect=selected_variables) @@ -2917,17 +2957,36 @@ def test_evaluate_expression_apply_to_individual( @pytest.mark.parametrize( "expression_block, selected_variables, expected_result", [ - ({"operation": "sum", "ordered_variables": ["alias_0"], "apply_to": "group"}, [[1, 2, 3]], [6]), - ({"operation": "difference", "ordered_variables": ["alias_0"], "apply_to": "group"}, [[]], [None]), ( - {"operation": "product", "ordered_variables": ["alias_0"], "apply_to": "group", "save_as": "abc"}, + {"aggregation": {"function": "sum", "operands": ["alias_0"], "mode": "aggregate"}}, + [[1, 2, 3]], + [6], + ), + ( + {"aggregation": {"function": "difference", "operands": ["alias_0"], "mode": "aggregate"}}, + [[]], + [None], + ), + ( + { + "aggregation": {"function": "product", "operands": ["alias_0"], "mode": "aggregate"}, + "save_as": "abc", + }, [{"a": 1, "b": 2, "c": 3}], [6], ), - ({"operation": "division", "ordered_variables": ["alias_0"], "apply_to": "group"}, [{}], [None]), - ({"operation": "no_op", "ordered_variables": ["a", "b", "c"], "save_as": "def"}, [2, 5, 8], [2, 5, 8]), ( - {"operation": "average", "ordered_variables": ["a", "b", "c", "d", "e", "f", "g", "h"]}, + {"aggregation": {"function": "division", "operands": ["alias_0"], "mode": "aggregate"}}, + [{}], + [None], + ), + ( + {"aggregation": {"function": "no_op", "operands": ["a", "b", "c"]}, "save_as": "def"}, + [2, 5, 8], + [2, 5, 8], + ), + ( + {"aggregation": {"function": "average", "operands": ["a", "b", "c", "d", "e", "f", "g", "h"]}}, [8, 7, 6, 5, 4, 3, 2, 1], [4.5], ), @@ -3132,7 +3191,7 @@ def test_evaluate_condition_short_circuits_when_validation_fails( mocker.patch.object(cv, "_validate_condition_clause", return_value=False) mock_eval = mocker.patch.object(cv, "_evaluate_expression") - valid = cv._evaluate_condition({"relationship": "equal"}, eager_termination) + valid = cv._evaluate_condition({"operator": "equal"}, eager_termination) assert not valid mock_eval.assert_not_called() @@ -3148,7 +3207,7 @@ def test_evaluate_condition_returns_false_when_side_not_evaluated( # Left evaluated False; right True mocker.patch.object(cv, "_evaluate_expression", side_effect=[("L", False), ("R", True)]) - valid = cv._evaluate_condition({"relationship": "equal", "left_hand": {}, "right_hand": {}}, eager_termination) + valid = cv._evaluate_condition({"operator": "equal", "left_hand": {}, "right_hand": {}}, eager_termination) assert not valid @@ -3161,7 +3220,7 @@ def test_evaluate_condition_equal_path(mocker: MockerFixture, eager_termination: mocker.patch.object(cv, "_evaluate_expression", side_effect=[("A", True), ("B", True)]) mock_eq = mocker.patch.object(cv, "_evaluate_equal_condition", return_value=True) - valid = cv._evaluate_condition({"relationship": "equal", "left_hand": {}, "right_hand": {}}, eager_termination) + valid = cv._evaluate_condition({"operator": "equal", "left_hand": {}, "right_hand": {}}, eager_termination) assert valid mock_eq.assert_called_once_with("A", "B") @@ -3177,7 +3236,7 @@ def test_evaluate_condition_greater_or_equal_short_circuit(mocker: MockerFixture mock_eq = mocker.patch.object(cv, "_evaluate_equal_condition", return_value=False) valid = cv._evaluate_condition( - {"relationship": "greater_or_equal_to", "left_hand": {}, "right_hand": {}}, eager_termination + {"operator": "greater_or_equal_to", "left_hand": {}, "right_hand": {}}, eager_termination ) assert valid @@ -3197,7 +3256,7 @@ def test_evaluate_condition_greater_or_equal_falls_back_to_equal( mock_eq = mocker.patch.object(cv, "_evaluate_equal_condition", return_value=True) valid = cv._evaluate_condition( - {"relationship": "greater_or_equal_to", "left_hand": {}, "right_hand": {}}, eager_termination + {"operator": "greater_or_equal_to", "left_hand": {}, "right_hand": {}}, eager_termination ) assert valid @@ -3213,7 +3272,7 @@ def test_evaluate_condition_not_equal_inverts_equality(mocker: MockerFixture, ea mocker.patch.object(cv, "_evaluate_expression", side_effect=[("foo", True), ("bar", True)]) mock_eq = mocker.patch.object(cv, "_evaluate_equal_condition", return_value=False) - valid = cv._evaluate_condition({"relationship": "not_equal", "left_hand": {}, "right_hand": {}}, eager_termination) + valid = cv._evaluate_condition({"operator": "not_equal", "left_hand": {}, "right_hand": {}}, eager_termination) assert valid mock_eq.assert_called_once_with("foo", "bar") @@ -3227,7 +3286,7 @@ def test_evaluate_condition_is_of_type_passes_eager(mocker: MockerFixture, eager mocker.patch.object(cv, "_evaluate_expression", side_effect=[("text", True), ("string", True)]) mock_is_type = mocker.patch.object(cv, "_evaluate_is_type", return_value=True) - valid = cv._evaluate_condition({"relationship": "is_of_type", "left_hand": {}, "right_hand": {}}, eager_termination) + valid = cv._evaluate_condition({"operator": "is_of_type", "left_hand": {}, "right_hand": {}}, eager_termination) assert valid mock_is_type.assert_called_once_with("text", "string", eager_termination) @@ -3242,7 +3301,7 @@ def test_evaluate_condition_is_null_branch(mocker: MockerFixture, eager_terminat mocker.patch.object(cv, "_evaluate_expression", side_effect=[(None, True), ("ignored", True)]) mock_is_null = mocker.patch.object(cv, "_evaluate_is_null", return_value=True) - valid = cv._evaluate_condition({"relationship": "is_null", "left_hand": {}, "right_hand": {}}, eager_termination) + valid = cv._evaluate_condition({"operator": "is_null", "left_hand": {}, "right_hand": {}}, eager_termination) assert valid mock_is_null.assert_called_once_with(None) @@ -3256,7 +3315,7 @@ def test_evaluate_condition_regex_branch(mocker: MockerFixture, eager_terminatio mocker.patch.object(cv, "_evaluate_expression", side_effect=[("abc", True), (r"a.c", True)]) mock_regex = mocker.patch.object(cv, "_evaluate_regex", return_value=True) - ok = cv._evaluate_condition({"relationship": "regex", "left_hand": {}, "right_hand": {}}, eager_termination) + ok = cv._evaluate_condition({"operator": "regex", "left_hand": {}, "right_hand": {}}, eager_termination) assert ok is True mock_regex.assert_called_once_with("abc", r"a.c") @@ -3279,7 +3338,7 @@ def test_validate_condition_clause_ok(mocker: MockerFixture) -> None: v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=True) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"left_hand": 1, "right_hand": 2, "relationship": "equal"} + clause = {"left_hand": 1, "right_hand": 2, "operator": "equal"} result = v._validate_condition_clause(clause, eager_termination=False) assert result is True log.assert_not_called() @@ -3290,7 +3349,7 @@ def test_validate_condition_clause_missing_both_no_eager(mocker: MockerFixture) v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=True) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"relationship": "equal"} + clause = {"operator": "equal"} result = v._validate_condition_clause(clause, eager_termination=False) assert result is False assert log.call_args_list == [call("left hand"), call("right hand")] @@ -3301,7 +3360,7 @@ def test_validate_condition_clause_missing_left_no_eager(mocker: MockerFixture) v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=True) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"right_hand": 2, "relationship": "equal"} + clause = {"right_hand": 2, "operator": "equal"} result = v._validate_condition_clause(clause, eager_termination=False) assert result is False assert log.call_args_list == [call("left hand")] @@ -3312,7 +3371,7 @@ def test_validate_condition_clause_missing_right_no_eager(mocker: MockerFixture) v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=True) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"left_hand": 1, "relationship": "equal"} + clause = {"left_hand": 1, "operator": "equal"} result = v._validate_condition_clause(clause, eager_termination=False) assert result is False assert log.call_args_list == [call("right hand")] @@ -3323,7 +3382,7 @@ def test_validate_condition_clause_missing_both_eager_raises(mocker: MockerFixtu v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=True) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"relationship": "equal"} + clause = {"operator": "equal"} with pytest.raises(KeyError): v._validate_condition_clause(clause, eager_termination=True) assert log.call_args_list == [call("left hand"), call("right hand")] @@ -3334,7 +3393,7 @@ def test_validate_condition_clause_invalid_relationship(mocker: MockerFixture) - v = CrossValidator() mocker.patch.object(v, "_validate_relationship", return_value=False) log = mocker.patch.object(v, "_log_missing_condition_clause_field") - clause = {"left_hand": 1, "right_hand": 2, "relationship": "bogus"} + clause = {"left_hand": 1, "right_hand": 2, "operator": "bogus"} result = v._validate_condition_clause(clause, eager_termination=False) assert result is False log.assert_not_called() @@ -3349,3 +3408,142 @@ def test_log_missing_condition_clause_field_only() -> None: e = v._event_logs[0] assert e["error"] == "Missing required condition clause field" assert e["message"] == "Missing the left hand field in condition clause." + + +# --------------------------------------------------------------------------- +# Tests for _evaluate_iterate_array_of_dicts +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "iter_block, alias_values, expected_result", + [ + # mode="filter": keep entries where field == compare_value + ( + { + "mode": "filter", + "in": "items", + "field": "type", + "compare_value": "target_type", + "operator": "equal", + }, + [ + [{"type": "A", "val": 1}, {"type": "B", "val": 2}, {"type": "A", "val": 3}], + "A", + ], + [{"type": "A", "val": 1}, {"type": "A", "val": 3}], + ), + # mode="filter": no entries match + ( + { + "mode": "filter", + "in": "items", + "field": "type", + "compare_value": "target_type", + "operator": "equal", + }, + [ + [{"type": "X"}, {"type": "Y"}], + "Z", + ], + [], + ), + # mode="enforce": all entries satisfy → [True] + ( + { + "mode": "enforce", + "in": "items", + "field": "status", + "compare_value": "expected_status", + "operator": "equal", + }, + [ + [{"status": "ok"}, {"status": "ok"}], + "ok", + ], + [True], + ), + # mode="enforce": not all entries satisfy → [False] + ( + { + "mode": "enforce", + "in": "items", + "field": "status", + "compare_value": "expected_status", + "operator": "equal", + }, + [ + [{"status": "ok"}, {"status": "fail"}], + "ok", + ], + [False], + ), + ], +) +def test_evaluate_iterate_array_of_dicts_success( + iter_block: dict[str, Any], alias_values: list[Any], expected_result: Any, mocker: MockerFixture +) -> None: + """_evaluate_iterate_array_of_dicts returns correct filtered/enforced result.""" + cv = CrossValidator() + mocker.patch.object(cv, "_get_alias_value", side_effect=alias_values) + + result, status = cv._evaluate_iterate_array_of_dicts( + iter_block, eager_termination=False, outer_relationship="equal" + ) + assert status is True + assert result == expected_result + + +def test_evaluate_iterate_array_of_dicts_unknown_relationship_no_eager() -> None: + """Unknown relationship returns (None, False) when not eager.""" + cv = CrossValidator() + iter_block = { + "mode": "filter", + "in": "items", + "field": "x", + "compare_value": "cmp", + "operator": "unknown_rel", + } + result, status = cv._evaluate_iterate_array_of_dicts( + iter_block, eager_termination=False, outer_relationship="equal" + ) + assert result is None + assert status is False + assert len(cv._event_logs) == 1 + + +def test_evaluate_iterate_array_of_dicts_unknown_relationship_eager() -> None: + """Unknown relationship raises ValueError when eager.""" + cv = CrossValidator() + iter_block = { + "mode": "filter", + "in": "items", + "field": "x", + "compare_value": "cmp", + "operator": "unknown_rel", + } + with pytest.raises(ValueError, match="Unknown operator"): + cv._evaluate_iterate_array_of_dicts(iter_block, eager_termination=True, outer_relationship="equal") + + +def test_evaluate_expression_with_iterate_array_of_dicts_and_save_as(mocker: MockerFixture) -> None: + """save_as at the expression_block level is applied when using iterate_array_of_dicts.""" + cv = CrossValidator() + expression_block = { + "for_each": { + "mode": "filter", + "in": "items", + "field": "type", + "compare_value": "cmp", + "operator": "equal", + }, + "save_as": "filtered_items", + } + array_data = [{"type": "A"}, {"type": "B"}] + mocker.patch.object(cv, "_get_alias_value", side_effect=[array_data, "A"]) + mock_save = mocker.patch.object(cv, "_save_to_alias_pool") + + result, status = cv._evaluate_expression(expression_block, eager_termination=False, relationship="equal") + assert status is True + assert result == [{"type": "A"}] + mock_save.assert_called_once_with(alias_name="filtered_items", value=[{"type": "A"}])