Skip to content

Commit cd9d08e

Browse files
Make the exporter export the list of classes into the output for classification networks
1 parent 426fa8e commit cd9d08e

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

training/export.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,19 @@ def export(config, output_path):
8888
first = tf.reshape(first, (-1, first.shape[-1]))
8989
stages[0][0]["weights"] = first.numpy().tolist()
9090

91+
network = {
92+
"geometry": {
93+
"intersections": config["projection"]["config"]["geometry"]["intersections"],
94+
"radius": config["projection"]["config"]["geometry"]["radius"],
95+
"shape": config["projection"]["config"]["geometry"]["shape"],
96+
},
97+
"mesh": config["projection"]["config"]["mesh"]["model"],
98+
"network": stages,
99+
}
100+
101+
# Add classification meta data
102+
if config["label"]["type"] == "Classification":
103+
network["class_map"] = {c["name"]: i for i, c in enumerate(config["label"]["config"]["classes"])}
104+
91105
with open(os.path.join(output_path, "model.yaml"), "w") as out:
92-
yaml.dump(
93-
{
94-
"mesh": config["projection"]["config"]["mesh"]["model"],
95-
"geometry": {
96-
"shape": config["projection"]["config"]["geometry"]["shape"],
97-
"radius": config["projection"]["config"]["geometry"]["radius"],
98-
"intersections": config["projection"]["config"]["geometry"]["intersections"],
99-
},
100-
"network": stages,
101-
},
102-
out,
103-
default_flow_style=None,
104-
width=float("inf"),
105-
)
106+
yaml.dump(network, out, default_flow_style=None, width=float("inf"))

0 commit comments

Comments
 (0)