@@ -88,18 +88,19 @@ def export(config, output_path):
8888 first = tf .reshape (first , (- 1 , first .shape [- 1 ]))
8989 stages [0 ][0 ]["weights" ] = first .numpy ().tolist ()
9090
91+ network = {
92+ "geometry" : {
93+ "intersections" : config ["projection" ]["config" ]["geometry" ]["intersections" ],
94+ "radius" : config ["projection" ]["config" ]["geometry" ]["radius" ],
95+ "shape" : config ["projection" ]["config" ]["geometry" ]["shape" ],
96+ },
97+ "mesh" : config ["projection" ]["config" ]["mesh" ]["model" ],
98+ "network" : stages ,
99+ }
100+
101+ # Add classification meta data
102+ if config ["label" ]["type" ] == "Classification" :
103+ network ["class_map" ] = {c ["name" ]: i for i , c in enumerate (config ["label" ]["config" ]["classes" ])}
104+
91105 with open (os .path .join (output_path , "model.yaml" ), "w" ) as out :
92- yaml .dump (
93- {
94- "mesh" : config ["projection" ]["config" ]["mesh" ]["model" ],
95- "geometry" : {
96- "shape" : config ["projection" ]["config" ]["geometry" ]["shape" ],
97- "radius" : config ["projection" ]["config" ]["geometry" ]["radius" ],
98- "intersections" : config ["projection" ]["config" ]["geometry" ]["intersections" ],
99- },
100- "network" : stages ,
101- },
102- out ,
103- default_flow_style = None ,
104- width = float ("inf" ),
105- )
106+ yaml .dump (network , out , default_flow_style = None , width = float ("inf" ))
0 commit comments