diff --git a/env/src/game_types.py b/env/src/game_types.py index 4246cf5ad..c18d3df3b 100644 --- a/env/src/game_types.py +++ b/env/src/game_types.py @@ -13,10 +13,11 @@ class ResourceName(enum.Enum): CrudeOil = "crude-oil" UraniumOre = "uranium-ore" + class PrototypeMetaclass(enum.EnumMeta): def __getattr__(cls, name): try: - attr = super().__getattr__(name) + attr = super().__getattr__(name) return attr except AttributeError: # Get all valid prototype names @@ -29,24 +30,38 @@ def __getattr__(cls, name): if matches: suggestion_msg = f". Did you mean: {', '.join(matches)}?" - raise AttributeError(f"'{cls.__name__}' has no attribute '{name}'{suggestion_msg}") + raise AttributeError( + f"'{cls.__name__}' has no attribute '{name}'{suggestion_msg}" + ) + class RecipeName(enum.Enum): """ Recipe names that can be used in the game for fluids """ + NuclearFuelReprocessing = "nuclear-fuel-reprocessing" UraniumProcessing = "uranium-processing" - SulfuricAcid = "sulfuric-acid" # Recipe for producing sulfuric acid with a chemical plant - BasicOilProcessing = "basic-oil-processing" # Recipe for producing petroleum gas with a oil refinery - AdvancedOilProcessing = "advanced-oil-processing" # Recipe for producing petroleum gas, heavy oil and light oil with a oil refinery - CoalLiquefaction = "coal-liquefaction" # Recipe for producing petroleum gas in a oil refinery - HeavyOilCracking = "heavy-oil-cracking" # Recipe for producing light oil in a chemical plant - LightOilCracking = "light-oil-cracking" # Recipe for producing petroleum gas in a chemical plant - - SolidFuelFromHeavyOil = "solid-fuel-from-heavy-oil" # Recipe for producing solid fuel in a chemical plant - SolidFuelFromLightOil = "solid-fuel-from-light-oil" # Recipe for producing solid fuel in a chemical plant - SolidFuelFromPetroleumGas = "solid-fuel-from-petroleum-gas" # Recipe for producing solid fuel in a chemical plant + SulfuricAcid = ( + "sulfuric-acid" # Recipe for producing sulfuric acid with a chemical plant + ) + BasicOilProcessing = ( + "basic-oil-processing" # Recipe for producing petroleum gas with a oil refinery + ) + AdvancedOilProcessing = "advanced-oil-processing" # Recipe for producing petroleum gas, heavy oil and light oil with a oil refinery + CoalLiquefaction = ( + "coal-liquefaction" # Recipe for producing petroleum gas in a oil refinery + ) + HeavyOilCracking = ( + "heavy-oil-cracking" # Recipe for producing light oil in a chemical plant + ) + LightOilCracking = ( + "light-oil-cracking" # Recipe for producing petroleum gas in a chemical plant + ) + + SolidFuelFromHeavyOil = "solid-fuel-from-heavy-oil" # Recipe for producing solid fuel in a chemical plant + SolidFuelFromLightOil = "solid-fuel-from-light-oil" # Recipe for producing solid fuel in a chemical plant + SolidFuelFromPetroleumGas = "solid-fuel-from-petroleum-gas" # Recipe for producing solid fuel in a chemical plant FillCrudeOilBarrel = "fill-crude-oil-barrel" FillHeavyOilBarrel = "fill-heavy-oil-barrel" @@ -66,7 +81,6 @@ class RecipeName(enum.Enum): class Prototype(enum.Enum, metaclass=PrototypeMetaclass): - AssemblingMachine1 = "assembling-machine-1", AssemblingMachine AssemblingMachine2 = "assembling-machine-2", AdvancedAssemblingMachine AssemblingMachine3 = "assembling-machine-3", AdvancedAssemblingMachine @@ -83,7 +97,6 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): Inserter = "inserter", Inserter - BurnerMiningDrill = "burner-mining-drill", BurnerMiningDrill ElectricMiningDrill = "electric-mining-drill", ElectricMiningDrill @@ -114,7 +127,7 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): SolarPanel = "solar-panel", SolarPanel UndergroundPipe = "pipe-to-ground", Pipe - HeatPipe = 'heat-pipe', Pipe + HeatPipe = "heat-pipe", Pipe Pipe = "pipe", Pipe SteelChest = "steel-chest", Chest @@ -140,7 +153,7 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): IronStick = "iron-stick", None SteelPlate = "steel-plate", None # Crafting requires smelting 5 iron plates CopperPlate = "copper-plate", None # Crafting requires smelting 1 copper ore - StoneBrick = "stone-brick", None # Crafting requires smelting 2 stone + StoneBrick = "stone-brick", None # Crafting requires smelting 2 stone CopperCable = "copper-cable", None PlasticBar = "plastic-bar", None EmptyBarrel = "empty-barrel", None @@ -151,9 +164,12 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): Lubricant = "lubricant", None PetroleumGas = "petroleum-gas", None - AdvancedOilProcessing = "advanced-oil-processing", None # These are recipes, not prototypes. - CoalLiquifaction = "coal-liquifaction", None # These are recipes, not prototypes. - SolidFuel = "solid-fuel", None # These are recipes, not prototypes. + AdvancedOilProcessing = ( + "advanced-oil-processing", + None, + ) # These are recipes, not prototypes. + CoalLiquifaction = "coal-liquifaction", None # These are recipes, not prototypes. + SolidFuel = "solid-fuel", None # These are recipes, not prototypes. LightOil = "light-oil", None HeavyOil = "heavy-oil", None @@ -178,7 +194,7 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): NuclearReactor = "nuclear-reactor", Reactor UraniumFuelCell = "uranium-fuel-cell", None - HeatExchanger = 'heat-exchanger', HeatExchanger + HeatExchanger = "heat-exchanger", HeatExchanger AutomationSciencePack = "automation-science-pack", None MilitarySciencePack = "military-science-pack", None @@ -186,7 +202,7 @@ class Prototype(enum.Enum, metaclass=PrototypeMetaclass): ProductionSciencePack = "production-science-pack", None UtilitySciencePack = "utility-science-pack", None ChemicalSciencePack = "chemical-science-pack", None - + ProductivityModule = "productivity-module", None ProductivityModule2 = "productivity-module-2", None ProductivityModule3 = "productivity-module-3", None @@ -213,11 +229,12 @@ def __init__(self, prototype_name, entity_class_name): @property def WIDTH(self): return self.entity_class._width.default # Access the class attribute directly - + @property def HEIGHT(self): return self.entity_class._height.default + prototype_by_name = {prototype.value[0]: prototype for prototype in Prototype} prototype_by_title = {str(prototype): prototype for prototype in Prototype} @@ -245,11 +262,13 @@ class Technology(enum.Enum): ElectricEnergy = "electric-energy-distribution-1" ElectricEnergy2 = "electric-energy-distribution-2" SolarEnergy = "solar-energy" + Engine = "engine" ElectricEngineering = "electric-engine" BatteryTechnology = "battery" # AdvancedBattery = "battery-mk2-equipment" NuclearPower = "nuclear-power" + # Optics Optics = "optics" # Mining technologies @@ -290,7 +309,7 @@ class Technology(enum.Enum): Lubricant = "lubricant" # Modules - # Modules = "modules" + Modules = "modules" # SpeedModule = "speed-module" # SpeedModule2 = "speed-module-2" # SpeedModule3 = "speed-module-3" @@ -315,7 +334,7 @@ class Technology(enum.Enum): ChemicalSciencePack = "chemical-science-pack" ProductionSciencePack = "production-science-pack" # UtilitySciencePack = "utility-science-pack" - #SpaceSciencePack = "space-science-pack" + # SpaceSciencePack = "space-science-pack" # Inserter technologies FastInserter = "fast-inserter" @@ -361,6 +380,7 @@ class Technology(enum.Enum): # Helper dictionary to look up technology by name string technology_by_name = {tech.value: tech for tech in Technology} + class Resource: Coal = "coal", ResourcePatch IronOre = "iron-ore", ResourcePatch @@ -369,4 +389,4 @@ class Resource: Water = "water", ResourcePatch CrudeOil = "crude-oil", ResourcePatch UraniumOre = "uranium-ore", ResourcePatch - Wood = "wood", ResourcePatch \ No newline at end of file + Wood = "wood", ResourcePatch diff --git a/env/src/lib/serialize.lua b/env/src/lib/serialize.lua index 6496f12f0..73a9f0229 100644 --- a/env/src/lib/serialize.lua +++ b/env/src/lib/serialize.lua @@ -1248,11 +1248,11 @@ global.utils.serialize_entity = function(entity) -- Add the current research to the lab if entity.name == "lab" then - if game.players[1].force.current_research ~= nil then - serialized.research = game.players[1].force.current_research.name - else - serialized.research = nil - end + -- if game.players[1].force.current_research ~= nil then + -- serialized.research = game.players[1].force.current_research.name + -- else + -- serialized.research = nil + -- end end -- Add input and output locations if the entity is a offshore pump diff --git a/freeplay/trajectory_runner.py b/freeplay/trajectory_runner.py index 2e7540d03..acc572915 100644 --- a/freeplay/trajectory_runner.py +++ b/freeplay/trajectory_runner.py @@ -301,22 +301,18 @@ async def run(self): ) continue - # current_entities = f"{instance.namespace.get_entities()}" - # current_inventory = format_inventory(instance.namespace.inspect_inventory()) - - # (previous_iteration_summary,) = await self.agent.report_summary( - # iteration=iteration, - # current_inventory=current_inventory, - # current_entities=current_entities, - # current_conversation=current_conversation, - # ) - - # if iteration_row_number: - # update_spreadsheet_cell( - # os.getenv("SPREADSHEET_ID"), - # f"Iterations!G{iteration_row_number}", - # previous_iteration_summary, - # ) + (previous_iteration_summary,) = await self.agent.report_summary( + step, + game_state, + execution_history, + ) + + if iteration_row_number: + update_spreadsheet_cell( + os.getenv("SPREADSHEET_ID"), + f"Iterations!G{iteration_row_number}", + previous_iteration_summary, + ) elapsed = time.time() - self.start_time elapsed_str = f"{int(elapsed // 3600):02d}:{int((elapsed % 3600) // 60):02d}:{int(elapsed % 60):02d}" diff --git a/trainer/agent.py b/trainer/agent.py index f5d100694..a68421ae5 100644 --- a/trainer/agent.py +++ b/trainer/agent.py @@ -358,11 +358,11 @@ def _format_messages( iteration_messages += [ Message( role="assistant", - content=execution.agent_output.code, + content=execution.passed_code(), ), Message( role="user", - content=execution.evaluation.response, + content=execution.evaluation.formatted(), ), ] @@ -401,6 +401,11 @@ def _format_messages( {game_state.inventory()} +### Research Status +Current research status. Note that research makes progress by putting required science packs into labs. + +{game_state.research_status()} + ## Important Notes - Always inspect game state before making changes - Consider long-term implications of actions @@ -446,20 +451,26 @@ async def _get_policy(self, messages: List[Message]) -> AgentOutput: async def report_summary( self, - iteration: int, - current_inventory: str, - current_entities: str, - current_conversation: Conversation, + step: Step, + game_state: ParsedGameState, + execution_history: List[Execution], ): instruction = "" - iteration_messages = [] - for message in current_conversation.messages: - if message.metadata.get("iteration") == iteration: - iteration_messages.append(message) - instruction = message.metadata.get("instruction") + history = "" + for execution in execution_history: + if execution.step.iteration_number != step.iteration_number: + continue + + instruction = execution.step.instruction + history += ( + f"[{execution.step.in_iteration_number}]\n" + + f"{execution.agent_output.code}\n" + + "=" * 50 + + f"{execution.evaluation.response}\n" + ) iteration_summary = "" - if iteration_messages: + if len(history) > 0: try: iteration_summary_response = await self.llm_factory.acall( messages=[ @@ -467,14 +478,9 @@ async def report_summary( "role": "user", "content": iteration_summary_prompt( instruction, - current_entities, - current_inventory, - "\n".join( - [ - f"role: {m.role}\ncontent: {m.content}\n" - for m in iteration_messages - ] - ), + game_state.entities, + game_state.inventory(), + history, ), } ], diff --git a/trainer/db.py b/trainer/db.py index 17c6edaac..3eae749e4 100644 --- a/trainer/db.py +++ b/trainer/db.py @@ -46,7 +46,7 @@ def get_connection(self): conn.close() async def get_resume_state( - self, collection_id + self, collection_id, step_number=None ) -> tuple[ Optional[Step], Optional[ParsedGameState], @@ -55,16 +55,31 @@ async def get_resume_state( """Get the state to resume from""" try: # Get most recent successful program to resume from - query = """ - SELECT * FROM data_points - WHERE collection_id = ? - ORDER BY step_number DESC - LIMIT 1 - """ - with self.get_connection() as conn: cur = conn.cursor() - cur.execute(query, (collection_id,)) + + if step_number is not None: + cur.execute( + """ + SELECT * FROM data_points + WHERE collection_id = ? + AND step_number = ? + ORDER BY step_number DESC + LIMIT 1 + """, + (collection_id, step_number), + ) + else: + cur.execute( + """ + SELECT * FROM data_points + WHERE collection_id = ? + ORDER BY step_number DESC + LIMIT 1 + """, + (collection_id,), + ) + results = cur.fetchall() if not results: diff --git a/trainer/definitions.py b/trainer/definitions.py index 985d1dc33..f0ce6942b 100644 --- a/trainer/definitions.py +++ b/trainer/definitions.py @@ -3,6 +3,7 @@ from env.src.models.game_state import GameState from typing import List, Dict from models.achievements import ProductionFlows +import re @dataclass @@ -49,6 +50,7 @@ def to_dict(self): "raw": self.raw.to_raw(), "entities": self.entities, "inventory": self.inventory(), + "research_status": self.research_status(), } @classmethod @@ -61,6 +63,31 @@ def from_dict(cls, data): def inventory(self) -> str: return format_inventory(self.raw.inventory) + def research_status(self) -> str: + completed = [] + for name, tech in self.raw.research.technologies.items(): + if tech.researched: + completed.append(name) + + current_research = "None" + if self.raw.research.current_research: + current_research = ( + f"{self.raw.research.current_research}" + + " (" + + "%.1f" % (self.raw.research.research_progress * 100) + + "%)" + ) + + completed_researchs = f"{completed}" + + return ( + "Current Research: " + + current_research + + "\n" + + "Completed: " + + completed_researchs + ) + @dataclass class Message: @@ -133,6 +160,20 @@ def from_dict(cls, data): ticks=data["ticks"], ) + def formatted(self) -> str: + return f"""Code Execution: +{self.evaluation.response} + +Achievements: +{self.evaluation.achievements} + +Flows: +{self.evaluation.flows} +""" + + +line_number_pattern = re.compile(r"^(\d+): \(") + @dataclass class Execution: @@ -155,6 +196,25 @@ def from_dict(cls, data): evaluation=Evaluation.from_dict(data["evaluation"]), ) + def passed_code(self) -> str: + lines = self.agent_output.code.split("\n") + + eval_lines = self.evaluation.response.split("\n") + + error_line_number = None + for line in eval_lines: + if ": ('Error occurred:\\n" in line or ": ('\\nException:" in line: + line_number = line_number_pattern.match(line) + if line_number: + error_line_number = int(line_number.group(1)) + break + + pass_line_number = ( + len(lines) if error_line_number is None else error_line_number + ) + + return "\n".join(lines[:pass_line_number]) + class Agent(ABC): @abstractmethod