-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvis.py
More file actions
30 lines (23 loc) · 752 Bytes
/
vis.py
File metadata and controls
30 lines (23 loc) · 752 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import fire
import csv
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# Evaluation results to visualise: one CSV row per example, with the gold answer
# in the "true_response" column and the model's answer in the "response_1" column.
path = "230616test_results/wf_prediction_epochs2/evaluation_tf.csv"
# The csv module documentation requires files to be opened with newline="" so
# that quoted fields containing embedded newlines are read as a single field;
# the encoding is pinned so parsing does not depend on the host's locale.
with open(path, "r", encoding="utf-8", newline="") as fh:
    golds, preds = [], []
    for line in csv.DictReader(fh):
        # Strip surrounding whitespace so e.g. "A" and "A\n" count as one label.
        gold = line["true_response"].strip()
        pred = line["response_1"].strip()
        golds.append(gold)
        preds.append(pred)
# Label vocabulary for the confusion matrix, taken from the gold answers only.
# sorted() instead of list(set(...)): iteration order of a set of str depends on
# the per-process hash seed, so the unsorted version could give a different
# label -> index mapping (and axis order in the saved figure) on every run.
labels = sorted(set(golds))
label_dict = {v: i for i, v in enumerate(labels)}
print(label_dict)
# confusion_matrix() ignores every sample whose gold or predicted value is not
# in `labels`. In this script `labels` is built from the gold answers, so any
# prediction outside that vocabulary silently disappears from the matrix.
# Report how many were excluded so the figure's row totals are not misread.
label_set = set(labels)
n_dropped = sum(1 for p in preds if p not in label_set)
if n_dropped:
    print(
        f"warning: {n_dropped} of {len(preds)} predictions are not among the "
        "labels and are excluded from the confusion matrix"
    )
# Entry [i, j] counts examples with gold label labels[i] predicted as labels[j].
cm = confusion_matrix(golds, preds, labels=labels)
print(cm)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.savefig("./fig.png")