diff --git a/faster_sam/cloudformation.py b/faster_sam/cloudformation.py index deccf6b4..f6efb3bf 100644 --- a/faster_sam/cloudformation.py +++ b/faster_sam/cloudformation.py @@ -344,6 +344,7 @@ def __init__( self.template = self.load(template_path) self.include_files() + self.template = self.evaluate_and_replace(self.template) self.set_parameters(parameters) @property @@ -545,17 +546,36 @@ def find_environment(self) -> Dict[str, Any]: for function in self.functions.values(): variables.update(function.environment) - environment = {} + for key, value in variables.items(): + variables[key] = str(value) - for key, val in variables.items(): - if isinstance(val, (str, int, float)): - environment[key] = str(val) - else: - value = IntrinsicFunctions.eval(val, self.template) - if value is not None: - environment[key] = str(value) + return variables - return environment + def evaluate_and_replace(self, obj: Dict[str, Any], parent_key=None) -> Dict[str, Any]: + functions = [ + "Fn::Base64", + "Fn::FindInMap", + "Fn::GetAtt", + "Fn::Join", + "Fn::Select", + "Fn::Split", + "Fn::Sub", + "Ref", + ] + + if isinstance(obj, dict): + for key in obj.keys(): + value = obj[key] + + if key in functions: + result = IntrinsicFunctions.eval(key, self.template) + + obj[parent_key] = result + + elif isinstance(value, dict): + obj[key] = self.evaluate_and_replace(value, parent_key=key) + + return obj class IntrinsicFunctions: @@ -586,50 +606,27 @@ def eval(function: Dict[str, Any], template: Dict[str, Any]) -> Any: NotImplementedError If the intrinsic function is not implemented. """ - implemented = ( - "Fn::Base64", - "Fn::FindInMap", - "Fn::GetAtt", - "Fn::Join", - "Fn::Select", - "Fn::Split", - "Fn::Sub", - "Ref", - ) + functions = { + "Fn::Base64": IntrinsicFunctions.base64, + "Fn::FindInMap": IntrinsicFunctions.find_in_map, + "Fn::GetAtt": IntrinsicFunctions.get_att, + "Fn::Join": IntrinsicFunctions.join, + "Fn::Select": IntrinsicFunctions.select, + "Fn::Split": IntrinsicFunctions.split, + "Fn::Sub": IntrinsicFunctions.sub, + "Ref": IntrinsicFunctions.ref, + } fun, val = list(function.items())[0] - if fun not in implemented: + if fun in functions.keys(): + return functions[fun](val, template) + else: logging.warning(f"{fun} intrinsic function not implemented") - - if "Fn::Base64" == fun: - return IntrinsicFunctions.base64(val) - - if "Fn::FindInMap" == fun: - return IntrinsicFunctions.find_in_map(val, template) - - if "Fn::GetAtt" == fun: - return IntrinsicFunctions.get_att(val, template) - - if "Fn::Join" == fun: - return IntrinsicFunctions.join(val, template) - - if "Fn::Select" == fun: - return IntrinsicFunctions.select(val, template) - - if "Fn::Split" == fun: - return IntrinsicFunctions.split(val, template) - - if "Fn::Sub" == fun: - return IntrinsicFunctions.sub(val, template) - - if "Ref" == fun: - return IntrinsicFunctions.ref(val, template) - - return None + return None @staticmethod - def base64(value: str) -> str: + def base64(value: str, template: Dict[str, Any]) -> str: """ Encode a string to base64. diff --git a/tests/fixtures/templates/example8.yml b/tests/fixtures/templates/example8.yml new file mode 100644 index 00000000..c1d9c8bf --- /dev/null +++ b/tests/fixtures/templates/example8.yml @@ -0,0 +1,21 @@ +Resources: + ApiGateway: + Type: AWS::Serverless::Api + Properties: + Name: sam-api + StageName: v1 + + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + Events: + HelloWorld: + Properties: + Path: /hello + Method: get + RestApiId: + Fn::Select: + - 1 + - Fn::Split: + - "|" + - "dev@gmail|dotz@gmail.com" diff --git a/tests/test_cloudformation.py b/tests/test_cloudformation.py index f9ac0afa..40b5aa62 100644 --- a/tests/test_cloudformation.py +++ b/tests/test_cloudformation.py @@ -269,6 +269,11 @@ def test_lambda_handler(self): handler_path = cloudformation.functions[key].handler self.assertEqual(handler_path, "tests.fixtures.handlers.lambda_handler.handler") + def test(self): + cloudformation = CloudformationTemplate("tests/fixtures/templates/example8.yml") + + print(cloudformation.template) + class TestIntrinsicFunctions(unittest.TestCase): def test_getatt_function(self):