-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_data.py
More file actions
60 lines (48 loc) · 2.18 KB
/
Copy pathgenerate_data.py
File metadata and controls
60 lines (48 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import csv
import random
import math
def generate_data(filename, count, include_result=True):
"""
生成数据文件
filename: 保存的文件名
count: 生成的行数
include_result: 是否包含结果列(推理数据通常也可以包含,方便人工比对,C++代码会忽略它)
"""
print(f"正在生成 {count} 条数据到 {filename} ...")
with open(filename, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
for _ in range(count):
# 0:加, 1:减, 2:乘, 3:除
op = random.randint(0, 3)
# 为了让简单网络更容易收敛,我们将数值限制在较小范围内 (-2.0 到 2.0)
# 这样可以避免 Sigmoid 激活函数饱和
a = round(random.uniform(-2.0, 2.0), 4)
b = round(random.uniform(-2.0, 2.0), 4)
result = 0.0
if op == 0: # +
result = a + b
elif op == 1: # -
result = a - b
elif op == 2: # *
result = a * b
elif op == 3: # /
# 避免除以 0,如果 b 太小,重新随机一个 b
while abs(b) < 0.1:
b = round(random.uniform(-2.0, 2.0), 4)
result = a / b
# 格式化结果,保留4位小数
result = round(result, 4)
# 写入 CSV
# 即使是推理数据,我们也把正确答案写进去作为第4列
# 之前的 C++ 代码在 infer 模式下会自动忽略第4列,但你可以打开文件人工核对
if include_result:
writer.writerow([op, a, b, result])
else:
writer.writerow([op, a, b])
print(f"完成: {filename}")
if __name__ == "__main__":
# 1. 生成 10,000 条训练数据
generate_data("train_data.csv", 10000, include_result=True)
# 2. 生成 50 条推理数据 (数量少一点方便在命令行查看)
# 我们也保留结果列,方便你打开 CSV 文件对比 C++ 跑出来的结果
generate_data("infer_data.csv", 50, include_result=True)