-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path4.2_CP3.py
More file actions
27 lines (21 loc) · 929 Bytes
/
Copy path4.2_CP3.py
File metadata and controls
27 lines (21 loc) · 929 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Consider the world population problem of Computer Problem 3.1.1. Find the best exponential fit of the data points
# by using linearization. Estimate the 1980 population, and find the estimation error.
import numpy as np
def main(actual_1980=4452584592):
    """Fit y = c1 * exp(c2 * t) to world-population data by linearization.

    Taking logarithms turns the model into the linear least-squares problem
    ln y = ln c1 + c2 * t.  Time t is measured in years since 1960 (samples
    at 1960, 1970, 1990, 2000), so the 1980 estimate is the model at t = 20.

    Args:
        actual_1980: Reference 1980 population used for the estimation
            error.  The default is the figure quoted alongside Computer
            Problem 3.1.1 in the textbook -- NOTE(review): verify against
            your edition.

    Returns:
        Tuple ``(estimate, error, rmse)`` of floats: the 1980 estimate, its
        absolute error relative to ``actual_1980``, and the root-mean-square
        error of the fitted curve on the original (not log) population scale.
    """
    t = np.array([0.0, 10.0, 30.0, 40.0])  # years since 1960
    population = np.array(
        [3039585530, 3707475887, 5281653820, 6079603571], dtype=float
    )

    # Design matrix for the linearized model ln y = k + c2 * t, k = ln c1.
    design = np.column_stack((np.ones_like(t), t))
    (log_c1, c2), *_ = np.linalg.lstsq(design, np.log(population), rcond=None)
    c1 = np.exp(log_c1)  # undo the linearization for the amplitude

    def model(years_since_1960):
        """Evaluate the fitted exponential c1 * exp(c2 * t)."""
        return c1 * np.exp(c2 * years_since_1960)

    estimate = float(model(1980 - 1960))
    error = abs(estimate - actual_1980)
    # Judge the fit in population units: log-space residuals are not
    # comparable with raw counts (the previous formula mixed the two).
    rmse = float(np.sqrt(np.mean((population - model(t)) ** 2)))

    print("c1 =", c1, " c2 =", c2)
    print("ESTIMATE 1980:", estimate)
    print("ESTIMATION ERROR:", error)
    print("RMSE: ", rmse)
    return estimate, error, rmse
if __name__ == "__main__":
    # Run the computation only when executed as a script, not on import.
    main()